LLM面面观之RLHF平替算法DPO

1. 背景

最近本qiang~老看到一些关于大语言模型的DPO、RLHF算法,但都有些云里雾里,因此静下心来收集资料、研读论文,并执行了下开源代码,以便加深印象。

此文是本qiang~针对大语言模型的DPO算法的整理,包括原理、流程及部分源码

2. DPO vs RLHF

RLHF vs DPO

上图左边是RLHF算法,右边为DPO算法,两图的差异对比即可体现出DPO的改进之处。

1. RLHF算法包含奖励模型(reward

model)和策略模型(policy model,也称为演员模型,actor model),基于偏好数据以及强化学习不断迭代优化策略模型的过程。

2. DPO算法不包含奖励模型和强化学习过程,直接通过偏好数据进行微调,将强化学习过程直接转换为SFT过程,因此整个训练过程简单、高效,主要的改进之处体现在于损失函数。

PS:

1. 偏好数据,可以表示为三元组(提示语prompt, 良好回答chosen, 一般回答rejected)。论文中的chosen表示为下标w(即win),rejected表示为下标l(即lose)

2. RLHF常使用PPO作为基础算法,整体流程包含了4个模型,且通常训练过程中需要针对训练的actor model进行采样,因此训练起来,稳定性、效率、效果不易控制。

1) actor model/policy

model: 待训练的模型,通常是SFT训练后的模型作为初始化

2) reference

model: 参考模型,也是经SFT训练后的模型进行初始化,且通常与actor model是同一个模型,且模型冻结,不参与训练,其作用是在强化学习过程中,保障actor model与reference model的分布差异不宜过大。

3) reward model:奖励模型,用于提供每个状态或状态动作对的即时奖励信号。

4) Critic model:作用是估计状态或状态动作对的长期价值,也称为状态值函数或动作值函数。

3. DPO算法仅包含RLHF中的两个模型,即演员模型(actor

model)以及参考(reference model),且训练过程中不需要进行数据采样。

4. RLHF可以参考附件中的引文

3. DPO的损失函数

DPO的损失函数

如何将RLHF的Reward model过程简化为上式,作者花了大量篇幅进行了推导,感兴趣的读者可以参考附件DPO的论文。

DPO算法的目的是最大化奖励模型(此处的奖励模型即为训练的策略),使得奖励模型对chosen和rejected数据的差值最大,进而学到人类偏好。

上式的后半部分通过对数函数运算规则,可以进行如下转化。

Loss公式转化

转化后的公式和源代码中的计算函数中的公式是一致的。

其中左半部分是训练的policy模型选择chosen优先于rejected,右半部分是冻结的reference模型选择chosen优先于rejected,二者的差值可类似于KL散度,保障actor模型的分布与reference模型的分布不会有较大的差异。

4. 微调流程

DPO微调流程

上图展示了DPO微调的大致流程,其中Trained

LM即为策略模型,Frozen LM即为参考模型,二者均是先进行SFT微调得到的模型进行初始化,其中Trained LM需要进行训练,Frozen LM不参与训练。

两个模型分别针对chosen和rejected进行预测获取对应的得分,再通过DPO的损失函数进行损失计算,进而不断的迭代优化。

5. 源码

源码参考代码:https://github.com/eric-mitchell/direct-preference-optimization

5.1 DPO损失函数


def preference_loss(policy_chosen_logps: torch.FloatTensor,


  policy_rejected_logps: torch.FloatTensor,


  reference_chosen_logps: torch.FloatTensor,


  reference_rejected_logps: torch.FloatTensor,

                    beta:  float,


  label_smoothing: float = 0.0,

                    ipo: bool  = False,


  reference_free: bool = False) -> Tuple[torch.FloatTensor,  torch.FloatTensor, torch.FloatTensor]:

    # policy_chosen_logps:训练模型对于chosen经过log后logits

         #  policy_rejected_logps:训练模型对于rejected经过log后logits

         #  reference_chosen_logps:训练模型对于chosen经过log后logits

         #  reference_rejected_logps:训练模型对于rejected经过log后logits

         # beta: policy和reference的差异性控制参数


         # actor模型选择chosen优先于rejected

    pi_logratios =  policy_chosen_logps - policy_rejected_logps

         # reference模型选择chosen优先于rejected

    ref_logratios =  reference_chosen_logps - reference_rejected_logps


    if reference_free:

        ref_logratios = 0


         #差值可类似于KL散度,保障actor模型的分布与reference模型的分布不会有较大的差异

    logits = pi_logratios -  ref_logratios  # also known as  h_{\pi_\theta}^{y_w,y_l}


    if ipo:

        losses = (logits -  1/(2 * beta)) ** 2  # Eq. 17 of  https://arxiv.org/pdf/2310.12036v2.pdf

    else:

        # Eq. 3  https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7  of https://arxiv.org/pdf/2305.18290.pdf)

                  #  label_smoothing为0,对应的DPO论文的算法

        losses =  -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta *  logits) * label_smoothing


         # chosen和rejected的奖励

    chosen_rewards = beta *  (policy_chosen_logps - reference_chosen_logps).detach()

    rejected_rewards = beta  * (policy_rejected_logps - reference_rejected_logps).detach()


    return losses,  chosen_rewards, rejected_rewards


5.2 批次训练过程


def get_batch_metrics(self, batch: Dict[str, Union[List,  torch.LongTensor]], loss_config: DictConfig, train=True):

         """Compute  the SFT or DPO loss and other metrics for the given batch of  inputs."""


         if loss_config.name  in {'dpo', 'ipo'}:

                  # policy模型针对chosen和rejected进行预测

                  policy_chosen_logps,  policy_rejected_logps = self.concatenated_forward(self.policy, batch)

                  with  torch.no_grad():

                          #  reference模型针对chosen和rejected进行预测

                          reference_chosen_logps,  reference_rejected_logps = self.concatenated_forward(self.reference_model,  batch)


                  if  loss_config.name == 'dpo':

                          loss_kwargs  = {'beta': loss_config.beta, 'reference_free': loss_config.reference_free,  'label_smoothing': loss_config.label_smoothing, 'ipo': False}

                  elif  loss_config.name == 'ipo':

                          loss_kwargs  = {'beta': loss_config.beta, 'ipo': True}

                  else:

                          raise  ValueError(f'unknown loss {loss_config.name}')

                  #损失计算

                  losses,  chosen_rewards, rejected_rewards = preference_loss(

                          policy_chosen_logps,  policy_rejected_logps, reference_chosen_logps, reference_rejected_logps,  **loss_kwargs)


                  reward_accuracies  = (chosen_rewards > rejected_rewards).float()


         elif  loss_config.name == 'sft':

                  policy_chosen_logits  = self.policy(batch['chosen_input_ids'],  attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)

                  policy_chosen_logps  = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'],  average_log_prob=False)


                  losses =  -policy_chosen_logps


         return losses.mean()


5.3 LM的交叉熵计算


def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor,  average_log_prob: bool = False) -> torch.FloatTensor:

    #经模型后的logits进行批量计算logps


    assert logits.shape[:-1]  == labels.shape


         #基于先前的token预测下一个token

    labels = labels[:,  1:].clone()

    logits = logits[:, :-1,  :]

    loss_mask = (labels !=  -100)


    # dummy token; we'll  ignore the losses on these tokens later

    labels[labels == -100] =  0


         #交叉熵函数

    per_token_logps =  torch.gather(logits.log_softmax(-1), dim=2,  index=labels.unsqueeze(2)).squeeze(2)


    if average_log_prob:

        return  (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

    else:

        return  (per_token_logps * loss_mask).sum(-1)


5.4 其他注意

1. hugging face设置代理

源码会从hugging face中下载英文语料和模型,由于网络限制,因此设置代理映射,将HF_ENDPOINT设置为https://hf-mirror.com,即设置: os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

2. 如果仅想要熟悉DPO整体流程,可以下载较小的生成式模型,如BLOOM 560M,GPT2等

6. 总结

一句话足矣~

本文主要针对大语言模型的DPO算法的整理,包括原理、流程及部分源码。

此外,建议大家可以针对源码进行运行,源码的欢迎大家一块交流。

7. 参考

(1) RLHF:https://blog.csdn.net/v_JULY_v/article/details/128579457

(2) DPO论文: https://arxiv.org/pdf/2305.18290v2.pdf

(3) DPO代码: https://github.com/eric-mitchell/direct-preference-optimization

(4) DPO理解1:https://medium.com/@joaolages/direct-preference-optimization-dpo-622fc1f18707

(5) DPO理解2: https://zhuanlan.zhihu.com/p/669825918

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,588评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,456评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,146评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,387评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,481评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,510评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,522评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,296评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,745评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,039评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,202评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,901评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,538评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,165评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,415评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,081评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,085评论 2 352

推荐阅读更多精彩内容