vqvae

Vq-VAE:向量量化VAE

VAE的本质就是通过隐变量的分布+decoder,获取目标数据分布
基础VAE的思路:对隐变量z进行各向同性标准正态分布的先验假设,训练完模型后我们就可以直接从先验假设中进行采样,将采样结果输入到decoder就可以得到目标分布中的一个样本。
VQ-VAE的思路:对隐变量的分布通过pixel cnn进行建模
假设某单通道图像集分布为p(X),x为其中一个样本,x_i是样本中的第i个像素,则样本xp(X)中出现的概率:p(x)=p(x_n|x_{n-1},x_{n-2},...,x_0)*p(x_{n-1}|x_{n-2},x_{n-3},...,x_0)...*p(x_2|x_1,x_0)*p(x_1|x_0)*p(x_0=\prod_i^n p(x_i|x_{<i})
这个过程称为自回归AutoRegressive,自回归模型由于要逐像素求解,所以对于生成大分辨率图像来说,计算量将是其一个性能瓶颈,为此我们可以在训练过程中采取这么一个策略:将图像编码到低维空间,然后再低维空间利用自回归模型进行建模,然后对低维空间进行解码求得高维空间的图像,训练结束后,我们就可以直接在通过自回归模型建模好的低维空间进行采样,然后解码得到符合目标分布的图像样本;另外对于cv领域的主导深度学习架构CNN来说,其输出值一般为连续值,而对连续值进行自回归建模几乎不可能,所以在VQ-VAE中是将连续值进行离散化,然后对离散化后的latent code进行自回归建模,具体来说就是:对encoder的输出做embedding操作(其实就是做聚类操作,embedding对应的是聚类中心),输出的每个(cx1x1)向量会对应一个embedding,然后将(cx1x1)向量用对应的embedding index替换,就得到了一个离散化的lantent code
VQ-VAE的整体流程:
输入图像\toEncoder\to z_e(x) \to最邻近搜索\to e(用e_i代替z_e(x_i))\toDecoder\to输出图像
后验假设:z对应的e的index进行建模(注意这里是首先假设index服从均匀分布:index共有0~(k-1),k个取值,然后利用pixel cnn对q(z|x)进行建模))

VQ-VAE整体流程(图源见水印)

VQ对应的就是获取离散化lantent code的过程
向量量化公式为:
VQ公式

上述公式是对lantent code进行了one-hot处理,本质是找离z_e{(x)}最近的embedding index
当我们的assumption为:q(z|x)服从0~K的均匀分布,VAE模型中的KL divergence就变成了常数.
kl散度计算公式:
D_{kl}(q(x)||p(x))=-\sum_iq(x_i)log\frac {q(x_i)}{p(x_i)}=-\sum_i q(x_i)logq(x_i)-q(x_i)logp(x_i)=-\sum_i-\frac1k=\frac 1k
其中q(x)时训练得到的分布,p(x)q(x)要拟合的分布,也就是0~K的均匀分布

VQ-VAE 的训练过程

stage1:VQ-VAE要训练的包括三部分:
encoder
decoder
embedding

损失函数总体理解:

loss = ||x-decoder(z+sg[z_q-z])||_2^2 + ||sg(z)-z_q||_2^2 + \beta ||z-sg(z_q)||_2^2

其中第一项是重构损失,用来训练encoder和decoder,需要注意的是,该项在反向传播的时候,是将embedding的梯度直接拷贝给encoder,因为该项并不用来优化embedding
第二项是固定encoder,优化embedding,sg是stop gradient的意思
第三项是固定embedding,优化encoder
具体损失函数设计细节见损失函数设计细节

VQ-VAE的相关代码:
1.整体流程:

    def __init__(self, input_dim, dim, K=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, dim, 4, 2, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 4, 2, 1),
            ResBlock(dim),
            ResBlock(dim),
        )
        self.codebook = VQEmbedding(K, dim)
        self.decoder = nn.Sequential(
            ResBlock(dim),
            ResBlock(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
            nn.Tanh()
        )
        self.apply(weights_init)
    def encode(self, x):
        z_e_x = self.encoder(x)
        latents = self.codebook(z_e_x)#indices
        return latents
    def decode(self, latents):
        z_q_x = self.codebook.embedding(latents).permute(0, 3, 1, 2)  # (B, D, H, W)注意这里
        x_tilde = self.decoder(z_q_x)
        return x_tilde
    def forward(self, x):
        z_e_x = self.encoder(x)
        z_q_x_st, z_q_x = self.codebook.straight_through(z_e_x)
        x_tilde = self.decoder(z_q_x_st)
        return x_tilde, z_e_x, z_q_x

2.embedding部分:

class VQEmbedding(nn.Module):
    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        self.embedding.weight.data.uniform_(-1./K, 1./K)
    def forward(self, z_e_x):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()#B,D,H,W->B,H,W,D
        latents = vq(z_e_x_, self.embedding.weight)#indices(h,w)
        return latents
    def straight_through(self, z_e_x):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()#B,D,H,W->B,H,W,D
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
        z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()#B,H,W,D->B,D,H,W


        z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
            dim=0, index=indices)#indices:indices_flatten:HW;z_q_x_bar_flatten :(HW,D)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)#(HW,D)->(H,W,D)
        z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()#(H,W,D)->(D,H,W)
        return z_q_x, z_q_x_bar#不优化embedding space,优化embedding space

2.1获取embedding

class VectorQuantization(Function):
    @staticmethod
    def forward(ctx, inputs, codebook):
        with torch.no_grad():
            embedding_size = codebook.size(1)#D
            inputs_size = inputs.size()#(H,W,D)
            inputs_flatten = inputs.view(-1, embedding_size)#(HW,D)

            codebook_sqr = torch.sum(codebook ** 2, dim=1)#求每个embedding的平方和
            inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)#(HXW,D)每个分量的平方和,(HXW,1)

            # Compute the distances to the codebook 欧式距离(a^2+b^2-2ab)^0.5
            distances = torch.addmm(codebook_sqr + inputs_sqr,
                inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)

            _, indices_flatten = torch.min(distances, dim=1)
            indices = indices_flatten.view(*inputs_size[:-1])
            ctx.mark_non_differentiable(indices)

            return indices#(H,W)

    @staticmethod
    def backward(ctx, grad_output):
        raise RuntimeError('Trying to call `.grad()` on graph containing '
            '`VectorQuantization`. The function `VectorQuantization` '
            'is not differentiable. Use `VectorQuantizationStraightThrough` '
            'if you want a straight-through estimator of the gradient.')
class VectorQuantizationStraightThrough(Function):
    @staticmethod
    def forward(ctx, inputs, codebook):
        indices = vq(inputs, codebook)
        indices_flatten = indices.view(-1)
       # 用 ctx 把该存的存起来,留着 backward 的时候用
        ctx.save_for_backward(indices_flatten, codebook)
        ctx.mark_non_differentiable(indices_flatten)

        codes_flatten = torch.index_select(codebook, dim=0,
            index=indices_flatten)#codebook:(K,D),indices_flatten:HW,codes_flatten:(HW,D)
        codes = codes_flatten.view_as(inputs)#(H,W,D)

        return (codes, indices_flatten)#embedding向量及对应的indices

    @staticmethod
    #由于 forward 有2个返回值,所以 backward需要2个参数 接收 梯度。
    def backward(ctx, grad_output, grad_indices):
        grad_inputs, grad_codebook = None, None

        if ctx.needs_input_grad[0]:
            # Straight-through estimator
            grad_inputs = grad_output.clone()#反向传播时候,将输出的梯度直接copy给输入,重构损失的反向传播
        if ctx.needs_input_grad[1]:#涉及到优化embedding
            # Gradient wrt. the codebook
            indices, codebook = ctx.saved_tensors
            embedding_size = codebook.size(1)

            grad_output_flatten = (grad_output.contiguous().view(-1, embedding_size))
            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output_flatten)

        return (grad_inputs, grad_codebook)

自定义反向传播:
1.https://blog.csdn.net/u012436149/article/details/78829329
2.https://zhuanlan.zhihu.com/p/344802526
模型训练:

def train(data_loader, model, optimizer, args, writer):
    #pdb.set_trace()
    for images, _ in data_loader:
        images = images.to(args.device)
        optimizer.zero_grad()
        #x_tilde:解码的图像
        #z_e_x:编码器的输出(B,H,W,D)
        #z_q_x:embeding(B,H,W,D),require_grad=True
        x_tilde, z_e_x, z_q_x = model(images)

        # Reconstruction loss
        loss_recons = F.mse_loss(x_tilde, images)#x_tilde的梯度只包含encoder和decoder,反向传播时候不会优化embedding
        # Vector quantization objective
        loss_vq = F.mse_loss(z_q_x, z_e_x.detach())#固定encoder,优化embedding
        # Commitment objective
        loss_commit = F.mse_loss(z_e_x, z_q_x.detach())#固定embedding,优化encoder

        loss = loss_recons + loss_vq + args.beta * loss_commit
        loss.backward()

损失函数设计细节

重构损失函数设计

Straight-Through Estimator操作(前向传播的时候可以用想要的变量(哪怕不可导),而反向传播的时候,用自己针对一些操作设计的梯度)
该操作的目的:
一般的VAE:输入图像\toEncoder\to z \toDecoder\to输出图像
VQ-VAE:输入图像\toEncoder\to z_e(x) \to最邻近搜索\to e(用e_i代替z_e(x_i))\toDecoder\to输出图像
普通VAE用于重建的z_e(x),而VQ-VAE用于重建的是z_q(x),所以理论上重建损失应为||x-decoder(z_q(x))||_2^2,但是获取z_q(x)过程中涉及到argmin操作,该操作不可导;根据Straight-Through Estimator思想,重新设计重构损失为:||x-decoder(z+sg[z_q-z])||_2^2,这样以来,在前向计算loss的时候该项变为||x-decoder(z_q)||_2^2,在反向传播的时候,由于固定了z_q-z的梯度,所以反传时候该项变为||x-decoder(z)||_2^2,就可以用来优化encoder(具体操作的时候就是反向时将VQ的输出的梯度直接拷贝给输入,见代码注解)

embedding(编码表优化)

由于embedding有很大的自由度(embedding刚开始训练的时候,一般是随机初始化),所以我们应该让embedding去靠近z_e(x_i),而不是让z_e(x_i)去接近embedding,所以我们可以将优化embedding的损失函数||z-z_q||_2^2拆解为||z-sg(z_q)||_2^2||z_q-sg(z)||_2^2,这样以来,前向传到时,与embedding有关的损失加倍,反向传播的时候,不影响原来各项的梯度;第一项固定embedding优化encoder,第二项固定encoder优化embedding,同时我们需要z_q去接近z,所以分别给两者一权重,并且需要后者权重大于前者权重,所以总体损失函数应为:
loss = ||x-decoder(z+sg[z_q-z])||_2^2 + \lambda ||sg(z)-z_q||_2^2 + \beta ||z-sg(z_q)||_2^2
其中\beta<\lambda,原文中\lambda=1,\beta=0.25

stage2:对离散化后的lantent code利用pixel cnn建模

经过stage1的处理,我们已经可以通过encoder+vq把图片编码为k=mm的二维矩阵了,该矩阵中的元素对应z_q(x_i)的embedding index,该矩阵在一定程度上也保留了输入图像的位置信息,我们可以用自回归模型比如PixelCNN,来对编码矩阵进行拟合。通过PixelCNN得到编码分布后,就可以随机生成一个新的编码矩阵,然后通过编码表E映射为浮点数矩阵z_q,最后经过decoder得到一张图片

这部分参考苏神的:https://spaces.ac.cn/archives/6760

pixel cnn对图像的建模过程:
p(x)=p(x_n|x_{n-1},x_{n-2},...,x_0)*p(x_{n-1}|x_{n-2},x_{n-3},...,x_0)...*p(x_2|x_1,x_0)*p(x_1|x_0)*p(x_0=\prod_i^n p(x_i|x_{<i})
用神经网络拟合各条件概率p(x_i|x_{<i})
相比于PixelRNN的串行生成各个像素的方式, PixelCNN模型一次就可以将图像 x 的全部像素都并行输入,并在输出端得到与各像素相应的条件概率

image.png

PixelCNN 的实现比较简单,考虑到要用前面的像素估计后面像素的概率,因此在构建 CNN 时,需要应用一个模板,如下是一个 5 × 5 5\times 55×5 的 mask:
image.png

该模板与传统CNN filter 的 weight 逐元点积后,再做常规 convolution 操作;
在构建Loss时,采用交叉熵(cross_entropy)来衡量两个概率的差异,例如:对于Mnist数据集,我们可以将其像素值*255将其变成256各强度的取值,网络输出256个channel,然后再channel维度做softmax将其转成概率值,然后两者做交叉熵损失(注意:pytorch里面的cross_entrophy是包含softmax操作的)
PixelCNN生成样本:

    def generate(self, label, shape=(8, 8), batch_size=64):
        param = next(self.parameters())
        x = torch.zeros( (batch_size, *shape),dtype=torch.int64,device=param.device)
        for i in range(shape[0]):
            for j in range(shape[1]):
                logits = self.forward(x, label)
                #获取(i,j)位置的像素值,logits[:,:,i,j]是(i,j)位置维度为output_dim维的向量
                probs = F.softmax(logits[:, :, i, j], -1)
               #从softmax结果采样1个概率值
                x.data[:, i, j].copy_( probs.multinomial(1).squeeze().data)
        return x

https://pytorch.org/docs/master/generated/torch.multinomial.html#torch.multinomial

PixelCNN:
1.https://zhuanlan.zhihu.com/p/115257230
2.https://blog.csdn.net/StreamRock/article/details/95516065

VQ-VAE的推理过程

1.通过PixelCNN获取离散化的lantent code
2.查表获取latent code对应的embedding,将embedding (z_q)输入到decoder,解码得到目标分布中的图像样本

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,402评论 6 499
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,377评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,483评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,165评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,176评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,146评论 1 297
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,032评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,896评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,311评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,536评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,696评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,413评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,008评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,659评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,815评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,698评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,592评论 2 353

推荐阅读更多精彩内容

  • 1 为什么要对特征做归一化 特征归一化是将所有特征都统一到一个大致相同的数值区间内,通常为[0,1]。常用的特征归...
    顾子豪阅读 1,335评论 0 1
  • 1 为什么要对特征做归一化 特征归一化是将所有特征都统一到一个大致相同的数值区间内,通常为[0,1]。常用的特征归...
    顾子豪阅读 6,338评论 2 22
  • 注明:本文是对一篇整理166篇文献的综述翻译,其中对应文献地址都已附上为方便点击查看学习。查看有的文献可能需要科学...
    leon_kbl阅读 4,312评论 0 6
  • 我们都知道,牛顿说过一句名言 If I have seen further, it is by standing ...
    weizier阅读 8,353评论 5 25
  • 表情是什么,我认为表情就是表现出来的情绪。表情可以传达很多信息。高兴了当然就笑了,难过就哭了。两者是相互影响密不可...
    Persistenc_6aea阅读 124,921评论 2 7