Intelligent Interactive Assistant - BERT Training (Part 2)

BERT Training: NER

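This part basically uses macanv's BERT+CRF as is.
The training data again comes from rasa and is then converted into the format the NER trainer reads: one character and its BIO label per line, with a blank line between sentences.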
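
For reference, the input side of the conversion can be sketched as follows. This is a minimal sketch that assumes the data is stored in the legacy Rasa NLU JSON layout (`rasa_nlu_data` / `common_examples`); the function names `load_rasa_examples` and `misaligned_entities`, the sample sentence, and its entity names are all hypothetical and only illustrate the fields the conversion relies on.

```python
import json


def load_rasa_examples(path):
    """Return the annotated examples from a legacy Rasa NLU JSON file (assumed layout)."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return data["rasa_nlu_data"]["common_examples"]


def misaligned_entities(example):
    """Return entities whose [start, end) slice of the text differs from their "value"."""
    text = example["text"]
    return [
        entity
        for entity in example.get("entities", [])
        if text[entity["start"]:entity["end"]] != entity.get("value")
    ]


# Hypothetical example: offsets are character positions, "end" is exclusive.
sample = {
    "text": "明天北京天气怎么样。",
    "intent": "ask_weather",
    "entities": [
        {"start": 0, "end": 2, "value": "明天", "entity": "date"},
        {"start": 2, "end": 4, "value": "北京", "entity": "city"},
    ],
}
```

Running `misaligned_entities` over every example before conversion is a cheap way to catch shifted offsets. If the data uses entity synonyms, a `value` that differs from the text span can be legitimate, so treat the result as a hint rather than an error.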

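    ###############################################################################################
    # Generate NER training data
    #
    # `examples` is expected to be the list of rasa training examples loaded earlier; each item
    # has a "text" string and an "entities" list with "start", "end" and "entity" fields.
    import pickle

    print("Starting NER data generation")
    # Use an explicit encoding so the Chinese text is written the same way on every platform.
    train_file = open("./ner_data/train.txt", "w", encoding="utf-8")
    dev_file = open("./ner_data/dev.txt", "w", encoding="utf-8")
    test_file = open("./ner_data/test.txt", "w", encoding="utf-8")
    idx = 0
    label_list = set()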
    for example in examples:

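        # Examples without annotations may omit the "entities" key.
        entities = example.get("entities", [])
        # Keep only the first sentence and terminate it with a single "。".
        text = example["text"].strip()
        text = text.split("。")[0] + "。"
        # Start with every character tagged "O" (outside any entity).
        labels = ["O"] * len(text)
        for entity in entities:
            start_pos = entity["start"]
            # Clip the span: an entity that reaches past the truncated text would
            # otherwise raise an IndexError.
            end_pos = min(entity["end"], len(text))
            label = entity["entity"]
            for i in range(start_pos, end_pos):
                # The first character of a span gets "B-", the remaining ones "I-".
                tag = ("B-" if i == start_pos else "I-") + label
                labels[i] = tag
                label_list.add(tag)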

        def print_line(line, labels):
            """Format one sentence as "<char> <label>" lines; return "" if the lengths differ."""
            if len(line) != len(labels):
                print("Error: length of text and length of labels are not equal, {}".format(line))
                return ""
            return "".join("{} {}\n".format(ch, tag) for ch, tag in zip(line, labels))

        formatted_text = print_line(text, labels)

        idx += 1
        #print("{}, text: {}, formatted_text: {}".format((idx%10), text, formatted_text))
        # Every example is written to train.txt. Out of every 10 examples, 2 are also
        # copied to dev.txt and 2 to test.txt, so dev and test overlap with the training set.
        # The trailing "\n" adds the blank line that separates sentences.
        train_file.write("{}\n".format(formatted_text))
        bucket = idx % 10
        if 6 <= bucket < 8:
            dev_file.write("{}\n".format(formatted_text))
        elif bucket >= 8:
            test_file.write("{}\n".format(formatted_text))
    train_file.close()
    dev_file.close()
    test_file.close()

    label_list = sorted(list(label_list))
    with open("./ner_data/label_list.pkl", "wb") as label_file:
        pickle.dump(label_list, label_file)
    print("Finished generation, total examples: {}, label list: {}".format(idx, label_list))
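
To sanity-check the generated files, they can be read back into sentences. The following is a minimal sketch, assuming the layout written above (one `<char> <label>` pair per line, blank line between sentences); `read_conll` is a hypothetical helper name, not part of macanv's project.

```python
def read_conll(path):
    """Parse a "<char> <label>" file into a list of (chars, labels) pairs."""
    sentences = []
    chars, labels = [], []
    with open(path, encoding="utf-8") as f:
        for raw in f:
            line = raw.rstrip("\n")
            if not line:
                # A blank line closes the current sentence, if there is one.
                if chars:
                    sentences.append((chars, labels))
                    chars, labels = [], []
                continue
            # Split on the last space so a space character in the text survives.
            ch, label = line.rsplit(" ", 1)
            chars.append(ch)
            labels.append(label)
    # Flush a final sentence that is not followed by a blank line.
    if chars:
        sentences.append((chars, labels))
    return sentences
```

For example, `read_conll("./ner_data/train.txt")` should return as many sentences as the example count printed at the end of the generation step, minus any examples that were skipped because of a length mismatch.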

For more details, see
https://github.com/macanv/BERT-BiLSTM-CRF-NER
