44

Bert与模型蒸馏: PKD和DistillBert

 3 years ago
source link: http://mp.weixin.qq.com/s?__biz=MjM5ODkzMzMwMQ%3D%3D&%3Bmid=2650416201&%3Bidx=2&%3Bsn=c32df5b205278df3dedffb066b56721f
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

Znui6v.jpg!mobile

最近要开始使用Transformer去做一些事情了,特地把与此相关的知识点记录下来,构建相关的、完整的知识结构体系,

以下是要写的文章,本文是这个系列的第五篇:

Overall

Bert模型虽然很火,但是模型太大,在device上基本不可用,要想更好的使用的话需要让模型变小。

而目前模型变小的技术大概有四种:

  • 模型量化:即把float值变为int8,可以直接将模型降为原来的四分之一。速度也会有提高。

  • 矩阵分解:大矩阵分解为小矩阵的乘积去拟合,可以显著降低size。

  • 模型结构改动:比如更多的参数共享,更高效的层次计算等。

  • 知识蒸馏:从大模型中去学习小模型。

本文介绍的就是知识蒸馏这一部分。主要介绍的是参考论文[1]中的PKD方法,参考论文[2]中的DistillBert本身较为简单,本文也会简要介绍一下。

知识蒸馏

首先,我们先来了解一下知识蒸馏,知识蒸馏是让一个小模型去学习一个大模型,所以首先会有一个预训练好的大模型,称之为Teacher模型,小模型被称为Student模型。知识蒸馏的方法就是会让Student模型去尽量拟合。这个的动机就在于跟ground truth的one-hot编码相比,Teacher模型的输出概率分布包含着更多的信息,从Teacher模型的概率分布中学习,能让Student模型充分去模拟Teacher模型的行为。

在具体的学习Teacher模型概率分布这个过程中,知识蒸馏还引入了温度的概念,即Teacher和Student的logits都先除以一个参数T,然后再去做softmax,得到的概率值再去做交叉熵。温度T控制着Student模型学习的程度,当T>1时,Student模型学习的是一个更加平滑的概率分布,当T<1时,则是更加陡峭的分布。因此,学习过程中,T一般是一个逐渐变小的过程。Teacher模型经过温度T之后的输出被称之为soft labels。

Soft labels的计算过程如下,P t 中的t代表的是Teacher模型,同样的,Student模型的输出也要经过类似的计算。

MbMN3eq.png!mobile

得到Teacher模型和Student的模型后,就可以计算知识蒸馏这部分的损失,得到L DS

RnQVr2I.png!mobile

然后student模型除了向teacher模型学习外,还需要向ground truth学习,得到L CE

A3IFJzn.png!mobile

然后再让两个损失去做加权平均。得到:

u2ERVfB.png!mobile

以上就是知识蒸馏的思想,这方法可以直接应用在Bert上,但是作为一个深层模型,在中间层次上的信息也很丰富,如何利用这部分的信息呢?这就有了PKD中提出的方法。

多层蒸馏

首先,PKD论文中先做了对比,减少模型宽度和减少模型深度,得到的结论是减少宽度带来的efficiency提高不如减少深度来的更大,因此,PKD主要关注减少模型深度。即Student模型比Teacher模型要浅。

论文所提出的多层蒸馏,即Student模型除了学习Teacher模型的概率输出之外,还要学习一些中间层的输出。论文提出了两种方法,第一种是Skip模式,即每隔几层去学习一个中间层,第二种是Last模式,即学习teacher模型的最后几层。如下图所示:

3Uj6Fra.png!mobile

如果是完全的去学习中间层的话,那么计算量很大。为了避免这个问题,我们注意到Bert模型中有个特殊字段[CLS],在蒸馏过程中,让student模型去学习[CLS]的中间的输出。

而对于中间层的学习,使用的损失函数是均方差函数:

Ajq6Bjm.png!mobile

最后的损失函数如下:

初始化的话就采用Teacher模型的前几层来做初始化。

PKD实验结果

在GLUE上的实验结果如下:

2QniIvE.png!mobile

可以看到,多层的方法在除了MRPC之外的任务上都能达到比较最好的效果。而究其原因,可能是因为MRPC的数据较少,从而导致了过拟合。

而Last和Skip模式的对比如下:

Vr6nIbb.png!mobile

可以看到,Skip一般会比Last模式要好。这是因为,Skip方式下,层次之间的距离较远,从而让student学习到各种层次的信息。

那么Student模型的计算量和参数数目是如何的呢?如下图所示:

zq6NVrq.png!mobile

由于需要每层计算损失,Student和Teacher模型每层的宽度都是一样(这其实是一个限制),Student模型的参数量主要少在层次少。

而更好的teacher模型会带来增长么?答案是不会的,可以看上图,把12层的Bert模型换成了24层的Bert模型,反而导致效果变差。究其原因,可能是因为在实验中,我们使用Teacher模型的前N层来初始化Student模型,对于24层模型来说,前N层更容易导致不匹配。而更好的方法则是Student模型先训练好,再去学Teacher模型。

DistillBert

DistillBert的做法就比较简单直接,同样的,DistillBert还是保证模型的宽度不变,模型深度减为一半。主要在初始化和损失函数上下了功夫:

  • 损失函数:采用知识蒸馏损失、Masked Language Model损失和cosine embedding损失加起来的值。

  • 初始化:用Teacher模型的参数进行初始化,不过是从每两层中找一层出来。

具体结果就不赘述了,可以去参考原始论文。

参考文献

  • [1]. Sun S, Cheng Y, Gan Z, et al. Patient knowledge distillation for bert model compression[J]. arXiv preprint arXiv:1908.09355, 2019.

  • [2]. Sanh V, Debut L, Chaumond J, et al. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter[J]. arXiv preprint arXiv:1910.01108, 2019.


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK