这里记录以下在fairseq中微调roberta和使用bart的方法。本来想使用bart微调的,但是遇到了bug现在还没调通,因此曲线救国,使用了roberta,后面如果调通了,会补上的,以下代码和官网上https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.glue.md一致,由于简书复制少空格可能会有错误,以官网为主。
1、安装
git clone https://github.com/pytorch/fairseq
cdfairseq
pip install --editable ./
2、如果想要训练更快,安装apex
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext"--global-option="--cuda_ext"\ --global-option="--deprecated_fused_adam"--global-option="--xentropy"\ --global-option="--fast_multihead_attn"./
3、使用bart,下载预训练模型,这里选择bart.large(其他还有:bart.base bart.large
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained('/path/to/bart.large',checkpoint_file='model.pt')
bart.eval()# disable dropout (or leave in train mode to finetune)
4、使用Bart做句子分类任务,这里使用了在Mnli数据集上微调的模型(将句子对分三类)
bart=torch.hub.load('pytorch/fairseq','bart.large.mnli')
bart.eval()
tokens=bart.encode('BART is a seq2seq model.','BART is not sequence to sequence.')
bart.predict('mnli',tokens).argmax()
tokens=bart.encode('BART is denoising autoencoder.','BART is version of autoencoder.')
bart.predict('mnli',tokens).argmax()