一文搞懂梯度下降&反向传播

如果把神经网络模型比作一个黑箱,把模型参数比作黑箱上面一个个小旋钮,那么根据通用近似理论(universal approximation theorem),只要黑箱上的旋钮数量足够多,而且每个旋钮都被调节到合适的位置,那这个模型就可以实现近乎任意功能(可以逼近任意的数学模型)。

显然,这些旋钮(参数)不是由人工调节的,所谓的机器学习,就是通过程序来自动调节这些参数。神经网络不仅参数众多(少则十几万,多则上亿),而且网络是由线性层和非线性层交替叠加而成,上层参数的变化会对下层的输出产生非线性的影响,因此,早期的神经网络流派一度无法往多层方向发展,因为他们找不到能用于任意多层网络的、简洁的自动调节参数的方法。

直到上世纪80年代,祖师爷辛顿发明了反向传播算法,用输出误差的均方差(就是loss值)一层一层递进地反馈到各层神经网络,用梯度下降法来调节每层网络的参数。至此,神经网络才得以开始它的深度之旅。

本文用python自己动手实现梯度下降和反向传播算法。请点击这里到Github上查看源码。

梯度下降

Figure 1: Gradient Decsent

梯度下降法是一种将输出误差反馈到神经网络并自动调节参数的方法,它通过计算输出误差的loss值(J)对参数W的导数,并沿着导数的反方向来调节W,经过多次这样的操作,就能将输出误差减小到最小值,即曲线的最低点。

虽然Tensorflow、Pytorch这些框架都实现了自动求导的功能,但为了彻底理解参数调节的过程,还是有必要自己动手实现梯度下降和反向传播算法。我相信你和我一样,已经忘了之前学的微积分知识,因此,到可汗学院复习下Calculus
Multivariable Calculus是个不错的方法,或是拜读这篇关于神经网络矩阵微积分的文章

Figure 2: https://explained.ai/matrix-calculus/index.html

Figure2是求导的基本公式,其中最重要的是Chain Rule,它通过引入中间变量,将“yx求导”的过程转换为“y对中间变量u求导,再乘以ux求导”,这样就将一个复杂的函数链求导简化为多个简单函数求导。

如果你不想涉及这些求导的细节,可以跳过具体的计算,领会其思想就好。

反向传播

def forward_backward(x, y):
  z1 = x @ W1 + b1
  a1 = relu(z1)
  outp = a1 @ W2 + b2
  loss = mse(outp, y)
  
  mse_grad(outp, y)
  lin_grad(a1, outp, W2, b2)
  relu_grad(z1, a1)
  lin_grad(x, z1, W1, b1)

对于神经网络模型:Linear -> ReLu -> Linear -> MSE(Loss function)来说,反向传播就是根据链式法则对mse(linear(relu(linear(x))), y)求导,用输出误差的均方差(MSE)对模型的输出求导,并将导数传回上一层神经网络,用于它们来对wbx(上上层的输出)求导,再将x的导数传回到它的上一层神经网络,由此将输出误差的均方差通过递进的方式反馈到各神经网络层。

对于mse(loss(linear(relu(linear(x))), y)求导的第一步是为这个函数链引入中间变量:

  • u = linear(x, w1, b1)
  • g = relu(u)
  • h = linear(g, w2, b2)
  • f = mse(h, y)

接着第二步是对各中间变量求导,最后才是将这些导数乘起来。

首先,反向传播的起点是对loss function求导,即f = mse(h, y) = (h - y)^2 / N\ (N=y向量的长度)\frac{\partial mse(h, y)}{\partial h} = 2/N * (h - y) * (1 - 0) = 2(h - y) / N

def mse_grad(outp, targ):
  outp.g = 2 / outp.shape[0] * (outp.squeeze() - targ.float()).unsqueeze(-1)

mse_grad()之所以用unsqueeze(-1)给导数增加一个维度,是为了让导数的shape和tensor shape保持一致。

def lin_grad(inp, outp, w, b):
  inp.g = outp.g @ w.t()
  w.g = inp.t() @ outp.g
  b.g = outp.g.sum(0)

linear层的反向传播是对h = wx + b求导,它也是一个函数链,也要先对中间变量求导再将所有导数相乘:

  • u = w * x
  • g = \sum u
  • h = g + b

这些中间变量的导数分别是:

  • \frac{\partial (w*x)}{\partial x} = diag(w), \frac{\partial (w*x)}{\partial w} = diag(x)
  • \frac{\partial sum(u)}{\partial u} = \vec 1^T
  • \frac{\partial (g + b)}{\partial b} = 1

y对向量x求导,指的是对向量所有的标量求偏导(\partial),即:[\frac{\partial y}{\partial x_1}, \frac{\partial y}{\partial x_2}, ..., \frac{\partial y}{\partial x_n}, ],这个横向量也称为y的梯度。

这里y = w * x,是一个向量,因此,y对x求导,指的是y的所有标量(y_1, y_2, ..., y_n)对向量x求偏导,即:\frac{\partial y}{\partial x} = [ \frac{\partial y_1}{\partial x}, \frac{\partial y_2}{\partial x}, ..., \frac{\partial y_n}{\partial x}]^T
= [[\frac{\partial w_1x_1}{\partial x_1}, \frac{\partial w_1x_1}{\partial x_2}, ..., \frac{\partial w_1x_1}{\partial x_n}], [\frac{\partial w_2x_2}{\partial x_1}, \frac{\partial w_2x_2}{\partial x_2}, ..., \frac{\partial w_2x_2}{\partial x_n}], ..., [\frac{\partial w_nx_n}{\partial x_1}, \frac{\partial w_nx_n}{\partial x_2}, ..., \frac{\partial w_nx_n}{\partial x_n}]]

这个矩阵称为雅克比矩阵,它是个对角矩阵,因为i \neq j时, \frac{\partial w_ix_i}{\partial x_j} = 0,因此\frac{\partial wx}{\partial x} = diag(w)

同理,\frac{\partial \sum u}{\partial u} = \sum \frac{\partial u}{\partial u} = \sum [\frac{\partial u_1}{\partial u}, \frac{\partial u_2}{\partial u}, ..., \frac{\partial u_n}{\partial u}]^T = [1, 1, ..., 1] = \vec 1^T

因此,所有中间导数相乘的结果:

  • \frac{\partial (w*x + b)}{\partial x} = w^T
  • \frac{\partial (w*x + b)}{\partial w} = x^T
  • \frac{\partial (w*x + b)}{\partial b} = 1

lin_grad()中的inp.g、w.g和b.g分别是求x、w和b的导数,以inp.g为例,它等于w^T,且需要乘以前面各层的导数,即outp.g @ w.t(),之所以要用点积运算符(@)而不是标量相乘,是为了让它的导数shape和tensor shape保持一致。同理,w.g和b.g也是根据相同逻辑来计算的。

def relu_grad(inp, outp):
  inp.g = (inp > 0.).float() * outp.g

ReLu层的求导相对来说就简单多了,当输入 <= 0时,导数为0,当输入 > 0时,导数为1。

Testing

求导运算终于结束了,接下来就是验证我们的反向传播是否正确。验证方法是将forward_backward()计算的导数和Pytorch自动微分得到的导数相比较,如果它们相近,就认为我们的反向传播算法是正确的。

xg = x_train.g.clone()
w1g = W1.g.clone()
w2g = W2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()

x2 = x_train.clone().requires_grad_(True)
w11 = W1.clone().requires_grad_(True)
b11 = b1.clone().requires_grad_(True)
w22 = W2.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

def forward_backward2(x, y):
  l1 = x @ w11 + b11
  z1 = relu(l1)
  outp = z1 @ w22 + b22
  return mse(outp, y)

loss = forward_backward2(x_train, y_train)
loss.backward()

首先,将计算好的参数导数保存到w1g、b1g、w2g和b2g中,再用Pytorch的自动微分来求w11、b11、w22和b22的导数。

def test_near(a, b): return np.allclose(a, b)

assert test_near(w1g, w11.grad)
assert test_near(b1g, b11.grad)
assert test_near(w2g, w22.grad)
assert test_near(b2g, b22.grad)

最后,用np.allclose()来比较导数间的差异,如果有任何一个导数不相近,assert就会报错。结果证明,我们自己动手实现的算法是正确的。

END

反向传播是遵循链式法则的,它将前向传播的输出作为输入,输入作为输出,通过递进的方式将求导这个动作从后向前传递回各层。神经网络参数的求导需要进行矩阵微积分计算,根据这些导数的反方向来调节参数,就可以让模型的输出误差的优化到最小值。


欢迎关注和点赞,你的鼓励将是我创作的动力

欢迎转发至朋友圈,公众号转载请后台留言申请授权~

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,001评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,210评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,874评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,001评论 1 291
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,022评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,005评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,929评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,742评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,193评论 1 309
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,427评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,583评论 1 346
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,305评论 5 342
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,911评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,564评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,731评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,581评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,478评论 2 352

推荐阅读更多精彩内容