TensorFlow从0到1 - 6 - 解锁梯度下降算法

TensorFlow从0到1系列回顾

上一篇 5 TF轻松搞定线性回归,我们知道了模型参数训练的方向是由梯度下降算法指导的,并使用TF的封装tf.train.GradientDescentOptimizer(0.01)(学习率为0.01)完成了机器自学习的过程。本篇开启梯度下降算法的黑盒一探究竟,并解锁几个TF API常用参数的真正含义:

  • learning rate;
  • steps;
  • epoch;
  • mini batch。
雪山速降

一般函数的最小值问题

4 第一个机器学习问题引入了损失函数的定义,即待训模型参数为自变量,计算模型输出与预期(label)的均方误差(MSE)。如下所示。

B-O-F-1 损失函数

所获得的这个新函数C(a,b)的最小值处的(a, b)值,就是我们所寻找的理想模型参数。就这样,一个回归问题变成了更加具体的求函数极值的问题。

更进一步,本节将之前损失函数自变量a和b一般化表示为v1,v2,把求解损失函数的最小化问题,转换为更一般的函数C(v1,v2)最小化问题,C(v1,v2)具有任意的函数形式。如果找到一般的函数最小值求解方法,那么具有特殊形式的损失函数最小值求解自不在话下。

对于C是一个或者少数几个变量的函数,可以通过函数极值点处的导数特性来获得多元方程组,直接求解极值点。但是我们准备放弃这种尝试,因为对于一个真实世界的机器学习问题,其模型的复杂程度通常会远远的高于线性模型,参数的个数也远不止两个,损失函数的形式会变成:C(v1, v2 ... vn),如果n数以亿计,用微积分的方法简直就是噩梦。

雪山速降的启发

把函数曲面的某个局部,想象成前面图中的雪山,如果想速降(以最快的速度下山),那么直觉上的最佳路径就是沿着雪山最陡峭的方向下山。

再打个比方,考虑有两个自变量的二次函数C(v1, v2),在三维视图中,它是一个曲面。假设有个小球靠自身重力滚落到曲面的底部,可以想象其路径也是沿着最陡峭的方向的。

梯度下降

如果我们不能直接看出函数的最小值,或者通过直接求解的方式得到函数最小值,那么利用雪山速降、小球滚落的启发,总是沿着最陡峭的下降方向移动,就会最快到达最小值点。

那么,“最陡峭”方向在数学上该怎么表达呢?

梯度的定义

微积分告诉我们,当把v1, v2, ... , vn各个自变量移动一个很小的值,C将有如下变化:

B-C-F-1 微积分

梯度定义有:

B-C-F-2 梯度

v的变化量为∆v ≡ (∆v1, ∆v2, ..., ∆vn)T,则C的变化量可重写为梯度向量▽C与v的变化向量∆v的点乘:

B-C-F-3 C的增量

梯度下降算法

直觉上,如果v朝某个方向上移动,导致C的增量是个负数,那么可以肯定C在“下降”。

开下脑洞,直接令∆v = -η▽C,其中η是一个正数,代入公式B-C-F-3有:

∆C ≈ -η▽C·▽C = -η‖▽C‖2 ≤ 0,此时∆C一定小于等于0,C在下降。

幸运的是,数学上可以证明对于一个非常小的步长∆v,令∆v = -η▽C可以使C的减小最大化。

总结起来就是:

  • -η▽C正是我们期望的∆v——移动方向是▽C的反方向,移动的幅度是η‖▽C‖
  • v移动∆v所造成的C的∆C,是-η‖▽C‖2

上面这个η就叫做学习率learning rate

回头再来看“最陡峭的一小步”的数学解释,那就是沿着梯度的反方向上走一小步。只要一小步一小步朝着正确的方向移动,迟早可以走到C(v1, v2, ..., vn)的最小值处。“梯度下降”,名副其实。

梯度下降的具体操作方法如下:

  1. 随机选取自变量的初始位置v(以后会专门讨论初始化的技巧);
  2. v → v' = v - η▽Cv(v移动到v',▽Cv是v处的梯度,η保持不变);
  3. v' → v'' = v' - η▽Cv'(v'移动到v'',▽Cv'是v'处的梯度,η保持不变);
  4. ...

v移动的次数,即训练的步数steps

v是各个自变量(v1, v2, ..., vn)的向量表示,那具体到每个自变量该如何移动呢?以v1,v2为例:

B-O-F-3 梯度下降

随机梯度下降算法

到此,梯度下降算法解决了如何寻求一般函数C(v1, v2, ..., vn)的最小值问题,再应用到机器学习之前,先别急,还差一小步。

B-O-F-2 损失函数

回到损失函数,再仔细看看其形式,发现它有个特别之处,即函数表达式与训练样本的数量密切相关,它是多个样本方差的累加,最后再求均值。一个样本集的样本数动辄成千上万,为了“梯度下降”一小步中要用到的▽C,这么多样本都要参与计算吗?

并不需要,实践中有巧妙的方法:

B-O-F-4 样本梯度均值

首先,损失函数的梯度▽C,实践中一般是通过样本集中单个样本梯度值▽Cx的均值得到。如果你对这个公式持怀疑态度,这不奇怪,一个简单的消除疑虑的做法就是用之前的线性模型和损失函数,用两个样本值分别计算一下等式两边,看是否相等即可。

对于样本集成千上万个样本,对每个样本x都求其▽Cx,计算量似乎更大了。先别急,往下看。可以用一个小批量样本,通过其中每个样本▽Cx的均值,来近似为▽C:

B-O-F-5 样本梯度均值的近似

这就是实践中采用的方法,被称为随机梯度下降法。那个小批量样本就是一个mini batch

把全部样本集分成一批批的小样本集,每全部遍历使用过1次,就称为1次epoch

据此,每个自变量更新的公式如下:

B-O-F-6 分量的增量

上一篇 5 TF轻松搞定线性回归
下一篇 7 TF线性回归参数溢出之谜


共享协议:署名-非商业性使用-禁止演绎(CC BY-NC-ND 3.0 CN)
转载请注明:作者黑猿大叔(简书)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,928评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,192评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,468评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,186评论 1 286
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,295评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,374评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,403评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,186评论 0 269
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,610评论 1 306
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,906评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,075评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,755评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,393评论 3 320
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,079评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,313评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,934评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,963评论 2 351

推荐阅读更多精彩内容

  • 现在我们设计了一个神经网络,但是它怎样通过学习来识别手写数字呢?首先需要的是被称为训练数据集的数据集合。我们将使用...
    魔法炼金术阅读 676评论 0 2
  • 去年Alaph GO击败李世石九段,社会掀起了机器学习技术讨论的热潮,不过很多人对机器学习并不了解,本文借由手写数...
    Sunhaorong阅读 2,548评论 0 4
  • 第二个Topic讲深度学习,承接前面的《浅谈机器学习基础》。 深度学习简介 前面也提到过,机器学习的本质就是寻找最...
    我偏笑_NSNirvana阅读 15,597评论 7 49
  • 昨天室友走了,宿舍只剩下我自己一个人,本来也没什么的,一直也都是一个人在做着自己的事情,上自习,去图书馆,本来也就...
    一个爱幻想的girl阅读 265评论 0 0
  • 目录本次给大家介绍的是我收集以及自己个人保存一些.NET面试题简介1.C# 值类型和引用类型的区别2.如何使得一个...
    寒剑飘零阅读 4,809评论 0 30