一文搞懂变分自编码器(VAE, CVAE)

变分自编码,英文是Variational AutoEncoder,简称VAE。它是包含隐变量的一种模型

变分自编码器与对抗生成网络类似,均是为了解决数据生成问题而生的。在自编码器结构中,通常需要一个输入数据,而且所生成的数据与输入数据是相同的。但是通常希望生成的数据具有一定程度的不同,这需要输入随机向量并且模型能够学习生成图像的风格化特点,因此在后续研究中以随机化向量作为输入生成特定样本的对抗生成网络结构便产生了。变分自编码器同样的以特定分布的随机样本作为输入,并且可以生成相应的图像,从此方面来看其与对抗生成网络目标是相似的。但是变分自编码器不需要判别器,而是使用编码器来估计特定分布。总体结构来看与自编码器结构类似,但是中间传递向量为特定分布的随机向量,这里需要特别区分:编码器、解码器、生成器和判别器

一. VAE原理

先假设一个隐变量Z的分布,构建一个从Z到目标数据X的模型,即构建X=g(Z),使得学出来的目标数据与真实数据的概率分布相近

VAE的结构图如下:

VAE对每一个样本X_k匹配一个高斯分布,隐变量Z就是从高斯分布中采样得到的。对K个样本来说,每个样本的高斯分布假设为\mathcal N(\mu_k,\sigma_k^2),问题就在于如何拟合这些分布。VAE构建两个神经网络来进行拟合均值与方差。即\mu_k=f_1(X_k),log\sigma_k^2=f_2(X_k),拟合log\sigma_k^2的原因是这样无需加激活函数

此外,VAE让每个高斯分布尽可能地趋于标准高斯分布\mathcal N(0,1)。这拟合过程中的误差损失则是采用KL散度作为计算,下面做详细推导:

VAE与同为生成模型的GMM(高斯混合模型)也有很相似,实际上VAE可看成是GMM的一个distributed representation的版本。GMM是有限个高斯分布的隐变量z的混合,而VAE可看成是无穷个隐变量z的混合,VAE中的z可以是高斯也可以是非高斯的

原始样本数据x的概率分布:
P(x)=\int_Z P(x)P(x|z)dz\tag{1}
假设z服从标准高斯分布,先验分布P(x|z)是高斯的,即x|z \sim N(\mu(z),\sigma(z))\mu(z)、\sigma(z)是两个函数, 分别是z对应的高斯分布的均值和方差,则P(x)就是在积分域上所有高斯分布的累加:

由于P(z)是已知的,P(x|z)未知,所以求解问题实际上就是求\mu,\sigma这两个函数。最开始的目标是求解P(x),且希望P(x)越大越好,这等价于求解关于x最大对数似然:
L=\sum_x logP(x)\tag{2}
logP(x)可变换为:
\begin{aligned} logP(x)&=\int_z q(z|x)logP(x)dz \\ &=\int_z q(z|x)log(\dfrac{P(z,x)}{P(z|x)})dz \\ &=\int_z q(z|x)log(\dfrac{P(z,x)}{q(z|x)}\dfrac{q(z|x)}{P(z|x)})dz\\ &=\int_z q(z|x)log(\dfrac{P(z,x)}{q(z|x)})dz+ \int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz\\ &=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz + \int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz \end{aligned}\tag{3}
到这里我们发现,第二项\int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz其实就是qP的KL散度,即KL(q(z|x)\;||\;P(z|x)),因为KL散度是大于等于0的,所以上式进一步可写成:
logP(x)\geq \int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz\tag{4}
这样就找到了一个下界(lower bound),也就是式子的右项,即:
L_b=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz\tag{5}
原式也可表示成:
logP(x)=L_b+KL(q(z|x)\;||\;P(z|x))
为了让logP(x)越大,目的就是要最大化它的这个下界

推到这里,可能会有个疑问:为什么要引入q(z|x),这里的q(z|x)可以是任何分布?

实际上,因为后验分布P(z|x)很难求(intractable),所以才用q(z|x)来逼近这个后验分布。在优化的过程中发现,首先q(z|x)logP(x)是完全没有关系的,logP(x)只跟P(z|x)有关,调节q(z|x)是不会影响似然也就是logP(x)的。所以,当固定住P(x|z)时,调节q(z|x)最大化下界L_b,KL则越小。当q(z|x)与不断逼近后验分布P(z|x)时,KL散度趋于为0,logP(x)就和L_b等价。所以最大化logP(x)就等价于最大化L_b

回顾L_b
\begin{aligned} L_b&=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz \\ &=\int_z q(z|x)log(\dfrac{P(z)}{q(z|x)})dz+\int_z q(z|x)logP(x|z)dz \\ &=-KL(q(z|x)\;||\;P(z)) + \int_z q(z|x)logP(x|z)dz \\ &=-KL(q(z|x)\;||\;P(z)) + E_{q(z|x)}[log(P(x|z))] \end{aligned}\tag{6}
显然,最大化L_b就是等价于最小化KL(q(z|x)\;||\;P(z))和最大化E_{q(z|x)}[log(P(x|z))]

第一项,最小化KL散度:前面已假设了P(z)是服从标准高斯分布的,且q(z∣x)是服从高斯分布\mathcal N(\mu,\sigma^2),于是代入计算可得:
\begin{aligned} KL(q(z|x)\;||\;P(z))=KL(\mathcal N(\mu,\sigma^2)\;||\;\mathcal N(0,1))=&\int\dfrac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left( log\dfrac{e^{\frac{-(x-\mu)^2}{2\sigma^2}}/\sqrt{2\pi\sigma^2}}{ e^{\frac{-x^2}{2}}/\sqrt{2\pi} } \right)dx \\&...\text{化简得到} \\=&\dfrac{1}{2}\dfrac{1}{\sqrt{2\pi\sigma^2}}\int e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left(-log\sigma^2 +x^2-\dfrac{(x-\mu)^2}{\sigma^2} \right)dx \\=&\dfrac{1}{2}\int \dfrac{1}{\sqrt{2\pi\sigma^2}} e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left(-log\sigma^2 +x^2-\dfrac{(x-\mu)^2}{\sigma^2} \right)dx \end{aligned}\tag{7}
对上式中的积分进一步求解,\dfrac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{-(x-\mu)^2}{2\sigma^2}}实际就是概率密度f(x),而概率密度函数的积分就是1,所以积分第一项等于-log\sigma^2;而又因为高斯分布的二阶矩就是E(X^2)=\int x^2f(x)dx=\mu^2+\sigma^2,正好对应积分第二项。又根据方差的定义可知\sigma=\int (x-\mu)dx,所以积分第三项为-1

最终化简得到的结果如下:
KL(q(z|x)\;||\;P(z))=KL(\mathcal N(\mu,\sigma^2)\;||\;\mathcal N(0,1))=\dfrac{1}{2}(-log\sigma^2+\mu^2+\sigma^2-1)\tag{8}
第二项,最大化期望。也就是表明在给定q(z|x)(编码器输出)的情况下P(x∣z)(解码器输出)的值尽可能高

  1. 第一步,利用encoder的神经网络计算出均值与方差,从中采样得到z,这一过程就对应式子中的q(z∣x)
  2. 第二步,利用decoder的N计算z的均值方差,让均值(或也考虑方差)越接近x,则产生x的几率logP(x|z)越大,对应于式子中的最大化logP(x∣z)这一部分

重参数技巧

最后模型在实现的时候,有一个重参数技巧,就是想从高斯分布\mathcal N(\mu,\sigma^2)中采样Z时,其实是相当于从\mathcal N(0,1)中采样一个\epsilon,然后再来计算 Z=\mu+\epsilon\times\sigma。这么做的原因是,采样这个操作是不可导的,而采样的结果是可导的,这样做个参数变换,Z=\mu+\epsilon\times\sigma这个就可以参与梯度下降,模型就可以训练了

class VAE(nn.Module):
    """Implementation of VAE(Variational Auto-Encoder)"""
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 200)
        self.fc2_mu = nn.Linear(200, 10)
        self.fc2_log_std = nn.Linear(200, 10)
        self.fc3 = nn.Linear(10, 200)
        self.fc4 = nn.Linear(200, 784)
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        mu = self.fc2_mu(h1)
        log_std = self.fc2_log_std(h1)
        return mu, log_std
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        recon = torch.sigmoid(self.fc4(h3))  # use sigmoid because the input image's pixel is between 0-1
        return recon
    def reparametrize(self, mu, log_std):
        std = torch.exp(log_std)
        eps = torch.randn_like(std)  # simple from standard normal distribution
        z = mu + eps * std
        return z
    def forward(self, x):
        mu, log_std = self.encode(x)
        z = self.reparametrize(mu, log_std)
        recon = self.decode(z)
        return recon, mu, log_std
    def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
        recon_loss = F.mse_loss(recon, x, reduction="sum")  # use "mean" may have a bad effect on gradients
        kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std))
        kl_loss = torch.sum(kl_loss)
        loss = recon_loss + kl_loss
        return loss

二. CVAE原理

在条件变分自编码器(CVAE)中,模型的输出就不是\mathbf{x}_j了,而是对应于输入\mathbf{x}_i的任务相关数据\mathbf{y}_i,不过套路和VAE是一样的,这次的最大似然估计变成了\log p_{\theta}(\mathbf{Y}\mid\mathbf{X}),即::
\begin{aligned} \log p_{\theta}(\mathbf{Y}\mid\mathbf{X})&=1\cdot\log p_{\theta}(\mathbf{Y}\mid\mathbf{X})\\ &=\left(\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\mathrm{d}\mathbf{z}\right)\log p_{\theta}(\mathbf{Y}\mid\mathbf{X}) \\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log p_{\theta}(\mathbf{Y}\mid\mathbf{X})\mathrm{d}\mathbf{z}\\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z}\mid\mathbf{X},\mathbf{Y})p_{\theta}(\mathbf{X})}\mathrm{d}\mathbf{z}\\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z}\mid\mathbf{X},\mathbf{Y})}\frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})p_{\theta}(\mathbf{X})}\mathrm{d}\mathbf{z}\\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z}\mid\mathbf{X},\mathbf{Y})}\mathrm{d}\mathbf{z}~+~\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})p_{\theta}(\mathbf{X})}\mathrm{d}\mathbf{z}\\ &=D_{K L}(q_{\phi}, p_{\theta}) ~+~ \ell(p_{\theta}, q_{\phi})\end{aligned} \tag{9}
ELBO(Empirical Lower Bound)\ell(p_{\theta}, q_{\phi}),进一步:
\begin{aligned} \ell(p_{\theta}, q_{\phi})&=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})p_{\theta}(\mathbf{X})}\mathrm{d}\mathbf{z}\\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{p_{\theta}(\mathbf{Y}\mid\mathbf{X},\mathbf{Z})p_{\theta}(\mathbf{Z}\mid\mathbf{X})p_{\theta}(\mathbf{X})}{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})p_{\theta}(\mathbf{X})}\mathrm{d}\mathbf{z}\\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{p_{\theta}(\mathbf{Z}\mid\mathbf{X})}{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})}\mathrm{d}\mathbf{z}~+~\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log p_{\theta}(\mathbf{Y}\mid\mathbf{X,\mathbf{Z}})\mathrm{d}\mathbf{z}\\ &=-D_{K L}(q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\mid p_{\theta}(\mathbf{Z}\mid\mathbf{X}))~+~\mathbb{E}_{q_{\phi}}[\log p_{\theta}(\mathbf{Y}\mid\mathbf{X},\mathbf{Z})] \end{aligned} \tag{10}
网络结构包含三个部分:

  • 先验网络p_{\theta}(\mathbf{z}\mid\mathbf{X}),如下图(b)所示
  • Recognition网络q_{\phi}(\mathbf{z}\mid\mathbf{X},\mathbf{Y}), 如下图(c)所示D
  • ecoder网络p_{\theta}(\mathbf{Y}\mid\mathbf{X},\mathbf{Z}),如下图(b)所示
class CVAE(nn.Module):
    """Implementation of CVAE(Conditional Variational Auto-Encoder)"""
    def __init__(self, feature_size, class_size, latent_size):
        super(CVAE, self).__init__()
        self.fc1 = nn.Linear(feature_size + class_size, 200)
        self.fc2_mu = nn.Linear(200, latent_size)
        self.fc2_log_std = nn.Linear(200, latent_size)
        self.fc3 = nn.Linear(latent_size + class_size, 200)
        self.fc4 = nn.Linear(200, feature_size)
    def encode(self, x, y):
        h1 = F.relu(self.fc1(torch.cat([x, y], dim=1)))  # concat features and labels
        mu = self.fc2_mu(h1)
        log_std = self.fc2_log_std(h1)
        return mu, log_std
    def decode(self, z, y):
        h3 = F.relu(self.fc3(torch.cat([z, y], dim=1)))  # concat latents and labels
        recon = torch.sigmoid(self.fc4(h3))  # use sigmoid because the input image's pixel is between 0-1
        return recon
    def reparametrize(self, mu, log_std):
        std = torch.exp(log_std)
        eps = torch.randn_like(std)  # simple from standard normal distribution
        z = mu + eps * std
        return z
    def forward(self, x, y):
        mu, log_std = self.encode(x, y)
        z = self.reparametrize(mu, log_std)
        recon = self.decode(z, y)
        return recon, mu, log_std
    def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
        recon_loss = F.mse_loss(recon, x, reduction="sum")  # use "mean" may have a bad effect on gradients
        kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std))
        kl_loss = torch.sum(kl_loss)
        loss = recon_loss + kl_loss
        return loss
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,417评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,921评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,850评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,945评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,069评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,188评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,239评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,994评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,409评论 1 304
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,735评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,898评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,578评论 4 336
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,205评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,916评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,156评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,722评论 2 363
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,781评论 2 351

推荐阅读更多精彩内容