14

变分自编码器(五):VAE + BN = 更好的VAE

 4 years ago
source link: https://kexue.fm/archives/7381
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

本文我们继续之前的变分自编码器系列,分析一下如何防止NLP中的VAE模型出现“KL散度消失(KL Vanishing)”现象。本文主要的参考文献是最近的论文 《A Batch Normalized Inference Network Keeps the KL Vanishing Away》 ,并补充了一些自己的描述。

值得一提的是,本文最后得到的方案相当简洁—— 往编码输出层加入BN ——但确实很有效,因此值得正在研究相关问题的读者一试。如果按照笔者的看法,它甚至可以成为VAE模型的“标配”。

让我们简单回顾一下VAE模型,它的训练流程大概可以图示为

eeumUnv.png!web

VAE训练流程图示

写成公式就是

$$\begin{equation}\mathcal{L} = \mathbb{E}_{x\sim \tilde{p}(x)} \Big[\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big]+KL\big(p(z|x)\big\Vert q(z)\big)\Big]

\end{equation}$$

其中第一项就是重构项,$\mathbb{E}_{z\sim p(z|x)}$是通过重参数来实现;第二项则称为KL散度项,这是它跟普通自编码器的显式差别。更详细的符号含义可以参考 《变分自编码器(二):从贝叶斯观点出发》

在NLP中,句子被编码为离散的整数ID,所以$q(x|z)$是一个离散型分布,可以用万能的“条件语言模型”来实现,因此理论上$q(x|z)$可以精确地拟合生成分布,问题就出在$q(x|z)$太强了,训练时重参数操作会来噪声,噪声一大,$z$的利用就变得困难起来,所以它干脆不要$z$了,退化为无条件语言模型(依然很强),$KL(p(z|x)\Vert q(z))$则随之下降到0,这就出现了KL散度消失现象。

这种情况下的VAE模型并没有什么价值:KL散度为0说明编码器输出的是0向量,而解码器则是一个普通的语言模型。而我们使用VAE通常来说是看中了它无监督构建编码向量的能力,所以要应用VAE的话还是得解决KL散度消失问题。事实上从2016开始,有不少工作在做这个问题,相应地也提出了很多方案,比如退火策略、换先验分布等,读者Google一下“KL Vanishing”就可以找到很多文献了,这里不一一溯源。

本文的方案则是直接针对KL散度项入手,简单有效而且没什么超参数。其思想很简单:

KL散度消失不就是KL散度项变成0吗?我调整一下编码器输出,让KL散度有一个大于零的下界,这样它不就肯定不会消失了吗?

这个简单的思想的直接结果就是:在$\mu_z$后面加入BN层,如图

RvUfuue.png!web

往VAE里加入BN

为什么会跟BN联系起来呢?我们来看KL散度项的形式:

\begin{equation}KL\big(p(z|x)\big\Vert q(z)\big) = \frac{1}{b} \sum_{i=1}^b \sum_{j=1}^d \frac{1}{2}\Big(\mu_{i,j}^2 + \sigma_{i,j}^2 - \log \sigma_{i,j}^2 - 1\Big)\end{equation}

上式是采样了$b$个样本进行计算的结果,而编码向量的维度则是$d$维。由于我们总是有$e^x \geq x + 1$,所以$\sigma_{i,j}^2 - \log \sigma_{i,j}^2 - 1 \geq 0$,因此

\begin{equation}KL\big(p(z|x)\big\Vert q(z)\big) \geq \frac{1}{b} \sum_{i=1}^b \sum_{j=1}^d \frac{1}{2}\mu_{i,j}^2 = \frac{1}{2}\sum_{j=1}^d \left(\frac{1}{b} \sum_{i=1}^b \mu_{i,j}^2\right)\end{equation}

留意到括号里边的量,其实它就是$\mu_z$在batch内的二阶矩,如果我们往$\mu_z$加入BN层,那么大体上可以保证$\mu_z$的均值为$\beta$,方差为$\gamma^2$($\beta,\gamma$是BN里边的可训练参数),这时候

\begin{equation}KL\big(p(z|x)\big\Vert q(z)\big) \geq \frac{d}{2}\left(\beta^2 + \gamma^2\right)\end{equation}

所以只要控制好$\beta,\gamma$,就可以让KL散度项有个正的下界,因此就不会出现KL散度消失现象了。

这样一来,KL散度消失现象跟BN就被巧妙地联系起来了,通过BN来“杜绝”了KL散度消失的可能性。更妙的是,加入BN层不会与原来的训练目标相冲,因为VAE中我们通常假设的先验分布就是标准高斯分布,它满足均值为0、方差为1的特性,而加入BN后$\sigma_z$的均值为$\beta$、方差为$\gamma^2$,所以直接固定$\beta=0,\gamma=1$就可以兼容原来的训练目标了。

本文简单分析了VAE在NLP中的KL散度消失现象,并介绍了通过BN层来防止KL散度消失的方案。这是一种简洁有效的方案,不单单是原论文,笔者私下也做了简单的实验,结果确实也表明了它的有效性,值得各位读者试用。

转载到请包括本文地址: https://kexue.fm/archives/7381

更详细的转载事宜请参考: 《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎/本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (2020, May 06). 《 变分自编码器(五):VAE + BN = 更好的VAE 》[Blog post]. Retrieved from https://kexue.fm/archives/7381


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK