Original article: https://www.jianshu.com/p/3d0bb34c488a
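Anyone working in natural language processing has probably noticed how much attention Google's BERT has been getting lately. It is reported to reach state-of-the-art results on 11 tasks, including SQuAD. For how BERT works, see the paper or the translated material others have posted online. Google has open-sourced the code on GitHub, and I suspect everyone in NLP is as eager as I am to start learning it.
Personally, I learn an open-source project in three steps: install it, run the demos, and only then read the source. Installing BERT is easy: just pull it from GitHub. Running the demos is not hard either if you follow README.md step by step, but when I tried it I hit a small pitfall, so I am recording it here for reference.
Enough preamble. This post covers only two demos: a sentence-pair classification task based on MRPC (Microsoft Research Paraphrase Corpus), and a reading comprehension task based on the SQuAD corpus. Running the demos takes the following steps: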
1. Download the BERT source code
Nothing much to say here, just clone it:
git clone https://github.com/google-research/bert.git
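2. Download the pre-trained model
Why choose the BERT-Base, Uncased model? Three reasons: (1) the training corpus is English, so the Chinese and multilingual models are out; (2) hardware is limited: if your GPU has less than 16 GB of memory, stick with Base and do not bother with Large; (3) Cased means case is preserved, Uncased means it is not. Unless you know that your task is case-sensitive (for example named entity recognition or part-of-speech tagging), Uncased usually works better.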
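Before moving on, it can help to confirm that the unpacked model folder holds the three things the later commands point at: vocab.txt, bert_config.json, and the bert_model.ckpt checkpoint files. The following is only a minimal sketch, assuming the archive was extracted into a single folder; the helper name missing_model_files is hypothetical and not part of the BERT repository.

```python
import os
import tempfile


def missing_model_files(model_dir):
    """Return the expected BERT file names that are absent from model_dir."""
    names = os.listdir(model_dir)
    missing = [n for n in ("vocab.txt", "bert_config.json") if n not in names]
    # The checkpoint is stored as several files sharing the bert_model.ckpt prefix.
    if not any(n.startswith("bert_model.ckpt") for n in names):
        missing.append("bert_model.ckpt*")
    return missing


# Demo on an empty temporary folder (stands in for a real model directory).
demo_dir = tempfile.mkdtemp()
print(missing_model_files(demo_dir))
```

For a real check, pass the folder that $BERT_BASE_DIR will point to; an empty list means nothing is missing.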
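3. Download the training data
(1) Download the MRPC corpus:
The official instructions download the GLUE data by running the script download_glue_data.py. With glue_data as the data directory and MRPC as the task, run (every python3 command in this post also works with python):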
python3 download_glue_data.py --data_dir glue_data --tasks MRPC
The original article uses python3, but the Python on my machine is python2, so the script does not run as-is. My modified download_glue_data.py is given below, and the command becomes:
python download_glue_data.py --data_dir glue_data --tasks MRPC
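# download_glue_data.py source
import os
import sys
import shutil
import argparse
import tempfile
import zipfile
import codecs

# urlretrieve lives in urllib on Python 2 and in urllib.request on Python 3;
# the alias keeps the urllib.urlretrieve calls below working on both.
try:
    import urllib.request as urllib
except ImportError:
    import urllib
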
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
"MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
"QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
"STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
"MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
"SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
"QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
"RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
"WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
"diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'

def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    data_file = "%s.zip" % task
    urllib.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")

def format_mrpc(data_dir, path_to_data):
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        urllib.urlretrieve(MRPC_TRAIN, mrpc_train_file)
        urllib.urlretrieve(MRPC_TEST, mrpc_test_file)
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
    urllib.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))

    # Collect the ID pairs that belong to the dev split.
    dev_ids = []
    with codecs.open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf-8") as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split('\t'))

    # Split the original training file into train.tsv and dev.tsv.
    with codecs.open(mrpc_train_file, encoding="utf-8") as data_fh, \
            codecs.open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf-8") as train_fh, \
            codecs.open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf-8") as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    # Write test.tsv with an index column in place of the label.
    with codecs.open(mrpc_test_file, encoding="utf-8") as data_fh, \
            codecs.open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf-8") as test_fh:
        header = data_fh.readline()  # skip the original header row
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
    print("\tCompleted!")

def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
        os.mkdir(os.path.join(data_dir, "diagnostic"))
    data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
    urllib.urlretrieve(TASK2PATH["diagnostic"], data_file)
    print("\tCompleted!")
    return

def get_tasks(task_names):
    task_names = task_names.split(',')
    if "all" in task_names:
        tasks = TASKS
    else:
        tasks = []
        for task_name in task_names:
            assert task_name in TASKS, "Task %s not found!" % task_name
            tasks.append(task_name)
    return tasks

def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    if not os.path.isdir(args.data_dir):
        os.mkdir(args.data_dir)
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
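The heart of format_mrpc is the train/dev split: every row of msr_paraphrase_train.txt whose pair of sentence IDs appears in dev_ids.tsv goes to dev.tsv, and all other rows go to train.tsv. Here is a small in-memory sketch of that rule. The function name split_train_dev and the sample rows are made up for illustration, and the lookup uses a set of tuples rather than the list used in the script.

```python
def split_train_dev(rows, dev_ids):
    """Split tab-separated MRPC rows into (train, dev) by their ID pair."""
    dev_pairs = {tuple(pair) for pair in dev_ids}  # set membership is O(1)
    train, dev = [], []
    for row in rows:
        label, id1, id2, s1, s2 = row.strip().split("\t")
        if (id1, id2) in dev_pairs:
            dev.append(row)
        else:
            train.append(row)
    return train, dev


# Hypothetical sample rows, not taken from the real corpus.
rows = [
    "1\t100\t101\tA first sentence.\tA paraphrase of it.",
    "0\t200\t201\tAnother sentence.\tSomething unrelated.",
]
train, dev = split_train_dev(rows, dev_ids=[["200", "201"]])
print("%d train, %d dev" % (len(train), len(dev)))  # 1 train, 1 dev
```

The set only changes the lookup cost; the resulting split is the same as with the list-based check.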
If the method above does not work, I found a copy that another user shared on Baidu Cloud:
Link: https://pan.baidu.com/s/1-b4I3ocYhiuhu3bpSmCJ_Q
Extraction code: z6mk
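(2) Download the SQuAD corpus:
This went without trouble. The following three files can be downloaded directly; place them under $SQUAD_DIR.
- train-v1.1.json: open it in the browser, then press command+s to save it
- dev-v1.1.json: open it in the browser, then press command+s to save it
- evaluate-v1.1.py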
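Files saved from the browser are easy to truncate, so a quick sanity check is to load the JSON and count what is inside. This sketch assumes the usual SQuAD v1.1 layout (a top-level "data" list of articles, each with "paragraphs" that hold a "context" and a list of "qas"); the function name and the tiny sample are made up for illustration.

```python
import json


def count_squad_questions(dataset):
    """Return (articles, paragraphs, questions) for a SQuAD v1.1 style dict."""
    articles = dataset["data"]
    paragraphs = [p for a in articles for p in a["paragraphs"]]
    questions = sum(len(p["qas"]) for p in paragraphs)
    return len(articles), len(paragraphs), questions


# A tiny hand-written sample in the SQuAD v1.1 layout (not real data).
sample = json.loads("""
{"version": "1.1", "data": [{"title": "Example", "paragraphs": [
  {"context": "BERT was released by Google.",
   "qas": [{"id": "q1", "question": "Who released BERT?",
            "answers": [{"text": "Google", "answer_start": 21}]}]}]}]}
""")
print(count_squad_questions(sample))  # (1, 1, 1)
```

For the real files, load train-v1.1.json or dev-v1.1.json from $SQUAD_DIR with json.load and pass the result to the same function.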
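4. Run the demos
(1) Sentence-pair classification on the MRPC corpus
Training:
Set environment variables that point to the pre-trained model files and the corpus: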
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue_data
From the BERT source directory, run run_classifier.py to fine-tune from the pre-trained model:
python run_classifier.py \
--task_name=MRPC \
--do_train=true \
--do_eval=true \
--data_dir=$GLUE_DIR/MRPC \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=/tmp/mrpc_output/
The model is saved in output_dir. The evaluation results are:
# Running this on a single machine took me about 3 hours... if you have a GPU, use it.
INFO:tensorflow:***** Eval results *****
INFO:tensorflow: eval_accuracy = 0.86519605
INFO:tensorflow: eval_loss = 0.40176657
INFO:tensorflow: global_step = 343
INFO:tensorflow: loss = 0.40176657
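The global_step of 343 is consistent with the training flags. As far as I can tell, run_classifier.py derives the number of training steps as int(examples / batch_size * epochs). Assuming the MRPC training split holds 3,668 examples (the size usually quoted for GLUE, not a number stated in this post), the arithmetic works out as follows:

```python
def num_train_steps(num_examples, batch_size, epochs):
    """Total optimizer steps: examples / batch size * epochs, truncated."""
    return int(num_examples / float(batch_size) * epochs)


# 3668 is an assumed training-set size; 32 and 3.0 come from the command above.
print(num_train_steps(3668, 32, 3.0))  # 343
```

So a different batch size or epoch count will change the final global_step accordingly.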
Prediction:
Point to the location of the fine-tuned model files:
export TRAINED_CLASSIFIER=/path/to/fine/tuned/classifier
Run the following to make predictions; the results are written to the output_dir folder:
python run_classifier.py \
--task_name=MRPC \
--do_predict=true \
--data_dir=$GLUE_DIR/MRPC \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$TRAINED_CLASSIFIER \
--max_seq_length=128 \
--output_dir=/tmp/mrpc_output/
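According to the BERT README, prediction writes a file named test_results.tsv to output_dir, with one line per test example and the class probabilities as tab-separated columns. Below is a minimal sketch for turning those probabilities into predicted class indices; the helper name and the demo file content are hypothetical.

```python
import os
import tempfile


def read_predictions(path):
    """Return the index of the highest-probability class for each line."""
    labels = []
    with open(path) as fh:
        for line in fh:
            if not line.strip():
                continue  # ignore blank lines
            probs = [float(x) for x in line.split("\t")]
            labels.append(probs.index(max(probs)))
    return labels


# Demo with made-up probabilities written to a temporary file.
demo_path = os.path.join(tempfile.mkdtemp(), "test_results.tsv")
with open(demo_path, "w") as fh:
    fh.write("0.9\t0.1\n0.2\t0.8\n")
print(read_predictions(demo_path))  # [0, 1]
```

Which index corresponds to which MRPC label depends on the label order used by the task processor, so check that before interpreting the numbers.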
(2) Reading comprehension on the SQuAD corpus
Set $SQUAD_DIR to the folder that holds the corpus, then run:
python run_squad.py \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--do_train=True \
--train_file=$SQUAD_DIR/train-v1.1.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v1.1.json \
--train_batch_size=12 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=/tmp/squad_base/
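The max_seq_length and doc_stride flags control how a passage that is too long for one input gets cut into overlapping windows. Here is a simplified sketch of that sliding-window idea, based on my reading of run_squad.py. It ignores the positions reserved for the question and the special tokens, so max_tokens stands for the room left for the passage, and the function name is made up for illustration.

```python
def doc_spans(num_tokens, max_tokens, doc_stride):
    """Return (start, length) windows covering num_tokens passage tokens."""
    spans = []
    start = 0
    while start < num_tokens:
        length = min(num_tokens - start, max_tokens)
        spans.append((start, length))
        if start + length == num_tokens:
            break  # the last window reached the end of the passage
        start += min(length, doc_stride)
    return spans


# A 500-token passage with at most 384 tokens per window and a stride of 128.
print(doc_spans(500, 384, 128))  # [(0, 384), (128, 372)]
```

A smaller doc_stride produces more overlapping windows per passage, which raises the number of training features and the running time.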
A predictions.json file is written to the output_dir folder. From that folder, run:
python3 $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json predictions.json
If you see a result like the following, everything ran correctly:
{"f1": 88.41249612335034, "exact_match": 81.2488174077578}
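The two numbers are exact match and token-level F1, both reported as percentages. Below is a simplified sketch of how such metrics are commonly computed for SQuAD, written from memory of the official script's approach (lowercase, strip punctuation, drop the articles a/an/the, then compare tokens). Treat it as an illustration only and keep using evaluate-v1.1.py for real scores.

```python
import re
import string
from collections import Counter


def normalize(text):
    """Lowercase, remove punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    return normalize(prediction) == normalize(truth)


def f1(prediction, truth):
    pred_tokens = normalize(prediction).split()
    true_tokens = normalize(truth).split()
    # Multiset intersection counts each shared token at most as often as it occurs in both.
    overlap = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / float(len(pred_tokens))
    recall = overlap / float(len(true_tokens))
    return 2 * precision * recall / (precision + recall)


print(exact_match("The Eiffel Tower", "eiffel tower."))  # True
print(f1("in Paris France", "Paris"))  # 0.5
```

The sketch scores a single prediction against a single reference answer; aggregating over several reference answers and over the whole dev set is left to the official script.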
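5. Summary
This post addressed two things:
(1) Running the demos for MRPC sentence-pair classification and SQuAD reading comprehension, which is largely a translation of part of README.md in the source repository;
(2) Providing an alternative source for corpus files that cannot be downloaded. Later posts in this series will walk through the BERT source code, so stay tuned.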
Reference
1. https://github.com/google-research/bert