# Sequence-length options: how many frames are given as context, and how many
# frames are predicted during training and during evaluation.
parser.add_argument(
    '--n_past',
    type=int,
    default=5,
    help='number of frames to condition on',
)
parser.add_argument(
    '--n_future',
    type=int,
    default=10,
    help='number of frames to predict during training',
)
parser.add_argument(
    '--n_eval',
    type=int,
    default=30,
    help='number of frames to predict during eval',
)
n_past个frame作为参考,预测之后n_future个frame。
LSTM Model
import models.lstm as lstm_models
# Recurrent components. With no model directory configured, build the three
# networks from scratch and initialise their weights; otherwise reuse the
# modules stored in `saved_model` (which is loaded outside this excerpt).
# NOTE(review): the excerpt lost its indentation; weight initialisation is
# assumed to belong to the fresh-build branch only -- confirm.
if opt.model_dir == '':
    frame_predictor = lstm_models.lstm(
        opt.g_dim+opt.z_dim, opt.g_dim, opt.rnn_size, opt.rnn_layers, opt.batch_size)
    posterior = lstm_models.gaussian_lstm(
        opt.g_dim, opt.z_dim, opt.rnn_size, opt.rnn_layers, opt.batch_size)
    prior = lstm_models.gaussian_lstm(
        opt.g_dim, opt.z_dim, opt.rnn_size, opt.rnn_layers, opt.batch_size)
    for fresh_module in (frame_predictor, posterior, prior):
        fresh_module.apply(utils.init_weights)
else:
    frame_predictor = saved_model['frame_predictor']
    posterior = saved_model['posterior']
    prior = saved_model['prior']
其中lstm.py如下:
class lstm(nn.Module):
    """Stack of LSTM cells between a linear embedding and a tanh output head.

    The module is stateful: ``self.hidden`` holds one ``(h, c)`` pair per
    layer and every ``forward`` call advances that state by one step.  Assign
    the result of ``init_hidden()`` to ``self.hidden`` to reset it.

    Args:
        input_size: feature size of each input vector.
        output_size: feature size of each output vector.
        hidden_size: width of the embedding and of every LSTM cell.
        n_layers: number of stacked ``nn.LSTMCell`` layers.
        batch_size: batch dimension used when allocating the hidden state.
    """

    def __init__(self, input_size, output_size, hidden_size, n_layers, batch_size):
        super(lstm, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.n_layers = n_layers
        self.embed = nn.Linear(input_size, hidden_size)
        self.lstm = nn.ModuleList(
            [nn.LSTMCell(hidden_size, hidden_size) for _ in range(self.n_layers)])
        self.output = nn.Sequential(
            nn.Linear(hidden_size, output_size),
            nn.Tanh())
        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a fresh all-zero ``(h, c)`` state for every layer.

        The state is allocated on the device that currently holds the
        module's parameters instead of unconditionally on CUDA, so the module
        can also be built and run on CPU-only machines.

        Returns:
            A list with ``n_layers`` tuples of two
            ``(batch_size, hidden_size)`` tensors.
        """
        device = self.embed.weight.device
        return [
            (torch.zeros(self.batch_size, self.hidden_size, device=device),
             torch.zeros(self.batch_size, self.hidden_size, device=device))
            for _ in range(self.n_layers)
        ]

    def forward(self, input):
        """Advance the recurrent state by one step and return the output.

        Args:
            input: tensor that is reshaped to ``(-1, input_size)``.

        Returns:
            Tensor of shape ``(-1, output_size)`` with values in ``[-1, 1]``
            (the head ends in ``nn.Tanh``).
        """
        h_in = self.embed(input.view(-1, self.input_size))
        for i, cell in enumerate(self.lstm):
            # The state may have been created before the module was moved
            # (e.g. ``.cuda()`` after construction); ``.to`` is a no-op when
            # the devices already match.
            state = tuple(s.to(h_in.device) for s in self.hidden[i])
            self.hidden[i] = cell(h_in, state)
            h_in = self.hidden[i][0]
        return self.output(h_in)
继承了nn.Module的类在实例化时提供大小参数(即__init__的参数),在调用实例时提供输入数据(执行forward计算)。里面的nn.Linear(in_features, out_features)、nn.LSTMCell(input_size, hidden_size)等等也是一样的工作原理。
lstm与gaussian_lstm区别在于:
class gaussian_lstm(nn.Module):
    """LSTM stack whose head produces the parameters of a Gaussian and a sample.

    Only the parts that differ from ``lstm`` are shown here; the ``...``
    placeholders stand for code omitted from this excerpt.
    """

    def __init__(self, input_size, output_size, hidden_size, n_layers, batch_size):
        ...
        # Two parallel linear heads: one for the mean, one for the log-variance.
        self.mu_net = nn.Linear(hidden_size, output_size)
        self.logvar_net = nn.Linear(hidden_size, output_size)
        ...

    def reparameterize(self, mu, logvar):
        """Draw ``mu + std * eps`` with ``eps ~ N(0, 1)`` and ``std = exp(logvar / 2)``."""
        std = logvar.mul(0.5).exp_()
        # Noise tensor with the same type and size as ``std``.
        noise = Variable(std.data.new(std.size()).normal_())
        return noise.mul(std).add_(mu)

    def forward(self, input):
        """Advance the recurrent state one step; return ``(z, mu, logvar)``."""
        layer_input = self.embed(input.view(-1, self.input_size))
        for layer in range(self.n_layers):
            next_state = self.lstm[layer](layer_input, self.hidden[layer])
            self.hidden[layer] = next_state
            # The hidden output of this layer feeds the next one.
            layer_input = next_state[0]
        mu = self.mu_net(layer_input)
        logvar = self.logvar_net(layer_input)
        return self.reparameterize(mu, logvar), mu, logvar
输出的是通过重参数化从正态分布中采样得到的z,以及该分布的均值mu和对数方差logvar(reparameterize中先用exp(0.5*logvar)得到标准差)。
Encoder / Decoder
以dcgan_64为例
# Pick the encoder/decoder implementation: the DCGAN backbone for 64-pixel-wide
# images lives in models/dcgan_64.py.
if opt.model == 'dcgan' and opt.image_width == 64:
    import models.dcgan_64 as model
dcgan_64.py如下:
import torch
import torch.nn as nn
class dcgan_conv(nn.Module):
    """Downsampling unit: 4x4 stride-2 convolution, batch norm, LeakyReLU(0.2).

    Args:
        nin: number of input channels.
        nout: number of output channels.
    """

    def __init__(self, nin, nout):
        super().__init__()
        layers = [
            nn.Conv2d(nin, nout, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(nout),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        # Kept under the attribute name ``main`` so parameter names stay the same.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the conv / norm / activation stack to ``input``."""
        features = self.main(input)
        return features
class dcgan_upconv(nn.Module):
    """Upsampling unit: 4x4 stride-2 transposed conv, batch norm, LeakyReLU(0.2).

    Args:
        nin: number of input channels.
        nout: number of output channels.
    """

    def __init__(self, nin, nout):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(nin, nout, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(nout),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        # Kept under the attribute name ``main`` so parameter names stay the same.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the transposed-conv / norm / activation stack to ``input``."""
        features = self.main(input)
        return features
class encoder(nn.Module):
    """DCGAN-style image encoder that also exposes its intermediate features.

    Args:
        dim: size of the output latent vector.
        nc: number of input image channels.
    """

    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        nf = 64
        # Four stride-2 stages, starting from a (nc) x 64 x 64 input.
        self.c1 = dcgan_conv(nc, nf)          # -> (nf)   x 32 x 32
        self.c2 = dcgan_conv(nf, nf * 2)      # -> (nf*2) x 16 x 16
        self.c3 = dcgan_conv(nf * 2, nf * 4)  # -> (nf*4) x 8 x 8
        self.c4 = dcgan_conv(nf * 4, nf * 8)  # -> (nf*8) x 4 x 4
        # Final 4x4 convolution without padding, followed by batch norm and tanh.
        self.c5 = nn.Sequential(
            nn.Conv2d(nf * 8, dim, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(dim),
            nn.Tanh(),
        )

    def forward(self, input):
        """Encode ``input``.

        Returns:
            A pair ``(latent, skips)``: the c5 output reshaped to
            ``(-1, dim)`` and the list of the c1..c4 outputs, in that order.
        """
        skips = []
        feat = input
        for stage in (self.c1, self.c2, self.c3, self.c4):
            feat = stage(feat)
            skips.append(feat)
        latent = self.c5(feat)
        return latent.view(-1, self.dim), skips
encoder将每一层都输出了
class decoder(nn.Module):
    """DCGAN-style image decoder with skip connections from the encoder.

    Args:
        dim: size of the input latent vector.
        nc: number of output image channels.
    """

    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        nf = 64
        # Latent vector (viewed as dim x 1 x 1) -> (nf*8) x 4 x 4.
        self.upc1 = nn.Sequential(
            nn.ConvTranspose2d(dim, nf * 8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(nf * 8),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        # Every later stage takes twice the channels because its input is
        # concatenated with the matching encoder feature map.
        self.upc2 = dcgan_upconv(nf * 8 * 2, nf * 4)  # -> (nf*4) x 8 x 8
        self.upc3 = dcgan_upconv(nf * 4 * 2, nf * 2)  # -> (nf*2) x 16 x 16
        self.upc4 = dcgan_upconv(nf * 2 * 2, nf)      # -> (nf)   x 32 x 32
        # Output stage: (nc) x 64 x 64, squashed by a sigmoid.
        self.upc5 = nn.Sequential(
            nn.ConvTranspose2d(nf * 2, nc, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Decode ``input = (vec, skip)`` into an image tensor.

        ``skip`` is indexed from 3 down to 0, i.e. the deepest encoder feature
        map is consumed first.
        """
        vec, skip = input
        x = self.upc1(vec.view(-1, self.dim, 1, 1))
        stages = (self.upc2, self.upc3, self.upc4, self.upc5)
        for stage, level in zip(stages, (3, 2, 1, 0)):
            x = stage(torch.cat([x, skip[level]], 1))
        return x
decoder的每一层和encoder的每一层连接起来。
Training Functions
def train(x):
    """Run one optimisation step on the frame sequence ``x``.

    Every step is conditioned on the ground-truth previous frame ``x[t-1]``;
    the latent ``z_t`` is sampled by the posterior from the encoded target
    frame ``x[t]``, while the prior only contributes to the KL term.

    Returns:
        ``(mse, kld)`` as numpy values, each divided by
        ``opt.n_past + opt.n_future``.
    """
    ...  # setup omitted from this excerpt
    mse = 0
    kld = 0
    n_steps = opt.n_past + opt.n_future
    for t in range(1, n_steps):
        # Encode the previous frame first and the target frame second.
        h, frame_skip = encoder(x[t-1])
        h_target = encoder(x[t])[0]
        # Refresh the skip features only while the step still belongs to the
        # conditioning frames (or always, when last_frame_skip is set);
        # otherwise the previously stored `skip` is reused.
        if opt.last_frame_skip or t < opt.n_past:
            skip = frame_skip
        z_t, mu, logvar = posterior(h_target)
        _, mu_p, logvar_p = prior(h)
        h_pred = frame_predictor(torch.cat([h, z_t], 1))
        x_pred = decoder([h_pred, skip])
        mse += mse_criterion(x_pred, x[t])
        kld += kl_criterion(mu, logvar, mu_p, logvar_p)

    loss = mse + kld*opt.beta
    loss.backward()

    for optimizer in (frame_predictor_optimizer, posterior_optimizer,
                      prior_optimizer, encoder_optimizer, decoder_optimizer):
        optimizer.step()

    # NOTE: the loop above runs n_steps - 1 times but the losses are divided
    # by n_steps.
    return mse.data.cpu().numpy()/n_steps, kld.data.cpu().numpy()/n_steps
当i < n_past(或开启last_frame_skip)时,每一步都更新skip;否则沿用此前最后一次得到的skip。
训练时每一步的输入始终是真实帧x[i-1](并不会把第i步的预测输出作为第i+1步的输入);同时每一步都对目标帧x[i]进行编码,并由posterior采样得到z_t。
for epoch in range(opt.niter):
    # Switch every sub-network to training mode at the start of the epoch.
    for module in (frame_predictor, posterior, prior, encoder, decoder):
        module.train()

    # Running sums of the per-batch losses returned by train().
    epoch_mse, epoch_kld = 0, 0
    progress = progressbar.ProgressBar(max_value=opt.epoch_size).start()
    for i in range(opt.epoch_size):
        progress.update(i+1)
        x = next(training_batch_generator)

        # One optimisation step on this batch.
        mse, kld = train(x)
        epoch_mse += mse
        epoch_kld += kld

    progress.finish()
.train(): Sets the module in training mode.
Predict
def make_gifs(x, idx):
    """Roll out ``nsample`` stochastic predictions for ``x`` and score them.

    NOTE(review): ``nsample``, ``progress``, ``all_gen``, ``ssim`` and ``psnr``
    are not defined in this excerpt -- presumably set up in omitted code;
    ``idx`` is not used in the visible part.
    """
    for s in range(nsample):
        progress.update(s+1)
        gen_seq = []
        gt_seq = []
        # Reset the recurrent state of all three LSTMs for this sample.
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        # Start from the first ground-truth frame.
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            # Refresh the skip features only for conditioning steps (or always,
            # when last_frame_skip is set); otherwise keep the stored `skip`.
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            # While a ground-truth target is still used (i + 1 < n_past) the
            # latent comes from the posterior; afterwards it comes from the prior.
            if i + 1 < opt.n_past:
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
            else:
                z_t, _, _ = prior(h)
            if i < opt.n_past:
                # Conditioning step: the predictor output is discarded (the call
                # only advances its recurrent state) and the next input is the
                # ground-truth frame.
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                # Prediction step: decode the predicted features and feed the
                # generated frame back in as the next input.
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                # Only generated frames and their ground truth are collected
                # for the metrics.
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
        # Store the SSIM / PSNR results of this sample.
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)
    progress.finish()
n_past及以后的图片都是未知的,因此从第n_past开始,将i的输出作为i+1的输入;从第n_past-1开始,从prior分布中根据输入采样。