vae

VAE-变分自编码器

变分:loss部分有kl divergence,kl散度是一个泛函数,泛函数求极值要用到变分法,VAE 的名字中“变分”,是因为它的推导过程用到了 KL 散度及其性质。
——https://zhuanlan.zhihu.com/p/34998569
(泛函是一种函数,其输入经常为函数,输出经常为实数)
——https://www.zhihu.com/question/21938224
###算法目标
构建一个latent code:z到目标图像target:x的一个生成模型

本质

两种分布之间的映射关系
得到目标x所服从的分布,通过采样得到目标分布中的样本
分布1:输入z所服从的分布,一般先验假设为高斯分布或者均匀分布
分布2:目标x所服从的分布
模型最终目的是学习一个x=g(z)的映射关系

解决方案

用已知分布(先验假设)去估计目标分布
分布1:输入z所服从的分布,一般先验假设为高斯分布或者均匀分布
分布2:目标x所服从的分布
VAE模型最终目的:学习一个x=g(z)的映射关系,将先验分布中
的样本通过映射关系就可以得到目标分布中的样本

难点

衡量模型生成数据所服从的分布p(g(z_i))与目标分布p(x_i)之间的差距,因为我们只有sample,没办法得到两者分布的表达式
(如果我们知道目标分布,直接在目标分布中进行采样就可以得到想要的数据了,就没必要探讨生成模型了)

VAE的解决方法

目标数据样本:x:{x_1,x_2,...,x_n...}
目标数据的分布:p(x)
进一步:p(x)=\sum_ip(z_i)p(x_i|z_i)(不严谨,意思理解即可)
这里的p(x_i|z_i)就描述了一个由z_i生成x_i的模型,此时再对p(z_i)做一个先验假设:p(z_i)服从标准正太分布,那么就可以从先验假设中进行采样,然后经过p(x_i|z_i)就可以得到目标分布中的样本

整体网络结构:

输入\mapstoencoder\mapsto均值 vs 方差(target:标准正态分布)\mapstoresample\mapstodecoder\mapsto输出(target:输入)

VAE整体结构(图源见水印)

上述结构图容易会让人困惑:例如,从图中的正太分布中采样得到的latent code:Z_1输入到生成器中得到的生成样本\hat X_1是否与真实样本X_1对应,只有对应才可以进行对比操作;而实际中,VAE的编码器是针对每个样本X_i通过编码器求得一组专属的(\mu_i,\sigma_i),进一步得到一个均值向量为\mu_i方差向量为\sigma_i的各向同性的正太分布,并令其向标准正太分布看齐,然后从均值为\mu_i方差为\sigma_i的正太分布中进行resample得到一个采样变量,进一步输入到生成器中解码得到与输入X_i对应的\hat X_i
对应的代码:

def forward(self, x):
        #pdb.set_trace()
        mu, logvar = self.encoder(x).chunk(2, dim=1)

        q_z_x = Normal(mu, logvar.mul(.5).exp())
        p_z = Normal(torch.zeros_like(mu), torch.ones_like(logvar))
        kl_div = kl_divergence(q_z_x, p_z).sum(1).mean()

        x_tilde = self.decoder(q_z_x.rsample())
        return x_tilde, kl_div

VAE的整体结构,更具体一步应该如下图所示:

VAE结构图(图源见水印)

如果只用"对比"的方式训练VAE,那么就很容易使得编码器输出的方差为0,(这里的方差相当于引入了"噪声"),VAE就退化成了一个普通的AE,就没有图像生成的能力了,只有图像编码和解码能力;为了让其有图像生成能力,VAE进一步约束正太分布向标准正太分布对齐,这样就可以再模型训练完毕之后扔掉编码器,直接从标准正太分布中采样输入到生成器就可以生成图像。
VAE结构图(图源见水印)

图像生成部分代码:

def generate_samples():
    model.eval()
    z_e_x = torch.randn(64, Z_DIM, 1, 1).cuda()#Z_DIM,1,1是Z_i的维度
    x_tilde = model.decoder(z_e_x)

    images = (x_tilde.cpu().data + 1) / 2

    save_image(
        images,
        'samples/vae_samples_{}.png'.format(DATASET),
        nrow=8
    )

利用FashionMnist训练VAE,生成样本结果:


vae_samples_0.png

vae_samples_1.png

vae_samples_2.png

git地址:https://github.com/ritheshkumar95/pytorch-vqvae.git
关于先验假设的另一种解释(假设p(z)服从标准正太分布本质上是假设后验分布p(z|x)服从标准正态分布)https://zhuanlan.zhihu.com/p/34998569

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容