26

生成时间序列的VAE——VRNN与SRNN模型浅析

 3 years ago
source link: https://zhuanlan.zhihu.com/p/272106709
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

时间序列作为一种很常见的数据结构,它的生成、表征以及预测等问题在金融、语音、语言等领域是非常重要的,但是由于时间序列具有依赖性、不确定性等特点,一些传统的生成模型可能并不适合这类数据。而变分自编码器(Variational Auto-Encoders,VAE)作为一种基于 变分贝叶斯推断 的生成式网络,它通过 潜在随机变量 (latent random variables)来实现样本的生成,从而有更好的鲁棒性。这篇文章将会介绍一类专门针对于时间序列生成的VAE模型——VRNN与SRNN,它们将原始的VAE拓展到了时间序列上,实现了对于时间序列的表征与生成。

原始VAE

关于VAE的原理这里就不再多说了,不熟悉的同学可以看下面几篇文章:

PaperWeekly:变分自编码器VAE:原来是这么一回事 | 附开源代码 zhuanlan.zhihu.com neQBVjA.jpg!mobile ALme:变分推断与变分自编码器(VAE) zhuanlan.zhihu.com nmUjYza.jpg!mobile 苏一:一文理解变分自编码器(VAE) zhuanlan.zhihu.com Z3ymYjY.jpg!mobile

直接上公式,VAE优化的目标是:

YreM3yN.png!mobile

其中 是样本, 是潜变量, UjAvQrI.png!mobile 是以 为参数的推断网络得到的近似后验分布(approximate posterior), BZZVbu.png!mobile 是以 为参数的生成网络得到条件生成分布, yqayEbE.png!mobile 是 的先验分布。

简单来看,VAE的流程是:首先输入样本 ,然后利用推断网络 UjAvQrI.png!mobile 得到潜变量 的 近似后验分布 ,然后从该分布中采样出多个 ,通过生成网络 BZZVbu.png!mobile 来计算 条件生成分布 ;对于给定的样本 ,它在条件生成分布中的对数似然 ZzaEJbm.png!mobile 越大表明模型重构的效果越好,同时为了能够让模型生成新样本,我们还要保证近似后验分布 UjAvQrI.png!mobile 与 一个已知的 先验分布 yqayEbE.png!mobile 尽可能地接近,这样在模型训练好了之后,我们就可以直接从先验分布里采样出 ,然后直接通过生成网络得到新样本。

Variational RNN

我们知道其实RNN模型是可以实现时间序列的生成和表征的,因此在介绍VRNN之前,这里先回顾一下传统RNN的建模过程:

对于一个时间序列 MZv67zU.png!mobile ,在时间步 ,RNN读取输入 Bzyuqef.png!mobile ,然后更新隐层状态 UryaiuA.png!mobile ,其中 rYNvUzi.png!mobile 是以 为参数的神经网络(例如LSTM、GRU等);由于RNN是按照从前往后的顺序依次计算的,那么序列 aErYbmQ.png!mobile 的联合概率分布可以写为: 2IFJZzB.png!mobile

其中 yUZRvmF.png!mobile ,即从 历史隐层状态 6BnYR3q.png!mobile 中得到 Bzyuqef.png!mobile 的概率分布, fyAfAvV.png!mobile 是以 为参数的函数。为了描述概率分布,通常把 fyAfAvV.png!mobile 分为两个过程,先通过一个神经网络 zU32umZ.png!mobile 得到一个 参数集 e6N3maz.png!mobile ,即 uEBjqu2.png!mobile ,然后得到一个用 e6N3maz.png!mobile 描述的分布 eiURRfj.png!mobile ,例如先利用神经网络得到高斯分布的均值 fmUNNzJ.png!mobile 和标准差 ,然后得到用 fmUNNzJ.png!mobile 和 描述的高斯分布 jMVrYjE.png!mobile

可以看出,在传统的RNN建模当中, 对于序列不确定性的建模是通过最后的输出函数 fyAfAvV.png!mobile 实现的 ,这样简单的方式对于波动大、变化多的序列数据来说是不够的,可能无法很好地分辨信号与噪声,因此VRNN这篇论文就把VAE的方法拓展到了RNN当中, 通过潜在随机变量来实现对多维时间序列的表征 ,下面就来具体介绍。

生成过程

我们先回忆一下原始VAE的生成过程:从一个先验分布 yqayEbE.png!mobile 里面采样出一个潜变量 ,然后通过条件生成网络 BZZVbu.png!mobile 得到重构样本的分布概率。

现在VRNN想要生成的是一个序列,那么应该在每一个时间步 中都有一个VAE,逐时间步地生成重构样本 367FZjI.png!mobile ;另外考虑到时间序列的依赖性,那么这里的先验分布 yqayEbE.png!mobile 不应该是完全独立的,而是应该 依赖于历史信息 ,我们这里把 6BnYR3q.png!mobile 当做历史信息(看做是对历史序列 Nn6JfyR.png!mobile 的编码),因此这里的先验分布形式为:

其中 EB7zUrq.png!mobile 是先验分布的均值和方差,它们依赖于前一时间步的隐层状态 6BnYR3q.png!mobile

类似地,条件概率分布 BZZVbu.png!mobile 同样也要依赖于 6BnYR3q.png!mobile ,其形式为:

而隐层状态的更新公式为:

Ub2Iju7.png!mobile

其中 yU3iAb7.png!mobile 是条件生成分布的均值和方差, jy6byqU.png!mobile 两个网络可以看做是 bAbQJfB.png!mobile 的特征提取器。

简单分析一下上面的公式,由于 6BnYR3q.png!mobile 依赖于 Nn6JfyR.png!mobilezENni2F.png!mobile ,因此关于 UNFjMfv.png!mobile6jYJjqU.png!mobile 的分布可以写做: eAbmq2y.png!mobile3M7FbuJ.png!mobile 注意 mUZvUbb.png!mobile ),那么整个序列的联合概率分布就可以写为: Y36RfeQ.png!mobile

推断过程

还是类似的思想,加入历史依赖后,后验分布的形式变为:

那么联合分布可以写为:

iYZNBb.png!mobile

模型的各个流程可以总结为下图:

AVZrmye.jpg!mobile

目标函数

这样我们就可以写出VRNN的目标函数,还是采用VAE变分下界的形式:

该式与原始VAE的目标函数基本一致,在保证近似后验分布 qeqqIz.png!mobile 与先验分布 3M7FbuJ.png!mobile 尽可能接近的同时,令条件生成分布的似然 67bmMn.png!mobile 尽可能大。

实验

论文在两种任务上对VRNN与其它生成模型做了对比,1)语音波形建模;2)手写体生成。下面是在多个数据集上的实验结果:

Qv2Inuf.jpg!mobile

baseline模型采用的是标准的RNN,-Gauss代表最终的输出函数是服从高斯分布(即网络输出 maUJBzf.png!mobile ),-GMM代表输出函数是服从混合高斯分布(即网络除了输出 maUJBzf.png!mobile 以外,还有权重 3MF3Ejq.png!mobile ),-I 代表先验分布不依赖历史信息(即与隐层状态 6BnYR3q.png!mobile 无关)

下图模型在语音训练集上生成的样本示例,前3行是全局波形,后3行是局部波形。可以看出相对于RNN-GMM,VRNN-Gauss重建的序列会包含的噪声更小。

iyqUziJ.jpg!mobile

下图是在手写体任务中的样本示例对比

2iYzyee.jpg!mobile

Stochastic RNN

接下来介绍SRNN,该模型在VRNN的基础上做了进一步的推广。概括来说,SRNN将RNN与state space model(SSM)结合到了一起,RNN与SSM都是时间序列建模中的常用方法,假设时间序列 VZJvI3z.png!mobile 可能依赖于 vAN7B3U.png!mobile ,下图展示了RNN与SSM的流程。

RfI32y2.jpg!mobile

可以看出RNN与SSM有类似的地方,比如说序列 q6nQbyE.png!mobile 的信息都会被整合到潜状态 eyYFFzN.png!mobileZfIZR3.png!mobile 中。对比来看:RNN具有很好的非线性拟合能力,但是其隐状态是确定性的;SSM的随机状态转移更适合不确定性建模,但是推断过程通常比较简单。那么我们能不能令两者的优势互补呢?这就是VRNN的动机,它 使SSM的随机状态转移变为非线性的,同时还保持了RNN的门控激活机制 ,结构如下图所示:

BBNBBjR.jpg!mobile

生成模型

先来看生成模型,如上图a所示,其实很直观,就是把RNN与SSM的隐状态结合了起来,这里需要注意的是 随机状态 ZfIZR3.png!mobile 依赖于确定性状态 eyYFFzN.png!mobile ,而反过来就没有依赖性了,这样是为了保证 将确定性状态与随机状态分隔开,使其不会受到随机性的影响

这样生成模型的联合概率分布就可以写为

其中 MNVvYzQ.png!mobilebQVry2B.png!mobileIriMvaI.png!mobile6VZrYvy.png!mobile 也类似; FbQBBjY.png!mobile 就是一个RNN cell的更新公式。

推断网络

推断网络的结构如上图b所示,我们要做的就是从整个样本序列中得到潜变量 EzyMFnq.png!mobile 的近似后验分布,因此推断网络的概率分布可以写成:

考虑到确定性状态 eyYFFzN.png!mobile 只依赖于 Q3QNv2u.png!mobilebAZnQvj.png!mobile ,则有

a6rA733.png!mobile

现在来看 AnIJru7.png!mobile ,我们希望通过所有信息( 包括未来信息 )来估计 JjERBnA.png!mobile 的后验分布,由概率图的知识可知

也就是说, 6jYJjqU.png!mobile 的后验分布 只依赖于 RnY7jij.png!mobile ,为了让 6jYJjqU.png!mobile 依赖于未来的信息 Qnema2.png!mobile ,我们在推断过程中再加入一个 反向传播 的状态 NRJNzu.png!mobile ,令其更新公式为 aEBVZff.png!mobile ,其实就是一个 反向的RNN ,这样 NRJNzu.png!mobile 就可以代表未来信息了,我们有:

其中 MvQfAz.png!mobile 代表两个向量的拼接, RFnMFbb.png!mobile 还是用均值和方差网络来表示成一个高斯分布,即 VJj6jm2.png!mobile

通过这样的推断网络,我们就可以通过全部信息来推断 6jYJjqU.png!mobile 的后验分布。

目标函数

下面为了讨论方便,认为 77ryE3u.png!mobile ,即 v6fQF33.png!mobile 是一个“随机变量”,其概率分布是一个以 a2QbAfI.png!mobile 为中心的delta函数。

模型需要在每一时间步上计算loss,因此目标函数可以写做

看起来稍微有点复杂,不过其实 还是原始VAE的那两项 ,其中 6Zraiuj.png!mobileZ3AZneb.png!mobile 的边缘概率分布,其计算公式为:

近似后验的改进

另外,论文还在实验中发现了一个问题: 在模型训练的过程中,KL散度项 MnQRVre.png!mobile 会很快接近于0。

其实也不难看出原因,只要推断网络中状态 NRJNzu.png!mobile 的更新公式 aEBVZff.png!mobile 只考虑 ZfqEZ3j.png!mobile 项,那么近似后验 RRzimyF.png!mobile 就可以很容易地模仿先验 mYVz2mN.png!mobile 的行为,从而使两者的差别变得非常小。

论文针对于这个问题提出了一个针对于近似后验的改进方法,我们不再直接计算

bu6bUjy.png!mobile

而是改为如下的残差形式:

qYjYby.png!mobile

其中

改进后的推断网络的算法流程如下图所示

rYBvuuV.jpg!mobile

简单来说,我们从包含了未来信息的 6Zraiuj.png!mobile 中采样 Z3AZneb.png!mobile ,然后再通过生成网络 bEJbAj.png!mobile 得到 6jYJjqU.png!mobile 的均值 UFne6fb.png!mobile ,我们认为这样的 UFne6fb.png!mobile 已经是生成网络对于 6jYJjqU.png!mobile 均值的 最好估计 了(因为用到的 Z3AZneb.png!mobile 是根据未来信息推断出来的);我们再让推断网络 nMVBbe3.png!mobile 在这个最好的 UFne6fb.png!mobile 的基础上进行修正后,最后得到近似后验的均值 nqI7ZzM.png!mobile 。另外,由于新方法在方差上起到效果,因此方差网络保持原来的计算方式不变。

通过这样的改进方式,目标函数中的KL项就不再依赖于 UFne6fb.png!mobile 了。

这里可能有点不好理解,具体解释一下:先来看KL项

ZZJRZbM.png!mobile

它衡量的是两个条件分布之间的差别,其实也就是两个 FBZ3Iz6.png!mobile 网络之间的差别,即对于同样的输入(条件参数) FfQvYb.png!mobile ,两个网络输出分布的差别。由于两个网络输出的都是一个高斯分布,我们知道两个高斯分布的KL散度可以写为(可以参考这篇文章)

其中 ZFBv6z2.png!mobile 是数据的个数,可以看出两分布均值 VNn2ue2.png!mobile 在KL散度中出现的形式是 3aYFRji.png!mobile ,由于两个网络输入的是相同的 Z3AZneb.png!mobile ,因此这里 bauummF.png!mobile ,那么根据近似后验均值的残差形式,可知 EFRfy2u.png!mobile ,这样就证明了KL项就不依赖于 UFne6fb.png!mobile

可以看出,改进以后的推断网络 nMVBbe3.png!mobile 学的只是 UFne6fb.png!mobilenqI7ZzM.png!mobile 之间的残差,它只需要学习如何通过使用未来信息去“ 纠正 ”先验分布就行了,这就使得在模型训练期间,推断网络更容易“ 跟踪 ”生成模型的变化。

实验

这篇论文在语音和音乐数据集上对SRNN以及一些baseline模型做了序列重构任务,实验对比结果如下,具体细节可以看论文原文。

J7jeauF.jpg!mobilejIfAnaj.jpg!mobile

其中filt表示推断过程只用当前时间步 的信息,而不用未来信息,即 Bb2eeyu.png!mobile ,而smooth就是之前设定的 JFf6Fre.png!mobile rAnMZfi.png!mobile 表示采用上文那个改进的残差近似后验;注意到baseline模型中采用了我们前一篇论文中的VRNN模型,符号含义与上文一致。

总结

这篇文章介绍了一类VAE在时间序列上的变体模型——VRNN和SRNN。总的来看,两生成模型的思路是一致的,都是 将VAE融入了RNN当中,利用RNN的隐层状态来建立时间序列的依赖关系,并在此基础上用VAE来对序列建模 。SRNN可以看做是VRNN的改进版本,改进的地方有:

1.生成过程

两模型的在生成过程的结构如下(只看第t步),可以看出,唯一的区别就是随机状态 A3ueQb3.png!mobile 对确定状态 eyYFFzN.png!mobile 的作用,VRNN将在这里做了改进,使得随机状态不会影响到确定状态。

aQzAJnN.jpg!mobile 左:VRNN 右:SRNN

2.推断过程

其实VRNN相当于是SRNN的filt版本,只用到了当前时间步 的信息,而SRNN的smooth版本在推断过程中又加了一个反向的RNN来捕捉未来信息,从实验证明smooth的效果比filt更好。

3.近似后验分布

SRNN没有直接用一个网络来得到近似后验分布,而是采用了残差的方式,从而避免了训练过程中为了减小KL项的loss,近似后验网络刻意地去“模仿”先验网络,使得推理网络更容易“跟踪”生成模型的变化。

参考资料

[1] A recurrent latent variable model for sequential data

[2] Sequential neural models with stochastic layers

[3] PaperWeekly:变分自编码器VAE:原来是这么一回事 | 附开源代码

[4] ALme:变分推断与变分自编码器(VAE)

[5] 苏一:一文理解变分自编码器(VAE)

[6] 小明:两个多变量高斯分布之间的KL散度


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK