从 0 开始的 TorchScript
source link: https://muyuuuu.github.io/2022/10/03/torch-jit-1/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
从 0 开始的 TorchScript
2022-10-03
12 13k 12 分钟
上一次正儿八经写博客是今年 2 月,5 月做了个比赛总结,其余的博客竟然都是刷题和算法,实属无聊。艰难的日子已经过去,准备学点模型部署相关的东西以及参与一个实际的开源项目,争取数据、算法和工程全链路打通。众所周知,对于一个不是很常用的东西,学完就忘,如 spark, Go
等学过的但很少用的东西,已经被我抛到九霄云外了。所以,这次学完模型的 trace
之后,尝试部署一些能实际运行的软件。
TorchScript
是 PyTorch
的 JIT
实现。JIT
全程是 Just In Time Compilation,也就是即使编译。在深度学习中 JIT
的思想更是随处可见,最明显的例子就是 Keras
框架的 model.compile 创建的静态图。
- 静态图需要先构建再运行,优势是在运行前可以对图结构进行优化,比如常数折叠、算子融合等,可以获得更快的前向运算速度。缺点也很明显,就是只有在计算图运行起来之后,才能看到变量的值,像
TensorFlow1.x
中的session.run
那样。 - 动态图是一边运行一边构建,优势是可以在搭建网络的时候看见变量的值,便于检查。缺点是前向运算不好优化,因为根本不知道下一步运算要算什么。动态图模型通过牺牲一些高级特性来换取易用性。
那么那到底 JIT
有哪些特性,使得 torch
这样的动态图框架也要走 JIT
这条路呢?或者说在什么情况下不得不用到 JIT
呢?下面主要通过介绍 TorchScript
来分析 JIT
到底带来了哪些好处。
JIT
是 Python
和 C++
的桥梁,我们可以使用 Python
训练模型,然后通过 JIT
将模型转为语言无关的模块,从而让 C++
可以非常方便得调用,从此「使用 Python
训练模型,使用 C++
将模型部署到生产环境」对 PyTorch
来说成为了一件很容易的事。而因为使用了 C++
,我们现在几乎可以把 PyTorch
模型部署到任意平台和设备上:树莓派、iOS、Android 等等。不然每次都要通过 python
调用模型,性能会大打折扣。
既然是为部署生产所提供的特性,那免不了在性能上面做了极大的优化,如果推断的场景对性能要求高,则可以考虑将模型(torch.nn.Module
)转换为 TorchScript Module
,再进行推断。有两种方式可以转换:
- 使用
TorchScript Module
的更简单的办法是使用Tracing
,Tracing
可以直接将PyTorch
模型(torch.nn.Module
)转换成TorchScript Module
。「trace
」顾名思义,就是需要提供一个「输入」来让模型forward
一遍,以通过该输入的流转路径,获得图的结构。这种方式对于forward
逻辑简单的模型来说非常实用,但如果forward
里面本身夹杂了很多流程控制语句,就会存在问题,因为同一个输入不可能遍历到所有的逻辑分枝。而没有被经过的分支就不会被trace
。 - 可以直接使用
TorchScript Language
来定义一个PyTorch JIT Module
,然后用torch.jit.script
来将他转换成TorchScript Module
并保存成文件。而TorchScript Language
本身也是Python
代码,所以可以直接写在Python
文件中。对于TensorFlow
我们知道不能直接使用Python
中的if
等语句来做条件控制,而是需要用tf.cond
,但对于TorchScript
我们依然能够直接使用if
和for
等条件控制语句,所以即使是在静态图上,PyTorch
依然秉承了「易用」的特性。
trace 方法
首先定义一个简单的模型:
import torch
class MyDecisionGate(torch.nn.Module):
def forward(self, x):
# 分支判断
if x.sum() > 0:
return x
else:
return -x
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.dg = MyDecisionGate()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
y = torch.tanh(self.dg(self.linear(x)) + h)
return y
my_cell = MyCell()
print(my_cell)
x, h = torch.rand(1, 4), torch.rand(1, 4)
print(my_cell(x, h))
我们可以绑定输入对模型进行 trace
:
import torch
class MyDecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.dg = MyDecisionGate()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
y = torch.tanh(self.dg(self.linear(x)) + h)
return y
my_cell = MyCell()
x, h = torch.rand(1, 4), torch.rand(1, 4)
trace_model = torch.jit.trace(my_cell, (x, h))
print(trace_model(x, h))
print(trace_model.code)
# def forward(self,
# x: Tensor,
# h: Tensor) -> Tensor:
# dg = self.dg
# linear = self.linear
# _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
# return torch.tanh(_0)
可以看到没有出现 if-else
的分支, trace
做的是:运行代码,记录出现的运算,构建 ScriptModule
,但是控制流就丢失了。然后流程丢失并不是好事,在 trace
只会对一个输入进行处理的情况下,对不同的输入得到的结果是不一样的,因为输入只会满足一个分支,因此 trace
的程序也只包含一个分支。
import torch
class MyDecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
my_cell = MyDecisionGate()
x = torch.tensor([-0.1, 0.05]) # 这两个 x trace 到的代码是不一样的
# x = torch.tensor([0.1, -0.05])
trace_model = torch.jit.trace(my_cell, (x))
print(trace_model(x))
print(trace_model.code)
因此,我们认为这样的 trace
没有泛化能力。而这种现象普遍发生在动态控制流中,即:具体执行哪个算子取决于输入的数据。
if x[0] == 4: x += 1
是动态控制流model: nn.Sequential = ... [m(x) for x in model]
不是class A(nn.Module):
backbone: nn.Module
head: Optiona[nn.Module]
def forward(self, x):
x = self.backbone(x)
if self.head is not None:
x = self.head(x)
return x
在之后的文章中,会介绍如何使 trace
具备泛化能力。
script 方法
script
方法直接分析 python
代码进行转换:使用他们提供的 script
编译器,将 python
的代码进行语法分析,并重新解释为 TorchScript
。
import torch
class MyDecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
class MyCell(torch.nn.Module):
def __init__(self, dg):
super(MyCell, self).__init__()
self.dg = dg
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h, new_h
scripted_gate = torch.jit.script(MyDecisionGate())
print(scripted_gate.code) # 含有流程控制
my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)
TorchScript
代码可以被它自己的解释器(一个受限的Python
解释器)调用。这个解释器不需要获得全局解释锁GIL,这样很多请求可以同时处理。- 这个格式可以让我们保存模型到硬盘上,在另一个环境中加载,例如服务器,也可以使用非
python
的语言。 TorchScript
提供的表示可以做编译器优化,做到更有效地执行。TorchScript
可以与其他后端/设备运行时进行对接,他们只需要处理整个项目,无需关心细节运算。
Trace 和 Script 谁更好?
通过上文我们可以了解到:
trace
只记录走过的tensor
和对tensor
的操作,不会记录任何控制流信息,如if
条件句和循环。因为没有记录控制流的另外的路,也没办法对其进行优化。好处是trace
深度嵌入python
语言,复用了所有python
的语法,在计算流中记录数据流。script
会去理解所有的code
,真正像一个编译器一样去进行词法分析语法分析句法分析,形成AST
树,最后再将AST
树线性化。script
相当于一个嵌入在Python/Pytorch
的DSL
,其语法只是Pytorch
语法的子集,这意味着存在一些op
和语法script
不支持,这样在编译的时候就会遇到问题。此外,script
的编译优化方式更像是CPU
上的传统编译优化,重点对于图进行硬件无关优化,并对if
、loop
进行优化。
在大模型的部署上 trace
更好,因为可以有效的优化复杂的计算图,如下所示:
class A(nn.Module):
def forward(self, x1, x2, x3):
z = [0, 1, 2]
xs = [x1, x2, x3]
for k in z: x1 += xs[k]
return x1
model = A()
print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
# z = [0, 1, 2]
# xs = [x1, x2, x3]
# x10 = x1
# for _0 in range(torch.len(z)):
# k = z[_0]
# x10 = torch.add_(x10, xs[k])
# return x10
print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
# x10 = torch.add_(x1, x1)
# x11 = torch.add_(x10, x2)
# return torch.add_(x11, x3)
因为 script
试图忠实地表示 Python
代码,所以即使其中一些是不必要的。例如:并不能对 Python
代码中的某些循环或数据结构进行优化。如上例,所以它实际上有变通方法,或者循环可能会在以后的优化过程中得到优化。但关键是:这个编译器并不总是足够聪明。对于复杂的模型, script
可能会生成一个具有不必要复杂性且难以优化的计算图。
tracing
有许多优点,事实上,在 Facebook/Meta
部署的分割和检测模型中,tracing
是默认的选择,仅当必要的时候使用 scripting
。因为 trace
不会破坏代码质量,可以结合 script
来避免一些限制。
python
是一个很大很动态的语言,编译器最多只能支持其语法功能和内置函数的子集,同理,script
也不例外。这个编译器支持 Python
的哪个子集?一个粗略的答案是:编译器对最基本的语法有很好的支持,但对任何更复杂的东西(类、内置函数、动态类型等)的支持度很低或者不支持。但并没有明确的答案:即使是编译器的开发者,通常也需要运行代码,看看能不能编译去判断是否支持。
所以不完整的 Python
编译器限制了用户编写代码的方式。尽管没有明确的约束列表,但可以从经验中看出它们对大型项目的影响:script
的问题会影响代码质量。很多项目只停留在了代码能 script
成功这一层面,使用基础语法,没有自定义类型,没有继承,没有内置函数,没有 lambda
等等的高级特性。因为这些高级的功能编译器并不支持或者部分支持,就会导致在某些情况下成功,但在其他情况下失败。而且由于没有明确的规范哪些是被支持的,因此用户无法推理或解决故障。因此,最终用户会仅仅停留在代码成功搬移,而不考虑可维护性和性能问题,会导致开发者因为害怕报错而停止进一步的探索高级特性。
如此下去,代码质量可能会严重恶化:垃圾代码开始积累,因为优良的代码有时无法编译。此外,由于编译器的语法限制,无法轻松进行抽象以清理垃圾代码。该项目的可维护状况逐渐走下坡路。如果认为 script
似乎适用于我的项目,基于过去在一些支持 script
的项目中的经验,我可能会出于以下原因建议不要这样做:
- 编译成功可能比你想象的更脆弱(除非将自己限制在基本语法上):你的代码现在可能恰好可以编译,但是有一天你会在模型中添加一些更改,并发现编译失败;
- 基本语法是不够的:即使目前你的项目似乎不需要更复杂的抽象和继承,但如果预计项目会增长,未来将需要更多的语言特性。
以多任务检测器为例:
- 可能有 10 个输入,因此最好使用一些结构/类。
- 检测器有许多架构选择,这使得继承很有用。
- 大型、不断增长的项目肯定需要不断发展的抽象来保持可维护性。
因此,这个问题的现状是:script
迫使你编写垃圾的代码,因此我们仅在必要时使用它。
Trace 细节
trace
让模型的 trace
更清楚,对代码质量有很少的影响。
如果模型不是以 Pytorch
格式表示的计算图,则 script
和 trace
都不起作用。例如,如果模型具有 DataParallel
子模块,或者如果模型将张量转换为 numpy
数组并调用 OpenCV
函数等,则必须对其进行重构。除了这个明显的限制之外,对 trace
只有两个额外的要求:
输入/输出格式是
Tensor
类型时才能被trace
。但是,这里的格式约束不适用于子模块:子模块可以使用任何输入/输出格式:类、kwargs
以及Python
支持的任何内容。格式要求仅适用于最外层的模型,因此很容易解决。如果模型使用更丰富的格式,只需围绕它创建一个简单的包装器,它可以与Tuple[Tensor]
相互转换。shape
。tensor.size(0)
是eager
模式下的整数,但它是tracing mode
下的tensor
。这个差异在trace
时是必要的,shape
的计算可以被捕获为计算图中的算子。由于不同的返回类型,如果返回的一部分是shape
是整数则无法trace
,这通常可以简单的解决。此外,一个有用的函数是torch.jit.is_tracing
,它检查代码是否在trace
模式下执行。
我们来看个例子:
>>> a, b = torch.rand(1), torch.rand(2)
>>> def f1(x): return torch.arange(x.shape[0])
>>> def f2(x): return torch.arange(len(x))
>>> # See if the two traces generalize from a to b:
>>> torch.jit.trace(f1, a)(b)
tensor([0, 1])
>>> torch.jit.trace(f2, a)(b)
tensor([0]) # WRONG!
>>> # Why f2 does not generalize? Let's compare their code:
>>> print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code)
def f1(x: Tensor) -> Tensor:
_0 = ops.prim.NumToTensor(torch.size(x, 0))
_1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
return _1
def f2(x: Tensor) -> Tensor:
_0 = torch.arange(1, dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
return _0
在 trace f2
函数时,lex(x)
是一个定值而非 tensor
,这样在传入其他长度的数据时就回报错。除了 len()
,这个问题也可能出现在:
.item()
将张量转换为int/float
。- 将
Torch
类型转换为numpy/python
原语的任何其他代码。
tensor.size()
在 trace
期间返回 Tensor
,以便在图中捕获形状计算。用户应避免意外将张量形状转换为常量。使用 tensor.size(0)
而不是 len(tensor)
,因为后者是一个 int
。这个函数对于将大小转换为张量很有用,在 trace
和 eager
模式下都可以使用。对于自定义类,实现 .size()
方法或使用 .__len__()
而不是 len()
,不要通过 int()
转换大小,因为它们会捕获常量。
这就是 trace
所需要的一切。最重要的是,模型实现中允许使用任何 Python
语法,因为 trace
根本不关心语法。
Trace 的泛化问题
Trace 和 Script 混合
>>> def f(x):
... return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
>>> m = torch.jit.trace(f, torch.tensor(3))
>>> print(m.code)
def f(x: Tensor) -> Tensor:
return torch.sqrt(x)
注意这种代码在 trace
时不会报错,只有 warning
的输出,因此我们要特别关注。trace
和 script
都有各自的问题,最好的方法是混合使用他们。避免影响代码质量,主要的部分进行 trace
,必要时进行 script
。如果有一个 module
里面有很多选择,但是我们不希望在 TorchScript
里出现,那么应该使用 tracing
而不是 scripting
,这个时候,trace
将内联 script
模块的代码。
import torch
class MyDecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
class MyCell(torch.nn.Module):
def __init__(self, dg):
super(MyCell, self).__init__()
self.dg = dg
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h, new_h
class MyRNNLoop(torch.nn.Module):
def __init__(self, scripted_gate, x, h):
super(MyRNNLoop, self).__init__()
# 对控制流进行 trace
self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))
def forward(self, xs):
h, y = torch.zeros(3, 4), torch.zeros(3, 4)
for i in range(xs.size(0)):
y, h = self.cell(xs[i], h)
return y, h
x, h = torch.rand(2, 4), torch.rand(2, 4)
scripted_gate = torch.jit.script(MyDecisionGate())
rnn_loop = torch.jit.script(MyRNNLoop(scripted_gate, x, h))
print(rnn_loop.code)
print(rnn_loop.cell.code)
我们简化一下:
model.submodule = torch.jit.script(model.submodule)
torch.jit.trace(model, inputs)
对于不能正确 trace
的子模块,可以进行 script
处理。但是并不推荐,更建议使用 @script_if_tracing
,因为这样修改 script
仅限于子模块的内部,而不影响模块的接口。使用 @script_if_tracing
装饰器,在 torch.jit.trace
时,@script_if_tracing
装饰器可以通过 script
编译。通常,这只需要对前向逻辑进行少量重构,以分离需要编译的部分(具有控制流的部分):
def forward(self, ...):
# ... some forward logic
@torch.jit.script_if_tracing
def _inner_impl(x, y, z, flag: bool):
# use control flow, etc.
return ...
output = _inner_impl(x, y, z, flag)
# ... other forward logic
只 script
需要的部分,代码质量相对于全部 script
被破坏的很少,被 @script_if_tracing
装饰的函数必须是不包含 tensor
模块运算的纯函数。因此,有时需要进行更多重构:
# Before:
if x.numel() > 0: # This branch cannot be compiled by @script_if_tracing because it refers to `self.layers`
x = preprocess(x)
output = self.layers(x)
else:
output = torch.zeros(...) # Create empty outputs
# After:
if x.numel() > 0: # This branch can now be compiled by @script_if_tracing
x = preprocess(x)
else:
x = torch.zeros(...) # Create empty inputs
# Needs to make sure self.layers accept empty inputs.
# If necessary, add such condition branch into self.layers as well.
output = self.layers(x)
同样的,我们可以在 script
中嵌套 trace
:
model.submodule = torch.jit.trace(model.submodule, submodule_inputs)
torch.jit.script(model)
这里的子模块是 trace
,但是实际中并不常用,因为会影响子模块的推理(当且仅当子模块的输入和输出都是 tensor
时才适用),这是很大的限制。但是 trace
作为子模块的时候也有很试用的场景:
class A(nn.Module):
def forward(self, x):
# Dispatch to different submodules based on a dynamic, data-dependent condition:
return self.submodule1(x) if x.sum() > 0 else self.submodule2(x)
@script_if_tracing
不能处理这样的控制流,因为它只支持纯函数。如果子模块很复杂不能被 script
,使用 trace
trace
子模块是很好的选择,这里就是 self.submodule2
和 self.submodule1
,类 A
还是要 script
的。
Script 优势
事实上,对于大多数视觉模型,动态控制流仅在少数易于编写 script
的子模块中需要。script
相对于 trace
,有两个有点:
- 一个数据有很多属性的控制流,
trace
无法处理 trace
只支持forward
方法,script
支持更多的方法
实际上,上述两个功能都在做同样的事情:它们允许以不同的方式使用导出的模型,即根据调用者的请求执行不同的运算符序列。下面是一个这样的特性很有用的示例场景:如果 Detector
是 script
化,调用者可以改变它的 do_keypoint
属性来控制它的行为,或者如果需要直接调用 predict_keypoint
方法。
class Detector(nn.Module):
do_keypoint: bool
def forward(self, img):
box = self.predict_boxes(img)
if self.do_keypoint:
kpts = self.predict_keypoint(img, box)
def predict_boxes(self, img): pass
def predict_keypoint(self, img, box): pass
这种要求并不常见。但是如果需要,如何在 trace
中实现这一点?我有一个不是很优雅的解决方案:Tracing
只能捕获一个序列的算子,所以自然的方式是对模型进行两次 Tracing
:
det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
然后我们可以为它们的模型设置别名(以不重复存储),并将两个 trace
合并到一个模块中以编写 script
:
det2.submodule.weight = det1.submodule.weight
class Wrapper(nn.ModuleList):
def forward(self, img, do_keypoint: bool):
if do_keypoint:
return self[0](img)
else:
return self[1](img)
exported = torch.jit.script(Wrapper([det1, det2]))
还可以使用单元测试来判断 trace
是否成功:
assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
此外,还可以通过优化程序,避免掉不必要的特殊情况:
if x.numel() > 0:
output = self.layers(x)
else:
output = torch.zeros((0, C, H, W)) # Create empty outputs
此外还需要注意设备问题,在 trace
期间会记录使用的设备,而 trace
不会对不同的设备进行泛化,但是部署时都会有固定的设备,这个问题不用担心。
>>> def f(x):
... return torch.arange(x.shape[0], device=x.device)
>>> m = torch.jit.trace(f, torch.tensor([3]))
>>> print(m.code)
def f(x: Tensor) -> Tensor:
_0 = ops.prim.NumToTensor(torch.size(x, 0))
_1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
return _1
>>> m(torch.tensor([3]).cuda()).device
device(type='cpu') # WRONG!
trace
有明显的局限性:本文大部分时间都在讨论 trace
的局限性以及如何解决它们。我实际上认为这是 trace
的优点:它有明确的限制和解决方案,所以你可以推断它是否有效。相反, script
更像是一个黑匣子:在尝试之前没有人知道它是否有效。
trace
具有较小的代码破坏范围: trace
和 script
都会影响代码的编写方式,但 trace
的代码破坏范围要小得多,并且造成的损害要小得多:
- 它限制了输入/输出格式,但仅限于最外层的模块。
- 在
trace
中混合script
,但可以只更改受影响模块的内部实现,而不是它们的接口。
另一方面, script
对以下方面有影响:
- 涉及的每个模块和子模块的接口,接口需要高级语法特性,针对接口编程时,千万别在接口设计上妥协。
- 这也可能最终影响训练,因为接口通常在训练和推理之间共享。
这也是为什么 script
会对代码质量造成很大损害的原因。Detectron2
支持 script
,但不推荐其他大型项目以可 script
且不丢失抽象为目标,因为这实在有点难度,除非它们也能像阿里巴巴那样得到 PyTorch
团队的支持。
PyTorch
深受用户喜爱,最重要的是编写 Python
控制流。但是 Python
的其他语法也很重要。如果能够编写 Python
控制流( 使用 script
)意味着失去其他优秀的语法,我宁愿放弃编写 Python
控制流的能力。事实上,如果 PyTorch
对 Python
控制流不那么执着,并且像这样(类似于 tf.cond
的 API
)为我提供了诸如 torch.cond
之类的符号控制流:
def f(x):
return torch.cond(x.sum() > 0, lambda: torch.sqrt(x), lambda: torch.square(x))
然后 f
可以正确 trace
,不再需要担心 script
。
保存和加载模型
traced.save('wrapped_rnn.pt')
loaded = torch.jit.load('wrapped_rnn.pt')
print(loaded)
print(loaded.code)
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK