base-llm 2.2.4 模型的推理与优化

一、理解模型输出

实体级别的F1值是衡量模型性能的核心标准,而非简单的Token分类准确率。

1.1 Token级准确率的陷阱

Token分类准确率,也就是模型预测的正确标签数占总标签数的比例。
指标具有误导性: 因为数据不均衡, 也就是实体词占比较低的场景中,预测都为0,也能等到高的Token准确率。

模型训练到一个阶段之后,预测结果会出现大量甚至全部为‘0’。 看上去Token准确率很高,实际上模型陷入预测多数类来最小化损失,这是过拟合现象,说明模型没有真正学会实体识别。

1.2 对推理流程的启发

模型原始输出(Token标签序列)本身不是最终交付物,需要一个后处理或解码步骤,转化为用户关心的结构化的实体列表。

二、从标签到实体: 解码预测序列

2.1 解码逻辑

贪心策略: 当前是current_entity, 遇到E-则结束
标签包括BMESO
B-: 代表实体开始
M-:实体中间
E-: 实体结束
S-: 单字实体
O: 非实体

严格模式:只有符合B开头E结尾的是正常实体
宽松策略: B-M-O也可以作为实体输出,取决于对召回率和精确率的不同侧重,需要根据实际需求来决定。

# code/C8/06_predict.py

def _extract_entities(self, tokens, tags):
   entities = []
   current_entity = None
   for i, tag in enumerate(tags):
       if tag.startswith('B-'):
           # 如果前一个实体未正确结束,则放弃
           if current_entity:
               pass # 或者可以根据业务逻辑决定是否保存不完整的实体
           current_entity = {"text": tokens[i], "type": tag[2:], "start": i}
       elif tag.startswith('M-'):
           # M 标签必须跟在 B- 或 M- 之后
           if current_entity and current_entity["type"] == tag[2:]:
               current_entity["text"] += tokens[i]
           else:
               # 非法 M 标签,重置当前实体
               current_entity = None
       elif tag.startswith('E-'):
           # E 标签必须跟在 B- 或 M- 之后
           if current_entity and current_entity["type"] == tag[2:]:
               current_entity["text"] += tokens[i]
               current_entity["end"] = i + 1
               entities.append(current_entity)
           # 实体已结束,重置
           current_entity = None
       elif tag.startswith('S-'):
           # S 标签表示单个字符的实体
           # 如果有未结束的实体,则放弃
           current_entity = None
           entities.append({"text": tokens[i], "type": tag[2:], "start": i, "end": i + 1})
       else: # 'O' 标签
           # O 标签意味着没有实体,或者实体已经结束
           # 如果有未结束的实体,则放弃
           current_entity = None
   
   # 循环结束后,不再处理任何未闭合的实体
   return entities

三、封装推理器

设计思想:

  1. 易于初始化
  2. 接口简洁: 提供简单的predict(text)方法,接收原始文本字符串,返回结构化实体列表。
  3. 与训练解耦

3.2 NerPredictor核心流程

3.2.1 初始化init

  1. 加载配置: 从模型目录加载config.json, 获取模型超参数和相关文件路径
  2. 加载词汇表和标签映射: vocabulary.json和tags.json,并构建id2tag映射
  3. 加载分词器: 初始化CharTokenizer
  4. 初始化模型并加载权重:
    • 根据配置初始化BiGruNerNetWork模型
    • 从模型目录加载best_model.pth模型权重,这里需要使用map_location=self.device来去顶模型可以被加载到指定的设备上
    • 调用model.to(self.device)将模型移动至指定设备
    • 调用model.eval() 模型切换到评估模式,关闭Dropout和BatchNorm等只在训练时使用的层,确包预测结果的确定性

3.2.2 预测predict

  1. 预处理: 文本-> token_ids-> Tensor[batch_size,seq_len]
    创建 attention_mask。将所有张量移动到 self.device。
  2. 模型预测:
    使用 with torch.no_grad(): 临时禁用梯度计算,减少内存消耗并加速推理过程。
    将 token_ids 和 attention_mask 送入模型,得到 logits。
  3. 后处理:
    对 logits 在最后一个维度上执行 argmax,得到预测的 label_ids 序列。
    使用 id2tag 映射,将 label_ids 转换为 tags 字符串列表。
    调用 _extract_entities 方法,完成最终的解码,返回实体列表。

四、自定义损失函数

因为数据不均衡,大多数都是非实体,可以在损失函数上,主动引导模型去关注实体样本。

4.1 核心策略

4.1.1. 加权交叉熵损失

给数量稀少的实体标签(BMES)一个更高的权重,非实体标签较低的权重。

4.1.2 硬负样本挖掘

采样,在大量非实体样本中,大多数都是模型可以轻易正确预测的简单样本,对损失贡献很小,反复学习易于不大。真正有价值的是哪些模型容易搞错的硬负样本,例如一个模型倾向于预测为实体的非实体Token。

在计算非实体部分的损失时, 不计算所有实体Token的平均损失,只选择损失值最大(Top-K)的一部分进行计算和反向传播, 相当于从海量的"多数派"中,筛选出了最右价值的疑难样本进行学习,提升了训练的效率和效果。

4.2.6 解读验证集损失

使用自定义损失函数会发现: 验证集上的F1分数在稳步提升,但loss值确停滞不前甚至上升,这是正常且符合预期的现象。
因为Trainer评估节点使用自定义、加权的损失函数来计算验证集,loss主要反应的是训练目标的优化情况,不是一个标准的评估指标。

  • 权重影响:实体部分损失赋予很高权重,少数几个实体相关的错误会导致loss值大幅波动或居高不下
  • 硬负样本挖掘影响: 策略会动态聚焦于模型最容易搞错的那些非实体标签,随着训练进行,简单敷衍本损失会降低,但模型会转而面对更棘手的硬样本,倒是计算出的non_ner_loss可能不会持续下降。

因此,当使用这些高级损失策略时,验证集loss不再是衡量模型好坏的主要标准。应将注意力放在最终目标的指标上,对于NER任务,是实体级别的F1分数,这也是Trainer用F1作为保存最佳模型的一句的原因。

五、优化训练工作流

可视化日志、提前停止和断点续训,让训练过程更加可控、高效和可靠。

5.1 训练过程可视化

日志记录功能模块化。使用tensorboard。
纯文本的训练日志虽然直接,但难以洞察模型训练全局动态, 例如观察损失是否平稳下降、验证集F1是否持续提升,以及模型是否出现过拟合现象,可以集成TensorBoard来实现可视化;同时为提高客服现行,建议在训练开始前固定随机数种子。

为什么要固定随机数种子?
在深度学习中,许多操作依赖随机数:
模型参数初始化
Dropout层的随机丢弃
数据打乱(shuffle)
随机数据增强
随机采样

5.2 早停实现

在每个epoch结束时被Trainer调用来检查是否需要早停。

提前停止是一种简单而高效的正则化策略。其核心思想是在训练过程中持续监控模型在验证集上的性能。如果验证集准确率(或损失)连续 N 个轮次(N 称为“耐心值” patience)没有超过历史最佳水平,就认为模型已经达到了最佳点或开始过拟合,此时应提前终止训练。我们在 Trainer 类的基础上创建一个子类,重写 train 方法以实现该逻辑。

5.3 实现断点续训

对需要数小时甚至数天的长时间训练任务,意外中断时常见风险,从头训练会造成巨大的时间浪费。断点续训机制允许保存训练过程中完整状态(包括模型权重、优化器状态和当前轮数),应在需要时从中恢复,继续训练。

六、小结

命名实体项目流程

  1. 数据处理与准备: 解析原有CmeEE数据集,构建全局统一的BMES标签映射,字符集词汇表,并最终封装成一个高效的、可复用的DataLoader。
  2. 模型构建与训练框架: 设计并实现基于Bi-GRU的序列标注模型,并围绕它打造了一个结构清晰、组件化的训练框架。通过模型、数据加载器、分词器、评估指标等核心功能解耦,构造了一个易于维护和扩展的Trainer类。
  3. 推理与工作流优化: 实现从模型输出到结构化实体的解码逻辑,并将其封装成一个开箱即用的NerPredictor推力器。同时,为了提升训练框架的健壮性和实用性,害继承了自定义损失函数来应对数据不均衡问题,并引入了TensorBoard可视化日志、提前停止和断点续训等功能。
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容