推荐 :谷歌JAX 助力科学计算

百家 作者:数据分析 2022-09-26 08:12:03
作者:王可汗; 审校:陈之炎
本文约3500字,建议阅读9分钟
本文为你介绍使用谷歌JAX助力科学计算。
谷歌最新推出的JAX,官方定义为CPU、GPU和TPU上的NumPy。它具有出色的自动微分(differentiation)功能,是可用于高性能机器学习研究的python库。Numpy在科学计算领域十分普及,但是在深度学习领域,由于它不支持自动微分和GPU加速,所以更多的是使用Tensorflow或Pytorch这样的深度学习框架。然而谷歌之前推出的Tensorflow API有一些比较混乱的情况,在1.x的迭代中,就存在如原子op、layers等不同层次的API。面对不同类型的用户,使用粒度不同的多层API本身并不是什么问题。但同层次的API也有多种竞品,如slim和layers等实则提高了学习成本和迁移成本。而JAX使用 XLA 在诸如GPU和TPU的加速器上编译和运行NumPy。它与 NumPy API 非常相似, numpy 完成的事情几乎都可以用 jax.numpy 完成,从而避免了直接定义API这件事。

下面简要介绍JAX的几个特性,并同时给出一些示例让读者能够快速入门上手。最后我们将结合科学计算的实例,展现google JAX在科学计算方面的巨大威力。

1.JAX特性

1)自动微分:

在深度学习领域,网络参数的优化是通过基于梯度的反向传播算法实现的。因此能够实现任意数值函数的微分对于机器学习有着十分重要的意义。下面结合官方文档的例子简要介绍这一特性。

首先介绍最简单的grad求一阶微分:可以直接通过grad函数求某一函数在某位置的梯度值

import jax.numpy as jnpfrom jax import grad, jit, vmapgrad_tanh = grad(jnp.tanh)print(grad_tanh(2.0))[OUT]:0.070650816

当然如果想对双切正弦函数继续求二阶,三阶导数,也可以这样做:

print(grad(grad(jnp.tanh))(2.0))print(grad(grad(grad(jnp.tanh)))(2.0))[OUT]:-0.136218680.25265405

除此之外,还可以利用hessian、jacfwd 和 jacrev 等方法实现函数转换,它们的功能分别是求解海森矩阵,以及利用前向或反向模式求解雅克比矩阵。Jacfwd和jacrev可以得到一样的结果,但是在不同的情形下求解效率不同,这是因为两者背后对应的微分几何中的push forward和pull back方法。而前面提到的grad则是基于反向模式。

在一些拟牛顿法的优化算法中,常常需要利用二阶的海森矩阵。为了实现海森矩阵的求解。为了实现这一目标,我们可以使用jacfwd(jacrev(f))或者jacrev(jacfwd(f))。但是前者的效率更高,因为内层的雅克比矩阵计算是通过类似于一个1维损失函数对n维向量的求导,明显使用反向模式更为合适。外层则通常是n维函数对n维向量的求导,正向模式更有优势。

2)向量化

无论是科学计算或者机器学习的研究中,我们都会将定义的优化目标函数应用到大量数据中,例如在神经网络中我们去计算每一个批次的损失函数值。JAX 通过 vmap 转换实现自动向量化,简化了这种形式的编程。

下面结合几个例子,说明这一用法:

vmap有3个最重要的参数:

  • fun: 代表需要进行向量化操作的具体函数;
  • in_axes:输入格式为元组,代表fun中每个输入参数中,使用哪一个维度进行向量化;
  • out_axes: 经过fun计算后,每组输出在哪个维度输出。

我们先来看二维情况下的一些例子:

import jax.numpy as jnpimport numpy as npimport jax

(1)先定义a,b两个二维数组(array)

a = np.array(([1,3],[23, 5]))print(a)[out]: [[ 1 3][23 5]]b = np.array(([11,7],[19,13]))print(b)[OUT]: [[11 7][19 13]]
(2)正常的两个矩阵element-wise的相加

print(jnp.add(a,b))#[[1+11, 3+7]]# [[23+19, 5+13]][OUT]: [[12 10][42 18]]

(3)矩阵a的行 + 矩阵b的行,然后根据out_axes=0输出,0表示行输出

print(jax.vmap(jnp.add, in_axes=(0,0), out_axes=0)(a,b))#[[1+11, 3+7]]#[[23+19, 5+13]][OUT]: [[12 10][42 18]]

(4)矩阵a的行 + 矩阵b的行,然后根据out_axes=1输出,1表示列输出

print(jax.vmap(jnp.add, in_axes=(0,0), out_axes=1)(a,b))# [[1+11, 3+7]]#[[23+19, 5+13]] 再以列转置输出[OUT]: [[12 42][10 18]]

理解了上面的例子之后,现在开始增加难度,换成三维的例子:

from jax.numpy import jnpA, B, C, D = 2, 3, 4, 5def foo(tree_arg):x, (y, z) = tree_argreturn jnp.dot(x, jnp.dot(y, z))from jax import vmapK = 6 # batch sizex = jnp.ones((K, A, B)) # batch axis in different locationsy = jnp.ones((B, K, C))z = jnp.ones((C, D, K))tree = (x, (y, z))vfoo = vmap(foo, in_axes=((0, (1, 2)),))print(vfoo(tree).shape)

你能够计算最后的输出吗?

让我们一起来分析一下。在这段代码中分别定义了三个全1矩阵x,y,z,他们的维度分别是6*2*3,3*6*4,4*5*6。而tree则控制了foo函数中矩阵连续点积的顺序。根据in_axes可知,y和z的点积最后结果为6个3*5的子矩阵,这是由于y和z此时相当于6个y的子矩阵(3*4维)和6个z的子矩阵(4*5维)点积。再与x点积,得到的最终结果为(6,2,5)。

3)JIT编译


XLA是TensorFlow底层做JIT编译优化的工具,XLA可以对计算图做算子Fusion,将多个GPU Kernel合并成少量的GPU Kernel,用以减少调用次数,可以大量节省GPU Memory IO时间。Jax本身并没有重新做执行引擎层面的东西,而是直接复用TensorFlow中的XLA Backend进行静态编译,以此实现加速。

jit的基本使用方法非常简单,直接调用jax.jit()或使用@jax.jit装饰函数即可:

import jax.numpy as jnpfrom jax import jitdef slow_f(x):# Element-wise ops see a large benefit from fusionreturn x * x + x * 2.0x = jnp.ones((5000, 5000))fast_f = jax.jit(slow_f) # 静态编译slow_f;%timeit -n10 -r3 fast_f(x)%timeit -n10 -r3 slow_f(x)10 loops, best of 3: 24.2 ms per loop10 loops, best of 3: 82.8 ms per loop

运行时间结果:fast_f(x)是slow_f(x) 在CPU上运行速度的3.5倍!静态编译大大加速了程序的运行速度。如图1 所示。



图 1  tensorflow和JAX中的XLA backend

2.JAX在科学计算中的应用

分子动力学是现代计算凝聚态物理的重要力量。它经常用于模拟材料。下面的实例将展现JAX在以分子动力学为代表的科学计算领域的巨大潜力。

首先简单介绍一下分子动力学。分子动力学的基本任务就是获得研究对象在不同时刻的位置和速度,然后基于统计力学的知识获取想得到的物理量,解释对象的行为和性质。

它的主要步骤包括:

第一步,设置研究对象组成粒子的初始位置和速度;
第二步,基于粒子的位置计算每个粒子的合力,并基于牛顿第二定计算粒子的加速度。(这里可能有小伙伴会问,如何计算?我们下文的势函数将为大家解释);
第三步,基于加速度算下一时刻粒子速度,根据速度计算下一时刻位置。

不断循环2-3步,得到粒子的运动轨迹。

如需要获得所有粒子的轨迹,根据牛顿运动方程,需要知道粒子的初始位置和速度,质量以及受力。粒子的受力是势能函数的负梯度,所以在分子动力学模拟中,必须确定所有原子之间的势能函数,即势能关于两个原子之间相对位置的函数,这个势函数我们也称之为力场。

在分子动力学中,复杂力场的优化是一类重要的问题。ReaxFF就是其中的代表。相比于传统力场基于静态化学键以及不随化学环境改变的静态电荷假设,ReaxFF引入键级势的概念,这允许键在整个模拟过程里形成和断开,并动态地为原子分配电荷。也正是由于这些特性的存在,反应力场的形式明显比经典力场更为复杂。这使得我们将其计算的能量等值与密度泛函或者实验值对比得到的损失函数进行反馈优化时更为困难,如图2 所示。



图2 反应力场的参数构成

各种全局优化方法,例如遗传算法,模拟退火算法,进化算法以及粒子群优化算法等等往往没有利用任何梯度信息,这使得这些搜索成本可能会非常昂贵。而JAX的出现为这一问题的解决带来了可能。

JAX-REAXFF:

1)流程


图3  Jax-ReaxFF流程


图3是Jax-ReaxFF的任务流概述,可以将其大致分为两个阶段:聚类和主优化循环。而主优化循环则分别包括利用梯度信息的能量最小化和力场参数优化。

聚类只要是根据相互作用列表进行聚类,在内存中正确对齐,以确保有效的单指令多数据(SIMD)并行化提高效率。

而主优化循环中能量最小化的过程是寻找能量最低最稳定几何构型的过程。它的具体做法是利用JAX求体系势能对原子坐标的梯度,进行优化。力场参数的优化在原文中则分别使用了两种拟牛顿优化方法——L-BFGS和SLSQP。这通scipy.optimize.minimize函数实现,其中向该函数直接传入JAX求解梯度的方法以提高效率。能量最小化和力场参数优化迭代循环。


图4 JAX-ReaxFF主循环优化

Github地址:
https://github.com/cagrikymk/JAX-ReaxFF

2)效果

作者在多个数据集上分别实现了参数的优化,可以看到相比于其他算法,利用JAX梯度信息的优化具有明显的速度优势。

图5   金属钴数据集结果


参考文献:
https://pubs.acs.org/doi/pdf/10.1021/acs.jctc.2c00363
https://jax.readthedocs.io/en/latest/faq.html
https://zhuanlan.zhihu.com/p/474724292
https://arxiv.org/abs/2010.09063
https://mp.weixin.qq.com/s/AoygUZK886RClDBnp1v3jw
END
自:数据派THU;

版权声明:本号内容部分来自互联网,转载请注明原文链接和作者,如有侵权或出处有误请和我们联系。

合作请加QQ:365242293  
数据分析(ID : ecshujufenxi )互联网科技与数据圈自己的微信,也是WeMedia自媒体联盟成员之一,WeMedia联盟覆盖5000万人群。

关注公众号:拾黑(shiheibook)了解更多

[广告]赞助链接:

四季很好,只要有你,文娱排行榜:https://www.yaopaiming.com/
让资讯触达的更精准有趣:https://www.0xu.cn/

公众号 关注网络尖刀微信公众号
随时掌握互联网精彩
赞助链接