SeqGAN的概念来自AAAI 2017的SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient一文。
Motivation
如题所示,这篇文章的核心思想是将GAN与强化学习的Policy Gradient算法结合到一起——这也正是D2IA-GAN在处理Generator的优化时使用的技巧。
而该论文的出发点也是意识到了标准的GAN在处理像序列这种离散数据时会遇到的困难,主要体现在两个方面:Generator难以传递梯度更新,Discriminator难以评估非完整序列。
对于前者,作者给出的解决方案对我来说比较熟悉,即把整个GAN看作一个强化学习系统,用Policy Gradient算法更新Generator的参数;对于后者,作者则借鉴了蒙特卡洛树搜索(Monte Carlo tree search,MCTS)的思想,对任意时刻的非完整序列都可以进行评估。
问题定义
根据强化学习的设定,在时刻t,当前的状态s被定义为“已生成的序列”
,记作
,而动作a是接下来要选出的元素
,所以policy模型就是
值得一提的是,这里的policy模型是stochastic,输出的是动作的概率分布;而状态的转移则显然是deterministic,一旦动作确定了,接下来的状态也就确定了。
根据Policy Gradient算法,Generator的优化目标是令从初始状态开始的value(累积的reward期望值)最大化:
其中,
是完整序列的reward,
是action-value函数,是指“在状态s下选择动作a,此后一直遵循着policy做决策,最终得到的value”。所以对于最右边的式子我们可以这样来理解:在初始状态下,对于policy可能选出的每个y,都计算对应的value,把这些value根据policy的概率分布加权求和,就得到了初始状态的value。
action-value函数
接下来的关键是如何定义因为Discriminator充当了这个强化学习系统的environment,所以Discriminator的输出应当作为reward。但是Discriminator只能对生成的完整序列进行评估,因此目前只能对完整序列状态的value进行定义:
这是远远不够的,必须要对任意状态的value都有定义。
蒙特卡洛树搜索(MCTS)
在评估任意时刻的序列时,我们考虑的其实都是它能带来的long-term reward,就像下围棋或象棋一样,每下一步棋都要以全局为考量。在围棋和象棋的求解算法中,MCTS是一个很重要的组成部分,所以作者想到了将它应用到当前的问题。
从名字得知,这种算法属于一种蒙特卡洛方法(Monte Carlo method)——根据维基百科,也称统计模拟方法,是指使用随机数(或更常见的伪随机数)来解决很多计算问题的方法。MCTS正是这样一种基于统计模拟的启发式搜索算法,常用于游戏的决策过程。
MCTS可以无限循环,而每一次循环都由以下4个步骤构成:
- Selection:从根节点开始,连续选择子节点向下搜索,直至抵达一个叶节点。子节点的选择方法一般采用UCT(Upper Confidence Bound applied to trees)算法,根据节点的“胜利次数”和“游戏次数”来计算被选中的概率,保持了Exploitation和Exploration的平衡,是保证搜索向最优发展的关键。
- Expansion:在叶节点创建多个子节点。
- Simulation:在创建的子节点中根据roll-out policy选择一个节点进行模拟,又称为playout或者rollout。它和Selection的区别在于:Selection指的是对于搜索树中已有节点的选择,从根节点开始,有历史统计数据作为参考,使用UCT算法选择每次的子节点;Simulation是简单的模拟,从叶节点开始,用自定义的roll-out policy(可以只是简单的随机概率)来选择子节点,且模拟经过的节点并不加入树中。
- Backpropagation:根据Simulation的结果,沿着搜索树的路径向上更新节点的统计信息,包括“胜利次数”和“游戏次数”,用于Selection做决策。
在SeqGAN中,实际上只应用了上述的Simulation过程:对于非完整的序列
,以
(等同于Generator)作为roll-out policy,将剩余的T-t个元素模拟出来,这样就可以利用Discriminator进行评估了。为了减小对value估计的误差,会进行N次模拟,对这N个结果取平均值。
最终得到了完整的action-value函数:
policy gradient计算
Generator目标函数的梯度可以初步推导为:
在此基础上,可以去掉期望项,构造一个无偏估计再继续推导:
源码对loss的实现为:
111行:x是一个batch生成的所有序列,原来是一个三维数组,这里进行了reshape并转化为one-hot vector,最终得到一个二维数组,每一行以one-hot的形式代表这些生成序列的每一个元素,行数是batch size*sequence length。
-
113行:最终也是得到一个二维数组,行数与上面相同,每一行代表这些生成序列每个时刻t关于所有候选元素的log概率分布,形如
-
114行:这里的括号对应110行,运算得到这些序列每个元素被选中的log likelihood,即
116行:这些生成序列每个时刻的reward。
-
117行:括号对应于109行的结尾,括号内的运算得到了每个时刻的