快速掌握TensorFlow中张量运算的广播机制

相信大家在使用numpy和tensorflow的时候都会遇到如下的错误

ValueError: operands could not be broadcast together with shapes (4,3) (4,)这是由于numpy和tensorflow中的张量在进行运算的时候形状不满足广播机制的要求,不理解广播机制的同学可能会通过各种魔改代码来让代码正常运行起来,但是却不知道为什么那样改就可以。

本文将从原理上介绍张量运算中经常用到的广播机制。

广播(broadcasting)指的是不同形状的张量之间的算数运算的执行方式

通过两个例子直观了解广播

数组与标量值的乘法

import numpy as np
arr = np.arange(5)
arr #-> array([0, 1, 2, 3, 4])
arr * 4 #-> array([ 0, 4, 8, 12, 16])

在上面的乘法运算中,标量值4被广播到了其他所有元素上

通过减去列平均值的方式对数组每一列进行距平化处理

arr = np.random.randn(4,3)
arr #-> array([[ 1.83518156, 0.86096695, 0.18681254],
# [ 1.32276051, 0.97987486, 0.27828887],
# [ 0.65269467, 0.91924574, -0.71780692],
# [-0.05431312, 0.58711748, -1.21710134]])
arr.mean(axis=0) #-> array([ 0.93908091, 0.83680126, -0.36745171])

关于mean中的axis参数,可以这样理解:

numpy中,axis = 0为行轴(竖直方向),axis = 1为列轴(水平方向),指定axis表示该操作沿axis进行,得到结果将是一个shape为除去该axisarray,对于多维张量,axis=i是指运算操作沿着第i个张量下标变化的方向进行。

在上例中,arr.mean(axis=0)表示对arr沿着轴0(竖直方向)求均值。显然,第0个下标变化的方向即为竖直方向,以第一列为例,4个元素的下标分别为[(0,0),(1,0),(2,0),(3,0)]

arr的shape为(4,3),除去axis=0的shape,结果为(1,3)或者(3,),这与上面的代码运行结果相符。

广播机制的原理

如果两个数组的后缘维度(从末尾开始算起的维度)轴长度相符其中一方的长度为1,则认为它们是广播兼容的。广播会在缺失维度和(或)轴长度为1的维度上进行。

demeaned = arr - arr.mean(axis=0)
demeaned
> array([[ 0.89610065, 0.02416569, 0.55426426],
[ 0.3836796 , 0.1430736 , 0.64574058],
[-0.28638623, 0.08244448, -0.35035521],
[-0.99339402, -0.24968378, -0.84964963]])
demeaned.mean(axis=0)
> array([ -5.55111512e-17, -5.55111512e-17, 0.00000000e+00])

在上面的对arr每一列减去列平均值的例子中,arr的后缘维度为3arr.mean(0)后缘维度也是3,满足轴长度相符的条件,广播会在缺失维度进行。

这里有点奇怪的是缺失维度不是axis=1,而是axis=0,个人理解是缺失维度指的是两个arr除了轴长度匹配的维度,在上面的例子中,正好是axis=0

arr.mean(0)沿着axis=0广播,可以看作是把arr.mean(0)沿着竖直方向复制4份,即广播的时候arr.mean(0)相当于一个shape=(4,3)的数组,数组的每一行均相同,均为arr.mean(0)

各行减去行均值

row_means = arr.mean(axis=1)
row_means.shape
> (4,)
arr - row_means
> ---------------------------------------------------------------------------

ValueError Traceback (most recent call last)

<ipython-input-10-3d1314c7e700> in <module>()
----> 1 arr - row_means


ValueError: operands could not be broadcast together with shapes (4,3) (4,)

直接相减,报错,无法进行广播。

回顾上面的原则,要么满足后缘维度轴长度相等,要么满足其中一方长度为1。在这个例子中,两者均不满足,所以报错。根据广播原则,较小数组的广播维必须为1。解决方案是为较小的数组添加一个长度为1的新轴。


numpy提供了一种通过索引机制插入轴的特殊语法。通过特殊的np.newaxis属性以及“全”切片来插入新轴。

下面的例子中,我们通过插入新轴的方式实现二维数组各行减去行均值。这里将行均值沿着水平方向进行广播,广播轴为axis=1,对row_means添加一个新轴axis=1

row_means[:,np.newaxis].shape
> (4, 1)
arr - row_means[:,np.newaxis]
> array([[ 0.87419454, -0.10002007, -0.77417447],
[ 0.46245243, 0.11956678, -0.58201921],
[ 0.36798351, 0.63453458, -1.00251808],
[ 0.17378588, 0.81521647, -0.98900235]])

另一个例子

a = np.array([1,2,3])
a.shape # -> (3,)
b = np.array([[1,],[2,],[3]]) # -> (3,1)
b - a # -> array([[ 0, -1, -2],
# [ 1, 0, -1],
# [ 2, 1, 0]])

上面的例子输出为什么是一个3*3的数组? 

我们来分析一下,根据广播原则,b满足其中一方轴长度为1,那么广播会沿着长度为1的轴,及axis=1进行,对数组b沿着axis=1即水平方向进行复制,相当于b变成一个shape(3,3)且各列均为[1,2,3]的数组。

一个维度为(3,3)的数组减去一个维度为(3,)的数组,满足后缘维度轴长度相等,数组a沿着axis=0即竖直方向进行广播,相当远a变成一个shape(3,3)且个行均为[1,2,3]的数组。

b-a的时候,

 b被广播成为

[[1,1,1],
[2,2,2],
[3,3,3]]

a被广播成为

[[1,2,3],
[1,2,3],
[1,2,3]]

所以b-a的结果是

[[0,-1,-2],
[1, 0,-1],
[2, 1, 0]]

三维情况

下面的例子中,构造一个3*4*5的随机数组arr_3d,我们希望实现对arr_3d的每个元素减去其深度(axis=2)方向的均值

#构造三维数组
arr_3d = np.random.randn(3,4,5)
#求深度方向的均值,想想结果的shape是什么?原始shape是(3,4,5)
#除去axis=2后还剩(3,4)
depth_means = arr_3d.mean(axis=2)
depth_means.shape
> (3, 4)
#arr(3,4,5)和depth_means(3,4)不能直接广播,后缘维度不相符且不存在轴长度为1的轴
arr_3d_new = arr_3d - depth_means[:,:,np.newaxis] #所以我们添加广播轴
arr_3d_new.mean(axis=2) #结果应该为0,这里是接近0的浮点数,符合预期
> array([[ -5.55111512e-17, 4.44089210e-17, 4.44089210e-17, 4.44089210e-17],
[ -8.88178420e-17, -1.11022302e-16, -6.66133815e-17,
0.00000000e+00],
[ 0.00000000e+00, -7.77156117e-17, -2.22044605e-17,
-2.22044605e-17]])





最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 218,122评论 6 505
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,070评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,491评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,636评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,676评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,541评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,292评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,211评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,655评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,846评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,965评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,684评论 5 347
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,295评论 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,894评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,012评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,126评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,914评论 2 355

推荐阅读更多精彩内容

  • 先决条件 在阅读这个教程之前,你多少需要知道点python。如果你想从新回忆下,请看看Python Tutoria...
    舒map阅读 2,580评论 1 13
  • 一、numpy概述 numpy(Numerical Python)提供了python对多维数组对象的支持:ndar...
    L_steven的猫阅读 3,470评论 1 24
  • 前言 numpy是支持 Python语言的数值计算扩充库,其拥有强大的高维度数组处理与矩阵运算能力。除此之外,nu...
    TensorFlow开发者阅读 3,212评论 0 35
  • 介绍 NumPy 是一个 Python 包。 它代表 “Numeric Python”。 它是一个由多维数组对象和...
    喔蕾喔蕾喔蕾蕾蕾阅读 1,779评论 0 5
  • 我今年19岁,距离二十岁只差五个月的时间,我和大多数人都一样,没什么过人之处,读小学然后初中高中,最后是大学,然而...
    柚子黑咔阅读 90评论 0 1