6

max函数光滑逼近:一种与softmax相关的形式

 2 years ago
source link: https://allenwind.github.io/blog/11886/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client
Mr.Feng Blog

NLP、深度学习、机器学习、Python、Go

max函数光滑逼近:一种与softmax相关的形式

一个关于max函数光滑逼近,其特例居然和均值、tanh函数、Logistics函数相关!

这个逼近其实在过去的文章中也略有提及,不过今天那它来单独分析一下。

max & argmax

x=[x1,⋯,xn]

max(x)就是要找向量中元素的最大值。argmax(x)要找向量中元素的最大值的标量,就是所在的位置。假设第xj最大,那么

j=argmax(x)

有时候我们希望这个“标量”能够对齐原来向量的长度,即表达成one-hot(argmax(x)),

one-hot(argmax(x))=[0,…,1,…,0]

等式右侧第j个分量为1,其他为0​。比如圆周率的头几个数,

x=[3,1,4,1,5,9,2] 6=argmax([3,1,4,1,5,9,2])

one-hot形式为,

[0,0,0,0,0,1,0]=one-hot(argmax([3,1,4,1,5,9,2]))

于是max(x)​可以使用argmax(x)​表示,

max(x)=x⊗one-hot(argmax(x))=n∑i=1xi×one-hot[argmax(x)][i]

这里one-hot[argmax(x)][i]表示one-hot[argmax(x)]的第i个分量。

argmax的光滑逼近

一下推导使用到max(x)​的一光滑形式,但这不是我们要的最终形式,

max(x)≈1αlog(n∑i=1eαxi)

上式右边部分称为Logsumexp。这个从凸优化的角度,考虑到Logsumexp的凸性,因此有,

f(x1,x2,…,xn)=1αlog(n∑i=1eαxi)=1αlog(n×1nn∑i=1eαxi)≤log(n)α+1αlog(n∑i=11n×eαxi)≤log(n)α+1nn∑i=1xi≤log(n)α+max(x1,x2,…,xn)

另一方面,假设有数据x1,…,xn,其中xj最大,有

f(x1,x2,…,xn)=1αlog(n∑i=1eαxi)≥1αlog(eαxj)=max(x1,…,xn) max(x1,…,xn)≤1αlog(n∑i=1eαxi)≤log(n)α+max(x1,x2,…,xn)

通过夹逼定理有,

limα→+∞1αlog(n∑i=1eαxi)=max(x1,…,xn)

同样我们假设x=[x1,⋯,xn]​中xj​最大,那么one-hot形式第j​个分量为1​,其他为0​,有推导,

[0,…,1,…,0]=one-hot(argmaxi=1,⋯,nxi)=one-hot(argmaxi=1,⋯,n[xi−max(x)])=one-hot(argmaxi=1,⋯,nexp[xi−max(x)])≈one-hot(argmaxi=1,⋯,nexp[xi−1αlog(n∑i=1eαxi)])=one-hot(argmaxi=1,⋯,nexp1α[αxi−log(n∑i=1eαxi)])=one-hot(argmaxi=1,⋯,nexp[αxi−log(n∑i=1eαxi)])=one-hot(argmaxi=1,⋯,neαxin∑i=1eαxi)≈[eαx1n∑i=1eαxi,…,eαxnn∑i=1eαxi]=softmax(αx)

需要说明几点:

  • 引入[xi−max(x)]使得最大值为0,使得e0=1,对应one-hot中的1
  • 引入ex是考虑到e0=1,0<ex|x<0<1,并拉大[x1,x2,…,xn]间的距离,更好适配one-hot特点
  • max不具有光滑性,被替换为其光滑近似logsumexp,logsumexp(x1,x2,…,xn)=log(∑ni=1exi)

理解好这三点就明白上述推导过程。

max的光滑逼近

于是,接着第一部分的推导,

max(x)=x⊗one-hot(argmax(x))=n∑i=1xi×one-hot[argmax(x)][i]≈n∑i=1eαxin∑i=1eαxixi=n∑i=1xieαxin∑i=1eαxi

于是,这就是我们要找的max(x)光滑形式,

max(x)≈n∑i=1xieαxin∑i=1eαxi

有三个关键点:

  • 如果α=0,则上式右侧为均值
  • 如果α=+∞,则上式右侧为max(x)
  • 如果α=−∞,则上式右侧为min(x)

第三点使用min(x)=−max(−x)性质容易理解。

取n=2,x1=0,x2=x​有特例,

max(x)≈max{0,x}=n∑i=1xieαxin∑i=1eαxi=xeαx1+eαx=xσ(αx)

这也是relu(x)=max0,x的光滑逼近,也就是激活函数Swish。

取n=2,x1=−x,x2=x​​,有特例,

max(x)=max{−x,x}≈n∑i=1xieαxin∑i=1eαxi=x(eαx−e−αx)eαx+e−αx=xtanh(αx)

这是f(x)=|x|==max−x,x的光滑逼近。

其实这个max(x)函数的光滑形式比较优雅,这里单独拿出来说一说。

转载请包括本文地址:https://allenwind.github.io/blog/11886
更多文章请参考:https://allenwind.github.io/blog/archives/


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK