图解自动微分的正向模式和逆向模式
source link: https://seanwangjs.github.io/2022/12/06/autograd-forward-and-reverse.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
图解自动微分的正向模式和逆向模式
自动微分建立在复合函数求导的链式规则之上,考虑以下复合函数
f(x)=a(b(c(x)))
则 f 对 x 的导数为
dfdx=dfdadadbdbdcdcdx
显然,上述公式存在两种计算顺序,第一种先对高阶函数求导,我们称之为逆向模式
dfdx=((dfdadadb)dbdc)dcdx
第二种先对低阶函数求导,我们称之为正向模式
dfdx=dfda(dadb(dbdcdcdx))
为了更直观地说明两种模式的区别,我们考虑一个简单的例子
f(x,y,z)=log(xy)+sin(z)
它的计算图如下所示
我们最终需要计算偏导数 ∂f∂x,∂f∂y,∂f∂z。下面我们首先推导正向模式自动微分:
- 计算 v1=xy,得到 ∂v1∂x,∂v1∂y
- 计算 v2=log(v1),得到 ∂v2∂v1,∂v2∂x,∂v2∂y
- 计算 v3=sin(z),得到 ∂v3∂z
- 计算 f=v2+v3,得到 ∂f∂v2,∂f∂v3,∂f∂x,∂f∂y,∂f∂z
可以看到,随着计算图的向前推进,我们也不断地得到最新值对输入的梯度,所以这种计算方式也称为自动微分的正向模式。
接下来我们再描述另一种计算模式:
- 计算 v1=xy,并构造另一张计算图,形状与原计算图类似,但描述的是梯度计算关系,这里插入乘法的梯度计算节点
- 计算 v2=log(v1),插入
log
运算的梯度计算节点
- 计算 v3=sin(z),插入
sin
运算的梯度计算节点
- 计算 f=v2+v3,插入加法的梯度计算节点
最后再以原计算图相反的方向遍历梯度计算图,从而得到 f 对 x,y,z 的梯度。这就是自动微分的逆向模式。
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK