INN实现理解——toy_example

github 地址:https://github.com/hagabbar/cINNamon

A Toy Example

1. 设置超参数:

2. 生成数据

^^ 调用函数来生成样本数据,参数 labels 用于限制生成的样本类型,labels 有三种取值:all、some、none,分别对应图 [A Toy Example] 中的三种样本分布;参数 tot_dataset_size 用于指定生成样本的个数。

pos 是一个大小为 (tot_dataset_size, 2) 的二阶矩阵,其元素都符合均值为0,方差为0.2的正态分布,表示样本点的坐标;labels 是一个 (tot_dataset_size, 3) 的矩阵,表示样本点的颜色 RGB 值。样本被均匀分为 8 堆,对每堆的样本坐标进行一定修改,使这堆样本点落在相同区域内,且有相同的颜色。

^^ 分别取 pos、labels 的前 test_split 个元素作为测试样本,画出测试数据的分布图如下:

3. 建立模型

^^ ndim_tot = max(ndim_x, ndim_y+ndim_z) + n_neurons,ndim_tot 的值对网络结构有重要影响,输入结点会将其作为维度值。如果维度 ndim_tot 相对较小,但却需要学习一个很复杂的转换,最好对网络的输入和输出都进行相同数量的 0 填充。这并不会改变输入和输出的固有维度,但使得网络内部层可以以一种更灵活的方式将数据嵌入到更大的表示空间。

^^ ReversibleGraphNet 构造函数会做四件事:
① 构造 INN 网络的正向连接,即:inp → t1 → t2 → t3 → outp。其中 t1、t2、t3 都是一个基础构建块,其结构为:

用公式表示为:

其中,s1、t1、s2、t2 都是一种线性映射关系,因此都被构造为一个有三层隐藏层的全连接神经网络。需要说明的是,隐藏层的神经元个数,被简单设置为输出层神经元个数的 2 倍。

② 确定 INN 网络的反向连接,使得可以进行反向训练。
③ 确定正向训练过程中涉及的变量及操作顺序。
④ 确定反向训练过程中涉及的变量及操作顺序。

4. 训练前准备工作

^^ 设置训练参数。

^^ 各项损失的相对权重。INN 训练过程中考虑三项损失:
① 模型输出 yi = s(xi) 与网络预测 fy(xi) 之间的偏差,损失记为 Ly(yi,fy(xi)),Ly 可以是任意有监督的损失;lamdb_predict 为 Ly 的权重;
② 模型输出 p(y = s(x)) = p(x) / |Js| 和潜在变量 p(z) 的边际分布的乘积与网络输出 q(y = fy(x),z = fz(x)) = p(x) / |Jyz| 间的偏差,记为 Lz(p(y)p(z),q(y,z));lambd_latent 为 Lz 的权重;
③ 输入端的损失 LxLx(p(x),q(x)) 表示了 p(x) 与后向预测分布 q(x) = p(y = fy(x)) p(z = fz(x)) / |Jx| 间的偏差;lambd_rev 为 Lx 的权重.

^^ 定义权重更新规则(optimizer),scheduler 对其进行封装,目的是使其学习率每隔 step_size 轮就进行一次衰减。

^^ 定义损失函数。需要定义三个函数来进行三种损失的计算,其中 Lx、Lz 是无监督损失,因此选择了 MMD_multiscale(多刻度的 MMD,MMD 常用于度量两个不同但相关的分布的距离)作为损失函数;Ly 是有监督损失,因此选择了 fit(即平方误差)。

^^ 建立测试数据装载器和训练集数据装载器。DataLoader 返回的是一个迭代器,可以使用迭代器分批获取数据,或直接使用 for 循环对其进行遍历。

^^ 初始化网络权重。这里 block 指各 INN 构建块, coeff 指 INN 构建块中的全连接神经网络。它们都是 Module 类的子类对象,因此使用了三层 for 循环对权重初始化。

^^ 从测试样本集中选取一部分,用于在模型训练完成后进行模型测试。

5. 训练模型

这个实现对 INN 训练了 2000 次。我们只看一次训练的步骤。

首先是设置此轮训练的学习率,这个一般由我们之前包装的 scheduler 根据训练轮数来进行设定。

^^ 训练网络。其中调用了 train() 函数:


核心函数 train()

首先要将训练涉及的各个模块的状态设置为 training。

^^ 设置 loss_factor,当 i_epoch 大于 300 时,其值为 1。

每轮训练只能使用数据装载器装载 n_its_per_epoch(设定为4)批的样本数据,对于每一批数据,进行如下处理。

^^ 对 x 和 yz 进行对其填充,使它们的维度和 ndim_tot 相同。在填充前,先为 y 增加随机噪声。这里也可以看出,z 服从标准正态分布

^^ 在开始训练前,需要清除已存在的梯度。

^^ 执行正向传播得到输出 output(正向计算由 PyTorch 实现),output 与输入有相同的维度。y_short 维度为 (样本数 × 4),其中,前 2 列表示 z,后两列表示 y。

^^ 计算损失 Ly,即为样本 y 和网络预测结果 y' (包括了补齐部分,但不包括 z)的均方误差。

^^ 计算损失 Lz。output_block_grad 维度为 (样本数 × 4),其中,前 2 列表示 z,后两列表示 y;其与 y_short 相对,区别是一个来源于正向网络预测结果 output,一个来源于样本 y。

^^ 这个 backward() 函数是 PyTorch 实现的,调用它是为了进行梯度计算。l 是正向过程的总损失,调用 l.backward() 计算梯度,是为了之后更新权值做准备。

^^ 这些都是反向训练需要的变量。y_rev 除了补齐部分外,包含了增加了随机噪声的、上一轮正向训练得到的 z;以及增加了随机噪声的原始样本 y。y_rev_rand 与 y_rev 大小相同,不同的地方在于,y_rev_rand 包含的是随机生成的服从标准正态分布的 z'。

^^ 对 y_rev 和 y_rev_rand 进行反向训练,得到输出结果。

^^ 计算反向训练的损失 Lx,可见其由两部分组成:一是样本 x 与反向训练结果 output_rev_rand 中的 x' 的差异;二是正向训练的输入与反向训练的输出(正向训练的输出为反向训练的输入)间的差异。

^^ l_rev 是逆向过程的总损失,调用 l_rev.backward() 计算梯度,是为了之后更新权值做准备。

^^ 将各参数的梯度值限制在 [-15.0,15.0] 区间内,然后更新网络权值。

^^ 一轮训练结束,返回训练每批数据的总损失。

train() 执行结束


在每轮训练结束后,使用测试数据对模型进行测试,画出其分布,即对训练结果可视化。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,657评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,662评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,143评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,732评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,837评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,036评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,126评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,868评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,315评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,641评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,773评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,470评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,126评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,859评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,095评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,584评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,676评论 2 351

推荐阅读更多精彩内容