import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
data_mean = 4
data_standard_deviation = 1.25
g_input_size = 1 # Random noise dimension coming into generator, per output vector
g_hidden_size = 50 # Generator complexity
g_output_size = 1 # size of generated output vector
d_input_size = 100 # Minibatch size; D sees one whole minibatch as a single input vector
d_hidden_size = 50 # Discriminator complexity
d_output_size = 1 # Single dimension for 'real' vs. 'fake'
minibatch_size = d_input_size # use batch gradient descent
d_lr = 2e-4
g_lr = 2e-4
optim_betas = (0.9, 0.999)
n_epochs = 10000
print_interval = 1400
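# Quick sanity check (illustrative): one minibatch drawn from the target
# Gaussian should have sample mean ~4 and sample std ~1.25.
sample = np.random.normal(data_mean, data_standard_deviation, (1, d_input_size))
print('target batch stats:', np.mean(sample), np.std(sample))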
class Generator(nn.Module):
    """Maps 1-D uniform noise to a 1-D sample of the learned distribution."""
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.map1(x))
        x = torch.sigmoid(self.map2(x))  # torch.sigmoid replaces the deprecated F.sigmoid
        x = self.map3(x)  # no final activation: outputs are unbounded real values
        return x
class Discriminator(nn.Module):
    """Scores an entire minibatch, flattened into one vector, as real vs. fake."""
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.map1(x))
        x = F.relu(self.map2(x))
        x = torch.sigmoid(self.map3(x))  # probability that the input batch is real
        return x
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size)
# D's dimensions: input_size = 100 (one full minibatch), hidden_size = 50, output_size = 1
print(G)
print(D)
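# Illustrative shape check: G maps per-sample noise (N, 1) to samples (N, 1),
# while D takes a whole batch flattened to (1, 100) and returns one probability.
z = torch.rand(d_input_size, g_input_size)
print(G(z).shape)         # torch.Size([100, 1])
print(D(G(z).t()).shape)  # torch.Size([1, 1])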
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr = d_lr, betas = optim_betas)
g_optimizer = optim.Adam(G.parameters(), lr = g_lr, betas = optim_betas)
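# Worked example of the BCE objective (illustrative values): a confident correct
# prediction costs little, a confident wrong one costs a lot.
print(criterion(torch.tensor([[0.9]]), torch.ones(1, 1)))   # -ln(0.9) ~ 0.105
print(criterion(torch.tensor([[0.9]]), torch.zeros(1, 1)))  # -ln(0.1) ~ 2.303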
def stats(d):
    """Sample mean and standard deviation of a list of values."""
    return [np.mean(d), np.std(d)]

def extract(v):
    """Flatten a tensor into a plain Python list of its elements."""
    return v.detach().flatten().tolist()
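# Example (illustrative): stats(extract(t)) reports the mean and std of every
# element of a tensor t.
print(stats(extract(torch.full((2, 3), 4.0))))  # [4.0, 0.0]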
for epoch in range(n_epochs):
    # Fresh data every step: one real minibatch and one batch of generator noise
    d_real_input = torch.Tensor(np.random.normal(data_mean, data_standard_deviation, (1, d_input_size)))
    g_input = torch.rand(d_input_size, g_input_size)  # (100, 1) uniform noise

    # 1. Train D: push D(real) toward 1 and D(fake) toward 0
    D.zero_grad()
    d_real_output = D(d_real_input)
    d_real_error = criterion(d_real_output, torch.ones(1, 1))
    d_real_error.backward()
    d_fake_input = G(g_input).detach()   # (100, 1); detached so G gets no gradient here
    d_fake_output = D(d_fake_input.t())  # transposed to (1, 100): D scores the whole batch
    d_fake_error = criterion(d_fake_output, torch.zeros(1, 1))
    d_fake_error.backward()
    d_optimizer.step()  # only updates D's parameters, using both stored gradients

    # 2. Train G: make D label G's output as real (target 1); only G is updated
    G.zero_grad()
    g_output = G(g_input)  # the fake data G generated, this time without detach
    g_error = criterion(D(g_output.t()), torch.ones(1, 1))
    g_error.backward()
    g_optimizer.step()

    if epoch % print_interval == 0:
        print(epoch)
        print(stats(extract(d_real_input)))  # stays near [4, 1.25]
        print(stats(extract(d_fake_input)))  # should drift toward [4, 1.25]
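# Visualize the result (a minimal sketch using the matplotlib import above):
# histogram of one generated batch against one real batch.
with torch.no_grad():
    g_samples = extract(G(torch.rand(d_input_size, g_input_size)))
real_samples = np.random.normal(data_mean, data_standard_deviation, d_input_size)
plt.hist(real_samples, bins=20, alpha=0.5, label='real')
plt.hist(g_samples, bins=20, alpha=0.5, label='generated')
plt.legend()
plt.title('MyGAN: real vs. generated samples')
plt.show()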