Pyro简介:产生式模型实现库(二),推断

我们仍然以一个例子来说明Pyro的推断功能。首先,我们引入头文件。

import matplotlib.pyplot as plt
import numpy as np
import torch

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

pyro.set_rng_seed(101)

例子:测量物体的重量

假如我们要测量物体的重量,而秤并不怎样精确,每每测量的结果存在稍许差异。为了补偿秤本身的误差,我们要把过程的“噪声”(即造成误差的不明因素)积分处理。下面的过程描述了数据产生的原理:
weight|guess \sim \cal{N}(guess, 1)
measurement|guess, weight \sim \cal{N}(weight, 0.75)
实现代码:

def scale(guess):
    weight = pyro.sample('weight', dist.Normal(guess, 1.))
    return pyro.sample('measurement', dist.Normal(weight, .75))

条件概率

上述是“正向”的数据产生过程,建模过程是很容易的。在实际的场景里,我们却只能通过观察数据,来“反推”数据的产生过程。Pyro中,产生数据的状态,是用sample()来实现的。
考虑scale,假如我们给定guess = 8.5作为输入,并且观察到measurement==9.5这样的样本,我们希望了解weight的分布范围是多少,即
(weight|guess, measurement==9.5) \sim ?
Pyro提供了pyro.condition来限制采样的状态。pyro.condition是一个“高阶函数”,即输入一个模型函数和一份观察值的字典,返回一个基于观察的新模型函数。

conditioned_scale = pyro.condition(scale, data={'measurement': 9.5})

这和Python的其他函数是一样的。我们可以用lambdadef的方法重写上面的句子:

def deferred_conditioned_scale(measurement, guess):
    return pyro.condition(scale, data={'measurement':  measurement})(guess)

还有一种更省事的写法,用obs这一关键字来提示pyro.condition观察值的情况。

def scale_obs(guess): # 该函数与 conditioned_scale是等价的
    weight = pyro.sample('weight', dist.Normal(guess, 1.))
    # 条件为给定观察值9.5
    return pyro.sample('measurement', dist.Normal(weight, 1.), obs=9.5)

多提一句,Pyro在pyro.condion中也集成了朱迪亚贝尔的“执行”命令pyro.do

用guide函数,灵活地推断

conditioned_scale函数里,我们在给定guessmeasurement==data的条件下,对weight进行推断。
推断算法在Pyro框架里,如pyro.infer.SVI,被定义在pyro.infer类中。对于被推断的任何随机函数,我们称其为guideguides,用来表示后验分布的近似结果。guide函数需要满足两个条件:

  1. 所有的独立变量(它们不依赖于其他随机变量),在model中出现的,也必在guide中出现。
  2. guide与model具有相同的参数(argument)。

guide在多种场景下发挥作用,如重要采样、拒绝采样、序列蒙特卡洛采样、MCMC、独立Metropolis-Hastings采样、变分推断、推断网络,等等。现在已经在Pyro完成封装的,有重要采样、MCMC、变分推断。在未来其余场景也将陆续完成。
虽然在不同场景下,guide可以灵活规定,原则上我们需要在guide中涵盖独立变量的完整采样过程。
scale中,给定guessmeasurement后,其后验概率为\cal{N}(9.14, .6)。由于这个例子比较简单,我们可以手算其后验概率的形式。(感兴趣的读者请参阅:http://www.stat.cmu.edu/~brian/463-663/week09/Chapter%2003.pdf
\mu_{update} = \frac{\sigma^2M + \tau^2nx}{n\tau^2+\sigma^2}
\sigma^2_{update} = \frac{\sigma^2\tau^2}{n\tau^2+\sigma^2}

def perfect_guide(guess):
    # sigma=0.75,tau=1,n=1,x=9.5,M=guess=8.5
    loc = (.75 ** 2 * guess + 9.5) / (1 + .75 ** 2) # 9.14
    scale = np.sqrt(.75 ** 2 / (1 + .75 ** 2)) # 0.6
    return pyro.sample('weight', dist.Normal(loc, scale))

从参数化的随机函数,到变分推断

上面的例子中,我们计算出了精确的后验概率分布。这是一种极为幸运的情况,而非通例。哪怕仍旧用scale这个简单的例子,如果weight经过某种非线性操作,后验分布就不再具有精确解了。

def scale(guess):
    weight = pyro.sample('weight', dist.Normal(guess, 1.))
    return pyro.sample('measurement', dist.Normal(some_nonlinear_function(weight), .75))

这时,我们需要重新估计一个函数,它的采样结果能最大程度地符合观察结果,或使某一损失函数最小化,这一过程叫做变分推断。在Pyro中,我们利用pyro.param来具体化guides函数的可选范围。
pyro.param是Pyro的键值对组成的容器。和pyro.sample一样,pyro.param通过第一个参数来命名。第一次声明pyro.sample的名字,容器中就会存储这个参数的名字和值,在以后再次调用时返回它的值。这个过程就像下面的sample_param_store.setdefault一样。

simple_param_store = {}
a = simple_param_store('a', torch.randn(1))

举个例子,我们要在scale_posterior_guide中,参数化ab,而非人工实例化它们:

def scale_parmeterized_guide(guess):
    a = pyro.param('a', torch.tensor(guess))
    b = pyro.param('b', torch.tensor(1.))
    return pyro.sample('weight', dist.Normal(a, torch.abs(b)))

插句题外话,上面的b加上了torch.abs函数,是因为正态分布的标准差必须是非负数。我们也可以通过Pytorch的constraint module来明确规定这一限制。

from torch.distributions import constraints

def scale_parameterized_guide_constrained(guess):
    a = pyro.param('a', torch.tensor(guess))
    b = pyro.param('b', torch.tensor(1.), constraint=constrains.positive)
    return pyro.sample('weight', dist.Normal(a, b)) # 不再需要 torch.abs

话说回来。Pyro这个代码库的最直接目的,就是执行随机变分推断(SVI)。这类操作包含下面三个特点:

  1. 参数都是实值张量
  2. 通过model和guide的执行历史,采样并计算得到损失函数的蒙特卡洛估计
  3. 通过梯度下降法,搜索最优的参数值

结合Pytorch的GPU加速和自动求导机制,Pyro能够在高维参数空间高效完成变分推断。在后面的教程中,我们会详细介绍。这里给出一个简单的例子:

guess = 8.5
pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned_scale, 
                     guide=scale_parameterized_guide, 
                     optim=pyro.optim.SGD({'lr':0.001, 'momentum':0.1}),
                     loss=pyro.infer.Trace_ELBO())
losses, a, b = [], [], []
num_steps = 2500
for t in range(num_steps):
    losses.append(svi.step(guess))
    a.append(pyro.sample('a').item())
    b.append(pyro.sample('b').item())

plt.plot(losses)
plt.title('ELBO')
plt.xlabel('step')
plt.ylabel('loss')
print('a = ', pyro.sample('a').item())
print('b = ', pyro.sample('b').item())

a = 9.107474327087402
b = 0.6285384893417358

plt.subplot(1, 2, 1)
plt.plot([0, num_steps], [9.14, 9.14], 'k:')
plt.plot(a)
plt.ylabel('a')

plt.subplot(1,2,2)
plt.plot([0, num_steps], [0.6, 0.6], 'k:')
plt.plot(b)
plt.ylabel('b')
plt.tight_layout()

由图可见,SVI的推断值,与真值是相当接近的。这正是我们所希望的。
应该注意的是,guide的参数优化过程,被存放在参数容器中。当我们需要做后验采样时,我们可以直接从guide中采样,为下游的任务所利用。

接下来的教程,我们将使用神经网络来构建scale函数,并用随机变分推断的方法构建图像的产生式模型,敬请期待。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,125评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,293评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,054评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,077评论 1 291
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,096评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,062评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,988评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,817评论 0 273
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,266评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,486评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,646评论 1 347
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,375评论 5 342
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,974评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,621评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,796评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,642评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,538评论 2 352

推荐阅读更多精彩内容