For the NLU part, BERT needs to be fine-tuned.
We use BERT's standard classification task, following the official BERT documentation and reference code.
The code only needs minor modifications to add a Rasa task: a new processor class (plus registering it, as shown in the snippet after the class):
# Added to BERT's run_classifier.py (DataProcessor, InputExample, tokenization
# and os are already available there); also add `import pickle` at the top.
class RasaProcessor(DataProcessor):
    """Processor for the Rasa intent-classification data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self, data_dir):
        """See base class."""
        # Deriving the label set by scanning the train file is risky (an intent
        # missing from train.tsv would be silently dropped), so we load the
        # label list written at data-generation time.
        if os.path.exists(os.path.join(data_dir, 'label_list.pkl')):
            with open(os.path.join(data_dir, 'label_list.pkl'), 'rb') as rf:
                self.labels = pickle.load(rf)
        else:
            # Fallback label set if no label_list.pkl exists.
            self.labels = ["O", 'B-TIM', 'I-TIM', "B-PER", "I-PER", "B-ORG",
                           "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]
        return self.labels

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:  # skip the header row
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[3])
            if set_type == "test":
                label = "重来"  # placeholder label for the test set; ignored at predict time
            else:
                label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, label=label))
        return examples
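To make the new task selectable, RasaProcessor also has to be registered in the processors dict in run_classifier.py's main(); fine-tuning is then launched with --task_name=rasa and the usual run_classifier.py flags, with data_dir pointing at the generated classifier_data directory. A minimal sketch:

# In run_classifier.py, main():
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "xnli": XnliProcessor,
    "rasa": RasaProcessor,  # the processor added above
}

Note that get_labels here takes a data_dir argument while the stock DataProcessor.get_labels takes none, so the call site in main() (label_list = processor.get_labels()) has to be updated to pass FLAGS.data_dir.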
The key part here is get_labels: in the Rasa framework, labels can be added dynamically (they correspond to the intents we want to recognize), so instead of hard-coding them we read the list from a label_list file.
This brings us to generating BERT's training data. We read the Rasa JSON file generated earlier by Chatito and convert it directly into BERT's classification format, reserving a fixed proportion of the examples as dev and test data (an illustrative input/output sample follows the script):
import json
import os
import pickle
import sys

with open(sys.argv[1], "r") as input_file:
    training_data = json.loads(input_file.read())

nlu_data = training_data.get("rasa_nlu_data", {})
examples = nlu_data.get("common_examples", [])

###############################################################################
# Generate classifier training data
###############################################################################
print("Starting classifier data generation")
os.makedirs("./classifier_data", exist_ok=True)
train_file = open("./classifier_data/train.tsv", "w+")
train_file.write("intent\tid1\tid2\ttext\n")
dev_file = open("./classifier_data/dev.tsv", "w+")
dev_file.write("intent\tid1\tid2\ttext\n")
test_file = open("./classifier_data/test.tsv", "w+")
test_file.write("intent\tid1\tid2\ttext\n")

idx = 0
label_list = set()
for example in examples:
    idx += 1
    # Every example goes into train.tsv; two out of every ten are additionally
    # copied into dev.tsv and another two into test.tsv.
    train_file.write("{}\t{}\t0\t{}\n".format(example["intent"], idx, example["text"]))
    if (idx % 10) < 6:
        pass
    elif (idx % 10) < 8:
        dev_file.write("{}\t{}\t0\t{}\n".format(example["intent"], idx, example["text"]))
    else:
        test_file.write("{}\t{}\t0\t{}\n".format(example["intent"], idx, example["text"]))
    label_list.add(example["intent"])

train_file.close()
dev_file.close()
test_file.close()

label_list = sorted(label_list)
with open("./classifier_data/label_list.pkl", "wb") as label_file:
    pickle.dump(label_list, label_file)
print("Finished generation, total examples {}, label list: {}".format(idx, label_list))
- Integrating BERT classification into macanv's BERT-BiLSTM-CRF-NER
Within macanv's framework, the code structure only needs slight adjustment.
In addition, an extra label2id.pkl file has to be generated for later use (e.g. mapping predictions back to intent names), by adding the following to macanv's training code:
# Save the label -> index map for later use at prediction time.
if not os.path.exists(os.path.join(output_dir, 'label2id.pkl')):
    with open(os.path.join(output_dir, 'label2id.pkl'), 'wb') as w:
        pickle.dump(label_map, w)
        print("label map: {}".format(label_map))
For the full details, see
https://github.com/xgzhang83/BERT-BiLSTM-CRF-NER/blob/master/bert_base/train/bert_classifier.py