Transformers的微调训练

数据集准备和预处理

通过dataset包加载数据集
定义Dataset.map要使用的预处理函数
定义DataCollator来用于构造训练batch

加载预训练模型

随机初始化ForSequenceClassificationHead

微调训练

Trainer是Huggingface transformers库的一个高级API,可以帮助我们快速搭建训练框架。
默认情况下,Trainer和TrainingArguments会使用:

batch size=8
epochs = 3
AdamW优化器

可以提供一个compute_metrics函数,用于输出我们希望有的一些指标。

import os
import torch
import numpy as np

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from transformers import AdamW, AutoTokenizer, AutoModelForSequenceClassification

# Same as before
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

sequences = [
    "I've been waiting for a HuggingFace course my whole life.",
    "This course is amazing!",
]
batch = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")
batch['labels'] = torch.tensor([1, 1])  # tokenizer出来的结果是一个dictionary,所以可以直接加入新的 key-value

optimizer = AdamW(model.parameters())
loss = model(**batch).loss  #这里的 loss 是直接根据 batch 中提供的 labels 来计算的,回忆:前面章节查看 model 的输出的时候,有loss这一项
loss.backward()
optimizer.step()

from datasets import load_dataset

raw_datasets = load_dataset("glue", "mrpc")
raw_train_dataset = raw_datasets['train']

tokenized_sentences_1 = tokenizer(raw_train_dataset['sentence1'])
tokenized_sentences_2 = tokenizer(raw_train_dataset['sentence2'])

from pprint import pprint as print
inputs = tokenizer("first sentence", "second one")
print(inputs)

def tokenize_function(sample):
    # 这里可以添加多种操作,不光是tokenize
    # 这个函数处理的对象,就是Dataset这种数据类型,通过features中的字段来选择要处理的数据
    return tokenizer(sample['sentence1'], sample['sentence2'], truncation=True)

# Dataset.map节约内存
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
print(tokenized_datasets)

# 划分batch的时候再进行padding
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

samples = tokenized_datasets['train'][:5]
print(samples.keys())

# 把这里多余的几列去掉
samples = {k:v for k,v in samples.items() if k not in ["idx", "sentence1", "sentence2"]}  # 把这里多余的几列去掉
print(samples.keys())

# 打印出每个句子的长度:
print([len(x) for x in samples["input_ids"]])

# 然后我们使用data_collator来处理padding
batch = data_collator(samples)  # samples中必须包含 input_ids 字段,因为这就是collator要处理的对象
print(batch.keys())

# 再打印长度:
print([len(x) for x in batch['input_ids']])

from transformers import Trainer, TrainingArguments

from datasets import load_metric
def compute_metrics(eval_preds):
    metric = load_metric("glue", "mrpc")
    logits, labels = eval_preds.predictions, eval_preds.label_ids
    # 上一行可以直接简写成:
    # logits, labels = eval_preds  因为它相当于一个tuple
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

training_args = TrainingArguments(output_dir='test_trainer') # 指定输出文件夹,没有会自动创建

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,  # 在定义了tokenizer之后,其实这里的data_collator就不用再写了,会自动根据tokenizer创建
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# 启动训练
trainer.train()

# trainer.save_model()

# 或者加载之前的训练结果

from transformers.trainer_utils import EvalPrediction, get_last_checkpoint

#last_checkpoint = get_last_checkpoint('test_trainer')
#trainer.train(resume_from_checkpoint=last_checkpoint)

参考资料

Huggingface🤗NLP笔记7:使用Trainer API来微调模型

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,080评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,422评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,630评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,554评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,662评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,856评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,014评论 3 408
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,752评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,212评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,541评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,687评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,347评论 4 331
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,973评论 3 315
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,777评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,006评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,406评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,576评论 2 349

推荐阅读更多精彩内容