太长不看版
- 模型在学习或调试过程中,设置
pyro.enable_validation(True)
; - 张量的“广播”,维度对齐自右向左:
torch.ones(3,4,5) + torch.ones(5)
; - 分布的尺寸
.sample().shape == batch_shape + event_shape
; - 分布的尺寸
.log_prob(x).shape == batch_shape
(没有event_shape
); - 使用
expand()
从Pyro中采样一批数据,或使用plate
机制自动扩展; - 使用
my_dist.to_event(1)
声明维度为依赖(dependent),或说不独立; - 使用
with pyro.plate('name', size):
声明条件独立; - 所有维度要么是依赖的,要么是条件独立的;
- 支持维度最左方的批处理,启动Pyro的并行处理;
- 使用负号指标,如
x.sum(-1)
,而不是x.sum(2)
; - 使用省略号,如
pixel = image[...,i, j]
; - 如果要枚举
i,j
,使用Vindex,如pixel = Vindex(image)[...,i, j]
;
- 使用负号指标,如
- 在调试过程中,使用Trace.format_shapes检查维度定义。
内容列表
- 概率分布的形状
-
plate
声明条件独立 - 在plate中部分采样
- 并行地枚举,张量的广播
文件头如下
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam
smoke_test = ('CI' in os.environ)
pyro.enable_validation(True) #这句话最好加上
# 我们借助这个函数,检查模型是否正确
def test_model(model, guide, loss):
pyro.clear_param_store()
loss.loss(model, guide)
概率分布的尺寸:batch_shape
和event_shape
Pytorch的张量Tensor
只有一个尺寸.shape
,但是Distributions
有两个尺寸.batch_shape
和.event_shape
,分别表示条件独立的随机变量的大小和不独立的随机变量的大小。这两部分构成了一个样本的尺寸。
x = d.sample()
assert x.shape == d.batch_shape + d.event_shape
由于计算对数似然只牵涉不独立的变量,所以.log_prob()
方法后,event_shape
就被缩并了,只剩下batch_shape
。
assert d.log_prob(x) == d.batch_shape
Distributions.sample()
方法可以输入一个参数sample_shape
,作为独立同分布(iid)的随机变量,所以指定样本大小的采样,具有三个尺寸。
x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + batch_shape + event_shape
总结来说
| iid | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape
由上可推论,单变量随机分布的event_shape
为0,因为每次采样值是一个实数,所以没有不独立的维度。像MultivariateNormal
多元高斯分布这样的概率分布,具有len(event_shape) == 1
,因为每个采样是一个向量,向量内部是彼此依赖的(这里假定方差矩阵不是对角阵)。而InverseWishart
逆威沙特分布具有len(event_shape) == 2
,等等。
关于概率分布尺寸的举例
从单变量随机分布开始。
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
# x是一个Pytorch张量,没有batch_shape和event_shape
assert x.shape == ()
assert d.log_prob(x).shape == ()
通过传入批参数,概率分布数据可以分成批。
d = Bernoulli(0.5 * torch.ones(3, 4))
assert d.batch_shape == (3,4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
另一种成批的方法,是通过expand()
。不过只在参数的最左侧维度独立时才可使用。
d = Bernoulli(torch.tensor([.1, .2, .3, .4])).expand([3, 4])
# 注意expand的参数写在一个列表中
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
多元高斯分布具有非空的event_shape
维度。对于这些分布来说,.sample()
和.log_prob()
的维度是不同的。
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3, )
x = d.sample()
assert x.shape == (3, ) # == batch_shape + event_shape
assert d.log_prob(x).shape == () # == batch_shape
改变分布的维度独立性
使用关键字.to_event(n)改变不独立维度的情况,其中n
表示从右数第n维度开始,声明为不独立维度。
d = Bernoulli(0.5 * torch.ones(3, 4)).to_event(1)
assert d.batch_shape == (3, )
assert d.event_shape == (4, )
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, )
用户必须小心地设置.to_event(n)
将batch_shape
缩减到合适的水平上,或者用pyro.plate
声明维度的独立性。采样仍旧会保留batch_shape+event_shape
的尺寸,然而log_prob(x)
只剩下batch_shape
。
声明为不独立,通常是安全的做法
在Pyro中,我们常常会声明维度是不独立的,哪怕它们实际上是独立的。请看这个例子:
x = pyro.sample('x', dist.Normal(0, 1).expand([10]).to_event(1))
assert x.shape == (10,)
上面的例子很容易就可以换成MultivariateNormal
分布。它将下面的写法简化了:
with pyro.plate('x_plate', 10):
x = pyro.sample('x', dist.Normal(0, 1)) #不需要expand,系统自动补全
assert x.shape == (10,)
实际上,这两份代码存在一点小小的差别。上面的代码中,Pyro默认x之间是不独立的,而下面的x则是条件独立的。声明为不独立通常是安全的,这与图论中的d-separation基于同一个原理:在不同节点之间多连一条边,即便节点之间不存在互相依赖关系,随着优化该边的权重将越来越低,并不影响最终结果;而本就存在依赖的节点少连了一条边,任优化策略多么高明,都无法弥补这一错误。这种错误常见于平均场假设的模型中。不过,在实际执行时,Pyro的SVI模块在估算Normal
分布时,两份代码的梯度估计值是一样的。
通过plate
声明维度为独立
Pyro的上下文管理器pyro.plate能够声明特定的维度为独立维度。推断算法可以利用这一独立性做一些算法优化,例如构造低方差的梯度估计器,再如求解推断问题不在指数空间而在线性空间采样。下面的例子中,我们将声明同一批次中的数据之间是互相独立的。
最简单的方法,是不声明独立维度,系统将缺省值-1——即最右边的维度,作为独立维度。
with pyro.plate('my_plate'):
# 在该上下文中,维度-1将作为独立维度
虽然效果是一样的,不过我们仍提倡用户写出来,以帮助用户调试代码:
with pyro.plate('my_plate', len(data)):
# 在该上下文中,维度-1将作为独立维度
从Pyro 0.2版本开始,plate语句可以嵌套使用。比如声明图像的每个像素都是独立的:
with pyro.plate('x_axis', 320):
# 在该上下文中,维度-1将作为独立维度
with pyro.plate('y_axis', 200):
# 在该上下文中,维度-2和-1将作为独立维度
我们习惯上总从右向左声明独立维度,所以指标是负的,如-1,-2,等等。
有时情况会更复杂一些,比如我们希望声明一些噪声依赖x
,另一些噪声依赖y
,还有一些噪声依赖二者。这时Pyro允许用户声明多重独立,为了清楚地标明独立维度,必须指定dim
这一参数,如下面的例子:
x_axis = pyro.plate('x_axis', dim = -2)
y_axis = pyro.plate('y_axis', dim = -3)
with x_axis:
# 在该上下文中,维度-2将作为独立维度
with y_axis:
# 在该上下文中,维度-3将作为独立维度
with x_axis, y_axis:
# 在该上下文中,维度-2和-3将作为独立维度
让我们举更多例子,来展示plate
的用法。
def model1():
a = pyro.sample('a', Normal(0, 1))
b = pyro.sample('b', Normal(torch.zeros(2), 1).to_event(1))
with pyro.plate('c_plate', 2):
c = pyro.sample('c', Normal(torch.zeros(2), 1))
with pyro.plate('d_plate', 3):
d = pyro.sample('d', Normal(torch.zeros(3, 4, 5), 1).to_event(2))
assert a.shape == () # batch_shape == (), event_shape == ()
assert b.shape == (2,) # batch_shape == (), event_shape == (2,)
assert c.shape == (2,) # batch_shape == (2,), event_shape == ()
assert d.shape == (3, 4, 5) # batch_shape == (3), event_shape == (4, 5)
##
x_axis = pyro.plate('x_axis', 3, dim=-2)
y_axis = pyro.plate('y_axis', 2, dim=-3)
with x_axis:
x = pyro.sample('x', Normal(0, 1))
with y_axis:
y = pyro.sample('y', Normal(0, 1))
with x_axis, y_axis:
xy = pyro.sample('xy', Normal(0, 1))
z = pyro.sample('z', Normal(0, 1).expand([5]).to_event(1))
assert x.shape == (3, 1) # batch_shape == (3, 1), event_shape==()
assert y.shape == (2, 1, 1) # batch_shape == (2, 1, 1), event_shape==()
assert xy.shape == (2, 3, 1) # batch_shape == (2, 3, 1), event_shape==()
assert z.shape == (2, 3, 1, 5) # batch_shape == (2, 3, 1), event_shape==(5,)
test_model(model1, model1, Trace_ELBO())
可视化如下:
batch dims | event dims
-----------+-----------
| a = sample("a", Normal(0, 1))
|2 b = sample("b", Normal(zeros(2), 1)
| .to_event(1))
| with plate("c", 2):
2| c = sample("c", Normal(zeros(2), 1))
| with plate("d", 3):
3|4 5 d = sample("d", Normal(zeros(3,4,5), 1)
| .to_event(2))
|
| x_axis = plate("x", 3, dim=-2)
| y_axis = plate("y", 2, dim=-3)
| with x_axis:
3 1| x = sample("x", Normal(0, 1))
| with y_axis:
2 1 1| y = sample("y", Normal(0, 1))
| with x_axis, y_axis:
2 3 1| xy = sample("xy", Normal(0, 1))
2 3 1|5 z = sample("z", Normal(0, 1).expand([5])
| .to_event(1))
为了在调试代码时方便地查看随机变量的形状,Pyro提供了Trace.format_shapes()
方法,在采样点上打印分布的形状(包含site['fn'].batch_shape
和site['fn'].event_shape
)、变量的形状(site['value'].shape
)、如果计算对数似然概率时log_prob
的形状(site['log_prob'].shape
)。
trace = poutine.trace(model1).get_trace()
trace.compute_log_prob() # 可选的,这句话可以打印log_prob的形状
print(trace.format_shapes())
打印结果:
Trace Shapes:
Param Sites:
Sample Sites:
a dist |
value |
log_prob |
b dist | 2
value | 2
log_prob |
c_plate dist |
value 2 |
log_prob |
c dist 2 |
value 2 |
log_prob 2 |
d_plate dist |
value 3 |
log_prob |
d dist 3 | 4 5
value 3 | 4 5
log_prob 3 |
x_axis dist |
value 3 |
log_prob |
y_axis dist |
value 2 |
log_prob |
x dist 3 1 |
value 3 1 |
log_prob 3 1 |
y dist 2 1 1 |
value 2 1 1 |
log_prob 2 1 1 |
xy dist 2 3 1 |
value 2 3 1 |
log_prob 2 3 1 |
z dist 2 3 1 | 5
value 2 3 1 | 5
log_prob 2 3 1 |
在plate
句块中采样部分张量
plate最重要的功能之一就是部分采样,plate
句块中的随机变量都是条件独立的。如果样本量为总样本的一半,那么样本损失的值将被认为是总损失的一半。
在实现部分时,用户需要通知Pyro采样量和样本总量的值,Pyro就会随机产生一定量的数据指标作为样本。
data = torch.arange(100.)
def model2():
mean = pyro.param('mean', torch.zeros(len(data)))
with pyro.plate('data', len(data), subsample_size=10) as ind:
assert len(ind) == 10
batch = data[ind]
mean_batch = mean[ind]
# 在batch中做一些计算
x = pyro.sample('x', Normal(mean_batch, 1), obs=batch)
assert x.shape == (10,)
test_model(model2, guide=lambda: None, loss=Trace_ELBO())
广播功能,实现数据的并行枚举
Pyro 0.2后的版本都支持离散随机变量的并行枚举功能。这一功能可以极大地减少计算变分推断时梯度估计的方差,确保优化的稳定性。
为了实现枚举,Pyro需要用户指定哪些维度是不独立的,哪些是独立的,只有不独立的维度才允许枚举。自然地,这一指定需要用到plate
语句,我们需要声明最大数量的枚举范围,这一关键字为max_plate_nesting
,它是SVI
类的一个参数(而且通过TraceEnum_ELBO传入)。通常来说,Pyro可以自动地指定枚举范围(只要运行一次model
和guide
,系统将了解枚举范围),不过在动态变化的模型中,用户需要人工地指定max_plate_nesting
的数值。
为了弄清楚max_plate_nesting
的作用机制,我们重新回顾model1()
,这一次我们关心三种维度的形状:最左边的枚举维度,中间的批维度,最右边的不独立维度。而max_plate_nesting
规定了中间的批维度。
max_plate_nesting = 3
|<--->|
enumeration|batch|event
-----------+-----+-----
|. . .| a = sample("a", Normal(0, 1))
|. . .|2 b = sample("b", Normal(zeros(2), 1)
| | .to_event(1))
| | with plate("c", 2):
|. . 2| c = sample("c", Normal(zeros(2), 1))
| | with plate("d", 3):
|. . 3|4 5 d = sample("d", Normal(zeros(3,4,5), 1)
| | .to_event(2))
| |
| | x_axis = plate("x", 3, dim=-2)
| | y_axis = plate("y", 2, dim=-3)
| | with x_axis:
|. 3 1| x = sample("x", Normal(0, 1))
| | with y_axis:
|2 1 1| y = sample("y", Normal(0, 1))
| | with x_axis, y_axis:
|2 3 1| xy = sample("xy", Normal(0, 1))
|2 3 1|5 z = sample("z", Normal(0, 1).expand([5]))
| | .to_event(1))
上面的例子中,如果我们声明(过度)充裕的max_plate_nesting=4
也是可以的,但不能声明例如max_plate_nesting=2
,因为2<3,这时系统将会报错。
我们再举一个例子:
@config_enumerate
#该修饰符表示枚举类型,不能省略!!
def model3():
p = pyro.param('p', torch.arange(6) / 6.)
locs = pyro.param('locs', torch.tensor([-1., 1.]))
# locs in [-1, 1]
# a in [0, 1, 2, 3, 4, 5]
a = pyro.sample('a', Categorical(torch.ones(6) / 6.))
# p[a] in [0, 1/6, 2/6, 3/6, 4/6, 5/6]
b = pyro.sample('b', Bernoulli(p[a])) # 声明b依赖于a
# b in [0, 1]
with pyro.plate('c_plate', 4):
c = pyro.sample('c', Bernoulli(0.4))
# c in [0, 1]
with pyro.plate('d_plate', 5):
d = pyro.sample('d', Bernoulli(0.3))
# d in [0, 1]
e_loc = locs[d.long()].unsqueeze(-1)
# e_loc in [-1, 1]
e_scale = torch.arange(1., 8.)
# e_scale in [1, 2, ..., 7]
e = pyro.sample('e', Normal(e_loc, e_scale).to_event(1)) # 依赖于d
# 枚举维度|批维度(独立维度)|不独立维度
assert a.shape == ( 6, 1,1 ) # 多类别分布的维度大小为6
assert b.shape == ( 2,1, 1,1 ) # 枚举伯努利分布,非扩增
assert c.shape == ( 2,1,1, 1,1 ) # 伯努利分布,非扩增
assert d.shape == ( 2,1,1,1, 1,1 ) # 伯努利分布,非扩增
assert e.shape == ( 2,1,1,1, 5,4, 7) # e是采样出来的,依赖于d
#
assert e_loc.shape == ( 2,1,1,1, 1,1, 1,) # 最后的逗号可以省略
assert e_scale.shape == ( 7,) # 注意逗号不能省略!!
test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))
我们重新来可视化一下:
max_plate_nesting = 2
|<->|
enumeration batch event
------------|---|-----
6|1 1| a = pyro.sample("a", Categorical(torch.ones(6) / 6))
2 1|1 1| b = pyro.sample("b", Bernoulli(p[a]))
| | with pyro.plate("c_plate", 4):
2 1 1|1 1| c = pyro.sample("c", Bernoulli(0.3))
| | with pyro.plate("d_plate", 5):
2 1 1 1|1 1| d = pyro.sample("d", Bernoulli(0.4))
2 1 1 1|1 1|1 e_loc = locs[d.long()].unsqueeze(-1)
| |7 e_scale = torch.arange(1., 8.)
2 1 1 1|5 4|7 e = pyro.sample("e", Normal(e_loc, e_scale)
| | .to_event(1))
我们分析一下这些维度。我们为Pyro指定了枚举的维度max_plate_nesting
:Pyro给a
赋予枚举维度-3,给b
赋予枚举维度-4,给c
赋予枚举维度-5,给d
赋予枚举维度-6。当用户不指定维度扩展后的数值时,新维度被默认为1,这方便计算。我们还可以观察到,log_prob
的形状广播的范围是枚举维度和独立维度,比如trace.nodes['d']['log_prob'].shape == (2,1,1,1,5,4)
使用Pyro的自带工具Trace.format_shapes():
trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()
trace.compute_log_prob() # 可选
print(trace.format_shapes())
结果:
Trace Shapes:
Param Sites:
p 6
locs 2
Sample Sites:
a dist |
value 6 1 1 |
log_prob 6 1 1 |
b dist 6 1 1 |
value 2 1 1 1 |
log_prob 2 6 1 1 |
c_plate dist |
value 4 |
log_prob |
c dist 4 |
value 2 1 1 1 1 |
log_prob 2 1 1 1 4 |
d_plate dist |
value 5 |
log_prob |
d dist 5 4 |
value 2 1 1 1 1 1 |
log_prob 2 1 1 1 5 4 |
e dist 2 1 1 1 5 4 | 7
value 2 1 1 1 5 4 | 7
log_prob 2 1 1 1 5 4 |
编写并行代码
在Pyro中,我们需要掌握两个取巧的技术,来实现并行采样:广播 、 椭圆分片。我们通过下面的例子来分别介绍枚举情形和非枚举情形下的用法。
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumeration = None # 设为True或False
def fun(observe):
p_x = pyro.param('p_x', torch.tensor(0.1), constraint=constraints.unit_interval)
p_y = pyro.param('p_y', torch.tensor(0.1), constraint=constraints.unit_interval)
x_axis = pyro.plate('x_axis', width, dim=-2)
y_axis = pyro.plate('y_axis', height, dim=-1)
# 在这些样本点上,分布形状取决于Pyro是否枚举
with x_axis:
x_active = pyro.sample('x_active', Bernoulli(p_x))
with y_axis:
y_active = pyro.sample('y_active', Bernoulli(p_y))
if enumerated:
assert x_active.shape == (2, 1, 1) # max_plate_nesting==2
assert y_active.shape == (2, 1, 1, 1)
else:
assert x_active.shape == (width, 1)
assert y_active.shape == (height, )
# 第一个trick:广播,broadcast。枚举和非枚举都可使用。
p = 0.1 + 0.5 * x_active * y_active
if enumerated:
assert p.shape == (2, 2, 1, 1)
else:
assert p.shape == (width, height)
dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
# 第二个trick:椭圆分片。Pyro可以在左方任意增加维度。
for x, y in sparse_pixels:
dense_pixels[..., x, y] = 1
if enumerated:
assert dense_pixels.shape == (2, 2, width, height)
else:
assert dense_pixels.shape == (width, height)
#
with x_axis, y_axis:
if observe:
pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)
def model4():
fun(observe=True)
def guide4():
fun(observe=False)
# Test: 非枚举
enumerated = False
test_model(model4, guide4, Trace_ELBO())
# Test: 枚举。注意目标函数为TraceEnum_ELBO
enumerated = True
test_model(model4, config_enumerate(guide4, 'parallel'), TraceEnum_ELBO(max_plate_nesting=2))
在pyro.plate内部实现自动广播
在以上所有model/plate的实现中,我们都使用了pyro.plate的自动扩增功能,使变量满足pyro.sample
规定的形状。这一广播方式等价于.expand()
。
我们稍许更改上面的代码作为例子,注意几点区别:
- 我们仅考虑并行枚举的情况,但对于串行的、非枚举的情况也适用;
- 我们将采样函数分离出来,model代码使用常规的形式,这样做有利于代码的维护;
-
pyro.plate
使用ELBO的num_particles参数,将上下文中最远的内容打包。
# 规定采样的样本量
num_particals = 100
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
with x_axis:
x_active = pyro.sample('x_active', Bernoulli(p_x).expand([num_particals, width, 1]))
with y_axis:
y_active = pyro.sample('y_active', Bernoulli(p_y).expand([num_particals, 1, height]))
return x_active, y_active
def sample_pixel_locations_full_broadcasting(p_x, p_y, x_axis, y_axis):
with x_axis:
x_active = pyro.sample('x_active', Bernoulli(p_x))
with y_axis:
y_active = pyro.sample('y_acitve', Bernoulli(p_y))
return x_active, y_active
def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
with x_axis:
x_active = pyro.sample('x_active', Bernoulli(p_x).expand([width, 1]))
with y_axis:
y_active = pyro.sample('y_active', Bernoulli(p_y).expand([height]))
return x_acitve, y_active
def fun(observe, sample_fn):
p_x = pyro.param('p_x', torch.tensor(0.1), constraint=constraints.unit_interval)
p_y = pyro.param('p_y', torch.tensor(0.1), constraint=constraints.unit_interval)
x_axis = pyro.plate('x_axis', width, dim=-2)
y_axis = pyro.plate('y_axis', height, dim=-1)
#
with pyro.plate('num_particals', 100, dim=-3):
x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
## 并行枚举指标被扩增在“num_particals”最左边
assert x_active.shape == (2, 1, 1, 1)
assert y_active.shape == (2, 1, 1, 1, 1)
p = 0.1 + 0.5 * x_active * y_active
assert p.shape == (2, 2, 1, 1, 1)
dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
for x, y in sparse_pixels:
dense_pixels[..., x, y] = 1
assert dense_pixels.shape == (2, 2, 1, width, height)
#
with x_axis, y_axis:
if observe:
pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)
def test_model_with_sample_fn(sample_fn):
def model():
fun(observe=True, sample_fn=sample_fn)
#
@config_enumerate
def guide():
fun(observe=False, sample_fn=sample_fn)
test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_full_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting)
在第一个采样函数中,我们像账房先生那样,仔细规定了Bernoulli
分布的的形状。请仔细观察num_particles
, width
和height
传入sample_pixel_locations
函数的方式。这一方式有些笨拙。
对于第二个采样函数,我们需要注意pyro.plate
的参数必须要提供,这样系统才能猜出批维度的形状。
我们可以看到,对于张量操作,使用pyro.plate
实现并行是多么容易!
pyro.plate
还具有将代码模块化的效果。