变分自编码,英文是Variational AutoEncoder,简称VAE。它是包含隐变量的一种模型
变分自编码器与对抗生成网络类似,均是为了解决数据生成问题而生的。在自编码器结构中,通常需要一个输入数据,而且所生成的数据与输入数据是相同的。但是通常希望生成的数据具有一定程度的不同,这需要输入随机向量并且模型能够学习生成图像的风格化特点,因此在后续研究中以随机化向量作为输入生成特定样本的对抗生成网络结构便产生了。变分自编码器同样的以特定分布的随机样本作为输入,并且可以生成相应的图像,从此方面来看其与对抗生成网络目标是相似的。但是变分自编码器不需要判别器,而是使用编码器来估计特定分布。总体结构来看与自编码器结构类似,但是中间传递向量为特定分布的随机向量,这里需要特别区分:编码器、解码器、生成器和判别器
一. VAE原理
先假设一个隐变量Z的分布,构建一个从Z到目标数据X的模型,即构建,使得学出来的目标数据与真实数据的概率分布相近
VAE的结构图如下:
VAE对每一个样本匹配一个高斯分布,隐变量就是从高斯分布中采样得到的。对个样本来说,每个样本的高斯分布假设为,问题就在于如何拟合这些分布。VAE构建两个神经网络来进行拟合均值与方差。即,拟合的原因是这样无需加激活函数
此外,VAE让每个高斯分布尽可能地趋于标准高斯分布。这拟合过程中的误差损失则是采用KL散度作为计算,下面做详细推导:
VAE与同为生成模型的GMM(高斯混合模型)也有很相似,实际上VAE可看成是GMM的一个distributed representation
的版本。GMM是有限个高斯分布的隐变量的混合,而VAE可看成是无穷个隐变量的混合,VAE中的可以是高斯也可以是非高斯的
原始样本数据的概率分布:
假设服从标准高斯分布,先验分布是高斯的,即。是两个函数, 分别是对应的高斯分布的均值和方差,则就是在积分域上所有高斯分布的累加:
由于是已知的,未知,所以求解问题实际上就是求这两个函数。最开始的目标是求解,且希望越大越好,这等价于求解关于最大对数似然:
而可变换为:
到这里我们发现,第二项其实就是和的KL散度,即,因为KL散度是大于等于0的,所以上式进一步可写成:
这样就找到了一个下界(lower bound),也就是式子的右项,即:
原式也可表示成:
为了让越大,目的就是要最大化它的这个下界
推到这里,可能会有个疑问:为什么要引入,这里的可以是任何分布?
实际上,因为后验分布很难求(intractable),所以才用来逼近这个后验分布。在优化的过程中发现,首先跟是完全没有关系的,只跟有关,调节是不会影响似然也就是的。所以,当固定住时,调节最大化下界,KL则越小。当与不断逼近后验分布时,KL散度趋于为0,就和等价。所以最大化就等价于最大化
回顾:
显然,最大化就是等价于最小化和最大化。
第一项,最小化KL散度:前面已假设了是服从标准高斯分布的,且是服从高斯分布,于是代入计算可得:
对上式中的积分进一步求解,实际就是概率密度,而概率密度函数的积分就是1,所以积分第一项等于;而又因为高斯分布的二阶矩就是,正好对应积分第二项。又根据方差的定义可知,所以积分第三项为-1
最终化简得到的结果如下:
第二项,最大化期望。也就是表明在给定(编码器输出)的情况下(解码器输出)的值尽可能高
- 第一步,利用encoder的神经网络计算出均值与方差,从中采样得到,这一过程就对应式子中的
- 第二步,利用decoder的计算的均值方差,让均值(或也考虑方差)越接近,则产生的几率越大,对应于式子中的最大化这一部分
重参数技巧:
最后模型在实现的时候,有一个重参数技巧,就是想从高斯分布中采样时,其实是相当于从中采样一个,然后再来计算 。这么做的原因是,采样这个操作是不可导的,而采样的结果是可导的,这样做个参数变换,这个就可以参与梯度下降,模型就可以训练了
class VAE(nn.Module):
"""Implementation of VAE(Variational Auto-Encoder)"""
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 200)
self.fc2_mu = nn.Linear(200, 10)
self.fc2_log_std = nn.Linear(200, 10)
self.fc3 = nn.Linear(10, 200)
self.fc4 = nn.Linear(200, 784)
def encode(self, x):
h1 = F.relu(self.fc1(x))
mu = self.fc2_mu(h1)
log_std = self.fc2_log_std(h1)
return mu, log_std
def decode(self, z):
h3 = F.relu(self.fc3(z))
recon = torch.sigmoid(self.fc4(h3)) # use sigmoid because the input image's pixel is between 0-1
return recon
def reparametrize(self, mu, log_std):
std = torch.exp(log_std)
eps = torch.randn_like(std) # simple from standard normal distribution
z = mu + eps * std
return z
def forward(self, x):
mu, log_std = self.encode(x)
z = self.reparametrize(mu, log_std)
recon = self.decode(z)
return recon, mu, log_std
def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
recon_loss = F.mse_loss(recon, x, reduction="sum") # use "mean" may have a bad effect on gradients
kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std))
kl_loss = torch.sum(kl_loss)
loss = recon_loss + kl_loss
return loss
二. CVAE原理
在条件变分自编码器(CVAE)中,模型的输出就不是了,而是对应于输入的任务相关数据,不过套路和VAE是一样的,这次的最大似然估计变成了,即::
则ELBO(Empirical Lower Bound)
为,进一步:
网络结构包含三个部分:
- 先验网络,如下图(b)所示
- Recognition网络, 如下图(c)所示D
- ecoder网络,如下图(b)所示
class CVAE(nn.Module):
"""Implementation of CVAE(Conditional Variational Auto-Encoder)"""
def __init__(self, feature_size, class_size, latent_size):
super(CVAE, self).__init__()
self.fc1 = nn.Linear(feature_size + class_size, 200)
self.fc2_mu = nn.Linear(200, latent_size)
self.fc2_log_std = nn.Linear(200, latent_size)
self.fc3 = nn.Linear(latent_size + class_size, 200)
self.fc4 = nn.Linear(200, feature_size)
def encode(self, x, y):
h1 = F.relu(self.fc1(torch.cat([x, y], dim=1))) # concat features and labels
mu = self.fc2_mu(h1)
log_std = self.fc2_log_std(h1)
return mu, log_std
def decode(self, z, y):
h3 = F.relu(self.fc3(torch.cat([z, y], dim=1))) # concat latents and labels
recon = torch.sigmoid(self.fc4(h3)) # use sigmoid because the input image's pixel is between 0-1
return recon
def reparametrize(self, mu, log_std):
std = torch.exp(log_std)
eps = torch.randn_like(std) # simple from standard normal distribution
z = mu + eps * std
return z
def forward(self, x, y):
mu, log_std = self.encode(x, y)
z = self.reparametrize(mu, log_std)
recon = self.decode(z, y)
return recon, mu, log_std
def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
recon_loss = F.mse_loss(recon, x, reduction="sum") # use "mean" may have a bad effect on gradients
kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std))
kl_loss = torch.sum(kl_loss)
loss = recon_loss + kl_loss
return loss