理解GAN

1. 基础

生成模型

在一些任务中,我们需要用到生成模型(aka. 概率生成模型),例如:作文生成,图像生成等任务。

生成模型定义:一系列用于随机生成可观测数据的模型。

说到底,生成模型主要是尽可能地捕获到真实数据的分布,然后以这个分布随机生成与真实数据很相似的数据。

从这里,我们就能了解到生成模型具备的两个基本功能:

  • 密度估计
  • 生成样本

这里的密度估计其实就是数据分布的估计

生成模型 vs 判别模型

生成模型可以用来进行密度估计,而密度估计是典型的无监督学习问题,需要对隐变量(分布的参数)进行建模,然后用EM算法来进行求解。

图1 带隐变量的生成模型.png

但是,生成模型还可以应用于监督学习,只需要将隐变量替换成标签变量即可。

图2 带类别的生成模型.png

监督学习的目标是建模输出标签的条件密度函数p(y|\textbf{x}),在这里,我们可以将其转换成联合概率密度函数p(\textbf{x}, y)的密度估计问题。

监督学习中比较典型的生成模型有朴素贝叶斯分类器隐马尔科夫模型

和生成模型相对应的另一类监督学习模型是判别模型。

判别式模型生成模型的主要区别:

  • 判别式模型直接建模条件概率密度函数 p(y|\textbf{x})
  • 生成模型建模其联合概率密度函数 p(\textbf{x}, y)

生成模型可以得到判别模型,但由判别模型得不到生成模型

隐式密度模型 vs 显式密度模型

显式密度估计:显式构建出样本的概率密度函数p(\textbf{x}; \theta),并通过最大似然估计来求解参数。
隐式密度估计:不需要估计出概率密度函数p(\textbf{x}; \theta),只需要将拟合模型,使其能够生成符合数据分布p(\textbf{x}; \theta)的样本。

由于生成模型建模的是联合概率 p(\textbf{x}, \textbf{z}; \theta) = p(\textbf{x} | \textbf{z}; \theta) p(\textbf{z}; \theta),因此在这里,样本的概率密度函数p(\textbf{x}; \theta) 指的就是 p(\textbf{x} | \textbf{z}; \theta)

隐式密度模型的一个关键是:如何确保生成网络产生的样本一定是服从真实的数据分布。
难点:不构建显式密度函数,就无法通过最大似然估计等方法来训练,那么该怎样去训练呢?

生成样本步骤

生成样本(aka. 采样):给定一个概率密度函数为p_{\theta}(\textbf{x})的分布,生成服从这个分布的样本。

对于图1中的图模型,我们能得到两个变量的局部条件概率p_{\theta}(\textbf{z})p_{\theta}(\textbf{x} | \textbf{z})

为什么是条件概率?因为他们都需要在参数\theta这个条件下。

此时的生成过程如下:

  1. 根据隐变量的先验分布p_{\theta}(\textbf{z})进行采样,得到样本\textbf{z}
  2. 根据条件分布p_{\theta}(\textbf{x} | \textbf{z})进行采样,得到\textbf{x}

因此在生成模型中,重点是估计条件分布p(\textbf{x} | \textbf{z}; \theta)

2. 什么是GAN

GAN:英文全称是Generative Adversarial Networks,中文名是生成对抗网络

GAN是通过对抗训练的方式来是的生成网络产生的样本服从真实数据分布。

对抗训练:有两个网络一起训练,一个是生成网络,一个是判别网络判别网络的目标:尽量准确地判断一个样本是来自 于真实数据还是生成网络产生的;生成网络的目标:尽量生成判别网络 无法区分来源的样本;设计的目的:通过判别网络来辅助生成网络的训练。

GAN是隐式密度模型,注意力在生产样本上,而不是在密度估计上。它采用了对抗网络的方式来对生成模型进行训练

判别网络目标函数

判别网络目标函数是最小化交叉熵,如下:

min_{\phi} - (\mathbb{E}_\textbf{x} [y log p(y=1 | x) + (1 - y) log p(y = 0 | x) ]) \qquad (1.1) \\ = max_{\phi} (\mathbb{E}_{\textbf{x} \sim p_r(\textbf{x})} [log D(x; \phi)] + \mathbb{E}_{\textbf{x'} \sim p_{\theta}(\textbf{x'})} [log (1 - D(x'; \phi) )] ) \qquad (1.2) \\ = max_{\phi} (\mathbb{E}_{\textbf{x} \sim p_r(\textbf{x})} [log D(x; \phi)] + \mathbb{E}_{\textbf{z} \sim p(\textbf{z})} [log (1 - D(G(z;\theta); \phi) )] ) \qquad (1.3)

生成网络目标函数

生成网络的目标刚好和判别网络相反,即让判别网络将自己生成的样本判别为真实样本,如下:

max_{\theta} (\mathbb{E}_{\textbf{z} \sim p(\textbf{z})} [log D(G(z;\theta); \phi) ] ) \qquad (1.4) \\ = min_{\theta} (\mathbb{E}_{\textbf{z} \sim p(\textbf{z})} [log (1 - D(G(z;\theta); \phi)) ] ) \qquad (1.5)

但是我们在训练的时候,我们往往用公式(1.4)来进行生成模型参数的更新。

我们知道,函数log(x),x ∈ (0,1)x接近1时的梯度要比接近0时
的梯度小很多,接近饱和区间。
这样,当判别网络 D 以很高的概率认为生成网络G产生的样本是“假”样本,即 (1 - D(G(z;\theta); \phi)) \rightarrow 1,此时目标函数关于\theta的梯度很小,从而不利于优化。

GAN的目标函数

GAN的目标函数是结合了生成模型和判别模型的目标函数,如下:

min_{\theta} \ max_{\phi} \ V(D, G) = \mathbb{E}_{\textbf{x} \sim p_{data} (\textbf{x})} [log D(\textbf{x}; \phi)] + \mathbb{E}_{\textbf{z} \sim p_{z} (\textbf{z})} [log ( 1 - D(G(\textbf{z}; \theta); \phi)) ] \qquad (1.6)

其中,GD分别值的是生成模型和判别模型,\theta\phi分别是GD的参数。

解释:在判别函数有足够的能力的情况下,好的生成模型生成的样本应该能够使得数据的期望(或者是数据的熵)越小。

算法

图3 生成对抗网络算法

在训练的时候,需要平衡两个网络的能力:

  • 对于判别网络来说,一开始的判别能力不能太强,否则难以提升生成网络的能力,因为若生成网络全都分类正确,生成网络的梯度就会消失(D(G(\textbf{z}; \theta); \phi) = 0,对\theta求导也是0)
  • 判别网络也不能太弱,否则针对它训练的生成网络也不会太好,应为若判别错误,生成网络进行梯度更新就会朝错误的方向更新
  • 技巧:每次迭代时,使得判别网络比生成网络能力稍强些
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容