6

在TVM中实现Instance Normalization层

 2 years ago
source link: https://bindog.github.io/blog/2018/11/27/add-instance-norm-in-TVM/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

在TVM中实现Instance Normalization层


0x00 背景

TVM不用多解释了,深度学习领域的编译器,直接从其他框架训练好的模型编译到目标平台上的可执行代码,速度和显存的都有大幅优化,该项目目前由DMLC组维护,正在快速迭代中。

原先我组的模型线上部署都是采用Caffe2+ONNX的形式,近来想进一步控制成本,TVM无疑能够满足我们的需求。但是,由于我组最近使用的模型se_resnet_101_ibn_a中包含了一个instance_normbatch_norm融合的特殊结构(IBN结构可参考该论文),而目前TVM暂不支持instance_norm,所以只能自己来进行实现了。

考虑到instance_normbatch_norm比较类似,所以最初的计划是从batch_norm入手,找出TVM源码中所有batch_norm出现过的地方,然后模仿一下,应该就能很轻松的实现instance_norm(当然后面的事实证明这真是太天真了)

那么下面就大致按照从前端到后端,从表面到深层的顺序简单介绍我在TVM上实现instance_norm的过程

0x01 ONNX中两者参数和属性的区别

instance_normbatch_norm虽然有很多相似之处,但是在ONNX中的属性和参数上仍有较大区别,参考pytorch代码中将pth模型转换为ONNX模型的代码片段,https://github.com/pytorch/pytorch/blob/master/torch/ONNX/symbolic.py

# batch_norm
out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
            epsilon_f=eps,
            momentum_f=1 - momentum,
            outputs=1 if not training else 5)

# instance_norm
return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)

可以看到,instance_norm中是没有running_mean和running_var参数的(batch_size为1,统计这个值没有意义),也没有momentum

参考https://github.com/ONNX/ONNX/blob/88d178459973449bfab7346958281a0c06cdd661/ONNX/defs/nn/defs.cc

BatchNormalization
{
    "Attr": [spatial, epsilon, momentum] # spatial这个属性一般不用 还有is_test这个属性在早期的ONNX版本也会出现
    "Input": [X, scale, B, mean, var]
    "Output": [Y, mean, var, save_mean, save_var]  #  后4个可选
}
InstanceNormalization
{
    "Attr": [epsilon, ]
    "Input": [input, scale, B]
    "Output": [output, ]
}

ONNX当中还有instance_norm的numpy实现示例,如下

class InstanceNormalization(Base):

    @staticmethod
    def export():  # type: () -> None
        def _instancenorm_test_mode(x, s, bias, epsilon=1e-5):  # type: ignore
            dims_x = len(x.shape)
            axis = tuple(range(2, dims_x))
            mean = np.mean(x, axis=axis, keepdims=True)
            var = np.var(x, axis=axis, keepdims=True)
            dim_ones = (1,) * (dims_x - 2)
            s = s.reshape(-1, *dim_ones)
            bias = bias.reshape(-1, *dim_ones)
            return s * (x - mean) / np.sqrt(var + epsilon) + bias

其他更多细节参考ONNX的文档

0x02 从前端到后端

本节涉及到修改文件tvm/nnvm/python/nnvm/frontend/onnx.py

所谓NNVM前端,其作用就是把五花八门的模型转换到NNVM内部中一个统一的表示,然后再编译,如下所示:

sym, params = nnvm.frontend.from_ONNX(ONNX_model)
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params, dtype={})

在前端的ONNX实现中InstanceNormalization是留空未实现的,仿照BatchNorm的格式将其实现,注意上一节我们提到的属性上的区别,有些东西instance_norm是没有的

class BatchNorm(OnnxOpConverter):
    
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
    # TODO(zhreshold): 'spatial' is not properly handled here.
        return AttrCvt(
            op_name='batch_norm',
            disables=['momentum'],
            ignores=['spatial', 'is_test', 'consumed_inputs'])(inputs, attr,
                                                                params)

class InstanceNorm(OnnxOpConverter):
    
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # NOTE instance norm parameters is different from batch norm
        return AttrCvt(op_name='instance_norm')(inputs, attr, params)

当然现在后端是没有instance_norm这个OP的,我们需要在NNVM后端注册instance_norm

0x03 在NNVM中注册新OP流程

本节主要涉及到改动以下两个文件:

  • tvm/nnvm/include/nnvm/top/nn.h
  • tvm/nnvm/src/top/nn/nn.cc

关于怎么在NNVM中注册一个OP可以参考这个文章MXNet 中新增 Operator,这里简单列出要点:

  1. 头文件中的Parameter Registration,定义该OP中有哪些参数
  2. Attribute Inference,包括shape和type的inference
  3. Forward & Backward Function,这一部分在MXNet中应该是必需的,而在TVM中由于我们只需要inference,这里暂时先不讨论这个
  4. Operator Registration,前面的事情完成后,最后一步就是OP的注册,这样才可以把我们定义的OP暴露给前端

其中Operator Registration中需要定义诸多信息,输入输出个数、名称、类型,FInferShape、FInferType、FCorrectLayout函数绑定。如果仔细观察 TVM源码会发现,有很多OP注册了FTVMCompute这个属性,但是也有一些OP并没有注册,这是实际上是非常重要的一个属性,通过它,NNVM才能知道这个OP该如何计算

0x04 从NNVM到TOPI

本节涉及到修改文件tvm/nnvm/python/nnvm/top/nn.py

刚才我们仅仅说了如何在NNVM中注册一个OP,在正式切入到TOPI的算法之前,我们需要让NNVM知道如何找到对应的TOPI算法,这里有以下几种实现的方式:

  1. 直接在FTVMCompute属性中注册,映射到topi::nn::xxx
  2. tvm/nnvm/python/nnvm/top/nn.py中进行注册,通过@reg.register_compute(xxx),此外注意到还有@reg.register_schedule(xxx),这是在注册Schedule(后面再详细讲)
  3. 还有一种比较有技巧性的操作,直接在tvm/nnvm/src/compiler/simplify_inference.cc里处理,batch_norm就是这么做的(享受同等待遇的OP还有dropout),直接把计算和fuse两步融为一步,算是优化到了比较极致的地步,这也直接导致模仿batch_norm实现instance_norm的计划破产……实际上相当于直接把batch_normdropout两个OP给替换掉了(只考虑inference的情况),dropout直接忽略掉,而batch_norm则替换为了若干个元OP的组合(减moving_mean,除以moving_var,乘gamma,加beta,还有broadcast之类的,所以在后面graph fuse的过程中,我们会看到这样的一个op,fuse___add_scalar___sqrt___rdiv_scalar___elemwise_mul)

所以在NNVM的OP定义中,batch_norm是不用注册FTVMCompute的,因为它被替换掉了,那能不能仿照batch_norm直接在simplify_inference.cc中写instance_norm的处理代码?个人感觉不好操作,毕竟原理不同,instance_norm涉及到对输入内容全局求均值、方差等,不像batch_norm在inference阶段的操作比较好分解。

所以最后选择在tvm/nnvm/python/nnvm/top/nn.py中进行注册,这里可以仿照lrn或者l2_normalize的写法

@reg.register_compute("instance_norm")
def compute_instance_norm(attrs, inputs, _):
    eps = attrs.get_float("epsilon")
    return topi.nn.instance_norm_inference(inputs[0], inputs[1], inputs[2], eps, False)

@reg.register_schedule("instance_norm")
def schedule_instance_norm(attrs, outs, target):
    with tvm.target.create(target):
        return topi.generic.schedule_instance_norm(outs)

reg.register_pattern("instance_norm", OpPattern.OPAQUE)

0x05 描述TOPI算法

本节涉及到修改的文件(Python和C++二选一即可)

  • tvm/topi/python/topi/nn/__init__.py
  • tvm/topi/python/topi/nn/instance_norm.py

  • tvm/topi/src/topi.cc
  • tvm/topi/include/topi/nn/instance_norm.h

TOPI的全称是TVM Operator Inventory,TVM很好的借鉴了Halide的思想,将算法(Algorithm)和计算过程(Schedule)分离,本小节先简单介绍如何用TVM描述算法,下一小节介绍Schedule

用TVM描述算法就是在理解OP计算过程的基础上,用TVM的表达式把算法写出来,以非常基础的Dense层(FC层)为例,其写法如下所示,要点我写到注释中了:

def dense_default(data, weight, bias=None):
    assert len(data.shape) == 2 and len(weight.shape) == 2, \
        "only support 2-dim dense"
    if bias is not None:
        assert len(bias.shape) == 1
    batch, in_dim = data.shape
    out_dim, _ = weight.shape
    # 定义一个IterVar
    k = tvm.reduce_axis((0, in_dim), name='k')
    # 定义计算方式,tvm.compute的第一个参数为输出的形状,第二个参数为lambda表达式
    # lambda表达式中定义了输出中每个“坐标”处的值是如何计算的,其中有用到上面定义的IterVar
    matmul = tvm.compute((batch, out_dim), \
                         lambda i, j: tvm.sum(data[i, k] * weight[j, k], axis=k), \
                         tag='dense')
    if bias is not None:
        matmul = tvm.compute((batch, out_dim), \
                             lambda i, j: matmul[i, j] + bias[j], \
                             tag=tag.BROADCAST)
    return matmul

Dense算是比较简单的例子了,也有较为复杂的OP,例如目标检测中的nms,可以看下它的实现https://github.com/dmlc/tvm/blob/master/topi/python/topi/vision/nms.py

回到instance_norm,它与batch_norm最主要的区别在于inference时,batch_norm是减去moving_mean,然后除以moving_var的平方根,而instance_norm没有moving_mean和moving_var,减去的是自身每个通道的mean,除以相应的var平方根,剩下的gamma和beta操作是一样的,实现如下

def instance_norm_inference(data, gamma, beta, eps, fix_gamma):
    assert len(data.shape) == 4, "only support 4-dim instance norm"
    batch, channel, height, width = data.shape
    rh = tvm.reduce_axis((0, height), name="rh")
    rw = tvm.reduce_axis((0, width), name="rw")
    # 当前batch, channel的均值
    s_mean = tvm.compute((batch, channel), \
            lambda b, c: tvm.sum(data[b, c, rh, rw] / (height * width), axis=[rh, rw]))
    # 再次使用相同的IterVar的时候需要重新定义一下,否则直接使用会报一个越界的错误
    rh = tvm.reduce_axis((0, height), name="rh")
    rw = tvm.reduce_axis((0, width), name="rw")
    s_mean_2 = tvm.compute((batch, channel), \
            lambda b, c: tvm.sum(tvm.intrin.power(data[b, c, rh, rw], 2) / (height * width), axis=[rh, rw]))
    # 当前batch, channel的方差
    s_var = tvm.compute((batch, channel), \
            lambda b, c: s_mean_2[b, c] - tvm.intrin.power(s_mean[b, c], 2))
    if fix_gamma:
        out = tvm.compute((batch, channel, height, width), \
            lambda b, c, h, w: (data[b, c, h, w] - s_mean[b, c]) / \
            tvm.intrin.sqrt(s_var[b, c] + eps) + beta[c])
    else:
        out = tvm.compute((batch, channel, height, width), \
            lambda b, c, h, w: (data[b, c, h, w] - s_mean[b, c]) / \
            tvm.intrin.sqrt(s_var[b, c] + eps) * gamma[c] + beta[c])
    return out

以上代码仅仅描述了算法,我们还需要在对应的target平台上注册相具体的计算方式Schedule

0x06 定义Schedule

本节涉及到修改的文件(Python和C++二选一)

  • tvm/topi/python/topi/generic/nn.py
  • tvm/topi/python/topi/cuda/nn.py

  • tvm/topi/include/topi/cuda/instance_norm.h

写完了Algorithm,接下来是Schedule,Schedule描述了TOPI的计算方式,这个在不同平台显然是有区别的,为了得到最优性能,我们需要自己来定义计算方式。以CUDA为例,我们需要定义blockIdx.xthreadIdx.x等等,这个有点涉及到CUDA底层编程的知识了

Schedule总的入口文件在topi/python/topi/generic/nn.py,而对接到不同的target平台,其Schedule代码在对应的文件夹下,例如CUDA平台的Schedule代码都在tvm/topi/python/topi/cuda文件夹下,arm-cpu平台的Schedule代码在tvm/topi/python/topi/arm_cpu

定义instance_norm在CUDA平台的Schedule时,可以直接在topi/python/topi/cuda/nn.py里写,代码如下,要点写在注释中了:

@generic.schedule_instance_norm.register(["cuda"])
def schedule_instance_norm(outs):
    # 从刚才的TOPI算法的最后返回tensor中,取出涉及到tvm.compute的几个tensor
    fin = outs[0].op.input_tensors
    mean = fin[1]
    var = fin[2]
    mean_2 = var.op.input_tensors[0]
    # instance_norm的计算过程中涉及到4个tvm.compute
    # 每个tvm.compute都需要Schedule
    ts = [mean, mean_2, var, outs[0]]

    s = tvm.create_schedule([x.op for x in ts])

    # 这里采用比较简单粗暴的方式,即每个tvm.compute都采用相同的Schedule方式
    # 事实上深入思考计算过程,还有进一步优化的可能,我也在做进一步的探索和尝试
    def _schedule(Out):  # the type of Out is tensor
        num_thread = 8
        block_x = tvm.thread_axis("blockIdx.x")
        block_y = tvm.thread_axis("blockIdx.y")
        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")

        by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
        bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
        s[Out].reorder(by, bx, ty, tx)
        s[Out].bind(ty, thread_y)
        s[Out].bind(tx, thread_x)
        s[Out].bind(by, block_y)
        s[Out].bind(bx, block_x)

    for x in ts:
        _schedule(x)

    return s

这个版本的Schedule实现是比较初级的,有很多优化并没有做到位,肯定还存在进一步优化的可能。

到这一步,instance_norm这个OP就加入到了TVM当中,当然目前只支持CUDA和LLVM两个target平台

7 Graph Fuse中出现的问题

完成了以上修改后,发现了一个非常奇怪的Bug,se_resnet_101_ibn_a模型在opt_level=0的时候正常(相当于没有做优化,速度和效率会略低),而在opt_level=3的时候会报错(做了大量优化,包括合并所有可以合并的元OP),报错信息提示在fuse_matmul_relu这个op上存在host到device内存的直接访问,查了大量资料发现,这个问题是由于Graph Fuse规则不合理或者对应平台Schedule未定义导致的。

但是很神奇的是se_resnet_101模型并没有这个问题,所以我一直在查找是不是instance_norm出了问题。但是经过排查后发现问题竟然出在SE模块上,而且是一个意想不到的地方……原因在具体的实现上,se_resnet_101_ibn_a里面的SE模块中的两个FC层是无bias的,所以在从pytorch转换到ONNX的时候,得到的图是由matmul, transpose, relu所组成的。而普通的se_resnet_101模型,默认SE模块中的FC层是有bias的,所以能够很优雅的转换为GEMM,而GEMM在Graph Fuse时是不会报错的,两者结构差异如下所示:

se_resnet_101的ONNX模型中SE结构

se_resnet_101_ibn_a的ONNX模型中SE结构

所以,解决的思路有两种

  • 直接解决TVM中的Graph Fuse规则存在的问题,但是由于我还没有研究深入到Graph Fuse那个层次(捂脸),也不太清楚底层合并的逻辑(未来准备研究研究),所以很难直接进行修复
  • 既然GEMM是能够直接被TVM优化的而不出问题的,那么直接把matmul换成GEMM不就行了?只需要把bias置为0即可。涉及到的部分在pytorch模型转换到ONNX的过程中,需要改动pytorch源码下的torch/ONNX/symbolic.py,作如下修改即可,实际结果是等价的
def matmul(g, self, other):
    # return g.op("MatMul", self, other)
    # 上一行为原代码,下面为改动后的实现,其中C就是全0的bias
    ty = self.type().scalarType().lower()
    C = g.constant(0, [other.type().sizes()[1]], ty)
    # 这里稍微解释下GEMM的计算方式,假设输入有三个矩阵A,B,C,计算公式如下(这里没考虑有转置属性的情况)
    # alpha * AB + beta * C
    return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0, broadcast_i=True)

实测使用TVM后的显存占用和速度(batch_size=8)如下所示:

  • TVM 显存429MB, 速度~37ms
  • ONNX 显存2145MB,速度~50ms

最后推荐一个模型结构可视化工具Netron,手动查看模型结构参数必备神器

仓促行文,研究的也不算深入,难免存在错误和问题,欢迎大佬们一起讨论交流~我的邮箱bindog###outlook.com

如果觉得本文对您有帮助,欢迎打赏我一杯咖啡钱~


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK