一、数据集介绍与加载
本次使用的数据集来自 Hugging Face 平台,是一个中文商品评价数据集,包含正向(positive)和负向(negative)两类标签,适用于文本分类任务。数据集通过人工标注完成,标注过程中对模糊语句进行了单独处理与审核,确保数据质量。
数据集包含三个部分:
- 训练集:9600 条文本/标签对
- 验证集:1200 条
- 测试集:1200 条
标签为 0(负向)和 1(正向),例如:
- 负向示例:“入住的时候刚装修味道很浓,房间内的床垫坐上去好像有点不稳。” → 标签 0
- 正向示例:“赶上五一促销入手,做工还行,接口够用,价格能接受,电池强劲…” → 标签 1
第一步:在线加载
from datasets import load_dataset
ds = load_dataset("lansinuote/ChnSentiCorp", cache_dir=cache_dir)
这种方式会默认从 Hugging Face 平台下载数据集到本地,且必须联网,即使本地已经有有数据集,也需要联网,但不会重复下载。数据集格式为 Arrow(平台自定义的加密格式),无法直接查看,可通过代码访问。
第二步:本地加载(推荐)
在第一步中已把数据集已下载到本地,然后可使用 Dataset.from_file 加载,并使用 PyTorch 创建自定义数据集:
- PyTorch 中自定 Dataset 类必须实现
__init__(self)、__getitem__(self, index)以及__len__(self)三个方法; -
__init__(self)方法中首先加载父类的方法,然后通过Dataset.from_file加载数据集; -
__getitem__(self, index)返回指定索引的数据,这里返回评论的文本和标签。
from datasets import Dataset
from torch.utils import data
class MyDataset(data.Dataset):
def __init__(self):
super().__init__()
self.train_dataset = Dataset.from_file(r'/Users/Desktop/huggingface/data/lansinuote___chn_senti_corp/default/0.0.0/b0c4c119c3fb33b8e735969202ef9ad13d717e5a/chn_senti_corp-train.arrow')
self.validation_dataset = Dataset.from_file(r'/Users/Desktop/huggingface/data/lansinuote___chn_senti_corp/default/0.0.0/b0c4c119c3fb33b8e735969202ef9ad13d717e5a/chn_senti_corp-validation.arrow')
self.test_dataset = Dataset.from_file(r'/Users/Desktop/huggingface/data/lansinuote___chn_senti_corp/default/0.0.0/b0c4c119c3fb33b8e735969202ef9ad13d717e5a/chn_senti_corp-test.arrow')
def __len__(self):
return len(self.train_dataset)
def __getitem__(self, item):
text = self.train_dataset[item]["text"]
label = self.train_dataset[item]["label"]
return text, label
if __name__ == '__main__':
data = MyDataset()
for d in data:
print(d)
注意:
- 路径应为包含
dataset_info.json的根目录。 - 使用绝对路径,并建议在路径字符串前加
r防止转义。
加载后可查看数据集信息:
{'text': '酒店的位置不错,附近都靠近购物中心和写字楼区。以前来大连一直都住,但感觉比较陈旧了。住的期间,酒店在进行装修,翻新和升级房间设备。好是好,希望到时房价别涨太多了。', 'label': 1}
{'text': '位置不很方便,周围乱哄哄的,卫生条件也不如其他如家的店。以后绝不会再住在这里。', 'label': 0}
{'text': '抱着很大兴趣买的,买来粗粗一翻排版很不错,姐姐还说快看吧,如果好我也买一本。可是真的看了,实在不怎么样。就是中文里夹英文单词说话,才翻了2页实在不想勉强自己了。我想说的是,练习英文单词,靠这本书肯定没有效果,其它好的方法比这强多了。', 'label': 0}
{'text': '东西不错,不过有人不太喜欢镜面的,我个人比较喜欢,总之还算满意。', 'label': 1}
{'text': '房间不错,只是上网速度慢得无法忍受,打开一个网页要等半小时,连邮件都无法收。另前台工作人员服务态度是很好,只是效率有得改善。', 'label': 1}
{'text': '挺失望的,还不如买一本张爱玲文集呢,以<色戒>命名,可这篇文章仅仅10多页,且无头无尾的,完全比不上里面的任意一篇其它文章.', 'label': 0}
二、模型和分词器的本地加载
2.1 模型来源
同样通过 Hugging Face 平台获取 bert-base-chinese,这是一个中文文本分类模型。
下载模型的同时需要下载对应的分词器,原因在于模型不能直接识别文字,分词器的作用是把每个文字转为模型可识别的数字,再把数字输入给模型,不同的模型有不同的分词器,因此模型与分词器必须匹配。
from transformers import AutoModel, BertTokenizerFast
model_name = 'ckiplab/bert-base-chinese'
cache_dir = 'model'
# 下载分词器
BertTokenizerFast.from_pretrained('bert-base-chinese', cache_dir=cache_dir)
print("tokenizer done")
# 下载模型
AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
print("model done")
平台上的模型大多存储在 Google 云盘,下载时需联网,但可缓存到本地供后续离线使用。
from transformers import BertModel, BertTokenizerFast
# 从本地加载模型和分词器
tokenizer = BertTokenizerFast.from_pretrained(r'/Users/Desktop/huggingface/model/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea')
model = BertModel.from_pretrained(r'/Users/Desktop/huggingface/model/models--ckiplab--bert-base-chinese/snapshots/efe27bb4a9373384e0120ffe1cf327714ceb61bf')
print(model)
2.2 模型结构简介
打印 BERT 模型,其结构主要分为三部分:
- Embedding 层:将输入的位置编码(即分词后的数字序列)转换为 768 维的词向量。
- Encoder 层:由多层 Transformer 编码器组成,用于提取特征。BERT-base 包含 12 层 Encoder,这是 Transformer 模型有效的最低层数要求。
- Pooler 层:由一个全联接层和一个激活函数组成,其中全联接层的输出维度为 768,这个很重要,关系到后面模型的搭建。
三、模型设计
定义模型时,只需要定义全连接层,再追加到 BERT 模型的输出。因此整体流程是:BERT —> 自定义全连接层 —> softmax。
- BERT 模型:作为通用语言理解基座,提供高质量的文本特征表示
- 自定义全连接层:作为专用任务头,将通用特征映射到具体业务分类
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
# in_features=768 是BERT的输出维度,out_features=2 是分类数
self.fc = torch.nn.Linear(in_features=768, out_features=2)
def forward(self, input_ids, attention_mask, token_type_ids):
# BERT 模型不需要训练
with torch.no_grad():
out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
# 取[CLS]位置的隐藏状态
out = self.fc(out.last_hidden_state[:, 0])
out = out.softmax(dim=1)
return out
BERT 的输入参数包含:
- input_ids: 将文本转换成的token索引序列。例如“我喜欢自然语言处理” → [101, 2769, 4263, 3315, 4476, 1566, 5543, 102]
- attention_mask: 标识哪些token是有效内容,哪些是填充(padding)。因为 BERT 要求输入是固定长度的,所以对于短于最大长度的序列,用 0 进行填充,会忽略这些填充词。
- token_type_ids: 区分句子对中的第一个句子和第二个句子
BERT 的这些输入将在 Dataloader 中,使用分词器提取。
BERT 的输出参数包含:
- last_hidden_state:最后一层的隐藏状态 [batch_size, seq_len, hidden_size]
- pooler_output:[CLS] token 经过线性层和 tanh 激活后的表示 [batch_size, hidden_size]
取 [CLS] 位置的隐藏状态,是因为 [CLS] 包含了整个句子的信息,并且预训练任务中已经训练它做类似的任务。其他 token 的表示则更多关注局部信息,不适合直接用于整个句子的分类。
四、创建自定义 Dataloader
Dataloader 需要去从 Dataset 读取数据,它提供了一种简便的方式来迭代数据集。
这里需要设置批处理数据读取,有一个很重要的操作就是要先定义一个 collate_fn 函数,在这个函数中我们将 Dataset 原始文本通过分词器(tokenizer)进行词向量的转换,转换为模型可理解的参数输入给模型。
collate_fn 是一个可选参数,允许用户自定义如何将多个数据样本合并成一个 batch,由于 bert 模型需要三个入参(input_ids, attention_mask, token_type_ids),因此需要把数据集通过分词器转化为这三个参数。
import torch
from my_dataset import MyDataset
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
# 加载分词器
token_path = r'/Users/Desktop/huggingface/model/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea'
tokenizer = BertTokenizerFast.from_pretrained(token_path)
def collate_fn(data):
sentes = [i[0] for i in data]
label = [i[1] for i in data]
data = tokenizer.batch_encode_plus(
batch_text_or_text_pairs=sentes,
truncation=True,
padding='max_length',
max_length=512,
return_tensors='pt', # 输出数据将作为 PyTorch 张量返回,而不是 NumPy 数组或其他格式
return_length=True
)
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
label = torch.LongTensor(label)
return input_ids, attention_mask, token_type_ids, label
# 初始化数据集
train_data = MyDataset()
# 数据加载
train_loader = DataLoader(
dataset=train_data,
batch_size=32,
shuffle=True,
drop_last=True,
collate_fn=collate_fn
)
五、模型训练
定义模型、优化器、损失函数,开启训练模式,在训练批次中,从 train_loader 加载数据,并把数据(input_ids, attention_mask, token_type_ids)传入给模型。
model = Model().to('cpu')
optimizer = AdamW(model.parameters(), lr=0.001)
loss_func = torch.nn.CrossEntropyLoss()
model.train()
for epoch in range(10):
for i, (input_ids, attention_mask, token_type_ids, label) in enumerate(train_loader):
input_ids, attention_mask, token_type_ids, label = input_ids.to('cpu'), attention_mask.to('cpu'), token_type_ids.to('cpu'), label.to('cpu')
out = model(input_ids, attention_mask, token_type_ids)
loss = loss_func(out, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % 5 == 0:
out = out.argmax(dim=1)
acc = (out == label).sum().item() / len(label)
print(f'epoch={epoch}, i={i}, loss={loss.item()}, acc={acc}')
torch.save(model.state_dict(), f'param/{epoch}bert.pt')
print(epoch, 'save success')
开始训练,打印准确率:
epoch = 0, i=0, loss=0.7204890847206116, acc=0.4375
epoch = 0, i=5, loss=0.6873999834060669, acc=0.65625
epoch = 0, i=10, loss=0.5797467827796936, acc=0.78125
epoch = 0, i=15, loss=0.6594133377075195, acc=0.625
epoch = 0, i=20, loss=0.6064494252204895, acc=0.71875
epoch = 0, i=25, loss=0.649083137512207, acc=0.625
epoch = 0, i=30, loss=0.5833480358123779, acc=0.75
epoch = 0, i=35, loss=0.5145970582962036, acc=0.875
……
六、模型预测
预测的流程与训练流程类似,也需要把输入的文字先通过分词器转为模型能理解的数字,再输入给模型。
def collate_fn2(data):
sentes = []
sentes.append(data)
data = tokenizer.batch_encode_plus(
batch_text_or_text_pairs=sentes,
truncation=True,
padding='max_length',
max_length=512,
return_tensors='pt',
return_length=True
)
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
return input_ids, attention_mask, token_type_ids
def test():
name = ["差评", "好评"]
model.load_state_dict(torch.load("param/1bert.pt"))
model.eval()
while True:
data = input("input content: ")
input_ids, attention_mask, token_type_ids = collate_fn2(data)
input_ids, attention_mask, token_type_ids = input_ids.to('cpu'), attention_mask.to('cpu'), token_type_ids.to('cpu')
with torch.no_grad():
out = model(input_ids, attention_mask, token_type_ids)
out = out.argmax(dim=1)
print(f"result={name[out]}\n")
input content: 5月8日付款成功,当当网显示5月10日发货,可是至今还没看到货物,也没收到任何通知,简不知怎么说好!!!
result=差评
input content: 下次不会再买了,不喜欢
result=差评
input content: 质量挺不错的,下次再来
result=好评