GPT系列:有监督微调GPT-2预训练模型,自动续写电视剧本

关键词:GPT预训练模型

前言

在前文GPT系列:GPT-2模型结构简述和实践中介绍了GPT-2的网络结构和minGPT项目的源码实现,并且以电视剧《狂飙》的其中一小段剧本作为输入,从头开始训练了一个小型的gpt-mini。本节介绍GPT-2中文预训练模型的使用,以及基于《狂飙》剧本对GPT-2进行有监督微调。


内容摘要

  • GPT-2预训练模型快速开始
  • GPT2LMHeadModel模型实现简述
  • 微调任务的输入和目标说明
  • GPT-2微调代码实践

GPT-2预训练模型快速开始

本篇使用的预训练模型为腾讯团队推出的gpt2-chinese-cluecorpussmall,它是基于CLUECorpusSmall数据训练,包含14GB的中文文本,由互联网上的新闻,问答,百科,评论等文本数据组成。
结合HuggingFace的GPT2LMHeadModel模型类,用Python可以很方便地的对该预训练模型进行调用。

from transformers import BertTokenizer, GPT2LMHeadModel
# 分词器
tokenizer = BertTokenizer.from_pretrained("./gpt2-chinese-cluecorpussmall")
# gpt-2预训练模型
model = GPT2LMHeadModel.from_pretrained("./gpt2-chinese-cluecorpussmall")
# 文本生成pipeline
text_generator = TextGenerationPipeline(model, tokenizer)

通过管道定义文本生成所需的模型和分词器,给到prompt即可实现文本自动生成。

>>> text = '明天降温了'
>>> res = text_generator(text, max_length=30, do_sample=False)
>>> print(res)
[{'generated_text': '明天降温了 , 但 是 , 我 们 还 是 要 注 意 防 寒 保 暖 , 不 要 让 寒 冷 刺 激 到 你'}]

其中max_length代表算上prompt的文本之后最大生成30个词,do_sample为False代表使用Greedy Search进行文本生成,每一步选择最大概率得分的词。
可以添加更多文本生成参数,包括top_k,温度系数,采用多项式分布采样,输出多个候选文本,比如

>>> res = text_generator(text, max_length=30, do_sample=True, top_k=5, temperature=0.8, num_return_sequences=3)
>>> for i in res:
        print(i)
{'generated_text': '明天降温了 , 但 是 我 们 还 是 要 把 握 好 时 机 , 不 要 错 过 了 。 我 们 的 投 资'}
{'generated_text': '明天降温了 , 你 就 会 明 白 , 我 们 是 多 么 的 幸 运 。 3 : 一 个 男 人 对 一'}
{'generated_text': '明天降温了 , 我 们 可 以 穿 上 一 件 秋 衣 。 王 女 士 说 。 记 者 了 解 到 ,'}

通过采样生成的3条文本增加了多样性,但是在连贯性和语义表达的清晰程度上都不如Greedy Search不采样的结果。


GPT2LMHeadModel模型实现简述

GPT2LMHeadModel类实现了GPT-2的训练和推理,以及调用generate方法实现文本生成。


GPT2LMHeadModel前向传播结构

GPT2LMHeadModel类包含两个子模块,分别是负责实现堆叠Decoder的GPT2Model,以及负责映射到单词得分的线性层Linear

class GPT2LMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        ...

GPT2LMHeadModel的主要输入为input_ids,past_key_values,labels,分别代表输入文本,模型维护的Q,V上下文信息,目标预测文本

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        ...
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
  • input_ids:输入GPT的文本token id,对于gpt2-chinese-cluecorpussmall最大长度不超过1024,在训练阶段输入全部文本,在预测阶段只需要获取下一个单词的embedding信息,因此只需要输入当前最后一个单词的token id,结合past_key_values即可完成子回归推理
  • past_key_values:维护各层各个注意力头在各个步长下的Q,V上下文信息。在推理阶段的自注意力模块,K只要输入一个词的token id即可,但是Q和V是截止到当前词之前所有词的信息,如果在模型的逐位推理过程中不记录past_key_values,则每次都需要将全部文本整体输入才能得到Q和V,导致推理效率低下。在模型首次推理的时候past_key_values设置为None。
  • labels:训练阶段的预测目标token id。若设置labels则模型为训练阶段,会计算loss。需要注意的是shifted right移位的操作已经在该模型内部实现,不需要将文本手动shifted right传入labels,换句话说直接设置labels和input_ids相同即可。另外的可以对labels某些位置设置为-100,代表该位置的预测结果不会计入loss计算。
推理阶段input_ids(Q),past_key_values(K,V)工作示意图

在前向传播阶段GPT2LMHeadModel先进行Decoder层的运算拿到last_hidden_state,再传入线性层拿到在词表中每个词的概率得分lm_logits

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            ...
        )
        # TODO BaseModelOutputWithPastAndCrossAttentions[0] => last_hidden_state
        hidden_states = transformer_outputs[0]
        ...
        lm_logits = self.lm_head(hidden_states)

输入的概率得分和labels进行比较计算交叉熵,作者通过切片将shifted right在模型内部实现,无需在模型外面额外处理

        loss = None
        if labels is not None:
            # TODO shifted right
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

最终返回CausalLMOutputWithCrossAttentions类,如果是训练阶段从中可以拿到loss,如果是推理阶段可以拿到下一个词的预测得分分布,以及past_key_values上下文向量信息用于下一次预测的输入。

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

past_key_values模型推理实现

有两种方法实现模型推理,分别是

  • 传入全部文本,past_key_values为None,取输出的一组token中最后一个token embedding
  • 传入当前词,以及截止到当前的past_key_values,取输出的唯一token embedding

以“明天降温了”作为prompt,采用贪婪搜索的方式生成20个单词

>>> tokenizer = BertTokenizer.from_pretrained("gpt2-chinese-cluecorpussmall")
>>> text = '明天降温了'

第一种形式生成如下,past_key_values直接为None

>>> text_list = list(text)
>>> for i in range(20):
        token = torch.tensor([tokenizer.convert_tokens_to_ids(text_list)])
        next_one_emb = model(token, past_key_values=None)
        next_one = torch.argmax(next_one_emb.logits[..., -1, :], dim=-1, keepdim=True)
        token = next_one
        text_list.append(tokenizer.convert_ids_to_tokens(next_one)[0])
>>> print(" ".join(text_list))
>>> 明 天 降 温 了 , 但 是 , 我 们 还 是 要 注 意 防 寒 保 暖 , 不 要 让 寒

第二种形式生成如下,需要不断更新past_key_values

>>> past_key_values = None
>>> res = []
>>> token = torch.tensor([tokenizer.convert_tokens_to_ids(list(text))])
>>> for i in range(20):
        next_one_emb = model(token, past_key_values=past_key_values)
        past_key_values = next_one_emb.past_key_values
        next_one = torch.argmax(next_one_emb.logits[..., -1, :], dim=-1, keepdim=True)
        token = next_one
        res.append(tokenizer.convert_ids_to_tokens(next_one)[0])
>>> print(text + " ".join(res2))
>>> 明天降温了, 但 是 , 我 们 还 是 要 注 意 防 寒 保 暖 , 不 要 让 寒

两种方法生成的结果完全一样,显然从性能角度考虑第二种方式更优。


generate文本生成实现

在快速开始环节介绍了TextGenerationPipeline这种简单快捷的方式来生成文本,通过模型自身的generate方法是另一种更通用的方式

>>> res = model.generate(input_ids=torch.LongTensor([tokenizer.convert_tokens_to_ids(list("明天降温了"))]),
                      max_length=25,  # 生成序列的最大长度
                      do_sample=False,  # 是否开启采样,默认是 False,即贪婪找最大条件概率的词
                      top_k=20,  # top-k-filtering 算法保留多少个 最高概率的词 作为候选,默认50
                      repetition_penalty=1.0,  # 重复词惩罚
                      temperature=1.0)  # 温度系数

>>> generated_texts = tokenizer.batch_decode(res, skip_special_tokens=True)
>>> print(generated_texts)
>>> ['明 天 降 温 了 , 但 是 , 我 们 还 是 要 注 意 防 寒 保 暖 , 不 要 让 寒']

微调任务的输入和目标说明

通过引入领域独有数据,在预训练GPT-2上继续训练,预测下一个词作为任务目标,完成对GPT-2的微调,使得生成的内容更加适配该领域的知识, 本节领域数据继续使用电视剧《狂飙》的部分电视剧本。

狂飙电视剧

GPT-2微调代码实践

微调的输入为上下文窗口最大128的文本的token id,由于GPT2LMHeadModel内部已经实现了shifted-right,因此预测目标和输入等同,数据处理过程如下

import torch.cuda
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup

MODEL = GPT2LMHeadModel.from_pretrained("./gpt2-chinese-cluecorpussmall")
TOKENIZER = BertTokenizer.from_pretrained("./gpt2-chinese-cluecorpussmall")

class Data(Dataset):
    def __init__(self, block_size):
        self.text = open("./data/text.txt", encoding="utf8").read()
        self.block_size = block_size

    def __len__(self):
        return len(self.text) - self.block_size + 1

    def __getitem__(self, item):
        block = self.text[item: item + self.block_size]
        return block

def collate_fn(batch_block):
    input_ids = []
    for i in batch_block:
        token = TOKENIZER.convert_tokens_to_ids(list(i))
        input_ids.append(token)

    return torch.LongTensor(input_ids).to(DEVICE)

data = Data(128)
data_loader = DataLoader(data, batch_size=48, collate_fn=collate_fn, shuffle=True, drop_last=False)

接下来直接调用GPT2LMHeadModel拿到loss,同时每训练50步让模型自动基于给定的prompt“高启强被捕之后”进行剧本续写,目的是查看随着微调损失的收敛,文本的生成是否更加贴合剧本信息

epochs = 20
optimizer = AdamW(MODEL.parameters(), lr=2e-4)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=50, num_training_steps=epochs * len(data_loader))
criterion = torch.nn.CrossEntropyLoss()
text = "高启强被捕之后"
print("--------------不做微调")
generate(text)

for epoch in range(epochs):
    for step, input_ids in enumerate(data_loader):
        MODEL.to(DEVICE).train()
        optimizer.zero_grad()
        forward = MODEL(input_ids=input_ids, labels=input_ids)
        loss = forward.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        if (step + 1) % 50 == 0:
            print("--------------step: {}, loss: {}".format(step + 1, loss.item()))
            generate(text)

文本生产采用多项式采样,设置一定的top-k以及温度系数,最大推理50个文本长度

def generate(text):
    encode = torch.LongTensor([TOKENIZER.convert_tokens_to_ids(list(text))]).to(DEVICE)
    output = MODEL.generate(input_ids=encode, max_length=50, do_sample=True, top_k=5, repetition_penalty=1.0,
                            temperature=0.5)
    res = TOKENIZER.batch_decode(output, skip_special_tokens=True)
    print("".join(res[0].split(" ")))

微调模型的训练日志如下

迭代次数 loss 生成内容
0 - 高启强被捕之后,他们的一些人被捕后,他们的一些人被捕后被释放。这些人被送到了一个叫做的地方。他们
100 0.949 高启强被捕之后,不断在反省中寻找着自己的线索。直至今天,他仍然坚持着自己的立场,绝不允许任何人将他交
200 0.249 高启强被捕之后,高启强心里咯噔一下,表面仍不动声色,冲着营业员做个手势,示意盯一下,自己捂着手机出了
300 0.164 高启强被捕之后,高启强被铐在审讯室的椅子上,跷着二郎腿。对面电视里正在放春晚的小品,笑声一阵响似一阵
400 0.139 高启强被捕之后,高启兰哇的一声哭了出来:我是他妹妹,警察大哥,我哥绝对是好人!安欣看着兄妹二

前两条结果为模型在微调初始阶段生成的内容,带有比较重的公开数据的味道,语句勉强通顺但是语义表达不清晰,第三条结果有明显改善,语句通顺且表达出了完整的意义,从第四条结果开始已经很贴切《狂飙》的剧本情节,第四条很类似剧本中曾经出现的原文,而第五条结果也成功生成出了《狂飙》中的人物高启兰,并且成功地表达出了两者的兄妹关系,完全区分不出是人写的还是GPT-2自动生成的,也映证了通过领域数据微调GPT-2预训练模型的有效性。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,734评论 6 505
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,931评论 3 394
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,133评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,532评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,585评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,462评论 1 302
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,262评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,153评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,587评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,792评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,919评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,635评论 5 345
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,237评论 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,855评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,983评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,048评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,864评论 2 354

推荐阅读更多精彩内容