关键词:提示学习
,P-Tuning
,BERT
,GPT2
前言
P-Tuning是清华团队提出的一种使用提示学习微调大模型的方法,它提出自适应学习的连续提示模板,来解决人工自然语言模板的不稳定性,本文对该方法进行简要介绍和实践。
内容摘要
- P-Tuning理论方法简介
- P-Tuning微调BERT实践
- P-Tuning微调GPT-2实践
- P-Tuning、PET、Fine-Tuning效果对比
P-Tuning理论方法简介
前文所介绍的《提示学习系列:prompt自然语言模板微调BERT/GPT2实现文本分类》中,指出用自然语言来诱导预训练模型完成NLU任务,例如在文本分类任务中,通过自然语言配合BERT的MLM完型填空过程来对要预测的分类做填空,而GPT-2也是构造自然语言让其进行续写得出分类类型,提示学习不同于额外增加分类层的fine-tuning,做到了训练和预测表达形式的统一,自然语言模板的提示学习示意图如下。
上图中要分类的文本是“欧联杯双马对决”,“下面是一篇关于MM的新闻”是人工构造的prompt,MM是完型填空需要预测的目标。在该类提示学习方法中,人工构造prompt的内容,以及拼接到原文的位置,都会影响模型的训练效果,往往改一个字都可能导致提示学习的性能大幅下降,难以求得最优的提示文本,因此人工模板的提示学习效果不稳定。
在论文中,作者举例通过提示求得城市X所在国家Y,不同的提示文本对模型影响巨大,体现在预测结果精确度的高方差,从最低19.8最高51.1,而引入P-Tuning不仅能降低预测方差,还能提升整体准确性。
P-Tuning的思想是,与其绞尽脑汁构造和搜索出最优的prompt文本,不如引入一部分可训练的embedding和人工模板组合,一齐作为prompt的表征,让其具备一定的自适应能力,从而来适配各种下游NLU任务,增强模型训练的稳定性,具体的,采用新的未知token来构成prompt的可训练部分,和部分人工模板拼接,P-Tuning的提示学习示意图如下。
其中u1,u2,u3代表未知的token,在BERT词表中对应[unused1]~[unused3],“新闻主题分类”是人工模板,让模型往有利于预测出M位置的方向上迭代token的表征,其中引入token的数量和拼接位置可自行调整设置。
P-Tuning同样适合冻结预训练模型参数和放开全参微调两种方式,当标注样本较少时采用冻结模型参数,只优化prompt token embeddng的方式,当标注样本充足时建议全参微调以达到最优的模型效果。
在原论文中,作者为了增强prompt部分的表征能力,引入LSTM+MLP来刻画token之间的前后依赖关系,使得其更加贴近自然语言。
P-Tuning微调BERT实践
本篇采用和提示学习系列:prompt自然语言模板微调BERT/GPT2实现文本分类中一样的数据样本,通过PyTorch快速实现p-tuning文本多分类,样本对新闻文本做15分类预测,对于每个预测类别都当一个完整的新token加入词表进行预测和损失计算,词表拓充如下。
MODEL_PATH = "./model_hub/chinese-roberta-wwm-ext"
PRE_TRAIN = BertForMaskedLM.from_pretrained(MODEL_PATH).to(DEVICE)
PRE_TRAIN_CONFIG = BertConfig.from_pretrained(MODEL_PATH)
TOKENIZER = BertTokenizer.from_pretrained(MODEL_PATH)
# TODO 加入新词,用于标记prompt占位符
TOKENIZER.add_special_tokens({'additional_special_tokens': ["[PROMPT]"]})
PROMPT_TOKEN_ID = TOKENIZER.get_vocab()["[PROMPT]"]
CONFIG = BertConfig.from_pretrained(MODEL_PATH)
LABELS = ['文化', '娱乐', '体育', '财经', '房产', '汽车', '教育', '科技', '军事', '旅游', '国际', '证券', '农业', '电竞', '民生']
TOKENIZER.add_tokens(LABELS)
PRE_TRAIN.resize_token_embeddings(len(TOKENIZER))
PRE_TRAIN.tie_weights()
TOKENIZER.save_pretrained("./test_add_word_p_tuning")
样本构造,将可学习的token和人工合并,拼接到原文的前面,完整的样本样式为[CLS] + [used1] + ... + [used3] + [MASK] + [used4] + ... + [used_n] + [token1] + .. + [token_n] + sample + [SEP]的形式,其中used为可学习token embedding,token为人工自然语言模板,MASK为预测目标,将非MASK位置的token预测label改为-100不记入损失,sample为原文。
PROMPT_LEN = (4, 4)
def collate_fn(data):
prompts, attention_mask, labels, label_no = [], [], [], []
for d in data:
token = TOKENIZER.convert_tokens_to_ids(list(d["text"]))
discrete_token = TOKENIZER.convert_tokens_to_ids(["新", "闻", "主", "题", "分", "类", "。"])
cls = [TOKENIZER.cls_token_id]
sep = [TOKENIZER.sep_token_id]
first_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[0]
mask = [TOKENIZER.mask_token_id] * 1
second_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[1]
# third_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[2]
# TODO [CLS] + [PROMPT] + [MASK] + [PROMPT] + token + [PROMPT] + [SEP]
# prompt = cls + first_prompt + mask + second_prompt + token + third_prompt + sep
prompt = cls + first_prompt + mask + second_prompt + discrete_token + token + sep
prompts.append(prompt)
attention_mask.append([1] * len(prompt))
labels.append(TOKENIZER.convert_tokens_to_ids(d["label_name"]))
label_no.append(d["label"])
# TODO 对输入进行padding
batch_max_length = max([len(x) for x in prompts])
for i in range(len(prompts)):
one_length = len(prompts[i])
if one_length < batch_max_length:
prompts[i] = prompts[i] + [0] * (batch_max_length - one_length)
attention_mask[i] = attention_mask[i] + [0] * (batch_max_length - one_length)
prompts = torch.LongTensor(prompts).to(DEVICE)
attention_mask = torch.LongTensor(attention_mask).to(DEVICE)
labels = torch.LongTensor(labels).to(DEVICE)
# TODO 对labels进行进行处理,[MASK]位置为label,其他位置为-100
label_ids = torch.empty_like(prompts).fill_(-100).long()
label_mask = (prompts == TOKENIZER.mask_token_id).nonzero()[:, 1].reshape(prompts.shape[0], 1)
# TODO 将MASK位置打上真实的词,其他置为-100
label_ids = label_ids.scatter_(1, label_mask, labels.unsqueeze(1))
return prompts, attention_mask, label_ids, label_no
train = Data("./short_news/train.json")
train_loader = DataLoader(train, collate_fn=collate_fn, batch_size=128, shuffle=True, drop_last=False)
单独对prompt设置网络模块,设置可学习的token embedding,根据预设的位置随机初始化embedding,经过LSTM和MLP得到最终token表征。
class PromptEncoder(nn.Module):
def __init__(self, prompt_num, embedding_size):
super(PromptEncoder, self).__init__()
self.input_ids = torch.arange(0, prompt_num).long().to(DEVICE)
self.embedding = nn.Embedding(prompt_num, embedding_size)
self.lstm = nn.LSTM(input_size=embedding_size, hidden_size=embedding_size // 2, batch_first=True,
bidirectional=True, num_layers=2)
self.mlp = nn.Sequential(nn.Linear(embedding_size, embedding_size), nn.ReLU(),
nn.Linear(embedding_size, embedding_size))
self.init_weight()
def init_weight(self):
nn.init.xavier_normal_(self.embedding.weight.data)
for name, weight in self.lstm.named_parameters():
if name.startswith("weight"):
nn.init.xavier_normal_(weight.data)
for layer in self.mlp:
if isinstance(layer, nn.Linear):
nn.init.xavier_normal_(layer.weight.data)
def forward(self):
embedding = self.embedding(self.input_ids).unsqueeze(0)
out = self.mlp(self.lstm(embedding)[0]).squeeze()
return out
在BERT网络中仅需要对输入层做改造,将used位置的token替换为PromptEncoder层输出即可,LM_FINE_TUNING参数决定是否冻结BERT预训练参数。
LM_FINE_TUNING = True
class PTuningBert(nn.Module):
def __init__(self):
super(PTuningBert, self).__init__()
self.pre_train = PRE_TRAIN
# TODO 如果仅微调prompt则冻结预训练模型
for param in self.pre_train.parameters():
param.requires_grad = LM_FINE_TUNING
self.embedding = self.pre_train.bert.get_input_embeddings() # TODO 单独拿到embedding层
self.prompt_encoder = PromptEncoder(sum(PROMPT_LEN), PRE_TRAIN_CONFIG.hidden_size)
def replace_embedding(self, prompt_embedding, raw_embedding, block_indices):
# TODO 矩阵每一行进行替换
for ids in range(block_indices.size()[0]):
for i in range(sum(PROMPT_LEN)):
# TODO 将PROMPT位置的embedding 一条样本一条样本,一个位置一个位置的 替换为随机初始化+LSTM+MLP的
# TODO block_indices[ids, i]:text中实际的PROMPT位置, i: PROMPT emb表中每个位置的id
raw_embedding[ids, block_indices[ids, i], :] = prompt_embedding[i, :]
return raw_embedding
def forward(self, input_ids, attention_mask, label_ids=None):
queries_for_embedding = input_ids.clone()
queries_for_embedding[(input_ids == PROMPT_TOKEN_ID)] = TOKENIZER.unk_token_id
raw_embeds = self.embedding(queries_for_embedding)
# TODO 拿到每个text中PROMPT的位置索引 [[p1,p2,p3], [], []...]
blocked_indices = (input_ids == PROMPT_TOKEN_ID).nonzero().reshape((input_ids.size()[0], sum(PROMPT_LEN), 2))[:, :, 1]
prompt_embeds = self.prompt_encoder()
# TODO 将原始raw_emb中的PROMPT(unk)位置替换掉
input_embedding = self.replace_embedding(prompt_embeds, raw_embeds, blocked_indices)
output = self.pre_train(inputs_embeds=input_embedding.to(attention_mask.device),
attention_mask=attention_mask,
labels=label_ids)
loss, logits = output.loss, output.logits
return loss, logits
P-Tuning微调GPT-2实践
同理,GPT-2的预测目标是最后一个token,将之前的所有token labe设置为-100不参与loss计算,网络部分和BERT实现一致,仅需要替换预训练模型即可。
def collate_fn(data):
prompts, attention_mask, label_ids, label_no = [], [], [], []
for d in data:
token = TOKENIZER.encode(d["text"])[1:-1]
discrete_token = TOKENIZER.convert_tokens_to_ids(["新", "闻", "主", "题", "分", "类", "。"])
first_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[0]
target_token = [TOKENIZER.convert_tokens_to_ids(d["label_name"])]
second_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[1]
# TODO [PROMPT] + token + [PROMPT] + target
prompt = first_prompt + discrete_token + token + second_prompt + target_token
label_id = [-100] * (len(prompt) - 1) + target_token
prompts.append(prompt)
attention_mask.append([1] * len(prompt))
label_ids.append(label_id)
label_no.append(d["label"])
....
P-Tuning、PET、Fine-Tuning效果对比
训练样本数量分别取1000,5000,20000,设置最大验证集10次早停,采用chinese-bert-wwm-ext和gpt2-chinese-cluecorpussmall作为预训练模型,全部采用全参微调的方式,对比Fine-Tuning,人工模板PET,P-Tuning的测试集F1效果。
模型策略 | 1000 | 5000 | 20000 |
---|---|---|---|
BERT + fine_tuning | 0.8324 | 0.852 | 0.8623 |
BERT + pet | 0.8283 | 0.8511 | 0.8565 |
GPT-2 + pet | 0.7796 | 0.8329 | 0.8383 |
BERT + p_tuning | 0.7858 | 0.82675 | 0.84995 |
GPT-2 + p_tuning | 0.8134 | 0.83485 | 0.85545 |
统计结果可视化如下
结果显示,在小尺寸的BERT和GPT-2上,不论PET还是P-Tuning这些提示学习微调的方法,微调结果都不如Fine-Tuning,至少有1个百分点的差距。在BERT上,P-Tuning的似乎明显不如PET,该结论和作者论文的结论相悖。在GPT-2上,P-Tuning明显提升了PET的效果,并且接近了BERT Fine-Tuning的效果,这点和作者的论文题目《GPT Understands, Too》这一结论一致。