Notes on PyTorch Internals II

原版英文链接:Edward Z. Yang's PyTorch internals : Inside 245-5D

Autograd

PyTorch与Python、Numpy等最大的区别在于自动求导机制,这种机制为神经网络训练过程中的梯度计算而特别设计。为了说明梯度计算,先给出如下一段计算进行展示:

i2h = torch.mm(W_x, x.t())
h2h = torch.mm(W_h, prev_h.t())
next_h = i2h + h2h
next_h2 = next_h.tanh()
loss = next_h2.sum()

通过PyTorch的自动求导机制,系统底层会进行类似如下的求导计算:

grad_loss = torch.tensor(1, dtype=loss.dtype()) 
grad_next_h2 += grad_loss.expand(next_h2.size())
grad_next_h += tanh_backward(grad_next_h2, next_h2)
grad_i2h, grad_h2h += grad_next_h, grad_next_h
W_h.grad += mm_mat1_backward(grad_h2h, prev_h.t(), W_h, 1)
W_x.grad += mm_mat1_backward(grad_i2h, x_t(), W_x, 1)

一个正向计算,有其对应的反向梯度计算。正向计算的输出通常也会作为输入,参与反向梯度计算。例如,tanh的计算转换为tanh_backward计算,tanh计算的输出next_h2成为tanh_backward计算的输入。

上述计算过程实际就是基于链式规则的反向梯度计算:
grad\_loss = d(loss) / d(loss)
grad\_loss取值恒为1
grad\_next\_h2 = \partial(loss) / \partial(next\_h2)
对每个分量求偏导数,求和计算的求导输出与next\_2h维度一致的全1向量,对应的函数实现为grad\_loss.expand(next\_h2.size())

为了支持自动求导计算,需要存储更多的信息。PyTorch增加Variable类型,封装(wrapping)Tensor并增加额外的元数据(Autograd Meta)用于反向梯度计算。此外也需要更新调度(dispatches)流程,在调度不同的设别实现之前(见Tensor介绍),需要对Variable进行拆封(unwrapping),增加额外需要进行梯度计算的信息后再一次封装(rewrapping)。

说明:以下内容原博客没有详细展开,非原文作者内容。由本文作者基于ppt和参考文献制作。

Autograd Meta

从PyTorch 0.4版本开始,支持梯度计算的Variable已经和Tensor合并。Tensor本身即包含了支持自动梯度计算所需的属性,主要包括:

data: Tensor类型,描述实际参与计算的数据内容。

requires_grad: bool类型,描述当前变量是否需要参与梯度计算。True表示参与梯度计算,PyTorch底层会开始跟踪所有该变量参与的计算路径。

grad: Tensor类型,存储当前变量的梯度值。如果requires_grad=False,该值为None;当requires_grad=True,如果计算路径上的输出节点没有调用backward()执行求导计算,取值仍为None;只有两个条件都满足时,grad才有取值。对于变量x,及其路径上的输出节点out(x),调用out.backward()后,x.grad=\partial{out}/\partial{x}

grad_fun: 梯度计算函数,用于求取输出节点相对当前变量的梯度值。通常,对于一个正向计算函数,有对应的梯度计算函数。正向计算的输出也会参与梯度的计算过程。

is_leaf: bool类型,描述当前变量在动态计算图中是否为叶子节点。一个节点是叶子节点的三种情况为:
1.通过显式声明的变量,例如x=torch.tensor(1.0)
2.通过全部为requires_grad=False的输出变量而创建的输出变量。
3.通过调用detach()得到的变量

只有requires_gradis_leaf都为True的节点,才会进行梯度的计算,得到的梯度值更新到grad属性中。

Jacobian Vector Product

backward()函数用来触发梯度的计算。该函数支持外部Tensor参数调用,参数的维度需要与调用backward()函数的变量一致。考虑到神经网络训练过程中,大部分情况是计算损失(标量)对高维权重的梯度,backward()参数的默认值为tensor(1.0),维度与损失标量一致,out.backward()调用其实是 out.backward(torch.tensor(1.0))

如果out是非标量变量,则需要传入维度一致的参数。例如:

x = torch.rand(1,3, requires_grad=True)
y = torch.rand(1,3, requires_grad=True)
out = x*y
out.backward(torch.FloatTensor([1.0, 1.0, 1.0]))

out1x3维的向量,调用backward()函数需要传入同等维度的参数。如果直接调用out.backward(),会触发RuntimeError: grad can be implicitly created only for scalar outputs错误。

节点间梯度的计算实际是Jacobian Vector乘积的过程。以正向计算的输出向量为输入向量,乘以输出向量相对于求导变量的Jacobian矩阵,得到当前变量的反向梯度值。

假设向量X=[x1, x2, ..., xn],通过某个函数映射一组输出向量f(X)=[f1(X), f2(X), ..., fm(X)],则f(x)的Jacobian矩阵为:

Jacobian Matrix (source: Wikipedia)

应用到PyTorch的反向梯度计算中,针对如下伪代码计算流程:

# pseudo-code
x = [x1, x2, ..., xn]
y = f(x)=[y1(x), y2(x), ..., ym(x)]
l = loss_fun(y)
l.backward()

实际底层通过Jacobian Vector乘积的流程为:

  1. 计算损失l相对向量y的梯度v=(\partial{l}/\partial{y_1}, \partial{l}/\partial{y_2}, ..., \partial{l}/\partial{y_m})^T
    该向量称为梯度张量(grad_tensor)。

  2. 计算损失l相对输入向量x的梯度。通过链式法则\partial{l}/\partial{x}=\partial{l}/\partial{y}\cdot\partial{y}/\partial{x}\partial{y}/\partial{x}为向量对向量的求导,需要用到Jacobian矩阵。结合第一步计算的梯度张量v

    Jacobian Vector Product [2]

通过Jacobian Vector乘积方式,实现了计算步骤中,任意维度输出相对任意维度输入的梯度计算。通过链式连接计算路径上各步骤,得到最终的梯度值。

那么计算路径如何描述呢?

Dynamic Computational Graph

在PyTorch中,计算路径通过动态计算图(Dynamic Computational Graph, DCG)描述。DGC由变量和函数组成,包含两层含义:

计算图:通过图的方式描述数学计算过程。在计算图中,节点由输入和计算函数组成,数据通过边流向下一个节点。PyTorch的计算图是有向无环图(Directed Acyclic Graph, DAG),叶子节点由输入变量组成,根节点由输出变量组成,中间节点既可以是算术操作符,也可以中间结果变量。

动态图: 在Caffe以及早期版本的TensorFlow中,计算图是静态的,即预先定义好计算的顺序。运行过程中只是数据的流通,计算图是不可变动的。在PyTorch中计算图是动态的,边构建边定义(define-by-run)。动态图支持构建过程中随时查看中间结果,而静态图则需要定义图之后,运行整个图才能查看结果。此外,动态图天然支持运行过程中结构动态变化的神经网络架构。当然,由于每次迭代都会构建一个全新的计算图,会影响计算速度。

反向梯度计算过程直接依赖于正向构建的计算图。计算从调用backward()函数的根节点开始,沿计算图反向追踪到所有可以触达的requires_grad=True的叶子节点(is_leaf=True)。只有requires_grad=True的叶子节点才需要计算梯度值。

针对如下的计算片段:

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0)
c = a * b
d = torch.tensor(4.0, requires_grad=True)
e = c * d
e.backward()

通过计算图描述为:


动态计算图(参考[3]制作)

计算图由变量节点和函数节点组成。

变量节点存储与计算相关的属性(见上文描述),函数节点分为正向计算节点和梯度计算节点。如图中的mul以及MulBackward节点,分别表示正向的乘积运算和乘积的反向梯度运算。并且,MulBackward节点的引用保存在正向计算输出的变量节点(如图中的ec)的grad_fun属性。

函数节点保存了当前计算的上下文(context),正向计算节点和反向梯度计算节点通过上下文对象共享中间存储结果。在PyTorch中,参与梯度计算的函数需继承自torch.autograd.Function类,并重载其两个静态函数forward(ctx, input)backward(ctx, input_grad)实现底层的前向计算和反向计算逻辑。在正向计算中调用 ctx.save_for_backward(...)存储反向梯度计算所需的中间变量(通常包含输入);在反向计算中通过ctx.saved_tensors来获取正向计算传递过来的中间变量。fowardbackward函数组成了关于某个特定计算的正向计算和反向梯度计算节点对。

变量节点的grad_fun指向变量对应的反向梯度计算节点,而计算节点的next_functions属性以列表方式,存储参与计算变量节点的其它输入节点所对应的反向梯度计算节点。例如,存储在节点e.grad_func的乘积反向梯度节点MultBackward,其属性next_functions=[(MultiBackward, 0), (AccumulateGrad, 0)]。当前变量节点为e,参与计算该节点的输入变量节点有cd。第一个元素MultBackward即为c节点的反向梯度计算节点(c是通过乘积计算得到,梯度计算函数为乘积反向梯度计算类型),其引用存储在c.grad_fun中;d节点是参与梯度计算的叶子节点,即反向梯度计算的终止点,系统会自动附上AccumulateGrad梯度累积计算节点,更新该节点的当前梯度值。此即为next_functions属性列表中第二项。

某些正向计算节点可能输出多个变量,例如:

xyz = torch.tensor[(1.0, 2.0, 3.0), requires_grad=True]
x, y, z = xyz.unbind()
out = x*z
out.backward()

unbind()会输出三个变量。为了确定反向梯度计算针对哪个输出变量,next_functions每个元素包含一个索引值,描述求导变量在原始变量中的位置。例如,在计算xz相对xyz的梯度时,需要记录xz在原始变量xyz中的位置(索引位置影响梯度的计算,例如\partial{x}/\partial(xyz)=[1.0, 0, 0]\partial{z}/\partial(xyz)=[0, 0, 1.0])。存储方式为out.grad_fun.next_functions=[(UnbindBackward, 0), (UnbindBackward, 2)],其中索引02对应xz位置。

在上述动态计算图中,黑线代表正向计算的数据流,蓝线代表反向梯度计算的数据流。由图可知,反向梯度计算时,并不是所有节点会接受反向梯度计算的更新,例如节点bc。回顾上文,只有requires_grad=True的叶子节点(is_leaf=True)才需要计算梯度值。由于变量brequires_grad=False,而c节点不是叶子节点,因此无需计算其梯度计算。

反向梯度计算始于e.backward()的调用。变量e为标量,使用默认值torch.tensor(1.0)作为参数即可。根据e.grad_fun的存储引用,调用MultBackward梯度计算函数,并将默认值1.0传递给该计算节点。该函数分别计算cd的梯度值为4.0(=d)6.0(=c),与输入值1.0乘积后,分别传递给节点cd的梯度计算函数。两个节点的梯度计算函数引用存储在e.next_functions列表中,其中第一项对应c节点的梯度函数,第二项对应节点d的梯度函数。

节点c非叶子节点,无需计算其梯度值。通过引用梯度函数c.grad_fun分别计算c对输入节点(ba)的梯度,将梯度流进一步往下传递。其中输入节点brequires_grad=False,因此无需计算梯度(对应梯度函数为None);对节点a的梯度值为3.0(=b),通过链式法则,乘以上一步传入的梯度值4.0,累积输出梯度值12.0a为叶子节点并且requires_grad=True,需计算该节点梯度值。通过缺省AccumlateGrad梯度计算函数,将累积梯度值12.0更新到节点a的当前值('0.0'),得到a节点在计算图上本次完整迭代后的最终值12.0

节点drequires_grad=True的叶子节点,需计算其梯度值。通过缺省的梯度累加函数AccumulateGrad,将上一节点传递进来的梯度值6.0,叠加到当前的值(0.0)并更新到 d.grad中。

至此,完成梯度在计算图上的一次完整迭代计算。所有需要计算梯度的节点(ab)都将得到更新的梯度值。多次迭代运行计算图,节点的梯度值会进行累加计算。为了降低存储量,PyTorch在每次调用backward()后,中间计算的buffer已经清除(网络训练中每次min-batch迭代都会基于新数据构建一个新的计算图,因此一次计算完成更新权重梯度后,默认清除中间结果,节省存储空间),在同一批节点上无法进行第二次backward()调用。通过设置retain_graph=True可以保留计算图中间结果,实现多次调用。例如:

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0)
c = a * b
d = torch.tensor(4.0, requires_grad=True)
e = c * d
e.backward(retain_graph=True)    # retain computational graph buffers
print(a.grad)        # a.grad = 12.0
e.backward()        
print(a.grad)        # a.grad = 24.0

每一次迭代,都会执行\partial{e}/\partial{a}=12.0计算,该梯度流通过累加函数AccumulateGrad增加到当前梯度值上(a.grad+=12.0)。

Notes on PyTorch Internals系列文章

Notes on PyTorch Internals I
Notes on PyTorch Internals II
Notes on PyTorch Internals III

参考文献

[1] Adam Paszke, Sam Gross, etc. Automatic differentiation in PyTorch
[2] Vaibhav Kumar. PyTorch Autograd
[3] Elliot Waite. PyTorch Autograd Explained — In-depth Tutorial

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,547评论 6 477
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,399评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,428评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,599评论 1 274
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,612评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,577评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,941评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,603评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,852评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,605评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,693评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,375评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,955评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,936评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,172评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 43,970评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,414评论 2 342