生成式大模型的RLHF技术(一):基础

一、概述

大语言模型(LLMs)在预训练的过程中通常会捕捉数据的特征,而这些训练数据通常既包含高质量的也包含低质量的,因此模型有时会产生不被期望的行为,如编造事实,生成有偏见或有毒的文本,甚至对人类有害的内容。因此,将LLMs与人类价值观(如helpful, honest, 和harmless, 即3H)对齐是非常重要的,目前采用的主流的技术即是基于人类反馈的强化学习技术(RLHF)。

通常来说,RLHF包括三个步骤:
①supervised fine-tuning (SFT):对LLMs进行微调,LLMs通过模仿人类标注的对话示例来学习通用的的类似人类的对话。
②reward model (RM) training:对于模型对同一个prompt的多个回复,利用人类标注来进行排序以获取人类偏好,然后单独使用另一个语言模型作为reward model,在这个reward model上使用标注的数据进行训练(类似排序任务)。
③proximal policy optimization (PPO):以训练得到的reward model作为reward function来继续训练优化LLMs,促进其与人类偏好的对齐。

RLHF

RLHF的整个过程如上图所示,也可参考OpenAI的InstructGPT的做法:InstructGPT:语言模型的人类反馈指令对齐。InstructGPT使用的数据规模如下表所示:

数据规模

本文会介绍基于PPO(proximal policy optimization)的RLHF的技术细节,包括Reward Modeling、Policy Gradient Methods、Advantage Actor-Critic (A2C)等。本文主要介绍A2C框架,其中包括四个模型的调度:
①Policy Model/Actor Model:由SFT之后的模型初始化而来。作为策略(policy)模型,用于接收上文,做出动作,预测下一个字符。学习完毕之后,我们最终使用的就是这个模型。
②Reference Model:和Actor Model同样初始化自SFT Model,训练过程中冻结参数,用于和Actor Model做对比,保证模型不要偏离原始SFT Model太多。
③Reward Model:作为环境(env),训练过程中冻结参数,针对每一个状态,给出奖励分数。
④Critic Model:由Reward Model初始化而来,用于近似价值函数,输入为当前的状态,估计当前状态的价值。

二、Reward Modeling

Reward model可以使用移除了最后一个unembedding层的预训练语言模型来作为基础架构,通常就是将最后一个token最终的embedding输入给一个线性层,然后得到一个标量值,即是reward的值。在InstructGPT中,作者尝试用了1.3B、6B和175B的GPT-3来做实验,最终综合考虑只用6B的模型来训练reward model。在训练reward model时,对于同一个输入promptx,有一个更偏好的输出y_w和一个相对不偏好的输出y_l。每一对偏好和不偏好的回复的损失为:

\mathcal{L}(\psi )=log\, \sigma (r(x,y_{w})-r(x,y_{l}))

这里的\sigma是sigmoid函数,r代表reward model,其参数为\psir(x,y)是一个reward model为xy预测的标量得分。另外,也可以额外加上一个模仿学习(imitation learning)的损失,也就是一个语言模型的预训练损失,来让模型模仿句子对中更偏好的那一个:

\mathcal{L}(\psi )=-\lambda \mathbb{E}_{(x,y_w,y_l)\backsim \mathcal{D}_{rm}}[log\, \sigma (r(x,y_{w})-r(x,y_{l}))]+\beta _{rm}\mathbb{E}_{(x,y_w)\backsim \mathcal{D}_{rm}}[log(r^{\prime}(x,y_w))]

这里的\lambda\beta _{rm}都是超参数,\mathcal{D}_{rm}是训练集的经验分布,r^{\prime}r除了顶部线性层不同以外是同一个模型(r^{\prime}线性层的维度为词典的大小),r^{\prime}(x,y_w)是给定promptx和偏好回复y_w后的似然。

在PPO阶段,使用训练得到的reward model来作为reward function训练policy模型时,还可以为reward function添加一个基于当前policy模型\pi _{\phi }^{RL}和SFT模型\pi ^{SFT}之间KL散度的惩罚项,此时reward function为:

r_{total}=r(x,y)-\eta KL(\pi _{\phi }^{RL}(y|x),\pi ^{SFT}(y|x))

这里的\eta是一个超参数。这个KL散度的惩罚项,通常来说有两个作用:
①作为一个entropy bonus,促进模型在policy空间中的探索,防止policy model过早地收敛到一个单一的模式。
②可以确保强化学习policy model的输出不会与reward model在训练阶段遇到的样本严重偏离。

三、Reinforcement Learning

将强化学习应用于对话生成是一种艰难的挑战,这是因为其状态-动作空间(state-action space)是非常巨大和广阔的。在这样的场景中,我们将人类的互动作为环境(environment)。在每个时间步t,agent(也就是LLMs)将从环境(即对话历史)中接收一个状态s_t,其中包括所有的对话历史文本(也就是prompt和已经生成的回复)。接着,基于policy\pi,agent的动作a_t为生成下一个token。环境相应的也会反馈一个rewardr(s_t,a_t),这个reward来自于一个reward functionrr是从人类偏好数据中训练得到的,也就是前文的reward model。此后,agent将转换到下一个状态s_{t+1}。整个过程将得到一个轨迹(trajectory)\tau =\left \{s_1,a_1,s_2,a_2,\dots ,s_T,a_T\right \}。对于LLMs的一个输入x和输出y来说,s_1=xa_i=y_i,可采取的动作的所有选择即是模型词典中的所有token。强化学习的目的即是最大化一个轨迹的累积reward(也就是回报,return)。一种有限期无折扣回报( finite-horizon undiscounted return)为R(\tau )=\textstyle\sum_{t=1}^{T^{\prime}}r(s_t,a_t),即有限步数的累积reward的加和。另一种无限期折扣回报( infinite-horizon discounted return)为R(\tau )=\textstyle\sum_{t=0}^{\infty }\gamma ^{t}r(s_t,a_t),这种计算方式考虑了整个轨迹上所获得的的所有回报,其中\gamma \in (0,1)为折扣率。

  1. Policy Gradient Methods

在介绍最常用的A2C方法之前,先介绍下更基础一些的policy gradient方法。Policy gradient方法是一种强化学习技术,其直接优化agent的policy,也就是从state到action的映射。Policy gradient方法的核心思想是使用梯度上升方法直接优化policy。从本质上讲,这类方法调整policy model的参数,使其朝着最大限度地提高预期return的方向优化。Policy\pi通常由\theta参数化,我们将其表示为\pi (a|s,\theta),即在状态s下采取动作a的概率。Policy gradient的参数更新方式为:

\theta \gets \theta +\alpha \nabla _{\theta }J(\theta )

这里的\alpha是学习率,J(\theta)表示当采用policy\pi _{\theta }时的期望return,其梯度\nabla _{\theta }J(\theta )被称为policy gradient。一个policy gradient的通用形式为:

\nabla _{\theta }J(\theta )=\mathbb{E}_{\tau \sim \pi _{\theta }}\left [\sum_{t=0}^{T}\nabla _{\theta }log\; \pi _{\theta }(a_{t},s_{t})\Phi _{t}\right ]

这里的\Phi _{t}可以是\Phi _{t}=R(\tau), \Phi _{t}=\textstyle\sum_{t^{\prime}=t}^{T}R(s_{t^{\prime}},a_{t^{\prime}}),\Phi _{t}=\textstyle\sum_{t^{\prime}=t}^{T}R(s_{t^{\prime}},a_{t^{\prime}})-b(s_{t})b是一个baseline)中的任意一个。这些不同的形式会得到policy gradient相同的期望值,但会得到不同的方差。

Return通过蒙特卡洛采样的方式来计算。如果得到的return是有利的,则会增大生成这些动作的概率。这种方法的优势在于其是无偏的,因为其只依赖于实际的return,而非需要估计它。然而这种方法具有非常高的方差,因为同一个prompt产生的不同的轨迹会计算得到不同的return值,这是由于环境(一个episode中的随机事件)和policy本身的随机性决定的。

为了降低这里的高方差,一种常用的策略是使用优势函数(advantage function)估计来替代原始return,即\Phi _{t}=A(s_t,a_t)。优势函数A(s_t,a_t)表示在状态s_t时采取动作a_t相比于同一状态下所有动作的平均质量来说是否会更好。这里需要引入两个概念:
①动作的价值Q(s_t,a_t):在t时刻,给定当前状态s_t,采取动作a_t,可以获得的return的期望,具体的,Q(s_t,a_t)=\mathbb{E}\left [\textstyle\sum _{t^{\prime}=t}^{T}r(s_{t^{\prime}},a_{t^{\prime}})|s_t,a_t\right ]
②状态的价值V(s_t):在t时刻,给定当前状态s_t,可以获得的return的期望,具体的,V(s_t)=\mathbb{E}_{a\in \mathcal{A}}\left [Q(s_t,a)\right ]=\textstyle\sum _{a}\pi _{\theta }(a|s_t)Q(s_t,a),其中\mathcal{A}是所有动作的集合。这里就是在上文s_t状态下,求和所有下一个token出现概率与token对应价值的乘积。

有了这两个概念即可得到优势函数A(s_{t},a_{t})=Q(s_{t},a_{t})-V(s_{t})。优势函数的意义在于只有当前动作的return比平均水平更高时才能获得正的优势值,从而被增强,如果低于平均水平就会被抑制。

使用优势函数的policy gradient方法是强化学习邻域的重要支柱。优势函数的估计方法是多种多样的,其中有一种广泛采用的估计方法为广义优势估计(Generalized Advantage Estimation, GAE),下一节将着重介绍。

  1. Generalized Advantage Estimation

优势函数A定义为Q函数与价值函数的差值。Q考虑一个具体的动作,而价值函数是所有可能的动作的平均。而在实践中,我们使用实际episode的return来估计Q函数,也就是蒙特卡洛采样的return,这会引入非常高的方差,因为未来的reward包含大量的噪声。一种减少这种噪声的方法是使用价值函数来估计未来(时间步t之后)的return,也就是Temporal Difference (TD)的方法,这种方法通常有较大的偏差。本节要介绍的GAE算法是一种介于使用one-step TD return(高偏差)和完全蒙特卡洛return(高方差)之间的算法,可以平衡偏差和方差。接下来的内容即是对GAE算法的推导。

首先,我们用\hat{R}_{t}^{k}来表示TD-k return,这是一种实际的reward和估计的return的集合:

\hat{R}_{t}^{k}=r_{t}+\gamma r_{t+1}+\cdots +\gamma ^{(k-1)}r_{t+k-1}+\gamma ^kV(s_{t+k})

这里的\gamma是折扣率。折扣率的意义可以这样理解:每一步虽然都有一个即时的reward,但是每一步对后面的可能状态都是有影响的,即后面的动作获取的reward都能累计到前面的动作的贡献。不过直接加上去可能不好,毕竟不是前面的动作直接获取的reward,但是可以打个折扣再加上去,即乘个小于1的\gamma

使用TD-k return的优势称为k-step优势,定义为:

\hat{A}_t^k=\hat{R}_t^k-V(s_t)=-V(s_t)+r_t+\gamma r_{t+1}+\cdots +\gamma ^{(k-1)}r_{t+k-1}+\gamma ^kV(s_{t+k})=\sum_{l=1}^{k}\gamma ^{l}\delta _{t+l}

这里的\delta _{t}=r_t+\gamma V(s_{t+1})-V(s_t),叫做TD error。在k-step优势中,如果k比较小,偏差会很高,因为优势估计只基于很少的步数,所以非常依赖价值函数的准确性。而在k比较大时,方差会非常高,因为优势估计会把很多噪声reward加进来。

为了平衡偏差和方差,GAE定义优势函数为k-step优势的指数移动平均,权重为(1-\lambda )\lambda ^{(k-1)}

\begin{align} \hat{A}_{t}^{GAE(\gamma ,\lambda )}&=(1-\lambda )(\hat{A}_{t}^{(1)}+\lambda \hat{A}_{t}^{(2)}+\lambda ^{2}\hat{A}_{t}^{(3)}+\cdots )\\ &=(1-\lambda )(\delta _{t}+\lambda (\delta _{t}+\gamma \delta _{t+1})+\lambda ^2(\delta _{t}+\gamma \delta _{t+1}+\gamma ^2\delta _{t+2})+\cdots )\\ &=(1-\lambda )(\delta _{t}(1+\lambda +\lambda ^2+\cdots )+\gamma \delta _{t+1}(\lambda +\lambda ^2+\lambda ^3+\cdots )+\gamma ^2\delta _{t+2}(\lambda ^2+\lambda ^3+\lambda ^4+\cdots )+\cdots )\\ &=(1-\lambda )(\delta _{t}(\frac{1}{1-\lambda })+\gamma \delta _{t+1}(\frac{\lambda }{1-\lambda })+\gamma ^2\delta _{t+2}(\frac{\lambda ^2}{1-\lambda })+\cdots )\\ &=\sum_{l=0}^{\infty }(\gamma \lambda )^l\delta _{t+l} \end{align}

GAE可以平滑地平衡高偏差(\lambda =0)和高方差(\lambda =1):

GAE(\gamma ,0):\hat{A}_t=\delta _t=r_t+\gamma V(s_{t+1})-V(s_t)\\ GAE(\gamma ,1):\hat{A}_t=\sum_{l=0}^{\infty }\gamma ^l\delta _{t+l}=\sum_{l=0}^{\infty }\gamma ^lr _{t+l}-V(s_t)

通过GAE,我们可以准确地得到优势函数A(s_t,a_t)的估计值\hat{A}_t。这一估计值在进行policy gradient估计时将扮演重要的角色:

\nabla _{\theta }\hat{J}(\theta )=\frac{1}{|\mathcal{D}|}\sum_{\tau \in \mathcal{D}}\displaystyle\sum_{t=1}^{T}\nabla _{\theta }log\; \pi _{\theta }(a_t|s_t)\hat{A}_t

这里的\mathcal{D}是有限批量的样本。后面我们将用\hat{\mathbb{E}}_t来表示\frac{1}{|\mathcal{D}|}\textstyle\sum_{\tau \in \mathcal{D}}\textstyle\sum_{t=1}^{T}

  1. Proximal Policy Optimization

PPO和TRPO是RL中的两种关键技术,旨在有效地训练policy而不损害其稳定性。这些方法的遵循“小而稳定的步骤”的思想,即轻微地进行policy的优化,而不是强制进行可能破坏整个学习过程的激进更新。

在传统的强化学习中,policy gradient的原则要求新旧policy在参数空间中保持接近。然而,参数空间中的这种接近并不一定等同于相似的性能,参数的轻微变化可能会极大地影响策略的有效性。这一部分原因要归结于神经网络是过参数化的,类似微调语言模型的LoRA方法,无需微调所有的模型参数即可将模型适配到特定的下游任务上。此外,如果不加限制地大步更新,就可能导致policy表现崩溃,这种情况通常被描述为“掉下悬崖(falling off the cliff)”。这种固有的风险是原始的policy gradient中样本效率的限制因素。

  • TRPO

TRPO的方法没有受到参数接近性(parameter closeness)的限制,而是对policy更新引入了一种不同的约束。其通过确保新旧policy model的KL散度在一个可接受的限制范围内来正则化policy的更新:

\max_{\theta } \hat{\mathbb{E}}_t\left [\frac{\pi _{\theta }(a_t|s_t)}{\pi _{\theta _{old}}(a_t|s_t)}\hat{A}_t\right ],\\ \mathrm{subject\; to}\; \hat{\mathbb{E}}_t[KL(\pi _{\theta _{old}}(\cdot |s_t),\pi _{\theta }(\cdot |s_t))]\le \delta

这里的\theta _{old}是更新之前旧的policy参数。

  • PPO-Penalty

PPO的方法有两个主要的变种:PPO-Penalty和PPO-Clip。TRPO是一种KL散度的硬性限制,而PPO-Penalty通过采用基于惩罚的方法而不是约束来以无约束优化问题的方式优化policy:

\mathcal{L}_{\mathrm{ppo-penalty}}(\theta )=\hat{\mathbb{E}}_t\left [\frac{\pi _{\theta }(a_t|s_t)}{\pi _{\theta _{old}}(a_t|s_t)}\hat{A}_t\right ]-\beta KL(\pi _{\theta _{old}}(\cdot |s_t),\pi _{\theta }(\cdot |s_t))

这里的\beta是惩罚因子。

  • PPO-Clip

PPO-Clip试图保持新policy接近旧policy,但不像TRPO那样对KL散度施加约束,而是在其目标函数中根据新旧policy的预测概率的比值进行截断。目标函数可以表示为:

\mathcal{L}_{\mathrm{ppo-clip}}(\theta )=\hat{\mathbb{E}}_t\left [\mathrm{min}\left (\frac{\pi _{\theta }(a_t|s_t)}{\pi _{\theta _{old}}(a_t|s_t)}\hat{A}_t,\mathrm{clip}\left (\frac{\pi _{\theta }(a_t|s_t)}{\pi _{\theta _{old}}(a_t|s_t)},1-\epsilon ,1+\epsilon \right )\hat{A}_t\right )\right ]

这里的\frac{\pi _{\theta }(a_t|s_t)}{\pi _{\theta _{old}}(a_t|s_t)}是新policy与旧policy的概率的比例,\epsilon是一个超参数,用于控制新policy能够偏离旧policy多远。\mathrm{clip}函数将概率的比值限制在(1-\epsilon ,1+\epsilon )之间。\mathrm{clip}函数充当正则化器,限制policy从一个迭代到下一个迭代的急剧变化的程度。防止过大的policy更新,确保了学习过程的鲁棒性,同时保持了比普通policy gradient方法更高效的样本学习。

  • 价值函数估计

在PPO中,文章开篇提到的critic model通常用来作为价值函数,来评估每个状态的期望return。其学习目标为最小化其预测值与真实return之间的差异,目标函数通常采用MSE损失,即:

\mathcal{L}_{\mathrm{critic}}(\phi )=\hat{\mathbb{E}}_t\left [\left \|V_\phi (s_t)-\hat{R}_t\right \|^2\right ]

这里的V_\phi (s_t)代表critic model(参数为\phi)在s_t状态的预测值,\hat{R}_t代表实际状态s_t的return值,通常估计为:\hat{R}_t=\textstyle\sum_{l=0}^{\infty }\gamma ^{l}r_{t+l}

  • 混合预训练梯度

为了缓解PPO训练后的模型在通用语言能力上的退化和灾难性遗忘问题,可以在强化学习训练过程中加入预训练数据,这种方法通常称为PPO-ptx,其损失函数为:

\mathcal{L}_{\mathrm{ppo-ptx}}(\theta )=\mathcal{L}_{\mathrm{ppo-clip}}(\theta )+\lambda _{ptx}\mathbb{E}_{x\backsim \mathcal{D}_{pretrain}}[log(\pi _{\theta }^{RL}(x))]

这里的\lambda _{ptx}是一个超参数,\mathcal{D}_{pretrain}是预训练数据的分布。

四、总结

综合前面的描述,整个PPO训练过程的算法可以表示为:

算法

另外也可参考整个流程的框架图:

框架

参考资料

Secrets of RLHF in Large Language Models Part I: PPO

大模型RLHF理论详细讲解

拆解大语言模型RLHF中的PPO

RLHF实践

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,588评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,456评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,146评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,387评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,481评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,510评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,522评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,296评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,745评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,039评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,202评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,901评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,538评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,165评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,415评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,081评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,085评论 2 352

推荐阅读更多精彩内容

  • 0. 引言 最近跟着 OpenAI 的 Spinning Up 教学文档 学习了一遍 Deep RL,对这个领域有...
    OurNote阅读 5,883评论 0 8
  • 简介 2022年11月,OpenAI推出了一款AI聊天机器人程序,其强大的问答能力瞬间引爆全网关注度。 组成部分:...
    臻甄阅读 1,724评论 0 0
  • 大家好,上期我们讲到研发人员正在研究解决语言模型中的一致性问题。ChatGPT 使用了人类反馈来指导学习过程,对其...
    城北楠哥阅读 363评论 0 0
  • 在公司看文档,对用到的一些知识做简单梳理;大部分idea来源于DeepMind或OpenAI PPO的目标函数 P...
    YukiRain阅读 630评论 0 0
  • 1. 背景 最近本qiang~老看到一些关于大语言模型的DPO、RLHF算法,但都有些云里雾里,因此静下心来收集资...
    mengrennwpu阅读 362评论 0 0