5

图解自动微分的正向模式和逆向模式

 1 year ago
source link: https://seanwangjs.github.io/2022/12/06/autograd-forward-and-reverse.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

图解自动微分的正向模式和逆向模式

自动微分建立在复合函数求导的链式规则之上,考虑以下复合函数

f(x)=a(b(c(x)))

则 f 对 x 的导数为

dfdx=dfdadadbdbdcdcdx

显然,上述公式存在两种计算顺序,第一种先对高阶函数求导,我们称之为逆向模式

dfdx=((dfdadadb)dbdc)dcdx

第二种先对低阶函数求导,我们称之为正向模式

dfdx=dfda(dadb(dbdcdcdx))

为了更直观地说明两种模式的区别,我们考虑一个简单的例子

f(x,y,z)=log(xy)+sin(z)

它的计算图如下所示

autograd-computation_graph.png

我们最终需要计算偏导数 ∂f∂x,∂f∂y,∂f∂z。下面我们首先推导正向模式自动微分:

  1. 计算 v1=xy,得到 ∂v1∂x,∂v1∂y
autograd-forward_1.png
  1. 计算 v2=log(v1),得到 ∂v2∂v1,∂v2∂x,∂v2∂y
autograd-forward_2.png
  1. 计算 v3=sin(z),得到 ∂v3∂z
autograd-forward_3.png
  1. 计算 f=v2+v3,得到 ∂f∂v2,∂f∂v3,∂f∂x,∂f∂y,∂f∂z
autograd-forward_4.png

可以看到,随着计算图的向前推进,我们也不断地得到最新值对输入的梯度,所以这种计算方式也称为自动微分的正向模式。

接下来我们再描述另一种计算模式:

  1. 计算 v1=xy,并构造另一张计算图,形状与原计算图类似,但描述的是梯度计算关系,这里插入乘法的梯度计算节点
autograd-backward_1.png
  1. 计算 v2=log(v1),插入log运算的梯度计算节点
autograd-backward_2.png
  1. 计算 v3=sin(z),插入sin运算的梯度计算节点
autograd-backward_3.png
  1. 计算 f=v2+v3,插入加法的梯度计算节点
autograd-backward_4.png

最后再以原计算图相反的方向遍历梯度计算图,从而得到 f 对 x,y,z 的梯度。这就是自动微分的逆向模式。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK