去年写的文章,从notion的博客搬到这边来发一下(本来想搬到微信公众号的,但是那个格式真的反人类就作罢了),原文请到这里看mewimpetus.以后文章都会再这边先发。
引言
扩散模型是今年AI领域最热门的研究方向。由其引发的AI绘画的产业变革正在如火如荼的进行,大有淘汰一大票初中级画师的势头,目前主流的(诸如OpenAI的DALL-E 2;Google的ImageGen;以及已经商业化的MidJourney;注重二次元的NovelAI;开源引爆这波热潮的stable-diffusion)图像生成模型效果已经让人惊艳,若是再发展几年,它带来的影响将不可估量,可以说整个绘画产业正在经历着一场百年未有之大变局。而这些功能强大的绘画模型,无疑都与Denoising Diffusion Probabilistic Models 摆脱不了关系,它的原始论文由Google Brain在2020年发表。 这篇博文主要带大家一起来探究一下DDPM的工作原理和实现细节。
扩散模型的基本流程
其实扩散模型的基本思路同GAN以及VAE并无二致,都是试图从一个简单分布的随机噪声出发,经过一系列的转换,转变成类似于真实数据的数据样本。
它主要包含前向加噪声和反向去噪声两个过程:
- 从真实的数据分布中随机采样一个图片,然后通过一个固定的过程逐步往上面添加高斯随机噪声,直到图片变成一个纯粹的噪声
- 构建一个神经网络,去学习一个去噪的过程,从一个纯粹的噪声出发,逐步还原回一个真实的图像。
接下来我们用数学形式来表达上面的两个过程。
前向扩散
我们将真实数据的分布定义为,然后可以从这个分布中随机采样一个”真图“ ,于是我们就可以定义一个前向扩散的递推过程为每个时间步添加少量高斯噪声并执行步。DDPM作者将定义为这样一个条件高斯分布(其中的是一个既定的递增表):
显然,当时刻的图像为的条件下,时刻的图像服从一个均值,方差的各项同性高斯分布。我们再观察一下这个递推式,因为和都小于1,显然的均值会比更加趋向于,方差也更趋向于,因此如果设计合适的序列,最终的将趋近于标准的高斯分布。根据高斯分布的性质1:
如果且与都是实数,那么。
上述的条件高斯分布显然可以通过从标准高斯分布的线性变换得到,我们定义,那么只要让,那么第个时间步的图像。
为了更好的计算任意时刻的条件分布,我们根据上面的递推式逐步推导到,为了方便推导,我们令 , 则有了推导1:
上式中第3行到第4行的推导用到了上述的性质1,以及高斯分布的另一个性质2:
如果与 是独立统计的高斯随机变量,那么,它们的和也满足高斯分布。
由性质1可知,,而 ,再根据性质2,就可得, 再根据性质1写回到多项式的形式即得到推导的结果。
基于这个最终的推导结果,因为是事先已经定义好的,我们只需要给出初始真实分布采样,即可以计算出任何第步的样本 ,而不需要每次都从开始一步步计算。
反向去噪
有了前向的过程,我们反过来想,既然前向扩散是一个马尔可夫过程,那么它的逆过程显然也是马尔可夫过程,如果我们可以构造一个相反的条件分布,那不就可以从最终的开始一步步地去噪,从而反推回初始的了吗? 但是我们并不知道反向条件高斯分布的均值和方差。不过,在这个深度学习的时代,我们可以从真实数据集出发,通过前向过程生成一系列的 的真实扩散序列,然后设计一个神经网络从这些序列中来近似学习一个分布使其接近真实的,其中的是这个神经网络需要学习的参数,于是从变换到的概率可以表示成:
当我们前向过程所定义的足够小时,反向过程也满足高斯分布,因此我们可以假设神经网络要学习的这个分布是高斯分布,这意味着它需要去学习其均值 和方差 ,换成与上述前向过程相同的表示则有递推公式:
借助这个公式,我们就可以完成去噪过程了,接下来的任务变成了如何训练这个神经网络。
如何训练
基本思路
不知大家又没有觉得这个加噪声和去噪声的过程和VAE的编码和解码的过程十分类似,那么是否可以从VAE的训练方式中得到一些启发呢?实际上作者就是这么想的。
显然,如果直接使用与的对比误差会导致模型过拟合成AE一样的无生成能力的模型。因此,我们使用与VAE类似的变分推断的方法,希望网络输出的尽量接近由真实变化而来的的分布,即最小化似然与真实的的。于是每一个时间步骤的误差可以定义为:
而当时,因为是确定的,因此可以忽略这部分,故而,因此
于是整个去噪过程的误差就是: 。实际训练时,我们并没有使用整体的误差,而是通过均匀随机选择 ,来最小化 。
目标函数
要直接计算上面的KL散度是困难的,但是正如前面所说的, 是一个高斯分布,于是根据贝叶斯公式有:
其中 代表所有剩余与无关的项。
根据高斯分布的基本方程:
与上述的推导结果位置依次对应可得其方差和均值为:
根据上面的推导1可得 ,带入上式可得:
最小化上述的KL散度,可以转化为计算神经网络的预测的均值方差与上述均值方差的L2损失:
DDPM的论文作者在论文中说他使用一个固定的方差取得了差不多的效果,因此他的神经网络只去学习了均值,而把方差设置成了 或者是,因此我们接下来的推导也只考虑均值。后来Improved diffusion models 这篇论文将其改进后就让神经网络同时去学习均值和方差了,有兴趣的同学可以自行去了解。
观察上面的,除了,其余项均为固定值与无关,于是我们不妨将神经网络的学习目标从高斯分布的均值转变为 ,即去预测每个事件步的噪声量而非高斯分布的均值,因此我们最终的目标函数就变成了:
然后,整个训练算法便是这样一个过程:
- 从真实的复杂未知分布随机抽取一个样本
- 从到均匀采样一个时间步
- 从均值为方差为的标准高斯分布中随机采样一个
- 计算随机梯度 ,并通过随机梯度下降优化
- 重复上述过程直到收敛
采样生成
当上述的神经网络学习好 , 就可以计算出均值 ,于是我们就可以从一个随机高斯噪声 ,通过条件去噪概率 进行采样生成,逐步从 到 。
具体来说Sampling是这样一个过程:
随机采样一个
-
令 ,依次执行:
返回最终的
网络结构
虽然有了训练的方案,但是如何来设计这个神经网络才能让我们这个扩散和反扩散的过程取得较好的效果呢?DDPM的作者选择了U-Net ,并且在实验中取得了很好的效果。
这个用于学习的U-Net网络十分复杂,由一系列的诸如下采样、上采样、残差、位置Embedding、ResNet/ConvNeXT block、注意力模块、Group Normalization等组件组合而成,为了让大家了解整个网络各个组件的具体结构和连接方式,我绘制了一个详细的网络图:
根据这个图,我们可以用tensorflow或者pytorch非常轻松的实现这个网络。不过显然这个网络很大,特别是图片很大时占用的显存会很高,而且采样步骤多推理也很慢,因此后面有很多对于DDPM的改进,篇幅关系,关于对DDPM的改进我们下篇文章再讲。