Vq-VAE:向量量化VAE
VAE的本质就是通过隐变量的分布+decoder,获取目标数据分布
基础VAE的思路:对隐变量进行各向同性标准正态分布的先验假设,训练完模型后我们就可以直接从先验假设中进行采样,将采样结果输入到decoder就可以得到目标分布中的一个样本。
VQ-VAE的思路:对隐变量的分布通过pixel cnn进行建模
假设某单通道图像集分布为,为其中一个样本,是样本中的第i个像素,则样本在中出现的概率:
这个过程称为自回归AutoRegressive,自回归模型由于要逐像素求解,所以对于生成大分辨率图像来说,计算量将是其一个性能瓶颈,为此我们可以在训练过程中采取这么一个策略:将图像编码到低维空间,然后再低维空间利用自回归模型进行建模,然后对低维空间进行解码求得高维空间的图像,训练结束后,我们就可以直接在通过自回归模型建模好的低维空间进行采样,然后解码得到符合目标分布的图像样本;另外对于cv领域的主导深度学习架构CNN来说,其输出值一般为连续值,而对连续值进行自回归建模几乎不可能,所以在VQ-VAE中是将连续值进行离散化,然后对离散化后的latent code进行自回归建模,具体来说就是:对encoder的输出做embedding操作(其实就是做聚类操作,embedding对应的是聚类中心),输出的每个(cx1x1)向量会对应一个embedding,然后将(cx1x1)向量用对应的embedding index替换,就得到了一个离散化的lantent code
VQ-VAE的整体流程:
输入图像Encoder 最邻近搜索 (用代替)Decoder输出图像
后验假设:对对应的的index进行建模(注意这里是首先假设index服从均匀分布:index共有0~(k-1),k个取值,然后利用pixel cnn对进行建模))
VQ对应的就是获取离散化lantent code的过程
向量量化公式为:
上述公式是对lantent code进行了one-hot处理,本质是找离最近的embedding index
当我们的assumption为:服从0~K的均匀分布,VAE模型中的KL divergence就变成了常数.
kl散度计算公式:
其中时训练得到的分布,是要拟合的分布,也就是0~K的均匀分布
VQ-VAE 的训练过程
stage1:VQ-VAE要训练的包括三部分:
encoder
decoder
embedding
损失函数总体理解:
其中第一项是重构损失,用来训练encoder和decoder,需要注意的是,该项在反向传播的时候,是将embedding的梯度直接拷贝给encoder,因为该项并不用来优化embedding
第二项是固定encoder,优化embedding,是stop gradient的意思
第三项是固定embedding,优化encoder
具体损失函数设计细节见损失函数设计细节
VQ-VAE的相关代码:
1.整体流程:
def __init__(self, input_dim, dim, K=512):
super().__init__()
self.encoder = nn.Sequential(
nn.Conv2d(input_dim, dim, 4, 2, 1),
nn.BatchNorm2d(dim),
nn.ReLU(True),
nn.Conv2d(dim, dim, 4, 2, 1),
ResBlock(dim),
ResBlock(dim),
)
self.codebook = VQEmbedding(K, dim)
self.decoder = nn.Sequential(
ResBlock(dim),
ResBlock(dim),
nn.ReLU(True),
nn.ConvTranspose2d(dim, dim, 4, 2, 1),
nn.BatchNorm2d(dim),
nn.ReLU(True),
nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
nn.Tanh()
)
self.apply(weights_init)
def encode(self, x):
z_e_x = self.encoder(x)
latents = self.codebook(z_e_x)#indices
return latents
def decode(self, latents):
z_q_x = self.codebook.embedding(latents).permute(0, 3, 1, 2) # (B, D, H, W)注意这里
x_tilde = self.decoder(z_q_x)
return x_tilde
def forward(self, x):
z_e_x = self.encoder(x)
z_q_x_st, z_q_x = self.codebook.straight_through(z_e_x)
x_tilde = self.decoder(z_q_x_st)
return x_tilde, z_e_x, z_q_x
2.embedding部分:
class VQEmbedding(nn.Module):
def __init__(self, K, D):
super().__init__()
self.embedding = nn.Embedding(K, D)
self.embedding.weight.data.uniform_(-1./K, 1./K)
def forward(self, z_e_x):
z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()#B,D,H,W->B,H,W,D
latents = vq(z_e_x_, self.embedding.weight)#indices(h,w)
return latents
def straight_through(self, z_e_x):
z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()#B,D,H,W->B,H,W,D
z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()#B,H,W,D->B,D,H,W
z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
dim=0, index=indices)#indices:indices_flatten:HW;z_q_x_bar_flatten :(HW,D)
z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)#(HW,D)->(H,W,D)
z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()#(H,W,D)->(D,H,W)
return z_q_x, z_q_x_bar#不优化embedding space,优化embedding space
2.1获取embedding
class VectorQuantization(Function):
@staticmethod
def forward(ctx, inputs, codebook):
with torch.no_grad():
embedding_size = codebook.size(1)#D
inputs_size = inputs.size()#(H,W,D)
inputs_flatten = inputs.view(-1, embedding_size)#(HW,D)
codebook_sqr = torch.sum(codebook ** 2, dim=1)#求每个embedding的平方和
inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)#(HXW,D)每个分量的平方和,(HXW,1)
# Compute the distances to the codebook 欧式距离(a^2+b^2-2ab)^0.5
distances = torch.addmm(codebook_sqr + inputs_sqr,
inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)
_, indices_flatten = torch.min(distances, dim=1)
indices = indices_flatten.view(*inputs_size[:-1])
ctx.mark_non_differentiable(indices)
return indices#(H,W)
@staticmethod
def backward(ctx, grad_output):
raise RuntimeError('Trying to call `.grad()` on graph containing '
'`VectorQuantization`. The function `VectorQuantization` '
'is not differentiable. Use `VectorQuantizationStraightThrough` '
'if you want a straight-through estimator of the gradient.')
class VectorQuantizationStraightThrough(Function):
@staticmethod
def forward(ctx, inputs, codebook):
indices = vq(inputs, codebook)
indices_flatten = indices.view(-1)
# 用 ctx 把该存的存起来,留着 backward 的时候用
ctx.save_for_backward(indices_flatten, codebook)
ctx.mark_non_differentiable(indices_flatten)
codes_flatten = torch.index_select(codebook, dim=0,
index=indices_flatten)#codebook:(K,D),indices_flatten:HW,codes_flatten:(HW,D)
codes = codes_flatten.view_as(inputs)#(H,W,D)
return (codes, indices_flatten)#embedding向量及对应的indices
@staticmethod
#由于 forward 有2个返回值,所以 backward需要2个参数 接收 梯度。
def backward(ctx, grad_output, grad_indices):
grad_inputs, grad_codebook = None, None
if ctx.needs_input_grad[0]:
# Straight-through estimator
grad_inputs = grad_output.clone()#反向传播时候,将输出的梯度直接copy给输入,重构损失的反向传播
if ctx.needs_input_grad[1]:#涉及到优化embedding
# Gradient wrt. the codebook
indices, codebook = ctx.saved_tensors
embedding_size = codebook.size(1)
grad_output_flatten = (grad_output.contiguous().view(-1, embedding_size))
grad_codebook = torch.zeros_like(codebook)
grad_codebook.index_add_(0, indices, grad_output_flatten)
return (grad_inputs, grad_codebook)
自定义反向传播:
1.https://blog.csdn.net/u012436149/article/details/78829329
2.https://zhuanlan.zhihu.com/p/344802526
模型训练:
def train(data_loader, model, optimizer, args, writer):
#pdb.set_trace()
for images, _ in data_loader:
images = images.to(args.device)
optimizer.zero_grad()
#x_tilde:解码的图像
#z_e_x:编码器的输出(B,H,W,D)
#z_q_x:embeding(B,H,W,D),require_grad=True
x_tilde, z_e_x, z_q_x = model(images)
# Reconstruction loss
loss_recons = F.mse_loss(x_tilde, images)#x_tilde的梯度只包含encoder和decoder,反向传播时候不会优化embedding
# Vector quantization objective
loss_vq = F.mse_loss(z_q_x, z_e_x.detach())#固定encoder,优化embedding
# Commitment objective
loss_commit = F.mse_loss(z_e_x, z_q_x.detach())#固定embedding,优化encoder
loss = loss_recons + loss_vq + args.beta * loss_commit
loss.backward()
损失函数设计细节
重构损失函数设计
Straight-Through Estimator操作(前向传播的时候可以用想要的变量(哪怕不可导),而反向传播的时候,用自己针对一些操作设计的梯度)
该操作的目的:
一般的VAE:输入图像Encoder Decoder输出图像
VQ-VAE:输入图像Encoder 最邻近搜索 (用代替)Decoder输出图像
普通VAE用于重建的,而VQ-VAE用于重建的是,所以理论上重建损失应为,但是获取过程中涉及到argmin操作,该操作不可导;根据Straight-Through Estimator思想,重新设计重构损失为:,这样以来,在前向计算loss的时候该项变为,在反向传播的时候,由于固定了的梯度,所以反传时候该项变为,就可以用来优化encoder(具体操作的时候就是反向时将VQ的输出的梯度直接拷贝给输入,见代码注解)
embedding(编码表优化)
由于embedding有很大的自由度(embedding刚开始训练的时候,一般是随机初始化),所以我们应该让embedding去靠近,而不是让去接近embedding,所以我们可以将优化embedding的损失函数拆解为和,这样以来,前向传到时,与embedding有关的损失加倍,反向传播的时候,不影响原来各项的梯度;第一项固定embedding优化encoder,第二项固定encoder优化embedding,同时我们需要去接近,所以分别给两者一权重,并且需要后者权重大于前者权重,所以总体损失函数应为:
其中,原文中
stage2:对离散化后的lantent code利用pixel cnn建模
经过stage1的处理,我们已经可以通过encoder+vq把图片编码为的二维矩阵了,该矩阵中的元素对应的embedding index,该矩阵在一定程度上也保留了输入图像的位置信息,我们可以用自回归模型比如PixelCNN,来对编码矩阵进行拟合。通过PixelCNN得到编码分布后,就可以随机生成一个新的编码矩阵,然后通过编码表E映射为浮点数矩阵,最后经过decoder得到一张图片
这部分参考苏神的:https://spaces.ac.cn/archives/6760
pixel cnn对图像的建模过程:
用神经网络拟合各条件概率
相比于PixelRNN的串行生成各个像素的方式, PixelCNN模型一次就可以将图像 x 的全部像素都并行输入,并在输出端得到与各像素相应的条件概率
PixelCNN 的实现比较简单,考虑到要用前面的像素估计后面像素的概率,因此在构建 CNN 时,需要应用一个模板,如下是一个 5 × 5 5\times 55×5 的 mask:
该模板与传统CNN filter 的 weight 逐元点积后,再做常规 convolution 操作;
在构建Loss时,采用交叉熵(cross_entropy)来衡量两个概率的差异,例如:对于Mnist数据集,我们可以将其像素值*255将其变成256各强度的取值,网络输出256个channel,然后再channel维度做softmax将其转成概率值,然后两者做交叉熵损失(注意:pytorch里面的cross_entrophy是包含softmax操作的)
PixelCNN生成样本:
def generate(self, label, shape=(8, 8), batch_size=64):
param = next(self.parameters())
x = torch.zeros( (batch_size, *shape),dtype=torch.int64,device=param.device)
for i in range(shape[0]):
for j in range(shape[1]):
logits = self.forward(x, label)
#获取(i,j)位置的像素值,logits[:,:,i,j]是(i,j)位置维度为output_dim维的向量
probs = F.softmax(logits[:, :, i, j], -1)
#从softmax结果采样1个概率值
x.data[:, i, j].copy_( probs.multinomial(1).squeeze().data)
return x
https://pytorch.org/docs/master/generated/torch.multinomial.html#torch.multinomial
PixelCNN:
1.https://zhuanlan.zhihu.com/p/115257230
2.https://blog.csdn.net/StreamRock/article/details/95516065
VQ-VAE的推理过程
1.通过PixelCNN获取离散化的lantent code
2.查表获取latent code对应的embedding,将embedding ()输入到decoder,解码得到目标分布中的图像样本