3

对比PyTorch、TensorFlow、JAX、Theano,我发现都在关注两大问题

 1 year ago
source link: https://blog.csdn.net/OneFlow_Official/article/details/128391980
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

对比PyTorch、TensorFlow、JAX、Theano,我发现都在关注两大问题

42210f8d7eccacde159ea0643b6ab2ae.jpeg

作者|王益

OneFlow社区编译

翻译|杨婷

最近,我在处理 PyTorch 分布式和 TorchRec 相关的工作,为此,我开始学习 PyTorch 2.0。在业余时间,我也在跟着Alpa作者学习JAX和XLA。如今回顾这些技术,我发现它们的关注点似乎都是如下两个问题:

  1. 包含自动求导和并行在内的函数转换,例如 vmap, pmap 和 pjit 等;

  2. 异构计算,CPU 负责控制流,GPU/TPU 负责张量计算和集合通信。

本文档中的所有例子都支持在 Colab 中运行:

Theano/Aesara

https://colab.research.google.com/drive/1eg7C5WMNokhXgXQ46pNA30dXUCklquPz

TensorFlow 1.x

https://colab.research.google.com/drive/1jc0ePg2AAXBihevtoZM_33mmhC70rzqz?usp=sharing

TensorFlow 2.x

https://colab.research.google.com/drive/1PbftzJ9E2_FyIiuozTpExMvlFky_G2nv

PyTorch 1.x

https://colab.research.google.com/drive/1v4hENL-IJ-C6VT5H9W1NC2te85D8VdJK

https://colab.research.google.com/drive/1PlFijLIzAttIBd3tBjiEbSgPXvq9lVlg

functorch/PyTorch 2.x

https://colab.research.google.com/drive/1o-yJ-5g1V084RDaiRw2PqfAjOG7Ty951

“函数转换”意为将一个程序转变成另一个程序,最常见的例子是自动求导(autograd)。自动求导采用用户编写的前向过程并创建后向过程,对于用户来说,编写自动求导通常都太过复杂。函数转换的主要难点在于:在编写函数转换算法时以何种方式表示输入和输出过程。

Theano:显式地构建 IR

Theano是最早的深度学习工具之一,也就是如今为人们所熟知的Aesara项目。Theano有一个允许用户在内存中将IR构建为数据结构的API,因此Theano可实现自动求导,并将结果输出为 Python 函数。

TensorFlow 1.x:用于运行 IR 的虚拟机

TensorFlow 1.x明确保留了构建IR的想法。若在TensorFlow中运行上述示例,结果不会有什么差别;但倘若在TensorFlow 1.x中来运行,最大的差别在于:我们不会将后向 IR 转换为 Python 函数,并使用 Python 解释器来运行。相反,我们会在TensorFlow runtime中来运行。

newCodeMoreWhite.png

PyTorch 1.x:没有前向IR

PyTorch不会像Theano或TensorFlow那样将前向传播转换为IR。反之,PyTorch 使用 Python 解释器来运行前向传播。这样做的弊端在于会在运行期间生成表示后向传播的 IR,我们称之为Eager模式(动态图模式)。

TensorFlow 2.x: 梯度带

TensorFlow 2.x增加了一个像PyTorch API的Eager模式API。此 API 追踪前向传播如何运行名为梯度带(GradientTape)的 IR 。TensorFlow 2.x可以从这个跟踪中找出后向传播。

JAX 不会向用户公开诸如梯度带等方面的低级别细节。简单说来,JAX的思维方式为:将输入和输出都用Python函数来表示。

对于想要自己编写的函数转换的高级用户,他们可以调用make_jaxpr等低级 API 来访问 IR,称为 JAXPR。

FuncTorch

FuncTorch和JAX类似,都是基于PyTorch的函数转换。

JAX的make_jaxpr类似于functorch的make_fx。

TensorFlow 2.x、JAX 和 functorch 都为前向传递构建了一个 IR,但 PyTorch Eager模式没有。IR 不仅可用于自动求导,还可用于其他类型的函数转换。在下列例子中,functorch.compile.aot_function调用了回调函数print_compile_fn两次,分别用于前向和后向传播。

2
高阶导数

PyTorch

TensorFlow 2.x

JAX

3
动态控制流

动态控制流(dynamic control flows)有两个层级:在 CPU 上运行的粗粒度级别和在 GPU /TPU 上运行的细粒度级别。本部分主要介绍在 CPU 上运行的粗粒度级别的动态控制流。下面我们将用(if/else)条件语句作为例子检验深度学习工具。

TensorFlow 1.x

在 TensorFlow 1.x 中,我们需要将条件语句显式构建到 IR 中。此时条件语句是一个特殊的运算符 tf.cond。

TensorFlow 2.x

TensorFlow 2.x 支持使用 tf.cond 和 tf.while_loop 显式构建控制流。此外,实验项目google/tangent中有AutoGraph功能,它可以将Python控制流转换为tf.cond或tf.while_loop。此功能利用了 Python 解释器支持的函数和函数源代码。例如下面的g函数调用了 Python 的标准库将源代码解析为 AST,然后调用 SSA 表单来理解控制流。

由于部分Python语法很复杂,所以通过解析源代码来理解控制流就显得很困难,这就导致AutoGraph经常出错。但如果这种方法很简单,那么Python开发者社区也不会在构建Python编译器时失败这么多次了。正是由于有这种挑战的存在,必须要明确地将控制流构建到 IR 中。为此,JAX 提供了 jax.lax.cond 和 jax.lax.for_loop函数。

jax.lax.cond(a < b, lambda : a*17, lambda: b+23)

考虑到这一点,你可能会觉得我们可以使用递归算法。但是下面用于计算阶乘的递归无法用JAX跟踪。

可能你还想调用factorial来计算 3!=6。但这会让递归深度超过最大值,因为递归不仅依赖于条件,还依赖于函数定义和调用。

PyTorch

PyTorch最初是Python-native。正如前文所说,由于多功能调度机制,grad 和 vamp 的函数转换都是即时的。值得注意的是:

  1. 相比Theano 和 TensorFlow构建IR后的函数转换,即时函数转换效率更高。

  2. 在进行grad和vmap 时,JAX也是即时函数转换。然而像pamp和pjit等更复杂的函数转换需要对整个计算过程进行概述,在这个过程中IR是必不可少的。

由于IR在pmap 和 pjit中的必要性,PyTorch社区最近添加了torch.condpytorch/pytorch#83154

4
分布式计算

根据执行代码或 IR 的不同方式,在使用 Python 解释器或runtime时,有两种分布式计算方法。

Python-Native

Theano和PyTorch采用了Python-native分布式计算方式。这种分布式训练工作包含多个Python解释器进程。这导致出现了以下结果。

  1. 打包和运行(Pack and run)。由于这些 Python 进程在不同的host上运行,因此我们需要打包用户程序和依赖项,并将它们发送到这些host上去运行。一直以来TorchX负责了这个打包过程。它支持例如Docker和torch.package等各种打包格式,并且可以与各种集群管理器配合使用,如Kubernetes和SLURM。

  2. 单程序多数据(SPMD)。由于将用户程序发送到各种host上要依赖于打包,与其他权重较轻的方式(如通过 RPC 发送代码)相比,这种方式不太灵活,因此,我们通常只发送一个程序。当所有这些进程运行同一程序时,这个作业就变成了单程序多数据(SPMD)作业。

Python-native SPMD

下面是一个简单的SPMD PyTorch程序,我们可以在相同或不同的host上使用进程运行这个程序。在这个过程中,我们只需要调用all_gather。真正的分布式训练程序会调用更高级别的API,例如torch.nn.parallel.DistributedDataParallel 和 torchrec.DistributedModelParallel, 然后再调用低级 API,例如 all_gather 和 all_reduce。

newCodeMoreWhite.png

Python-native Non-SPMD

PyTorch 不仅限于 SPMD 式的分布式训练。它还通过torch.distributed.pipeline.sync.Pipe和PiPPy project提供流水并行,其中流水并行的各个阶段在不同的设备上运行不同的程序。这些阶段常通过 torch.rpc 包来沟通。

分布式运行时机制

分布式 TensorFlow 作业由运行 TensorFlow runtime 程序的进程组成,而不是由 Python 解释器组成。此分布式运行时作业执行 TensorFlow graph (IR),它是由执行用户程序的 Python 解释器生成。

用户程序可以使用低级API(如 tf.device)去指定作业要运行什么操作、在哪台设备和主机上运行等等。因为API有runtime,所以可以做到这一点。

与PyTorch一样,TensorFlow也为分布式训练提供了高级API tf.distributed.strategy,Keras和DTensor。

分布式运行时极大地方便了训练服务的维护,因为我们不再将用户程序打包到集群上运行。相反,我们打包运行时程序,因为相比用户程序,运行时程序更加统一。

JAX 支持 Python-native 和分布式运行时。

JAX 提供例如vmap、pmap 和 pjit的函数转换,这可以将 Python 函数转换为分布式程序。

(本文经授权后由OneFlow社区编译,译文转载请联系获得授权。原文:https://quip.com/Y8qtAyV4EXRg)

其他人都在看

欢迎Star、试用OneFlow最新版本:GitHub - Oneflow-Inc/oneflow: OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient.OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient. - GitHub - Oneflow-Inc/oneflow: OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient.fluidicon.pnghttps://github.com/Oneflow-Inc/oneflow/

文章知识点与官方知识档案匹配,可进一步学习相关知识

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK