BERT 训练之 NER
这部分基本沿用了 macanv 的 BERT+CRF 实现。
训练数据同样使用 rasa,然后进行格式转换。
###############################################################################################
# Generate NER training data
#
# Convert the entity-annotated `examples` into character-level BIO files
# (train/dev/test) under ./ner_data/ and pickle the sorted set of tags seen.
print("Starting NER data generation")


def print_line(line, labels):
    """Render one sentence as ``<char> <label>`` rows, one row per character.

    Args:
        line: The sentence text; each character becomes one output row.
        labels: One BIO tag per character of ``line``.

    Returns:
        The concatenated rows (each terminated by a newline), or an empty
        string, after printing an error, when the two lengths differ.
    """
    if len(line) != len(labels):
        print("Error: length of text and length of labels are not equal, {}".format(line))
        return ""
    # NOTE(review): a whitespace character inside `line` produces a row that
    # starts with a space; confirm the downstream reader copes with that.
    return "".join("{} {}\n".format(char, tag) for char, tag in zip(line, labels))


idx = 0
# Only B-/I- tags that are actually assigned get collected; "O" is not added.
label_list = set()
# Explicit UTF-8 so the Chinese text does not depend on the platform default
# encoding; the context manager closes the files even if an example is bad.
with open("./ner_data/train.txt", "w", encoding="utf-8") as train_file, \
        open("./ner_data/dev.txt", "w", encoding="utf-8") as dev_file, \
        open("./ner_data/test.txt", "w", encoding="utf-8") as test_file:
    for example in examples:
        entities = example["entities"]
        raw_text = example["text"]
        # Entity offsets index into the raw text, so remember how many leading
        # whitespace characters strip() removes and shift the offsets by it.
        lead = len(raw_text) - len(raw_text.lstrip())
        # Keep only the first sentence and make sure it ends with a full stop.
        text = raw_text.strip().split("。")[0] + "。"

        labels = ["O"] * len(text)
        for entity in entities:
            label = entity["entity"]
            # Clamp to the kept sentence: entities located after the first
            # "。" would otherwise index past the end of `labels`.
            start_pos = max(entity["start"] - lead, 0)
            end_pos = min(entity["end"] - lead, len(text))
            for i in range(start_pos, end_pos):
                tag = ("B-" if i == start_pos else "I-") + label
                labels[i] = tag
                label_list.add(tag)

        formated_text = print_line(text, labels)
        idx += 1
        # The extra newline leaves a blank line between sentences.
        record = "{}\n".format(formated_text)

        # NOTE(review): every example is written to the training file, including
        # the ones that also go to dev/test below, so the splits overlap.
        # Kept as in the original; confirm this is intended.
        train_file.write(record)
        bucket = idx % 10
        if 6 <= bucket < 8:
            dev_file.write(record)
        elif bucket >= 8:
            test_file.write(record)

label_list = sorted(label_list)
with open("./ner_data/label_list.pkl", "wb") as label_file:
    pickle.dump(label_list, label_file)
print("Finished generation, total example {}, label list: {}".format(idx, label_list))