提示学习系列:prompt自然语言模板微调BERT/GPT2实现文本分类

关键词:提示学习PromptBERTGPT2

前言

提示学习(Prompt-Based Learning)不同于传统的监督学习,它直接利用在大量原始语料上训练的到的预训练模型,配合一个提示函数,即可完成小样本甚至零样本学习,是NLP领域的新范式,本文介绍基于人工设计提示模板(Pattern-Exploiting Training)在BERT/GPT2上做文本多分类学习的实践案例。


内容摘要

  • 人工设计提示模板PET理论简介
  • BERT完型填空任务说明
  • PET微调BERT代码实践
  • GPT2 shift right错位预测任务说明
  • PET微调GPT2代码实践
  • 提示学习和有监督学习效果对比

人工设计提示模板PET理论简介

人工设计提示模板是指,以自然语言的形式,人工地设计出一个带有为填充词槽的句子模板,将它加入原始的输入中,一齐输入给语言模型,让语言模型以概率的形式填充词槽,从而完成任务。
我们以新闻主题分类为例,人工设计提示模型的示意图如下

人工提示模型+BERT

原始输入为“欧联杯双马对决”,人工设计的提示模板为“下面是一篇关于MM的新闻”,其中M代表[MASK],留下了两个词槽,输入给BERT,期望BERT用完型填空的方式预测出两个[MASK]位置分别为“体育”二字,即体育的条件联合概率最大。
PET的方式将预训练任务和下游任务统一起来,使得文本分类的下游任务转化为了预训练过程中的MLM,相比于有监督方式需要额外加一层全连接,PET的方式和预训练模型本身的能力更加契合,表达更加自然。
基于HuggingFace下的BERT预训练模型和BERT完型填空API BertForMaskedLM可以完成对该方案的快速验证,测试在大规模语聊上训练的BERT配合模板的使用,是否具备直接推断出新闻主题的能力,代码如下

import numpy as np
import torch
from transformers import BertForMaskedLM, BertTokenizer, BertConfig

MODEL_PATH = "/home/model_hub/bert-base-chinese"
PRE_TRAIN = BertForMaskedLM.from_pretrained(MODEL_PATH).to("cuda:0")
TOKENIZER = BertTokenizer.from_pretrained(MODEL_PATH)
CONFIG = BertConfig.from_pretrained(MODEL_PATH)

labels = [
    u'文化', u'娱乐', u'体育', u'财经', u'房产', u'汽车', u'教育', u'科技', u'军事', u'旅游', u'国际',
    u'证券', u'农业', u'电竞', u'民生'
]

label_ids = np.array([TOKENIZER.encode(l)[1:-1] for l in labels])
prompt = "下面是一篇关于[MASK][MASK]的新闻。"

def get_one_res_clp(prompt, text):
    text = prompt + text
    token = TOKENIZER.encode_plus(text, return_tensors="pt", max_length=CONFIG.max_position_embeddings, truncation=True,
                                  padding=True).to("cuda:0")
    out = PRE_TRAIN(**token)["logits"]
    label_score = out[0, 8, label_ids[:, 0]] * out[0, 9, label_ids[:, 1]]
    max_index = label_score.argmax(dim=0)
    return labels[max_index]

>>> get_one_res_clp(prompt, "欧联杯双马对决")
'军事'
>>> get_one_res_clp(prompt, "这两年钢材、水泥、砖头等建材价格还会不会降到三年前水平?")
'房产'
>>> get_one_res_clp(prompt, "KUKA机器人与橙子自动化全球战略合作发布")
'科技'
>>> get_one_res_clp(prompt, "该如何在希腊选择合适房源")
'房产'

首先定义了分类池labels作为搜索池,get_one_res_clp函数抽取了两个字相乘的条件联合概率作为依据,取搜索池中概率最大的类别作为结果,从预测的4条结果来看,首句预测分类错误,其他正确,因此得到结论BERT预训练模型不做任何微调,只需要加以提示词引导,即可具备一定的文本分类能力。


BERT完型填空任务说明

本节说明下BERT模型在预训练过程中的完型填空训练方式,为构造提示学习样本做前期准备。
BERT在预训练过程中会对预料进行随机MASK,具体的每个位置的token会以15%的几率进行特殊处理,特殊处理下又有80%的几率直接MASK,10%的几率维持原样,10%的几率随机改变为语料中的其他一个token,BERT的任务就是预测这些被遮蔽的词,通过学会完型填空任务来理解语义。
从模型结构来看,训练过程中的MLM单独具有BertOnlyMLMHead结构,其按照次序包含一层768*768的Dense,一层GELU,一层LayerNorm,一层768×21128的Dense,其中最后一层的Dense采用tie embedding的形式,和BERT输入层的word Embedding共享,示意图如下

BERT MLM网络结构

一般BERT在有监督使用的场景不需要BertOnlyMLMHead层,而是添加了BertPooler层取[CLS]位置的池化输出完成下游微调任务。


PET微调BERT代码实践

有了前文的铺垫,只需要设计一个提示模板,和原始文本拼接,构造成MLM的采样MASK格式即可对BERT完型填空任务进行微调,微调之后再预测词槽位置即可完成分类任务。
对于新闻主题分类标签位置的MASK,采用一个新的token代替,因此避免了标签长度大小不一样导致的条件概率无法比较的问题。
首先拓展词表,将标签label加入进来

import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertForMaskedLM, BertTokenizer, BertConfig, AdamW, get_linear_schedule_with_warmup

from sklearn.metrics import classification_report

DEVICE = "cuda:0"
MODEL_PATH = "/home/model_hub/chinese-roberta-wwm-ext"
PRE_TRAIN = BertForMaskedLM.from_pretrained(MODEL_PATH).to(DEVICE)
TOKENIZER = BertTokenizer.from_pretrained(MODEL_PATH)
CONFIG = BertConfig.from_pretrained(MODEL_PATH)

PROMPT = "下面是一篇关于[MASK]的新闻。"
LABELS = ['文化', '娱乐', '体育', '财经', '房产', '汽车', '教育', '科技', '军事', '旅游', '国际', '证券', '农业', '电竞', '民生']
TOKENIZER.add_tokens(LABELS)
PRE_TRAIN.resize_token_embeddings(len(TOKENIZER))
PRE_TRAIN.tie_weights()
# TOKENIZER.save_pretrained("./test_add_word")
LABELS_IDS = [TOKENIZER.convert_tokens_to_ids(x) for x in LABELS]

采用add_tokens将类别加入词表,resize_token_embeddings重新初始化词表的维度,手动调用tie_weights使得权重共享,样本处理核心代码如下

class Data(Dataset):
    def __init__(self, path, sample_length=20000):
        super(Data, self).__init__()
        self.data = [eval(line.strip()) for line in open(path, encoding="utf8").readlines()[:sample_length]]

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)


def label_mask_collate_fn(batch_data):
    texts, labels = [], []
    for b in batch_data:
        texts.append(PROMPT + b["text"])
        labels.append(b["label_name"])
    token = TOKENIZER.batch_encode_plus(texts, padding=True, max_length=CONFIG.max_position_embeddings, truncation=True)
    input_ids = token["input_ids"]
    source, target = [], []
    for input_id, label in zip(input_ids, labels):
        s, t = random_mask(input_id, label)
        source.append(s)
        target.append(t)
    return torch.tensor(source).to(DEVICE), torch.tensor(token["token_type_ids"]).to(DEVICE), torch.tensor(
        token["attention_mask"]).to(DEVICE), torch.tensor(target).to(DEVICE)


def random_mask(token_ids, label=None):
    rand = np.random.random(len(token_ids))
    # TODO input_ids 带有[MASK]的输入
    # TODO target 原始输入,并且把padding位置和非[MASK]置为-100
    input_ids, target = [], []
    for r, t in zip(rand, token_ids):
        if t == 0:
            input_ids.append(t)
            target.append(-100)
        else:
            if r < 0.15 * 0.8:
                input_ids.append(TOKENIZER.convert_tokens_to_ids("[MASK]"))
                target.append(t)
            elif r < 0.15 * 0.9:
                input_ids.append(t)
                target.append(t)
            elif r < 0.15:
                input_ids.append(np.random.choice(TOKENIZER.vocab_size - 1) + 1)
                target.append(t)
            else:
                input_ids.append(t)
                target.append(-100)
    if label:
        token = TOKENIZER.convert_tokens_to_ids(label)
        input_ids[8] = TOKENIZER.mask_token_id
        target[8] = token
    return input_ids, target


def prompt_collate_fn(batch_data):
    texts, labels = [], []
    for b in batch_data:
        texts.append(PROMPT + b["text"])
        labels.append(b["label"])
    token = TOKENIZER.batch_encode_plus(texts, return_tensors="pt", max_length=CONFIG.max_position_embeddings,
                                        truncation=True, padding=True)
    return token.to(DEVICE), torch.tensor(labels).to(DEVICE)

label_mask_collate_fn是训练数据的处理流程,它将PROMPT拼接在原始输入的前面,并且采用random_mask进行整体全文的采样遮蔽,额外的对[MASK]位置进行遮蔽,target给到正确的token id。在计算loss的时候非MASK位置被填充为-100被交叉熵损失自动忽略。
prompt_collate_fn是验集合和测试集的处理流程,只需要拼接到前面进行编码,因为只需要对标签位置进行预测,无法进行额外的MASK。
模型部分核心代码如下

def eval_metrics(model, loader):
    res = []
    model.eval()
    cnt, hit = 0, 0
    with torch.no_grad():
        for token, label in tqdm(loader):
            logits = model(**token).logits  # [batch_size, seq_len, emb_size]
            logits = logits[:, 8, LABELS_IDS]
            max_index = logits.argmax(dim=-1)
            acc = (max_index == label).sum(0).item()
            cnt += label.size(0)
            hit += acc
            res.append(max_index)
    return hit / cnt, res


if __name__ == '__main__':
    train = Data("./short_news/train.json")
    val = Data("./short_news/val.json")
    test = Data("./short_news/test.json")
    # TODO 控制有监督的样本数量
    train = train
    train_loader = DataLoader(train, collate_fn=label_mask_collate_fn, batch_size=16, shuffle=True)
    val_loader = DataLoader(val, collate_fn=prompt_collate_fn, batch_size=16, shuffle=False)
    test_loader = DataLoader(test, collate_fn=prompt_collate_fn, batch_size=16, shuffle=False)
    # acc, pred = eval_metrics(PRE_TRAIN, test_loader)
    # print("[test] acc: {}".format(acc))

    epochs = 200
    pre_train_lr = 2e-5
    warm_step, total_step = 100, epochs * len(train_loader)
    handler = train_handler(PRE_TRAIN, model_path="./pet_add_word/pytorch_model.bin", acc_name="acc")
    optimizer = AdamW(PRE_TRAIN.parameters(), lr=pre_train_lr)
    schedule = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=warm_step,
                                               num_training_steps=total_step)
    early_stop_flag = False

    for epoch in range(epochs):
        for step, (source, token_type_ids, attention_mask, target) in enumerate(train_loader):
            PRE_TRAIN.train()
            optimizer.zero_grad()
            loss = PRE_TRAIN(input_ids=source, attention_mask=attention_mask, token_type_ids=token_type_ids,
                             labels=target).loss
            loss.backward()
            optimizer.step()
            schedule.step()
            step += 1
            print("epoch: {}, step: {}, loss: {}".format(epoch + 1, step, loss.item()))
            if step % 200 == 0 or step == len(train_loader):
                acc, pred = eval_metrics(PRE_TRAIN, val_loader)
                print("[evaluation] acc: {}".format(acc))
                handler.get_step_metrics(acc)
                if handler.early_stop():
                    early_stop_flag = True
                    print("early stop...")
                    break
        if early_stop_flag:
            break

    model2 = torch.load("./pet_add_word/pytorch_model.bin").to(DEVICE)
    acc, pred = eval_metrics(model2, test_loader)
    print("[test] acc: {}".format(acc))

模型采用BertForMaskedLM,该类能够直接输出正向传播的loss,最小化该loss直到早停收敛。


GPT2 shift right错位预测任务说明

BERT采用MLM任务完成预测出分类,同样的GPT2也可以通过文本补全来预测出分类,GPT2采用shift right的方式,当前词的输出来作为下一个token的预测依据,在本案例场景中示意图如下

提示模板+GPT-2的方式

在HuggingFace的GPT2LMHeadModel中自带了shift right,同样将非MASK位置改为-100即可忽略损失。


PET微调GPT2代码实践

同样将所有新闻类别标签新设置为一个token,核心样本构造代码如下

def prompt_collate_fn(batch_data):
    texts, target_ids, label_ids, label_no = [], [], [], []
    for b in batch_data:
        texts.append((b["text"], PROMPT))
        label_ids.append(TOKENIZER.convert_tokens_to_ids(b["label_name"]))
        label_no.append(b["label"])
    token = TOKENIZER.batch_encode_plus(batch_text_or_text_pairs=texts, return_tensors="pt",
                                        max_length=CONFIG.max_position_embeddings,
                                        truncation=True, padding=True)
    input_ids, token_type_ids, attention_mask = token["input_ids"], token["token_type_ids"], token["attention_mask"]
    target_ids = torch.empty(input_ids.shape).long().fill_(-100)
    mask_ids = (token["input_ids"] == TOKENIZER.mask_token_id).nonzero()[:, 1].unsqueeze(1)
    target_ids = target_ids.scatter_(1, mask_ids, torch.LongTensor(label_ids).unsqueeze(1))
    return input_ids.to(DEVICE), token_type_ids.to(DEVICE), attention_mask.to(DEVICE), target_ids.to(
        DEVICE), mask_ids.to(DEVICE), torch.LongTensor(label_no).to(DEVICE)

其中将prompt拼接原始输入的后面,初始化一个全是-100的向量,通过torch.scatter将MASK位置改为对应的token id。
训练过程代码如下

def eval_metrics(model, loader):
    model.eval()
    cnt, hit = 0, 0
    with torch.no_grad():
        for input_ids, token_type_ids, attention_mask, target_ids, mask_ids, label_no in tqdm(loader):
            res = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
                           labels=target_ids)  # [batch_size, seq_len, emb_size]
            logits, loss = res["logits"], res["loss"]
            batch_size = input_ids.shape[0]
            cnt += batch_size
            for i in range(batch_size):
                logit = logits[i, mask_ids[i, 0], LABELS_IDS]
                max_index = logit.argmax(dim=-1)
                acc = (max_index == label_no[i]).sum(0).item()
                hit += acc
    return hit / cnt, loss


if __name__ == '__main__':
    train = Data("./short_news/train.json")
    val = Data("./short_news/val.json")
    test = Data("./short_news/test.json")
    # TODO 控制有监督的样本数量
    train = train
    train_loader = DataLoader(train, collate_fn=prompt_collate_fn, batch_size=16, shuffle=True)
    val_loader = DataLoader(val, collate_fn=prompt_collate_fn, batch_size=16, shuffle=False)
    test_loader = DataLoader(test, collate_fn=prompt_collate_fn, batch_size=16, shuffle=False)

    epochs = 200
    pre_train_lr = 2e-5
    warm_step, total_step = 100, epochs * len(train_loader)
    handler = train_handler(PRE_TRAIN, model_path="./pet_gpt_add_word/pytorch_model.bin", acc_name="acc")
    optimizer = AdamW(PRE_TRAIN.parameters(), lr=pre_train_lr)
    schedule = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=warm_step,
                                               num_training_steps=total_step)
    early_stop_flag = False

    for epoch in range(epochs):
        for step, (source, token_type_ids, attention_mask, target, mask_ids, label_no) in enumerate(train_loader):
            PRE_TRAIN.train()
            optimizer.zero_grad()
            loss = PRE_TRAIN(input_ids=source, attention_mask=attention_mask, token_type_ids=token_type_ids,
                             labels=target).loss
            loss.backward()
            optimizer.step()
            schedule.step()
            step += 1
            print("epoch: {}, step: {}, loss: {}".format(epoch + 1, step, loss.item()))
            if step % 200 == 0 or step == len(train_loader):
                acc, loss = eval_metrics(PRE_TRAIN, val_loader)
                print("[evaluation] acc: {}".format(acc))
                handler.get_step_metrics(acc)
                if handler.early_stop():
                    early_stop_flag = True
                    print("early stop...")
                    break
        if early_stop_flag:
            break

    model2 = torch.load("./pet_gpt_add_word/pytorch_model.bin").to(DEVICE)
    acc, pred = eval_metrics(model2, test_loader)
    print("[test] acc: {}".format(acc))

提示学习和有监督学习效果对比

由于标签是新定义了一个token,embedding是随机初始化的,没有复用预训练模型本身的embedding,相当于没有用预训练模型本身的知识,因此这种模式对样本的数量有一定要求,要重新训练标签的token embedding。分别取样本量为1000,5000,20000,分别对比BERT+全连接有监督微调,PET微调BERT,PET微调GPT-2三种策略的准确率,其中样本数为0代表直接使用预训练模型不微调,相当于零样本学习。对比结果如下

样本数 BERT+FC有监督 BERT+PET模板 GPT-2+PET模板
1000 0.8324 0.8283 0.7796
5000 0.852 0.8511 0.8329
20000 0.8623 0.8565 0.8383

从结果来看,BERT+PET的效果略低于BERT+FC,GPT-2+PET效果较BERT有明显下降,基于PET的提示学习准确率逼近BERT有监督微调。

多个策略下文本分类准确率

一个基于BERT+PET预测方法代码如下

import torch
from transformers import BertTokenizer, BertConfig


class Predictor:
    def __init__(self):
        self.device = "cuda:0"
        self.prompt = "下面是一篇关于[MASK]的新闻。"
        self.model = torch.load("./pet_add_word/pytorch_model.bin")
        self.model_config = BertConfig.from_pretrained("./pet_add_word")
        self.tokenizer = BertTokenizer.from_pretrained("./pet_add_word")
        self.labels = ['文化', '娱乐', '体育', '财经', '房产', '汽车', '教育', '科技', '军事', '旅游', '国际', '证券', '农业', '电竞', '民生']
        self.labels_ids = [self.tokenizer.convert_tokens_to_ids(x) for x in self.labels]
        self.model.to(self.device)

    def preprocess(self, text):
        texts = self.prompt + text
        token = self.tokenizer.batch_encode_plus([texts], return_tensors="pt",
                                                 max_length=self.model_config.max_position_embeddings,
                                                 truncation=True, padding=True)
        return token.to(self.device)

    def infer(self, text):
        token = self.preprocess(text)
        with torch.no_grad():
            logits = self.model(**token).logits
            logits = logits[:, 8, self.labels_ids]
            print(logits)
            max_index = logits.argmax(dim=-1)
        return self.labels[max_index]


if __name__ == '__main__':
    predictor = Predictor()
    res = predictor.infer("你认为《诗经》中最美的句子是什么?")
    print(res)

>>> 文化

全文完毕。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容