50行代码实现GAN

人生苦短我用GAN

首先声明一下,本教程面向入门吃瓜群众,大牛可以绕道,闲话不多说,先方一波广告。(高级GAN玩法),怎么说,我越来越感觉到人工智能正在迎来生成模型的时代,以前海量数据训练模型的办法有点揠苗助长,看似效果很好,实际上机器什么卵都没有学到(至少从迁移性上看缺少一点味道,不过就图片领域来说另当别论,在CV领域监督学习还是相当成功)。
但是问题来了,GAN这么屌这么牛逼,我怎么搞?怎么入门?谁带我?慌了!

莫慌,50行代码你就可以成为无监督学习大牛

我最讨厌那些,嘴里一堆算法,算法实现不出来的人。因为我喜欢看到结果啊!尤其是一些教程,就是将论文,鸡巴论文奖那么多有什么用?你码代码给我看啊,我不知道数据是什么,不知道输入维度是什么,输出什么,里面到底发生了什么变化我怎么学?这就有点像,典型的在沙漠里教你钓鱼,在我看来,论文应该是最后才去看的东西。但是问题在于,你要有一个入门的教程啊。我想这是一个鸿沟,科研里面,理论和动手的鸿沟。
这篇教程就是引路人了。欢迎加入生成模型队伍。这个教程会一直保持更新,因为科技每天变幻莫测,同时我还会加入很多新内容,改进一些在以后看来是错误的说法。

首先,我们废话不多说了,直接show you the code:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
from scipy import stats


def generate_real_data_distribution(n_dim, num_samples):
    all_data = []
    for i in range(num_samples):
        x = np.random.uniform(0, 8, n_dim)
        y = stats.lognorm.pdf(x, 0.6)
        all_data.append(y)
    all_data = np.array(all_data)
    print('generated data shape: ', all_data.shape)
    return all_data


def batch_inputs(all_data, batch_size=6):
    assert isinstance(all_data, np.ndarray), 'all_data must be numpy array'
    batch_x = all_data[np.random.randint(all_data.shape[0], size=batch_size)]
    return Variable(torch.from_numpy(batch_x).float())


def main():
    # 给generator的噪音维数
    n_noise_dim = 30
    # 真实数据的维度
    n_real_data_dim = 256
    num_samples = 666
    lr_g = 0.001
    lr_d = 0.03
    batch_size = 6
    epochs = 1000

    real_data = generate_real_data_distribution(n_real_data_dim, num_samples=num_samples)
    print('sample from real data: \n', real_data[: 10])

    g_net = nn.Sequential(
        nn.Linear(n_noise_dim, 128),
        nn.ReLU(),
        nn.Linear(128, n_real_data_dim)
    )

    d_net = nn.Sequential(
        nn.Linear(n_real_data_dim, 128),
        nn.ReLU(),
        nn.Linear(128, 1),
        nn.Sigmoid()
    )

    opt_d = torch.optim.Adam(d_net.parameters(), lr=lr_d)
    opt_g = torch.optim.Adam(g_net.parameters(), lr=lr_g)

    for epoch in range(epochs):
        for i in range(num_samples // batch_size):
            batch_x = batch_inputs(real_data, batch_size)
            batch_noise = Variable(torch.randn(batch_size, n_noise_dim))

            g_data = g_net(batch_noise)

            # 用G判断两个输出分别多大概率是来自真正的画家
            prob_fake = d_net(g_data)
            prob_real = d_net(batch_x)

            # 很显然,mean里面的这部分是一个负值,如果想整体loss变小,必须要变成正直,加一个负号,否则会越来越大
            d_loss = -torch.mean(torch.log(prob_real) + torch.log(1 - prob_fake))
            # 而g的loss要使得discriminator的prob_fake尽可能小,这样才能骗过它,因此也要加一个负号
            g_loss = -torch.mean(torch.log(prob_fake))

            opt_d.zero_grad()
            d_loss.backward(retain_variables=True)
            opt_d.step()

            opt_g.zero_grad()
            g_loss.backward(retain_variables=True)
            opt_g.step()

            print('Epoch: {}, batch: {}, d_loss: {}, g_loss: {}'.format(epoch, i, d_loss.data.numpy()[0],
                                                                        g_loss.data.numpy()[0]))

if __name__ == '__main__':
    main()

这些代码,总共,也就是90行,核心代码50行,基本上,比你写一个其他程序都端,什么红黑算法,什么排序之类的。我个人比较喜欢简约,我很多时候不喜欢太鸡巴隆昌的代码。

直接开始训练吧

这个GAN很简单,三部分:

  • real data生成,这个real data我们怎么去模拟呢?注意这里用的数据是二维的,不是图片,图片是三维的,二维你可以看成是csv,或者是序列,在这里面我们每一行,也就是一个样本,是sample自某个分布的数据,这里用的分布式lognorm;
  • d_net 和 g_net,这里两个net都是非常小,小到爆炸,这如果要是用tensorflow写就有点蛋疼了,我选择PyTorch,一目了然;
  • loss,loss在GAN中非常重要,是接下来的重点。

OK,一阵复制粘贴,你就可以训练一个GAN,这个GAN用来做什么?就是你随机输入一个噪音,生成模型将会生成一个和lognorm分布一样的数据。也就是说,生成模型学到了lognrom分布。这能说明什么?神经网络学到了概率!用到图片里面就是,他知道哪个颜色快可能是什么东西,这也是现在的CycleGAN, DiscoGAN的原理。

我吃饭去了

未完待续...

来了

继续刚才的,好像我写的文章没有人看啊,伤感。自己写自己看吧,哎,我骚味改了一下代码,loss函数部分,之前的写错了,我偷一张图把。



这个是公式,原始GAN论文里面给的公式,但是毫无疑问,正如很多人说的那样,GAN很容易漂移:

Epoch: 47, batch: 66, d_loss: 0.7026655673980713, g_loss: 2.0336945056915283
Epoch: 47, batch: 67, d_loss: 0.41225430369377136, g_loss: 2.1994106769561768
Epoch: 47, batch: 68, d_loss: 0.674636960029602, g_loss: 1.5774009227752686
Epoch: 47, batch: 69, d_loss: 0.5779278874397278, g_loss: 2.2797725200653076
Epoch: 47, batch: 70, d_loss: 0.4029145836830139, g_loss: 2.200833559036255
Epoch: 47, batch: 71, d_loss: 0.7264774441719055, g_loss: 1.5658557415008545
Epoch: 47, batch: 72, d_loss: 0.46858924627304077, g_loss: 2.355680227279663
Epoch: 47, batch: 73, d_loss: 0.6716371774673462, g_loss: 1.7127293348312378
Epoch: 47, batch: 74, d_loss: 0.7237206101417542, g_loss: 1.4458404779434204
Epoch: 47, batch: 75, d_loss: 0.9684935212135315, g_loss: 1.943861961364746
Epoch: 47, batch: 76, d_loss: 0.4705852270126343, g_loss: 2.439894199371338
Epoch: 47, batch: 77, d_loss: 0.4989328980445862, g_loss: 1.5290288925170898
Epoch: 47, batch: 78, d_loss: 0.44530192017555237, g_loss: 2.9254989624023438
Epoch: 47, batch: 79, d_loss: 0.6329593658447266, g_loss: 1.7527830600738525
Epoch: 47, batch: 80, d_loss: 0.42348209023475647, g_loss: 1.856258749961853
Epoch: 47, batch: 81, d_loss: 0.5396828651428223, g_loss: 2.268836498260498
Epoch: 47, batch: 82, d_loss: 0.9727945923805237, g_loss: 1.0528483390808105
Epoch: 47, batch: 83, d_loss: 0.7551510334014893, g_loss: 1.508225917816162
Epoch: 47, batch: 84, d_loss: 2.4204068183898926, g_loss: 1.5375216007232666
Epoch: 47, batch: 85, d_loss: 1.517686128616333, g_loss: 0.6334291100502014
Epoch: 47, batch: 86, d_loss: inf, g_loss: 0.7990849614143372
Epoch: 47, batch: 87, d_loss: nan, g_loss: nan
Epoch: 47, batch: 88, d_loss: nan, g_loss: nan
Epoch: 47, batch: 89, d_loss: nan, g_loss: nan
Epoch: 47, batch: 90, d_loss: nan, g_loss: nan
Epoch: 47, batch: 91, d_loss: nan, g_loss: nan

你如果train一下的话会发现,到一定程度就会nan,这个nan我就无法理解了,按道理来说,从loss来看我们定义的来自以log,如果为无穷那么应该是log(0)了,但是我们的discriminator出来的函数是sigmoid啊,sigmoid不可能为0,只看是0-1且不包括闭区间。这个问题比较玄学。

既然nan的话,我也不深究是因为啥了,总之这个重点在于loss,因为后面GAN的变种基本上都是在loss的训练形式上。

GAN 生成mnist

我们现在玩一下mnist把。

交流

我见了一个GAN群,加我微信让我拉进来。jintianiloveu, 顺便下载一个我做的app吧,内侧中,专门用来看美女图片的,你懂得。。传送门

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,417评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,921评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,850评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,945评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,069评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,188评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,239评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,994评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,409评论 1 304
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,735评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,898评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,578评论 4 336
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,205评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,916评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,156评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,722评论 2 363
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,781评论 2 351

推荐阅读更多精彩内容