智能交互助手 - bert训练

NLU的部分,把bert要finetune一下

用的是bert标准的classification的任务

直接引用官方的说明


bert

代码简单修改了一下,增加了rasa的task

class RasaProcessor(DataProcessor):
  """Processor for the MRPC data set (GLUE version)."""

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self, data_dir):
    """See base class."""
    # 通过读取train文件获取标签的方法会出现一定的风险。
    if os.path.exists(os.path.join(data_dir, 'label_list.pkl')):
        with open(os.path.join(data_dir, 'label_list.pkl'), 'rb') as rf:
            self.labels = pickle.load(rf)
    else:
        self.labels = ["O", 'B-TIM', 'I-TIM', "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]
    return self.labels

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[3])
      if set_type == "test":
        label = "重来"
      else:
        label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, label=label))
    return examples

这里的核心就是get_labels
因为rasa框架里,label是可以动态添加的,也就是对应了我们要识别的意图,因此,采用label_list文件读取。

这里就涉及到bert训练数据的生成

基于之前chatito生成的rasa的json文件,直接读取,转换成bert的classification格式的数据,同时保留一定比例的训练、测试数据

with open(sys.argv[1], "r") as input_file:
    training_data = json.loads(input_file.read())
    #print(training_data)
    nlu_data = training_data.get("rasa_nlu_data","")
    examples = nlu_data.get("common_examples","")

    ###############################################################################################
    # Generate classifier training data
    #
    print("Starting classifier data generation")
    train_file = open("./classifier_data/train.tsv", "w+")
    train_file.writelines("intent\tid1\tid2\ttext\n")
    dev_file = open("./classifier_data/dev.tsv", "w+")
    dev_file.writelines("intent\tid1\tid2\ttext\n")
    test_file = open("./classifier_data/test.tsv", "w+")
    test_file.writelines("intent\tid1\tid2\ttext\n")
    idx = 0
    label_list = set()
    for example in examples:
        idx += 1
        #print("{}, intent: {}, text: {}".format((idx%10), example["intent"], example["text"]))
        train_file.writelines("{}\t{}\t0\t{}\t\n".format(example["intent"], idx, example["text"]))
        if (idx % 10) < 6:
            #train_file.writelines("{}\t0\t0\t{}\t\n".format(example["intent"], example["text"]))
            pass
        elif (idx % 10) < 8:
            dev_file.writelines("{}\t{}\t0\t{}\t\n".format(example["intent"], idx, example["text"]))
        else:
            test_file.writelines("{}\t{}\t0\t{}\t\n".format(example["intent"], idx, example["text"]))
        label_list.add(example["intent"])
    train_file.close()
    dev_file.close()
    test_file.close()

    label_list = sorted(list(label_list))
    with open("./classifier_data/label_list.pkl", "wb") as label_file:
        pickle.dump(label_list, label_file)
    print("Finished generation, total example {}, label list: {}".format(idx, label_list))
  • 将bert的classification整合到macanv的BERT-BILSTM-CRF-NER
    在macanv的框架里,稍微调整了一下代码结构
    同时,还需要额外生成label2id.pkl文件,以供后续使用
    需要在macanv的训练代码中额外增加:

  # 保存label->index 的map
  if not os.path.exists(os.path.join(output_dir, 'label2id.pkl')):
      with open(os.path.join(output_dir, 'label2id.pkl'), 'wb') as w:
          pickle.dump(label_map, w)

      print("label map: {}".format(label_map))

详细内容请参考
https://github.com/xgzhang83/BERT-BiLSTM-CRF-NER/blob/master/bert_base/train/bert_classifier.py

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。