BERT实战(下)-问答任务

上半部分介绍了如何从BERT模型提取嵌入,下半部分介绍如何针对下游任务进行微调,分为四个任务。

下游任务:

  1. 情感分类任务
  2. 自然语言推理任务
  3. 问答任务
  4. 命名实体识别任务

微调方式:

  1. 分类器层与BERT模型一起更新权重(通常情况且效果更好
  2. 仅更新分类器层的权重而不更新BERT模型的权重。BERT模型仅作为特征提取器

1 问答任务

1.1 任务说明

问答任务分为两种:

  1. 抽取式:从给定的上下文中抽取回答。
  2. 摘要式:从给定的上下文中生成对于问题正确的回答。

这里介绍抽取式:输入是一个问题和一个含有答案的段落,然后段落中提取答案。本质上讲是返回包含答案的文本段。模型实际上预测的是答案在段落中的起始位置和结束位置的索引。

步骤:

  1. 引入起始向量S和结束向量E
  2. 将输入的问题和含有答案的段落输入BERT获得每个标记的特征
  3. 段落标记的特征向量与S和E分别点积,然后段落标记之间进行softmax,获得每个标记作为起始位置或结束位置的概率:
    P_i=\frac{e^{S\cdot R_i}}{\sum_je^{S\cdot R_j}} 和 P_i=\frac{e^{E\cdot R_i}}{\sum_je^{E\cdot R_j}}
    针对问答任务微调预训练BERT模型

1.2 代码

1.2.1 QA任务代码

from transformers import AutoModelForQuestionAnswering, AutoTokenizer, Trainer, TrainingArguments
from transformers import DefaultDataCollator
from nlp import load_dataset
from huggingface_hub import notebook_login
from transformers import pipeline

notebook_login()


def train():
    # squad: https://huggingface.co/datasets/squad
    squad = load_dataset("squad", split="train[:5000]")
    print("squad: {}".format(squad))

    # train_test_split=0.2
    dataset = squad.train_test_split(test_size=0.2)
    train_set = dataset["train"]
    test_set = dataset["test"]

    print("train_set[0]: {}".format(train_set[0]))
    print("test_set[0]: {}".format(test_set[0]))
    print("train_set: {}".format(train_set))
    print("test_set: {}".format(test_set))

    # 加载模型: distilbert
    model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased")

    # 词元分析器
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # 预测的是: answer在question和context拼接, tokenize成token后的起始token索引和结束token索引
    def preprocess_function(examples):
        questions = [q.strip() for q in examples["question"]]
        # Huggingface Transformers各类库介绍(Tokenizer、Pipeline): https://blog.csdn.net/weixin_42475060/article/details/128105633
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=384,  # 限制最大长度
            truncation="only_second",  # 截第二个输入, 即context
            return_offsets_mapping=True,  # 返回每个token在输入字符串中的偏移量
            padding="max_length",  # 按最大长度补齐
        )

        # 从inputs中提前移除了offset_mapping, 因为不需要这个输入
        offset_mapping = inputs.pop("offset_mapping")
        answers = examples["answers"]
        start_positions = []
        end_positions = []

        for i, offset in enumerate(offset_mapping):
            answer = answers[i]
            start_char = answer["answer_start"][0]  # 答案在context字符串中的起始索引
            end_char = answer["answer_start"][0] + len(answer["text"][0])  # 答案在context字符串中的结束索引
            sequence_ids = inputs.sequence_ids(i)

            # Find the start and end of the context
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # If the answer is not fully inside the context, label it (0, 0)
            if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
                start_positions.append(0)
                end_positions.append(0)
            else:
                # Otherwise it's the start and end token positions
                idx = context_start
                while idx <= context_end and offset[idx][0] <= start_char:
                    idx += 1
                start_positions.append(idx - 1)

                idx = context_end
                while idx >= context_start and offset[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return inputs

    # 把没有用到的原始输入抛弃: remove_columns=squad.column_names
    train_set = train_set.map(preprocess_function, batched=True, remove_columns=squad.column_names)
    test_set = test_set.map(preprocess_function, batched=True, remove_columns=squad.column_names)

    # 使用默认的DefaultDataCollator, 将输入数据转化为pytorch的tensor
    data_collator = DefaultDataCollator()

    training_args = TrainingArguments(
        output_dir="~/Documents/huggingface_local_hub/llm/task_qa_distilbert",
        hub_model_id="smile367/task_qa_distilbert",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=3,
        weight_decay=0.01,
        push_to_hub=True,
    )

    # qa任务使用交叉熵损失训练
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_set,
        eval_dataset=test_set,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    # https://git-lfs.com/
    trainer.push_to_hub()
    print("ok")


def inference():
    test_data = load_dataset("squad", split="validation[:1]")
    print("question: {}".format(test_data["question"]))
    print("context: {}".format(test_data["context"]))
    print("answers: {}".format(test_data["answers"]))
    question_answerer = pipeline("question-answering", model="smile367/task_qa_distilbert")
    result = question_answerer(question=test_data["question"], context=test_data["context"])
    print("result: {}".format(result))
    print("ok")

1.2.2 squad数据集预处理过程解析

  • 可以单步调试观察每一步的处理过程
from transformers import AutoTokenizer
from nlp import load_dataset

if __name__ == '__main__':
    dataset = load_dataset("squad", split="train[:2]")
    question = dataset["question"]
    context = dataset["context"]
    answers = dataset["answers"]

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    inputs = tokenizer(question, context, max_length=384, truncation="only_second", return_offsets_mapping=True, padding="max_length")
    start_positions = []
    end_positions = []
    offset_mapping = inputs.pop("offset_mapping")

    for i, offset in enumerate(offset_mapping):
        print()
        print("question: {}".format(question[i]))
        print("context: {}".format(context[i]))
        print("answers: {}".format(answers[i]["text"]))
        print("inputs.tokens: {}".format(inputs.tokens(i)))
        print("inputs.sequence_ids: {}".format(inputs.sequence_ids(i)))
        print("offset: {}".format(offset))

        # 答案在context字符串中的起始索引和结束索引(注意不是tokens索引)
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        print("start_char: {}, end_char: {}, context[start_char: end_char]: {}".format(start_char, end_char, context[i][start_char: end_char]))

        # 找到context的在tokens中起始索引和结束索引
        sequence_ids = inputs.sequence_ids(i)
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1
        print("context_start: {}, context_end: {}, tokens[context_start: context_end]: {}".format(context_start, context_end, inputs.tokens(i)[context_start: context_end]))

        # If the answer is not fully inside the context, label it (0, 0)
        # 答案在context字符串的索引不在context的索引之内
        print("offset[context_start][0]: {}, offset[context_end][1]: {}".format(offset[context_start][0], offset[context_end][1]))
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start  # 从context起始token的索引向结束token的索引扫描, 找到start_char的token的索引
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end  # 从context结束token的索引向开始token的索引扫描, 找到end_char的token的索引
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)
            print("start_position: {}, end_position: {}, inputs.tokens[start_position: end_position + 1]: {}".format(start_positions[i], end_positions[i], inputs.tokens(i)[start_positions[i]: end_positions[i] + 1]))

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    print("inputs: {}".format(inputs))
    print("ok")

参考资料

[1]. BERT基础教程Transformer大模型实战
[2]. huggingface官网QA任务指南:https://huggingface.co/docs/transformers/tasks/question_answering
[3]. squad数据集文档: https://huggingface.co/datasets/squad
[4]. huggingface hub使用说明:https://www.jianshu.com/p/5337a01f1cae?v=1683977910521
[5]. Huggingface Transformers各类库介绍(Tokenizer、Pipeline):https://blog.csdn.net/weixin_42475060/article/details/128105633

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,384评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,845评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,148评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,640评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,731评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,712评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,703评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,473评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,915评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,227评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,384评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,063评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,706评论 3 324
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,302评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,531评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,321评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,248评论 2 352

推荐阅读更多精彩内容