15

winograd int8实现技巧

 3 years ago
source link: https://zhuanlan.zhihu.com/p/266040003
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

随着深度学习的应用普及,推理框架在深度学习模型的落地产品化过程中扮演了越来越重要的角色。可用于加速卷积计算且历史悠久的的winograd算法重新粉墨登场,又一次开始了它的神奇表演。

winograd算法的原理网上有很多文章介绍这里不再赘述,笔者这里给出一张计算流程图和计算公式和转换矩阵:

JbeQfau.png!mobile

eIVJbqQ.jpg!mobile winograd计算流程图

emUZJfm.png!mobile

mmYVreZ.png!mobile , myQf2af.png!mobile , Iriu6nm.png!mobile

2uaYRr7.png!mobile

nm6n6jb.png!mobile , 32uQZ3q.png!mobile , VZNvUr2.png!mobile

eMNrAnb.png!mobile

,

Bf6FzaM.png!mobile , mYjuayQ.png!mobile

winograd在float32浮点运算下表现确实不错,在3x3 stride=1的卷积层时取得了很好的加速效果,但对于int8量化卷积来说,加速效果不甚理想,整型可没有浮点这么大的动态范围,如何保证运算过程整型不会溢出将是个令人头疼的问题。

假设将网络量化成int8, int8的权重和int8的输入,先看 M7RbAzv.png!mobile ,

J7re6v7.png!mobile

AzE77bZ.png!mobile

2UVFFji.png!mobile

为了保证计算不溢出, iUnQjmY.png!mobile 需要额外的2个bit, 而 YZZvieR.png!mobile 需要额外的1个bit,换言之,为了保证安全的int8*int8, 权重和输入分别需要量化到int6和int7.

由上分析推广到 v6zIbeM.png!mobileeMvQJfN.png!mobile 需要额外的4个bit, 7JR7JnU.png!mobile 需要额外的2个bit,换言之,为了保证安全的int8*int8, M7RbAzv.png!mobile 权重和输入实际需要分别量化到int4和int6才可以!

由上面 M7RbAzv.png!mobile 的分析方法推广到 FZjYvaV.png!mobile , 为了保证安全的int8*int8, eMvQJfN.png!mobile 需要额外的10个bit, 7JR7JnU.png!mobile 需要额外的7个bit……换言之,压根没法保证int8*int8的安全计算,这时候只能将int8扩展到int16,执行int16*int16的计算。

为了保证计算结果的安全性并且不损失输入量化数据的动态范围,实际当中int8 winograd执行的基本都是int16*int16计算。

对于 M7RbAzv.png!mobile 来说,只有权值变换矩阵G含有小数,在int8计算时只需要把G'=2*G作为新的变换矩阵即可,在int8输出前requant(int8->int8)/dequant(int8->fp32)时修改w_ scale值w_ scale' = w_scale * 0.25f即可,考虑int16表示的动态范围是-32768~32767,可以简单验证下这个变换的过程不会产生溢出。

但对于 FZjYvaV.png!mobile 来说就不同了:

BNVrmuj.png!mobile ,

nMBFBvj.png!mobile ,

注:ok表示不会溢出, overflow表示有可能溢出,实际上右下角数据= nE73q2e.png!mobile,实际当中非常有可能溢出( yQV3Ezm.png!mobile即可溢出), 实际int8量化后的取值大部分选择去除掉-128这个点,实际范围选择在-127~127.

这里我们采取一个trick来避免int8计算weight_transform带来的数据溢出问题,利用提取公共因子法,最后一列和最后一行转int16输出前除24,中间计算结果改用int32来存储以保持权值矩阵转换过程计算的正确性,

在后面计算batched gemm计算时,修改原来使用的int16*int16的计算kernel函数接口加入alpha系数:

gemm_int16(M, N, K, int16* A, int16* B, int32* C, int alpha)

在batched gemm kernel函数内部, 把int32的计算结果store出去之前,先mul alpha;

在kernel函数外部调用的地方对index进行判断alpha分别传1或者24, 把在weight_ transform过程中除掉的系数在gemm计算时再乘回来,下面给出weights_ transform示意代码:

void weight_trans_4x4_3x3_int8(int16_t* dest, const int8_t* din, int ch_in,
                           int ch_out, void* workspace) {
  const int32_t coeff[6][3] = {{6, 0, 0},
                                {-4, -4, -4},
                                {-4, 4, -4},
                                {1, 2, 4},
                                {1, -2, 4},
                                {0, 0, 24}};
  int32_t* ptr_out = static_cast<int32_t*>(workspace);
  for (int i = 0; i < ch_out; i++) {
    for (int j = 0; j < ch_in; j++) {
      const int8_t* kernel0 =
          static_cast<const int8_t*>(din) + (i * ch_in + j) * 9;

      const int8_t* k0 = kernel0;
      const int8_t* k1 = kernel0 + 3;
      const int8_t* k2 = kernel0 + 6;

      int32_t tmp[6][3];
      for (int i = 0; i < 6; i++) {
        tmp[i][0] =
            static_cast<int32_t>(k0[0]) * coeff[i][0] +
            static_cast<int32_t>(k0[1]) * coeff[i][1] +
            static_cast<int32_t>(k0[2]) * coeff[i][2];
        tmp[i][1] =
            static_cast<int32_t>(k1[0]) * coeff[i][0] +
            static_cast<int32_t>(k1[1]) * coeff[i][1] +
            static_cast<int32_t>(k1[2]) * coeff[i][2];
        tmp[i][2] =
            static_cast<int32_t>(k2[0]) * coeff[i][0] +
            static_cast<int32_t>(k2[1]) * coeff[i][1] +
            static_cast<int32_t>(k2[2]) * coeff[i][2];
      }

      for (int j = 0; j < 6; j++) {
        int32_t* tmpp = &tmp[j][0];
        for (int i = 0; i < 6; i++) {
          ptr_channel[j * 6 + i] = tmpp[0] * coeff[i][0] +
                                   tmpp[1] * coeff[i][1] +
                                   tmpp[2] * coeff[i][2];
          if (i == 5 || j == 5)
            ptr_channel[j * 6 + i] /= 24;
        }
      }
    }
  }
... ... 
}

batched gemm接口调用示意代码:

for (int gi = 0; gi < 36; ++gi) {
     ... ...
     ... ...
     int col_idx = gi / 6;
     int row_idx = gi % 6;
     if (col_idx == 5 || row_idx == 5) {
        gemm_int16_alpha(
            M, N, K, A, B, C, 24);
     } else {
        gemm_int16_alpha(
            M, N, K, A, B, C, 1);
     }
}

scale转换部分代码:

... ...
for (auto& ws : w_scale_) {
      ws /= 576;
}
... ...

这里给出了一个优化winograd int8 4x4_3x3实现的一个trick, 笔者在华为P30 ARM A76大核单线程上测试与之前的winograd int8 2x2_3x3实现相比,有明显的算法性能增益,这里就不给出具体的测试性能数据了.

容易掉进的坑:在计算weight_transform时采用float32形式,最后把计算的结果static_cast到int16,这样会丢弃float32计算结果的小数部分,直接导致最后计算结果的错误。

对于 EnYzAj6.png!mobile ,情况要更复杂一些,除了权值转换矩阵G之外, 输入转换矩阵B和输出转换矩阵A也含有小数,感兴趣的同学可以推导一下能否采取相同的trick保证input_ transform和weight_ transform转换之后的结果落在int16有效的表示范围内(-32768~32767),在batched gemm时进行系数纠正.

天色已晚,晚安。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK