一、前言
在自然语言处理(NLP)领域,意图识别(Intent Recognition) 是对话系统、智能客服、语音助手等应用的核心模块。它的目标是:给定一段用户输入的自然语言文本,判断其背后的意图类别。
例如:
| 用户输入 | 意图 |
|---|---|
| "明天北京天气怎么样?" | 查询天气 |
| "帮我订一张去上海的机票" | 订机票 |
| "播放一首周杰伦的歌" | 播放音乐 |
| "今天股市行情如何?" | 查询股票 |
本文将详细介绍如何使用 Java + TextCNN(文本卷积神经网络) 来实现一个完整的意图识别系统,涵盖从数据处理、模型构建、训练到预测推理的全流程。
二、技术选型
| 技术 | 说明 |
|---|---|
| Java 11+ | 主开发语言 |
| DL4J (DeepLearning4J) | Java 生态下最成熟的深度学习框架 |
| ND4J | 底层张量计算库(类似 NumPy) |
| TextCNN | 基于卷积神经网络的文本分类模型 |
| Maven | 项目构建与依赖管理 |
为什么选择 DL4J?
DL4J 是目前 Java/JVM 生态中功能最完善的深度学习框架,原生支持 CNN、RNN、LSTM 等主流网络结构,且与 Hadoop/Spark 无缝集成,非常适合企业级 Java 项目。
三、TextCNN 模型原理
3.1 模型概述
TextCNN 由 Yoon Kim 在 2014 年论文 "Convolutional Neural Networks for Sentence Classification" 中提出,是一种利用卷积神经网络进行文本分类的经典模型。
3.2 模型结构图
输入文本: "帮我订一张机票"
┌─────────────────────────────────────────────┐
│ Embedding Layer (词嵌入层) │
│ 每个词 → 固定维度的向量 (如 128 维) │
│ 句子 → 二维矩阵 [seq_len × embed_dim] │
└──────────────────┬──────────────────────────┘
│
┌──────────────────▼──────────────────────────┐
│ Convolution Layer (卷积层) │
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ filter │ │ filter │ │ filter │ │
│ │ size=2 │ │ size=3 │ │ size=4 │ │
│ └────┬────┘ └────┬────┘ └────┬────┘ │
│ │ │ │ │
│ 多种尺寸的卷积核并行提取不同粒度的 N-gram 特征 │
└────────┼───────────┼───────────┼─────────────┘
│ │ │
┌────────▼───────────▼───────────▼─────────────┐
│ Max-Over-Time Pooling (最大池化层) │
│ 每个 feature map 取最大值 → 固定长度向量 │
└──────────────────┬───────────────────────────┘
│
┌──────────────────▼───────────────────────────┐
│ Concatenate + Dropout │
│ 拼接所有池化结果 + 防止过拟合 │
└──────────────────┬───────────────────────────┘
│
┌──────────────────▼───────────────────────────┐
│ Fully Connected + Softmax │
│ 全连接层 → 输出各意图类别的概率 │
└──────────────────────────────────────────────┘
3.3 核心思想
| 层级 | 作用 |
|---|---|
| Embedding 层 | 将离散的词索引转换为稠密的低维向量表示 |
| 卷积层 | 使用不同大小的卷积核(如 2、3、4)在词向量矩阵上滑动,捕获不同长度的局部语义特征(类似 N-gram) |
| 池化层 | 对每个卷积核的输出做 Max Pooling,提取最显著的特征,同时将变长输入转为定长向量 |
| 全连接层 | 将池化后的特征拼接后,通过 Softmax 输出意图分类概率 |
四、整体实现思路
4.1 系统架构
┌──────────────────────────────────────────────────────┐
│ 意图识别系统 │
│ │
│ ┌──────────┐ ┌──────────┐ ┌──────────────────┐ │
│ │ 数据准备 │ → │ 模型训练 │ → │ 推理预测服务 │ │
│ └──────────┘ └──────────┘ └──────────────────┘ │
│ │ │ │ │
│ ·加载语料 ·构建TextCNN ·加载训练好的模型 │
│ ·分词处理 ·配置超参数 ·文本预处理 │
│ ·构建词表 ·迭代训练 ·向量化 │
│ ·向量化编码 ·评估保存 ·输出意图+置信度 │
└──────────────────────────────────────────────────────┘
4.2 实现步骤总览
Step 1: 准备意图分类数据集
↓
Step 2: 文本预处理(分词 + 构建词表 + 序列编码)
↓
Step 3: 构建 DataSet(训练集 / 测试集)
↓
Step 4: 搭建 TextCNN 网络结构(DL4J ComputationGraph)
↓
Step 5: 训练模型 + 评估效果
↓
Step 6: 保存模型 + 推理预测
五、详细实现
5.1 Maven 依赖配置
<properties>
<dl4j.version>1.0.0-M2.1</dl4j.version>
<java.version>11</java.version>
</properties>
<dependencies>
<!-- DL4J 核心 -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<!-- ND4J 后端 (CPU) -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>${dl4j.version}</version>
</dependency>
<!-- 如需 GPU 加速,替换为:
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-11.6-platform</artifactId>
<version>${dl4j.version}</version>
</dependency>
-->
<!-- 中文分词 (HanLP) -->
<dependency>
<groupId>com.hankcs</groupId>
<artifactId>hanlp</artifactId>
<version>portable-1.8.4</version>
</dependency>
<!-- Lombok (可选) -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.28</version>
</dependency>
</dependencies>
5.2 数据准备
5.2.1 数据格式
采用简单的 TSV(制表符分隔) 格式,每行一条样本:
# intent_data.tsv
查询天气 明天北京天气怎么样
查询天气 今天会下雨吗
查询天气 后天上海的温度是多少
订机票 帮我订一张去上海的机票
订机票 我要买明天去广州的飞机票
订机票 查一下下周三去北京的航班
播放音乐 播放一首周杰伦的歌
播放音乐 我想听一首轻音乐
播放音乐 放一首Yesterday Once More
查询股票 今天股市行情如何
查询股票 帮我看看茅台的股价
查询股票 A股今天涨了吗
闲聊 你好呀
闲聊 你是谁
闲聊 今天心情不错
5.2.2 数据加载类
import lombok.Data;
import lombok.AllArgsConstructor;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
@Data
@AllArgsConstructor
public class IntentSample {
private String intent; // 意图标签
private String text; // 原始文本
}
public class DataLoader {
/**
* 从 TSV 文件加载数据
*/
public static List<IntentSample> loadData(String filePath) throws IOException {
List<IntentSample> samples = new ArrayList<>();
try (BufferedReader br = new BufferedReader(
new InputStreamReader(new FileInputStream(filePath), StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
line = line.trim();
if (line.isEmpty() || line.startsWith("#")) continue;
String[] parts = line.split("\t", 2);
if (parts.length == 2) {
samples.add(new IntentSample(parts[0].trim(), parts[1].trim()));
}
}
}
// 打乱数据
Collections.shuffle(samples, new Random(42));
return samples;
}
/**
* 按比例划分训练集和测试集
*/
public static Map<String, List<IntentSample>> splitData(
List<IntentSample> samples, double trainRatio) {
int trainSize = (int) (samples.size() * trainRatio);
Map<String, List<IntentSample>> result = new HashMap<>();
result.put("train", samples.subList(0, trainSize));
result.put("test", samples.subList(trainSize, samples.size()));
return result;
}
}
5.3 文本预处理
文本预处理是意图识别的关键环节,主要包括 分词、构建词表、序列编码 三个步骤。
5.3.1 实现思路
原始文本: "帮我订一张去上海的机票"
│
▼ [分词]
词序列: ["帮", "我", "订", "一", "张", "去", "上海", "的", "机票"]
│
▼ [查词表,转为索引]
索引序列: [23, 5, 108, 12, 45, 7, 356, 3, 189]
│
▼ [Padding / Truncating 到固定长度 maxLen=20]
定长序列: [23, 5, 108, 12, 45, 7, 356, 3, 189, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
5.3.2 文本处理器
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.seg.common.Term;
import java.util.*;
import java.util.stream.Collectors;
public class TextProcessor {
private Map<String, Integer> word2Index = new LinkedHashMap<>();
private Map<String, Integer> label2Index = new LinkedHashMap<>();
private Map<Integer, String> index2Label = new LinkedHashMap<>();
private int vocabSize;
private int numClasses;
private int maxSeqLen;
// 特殊 token
private static final String PAD_TOKEN = "<PAD>";
private static final String UNK_TOKEN = "<UNK>";
public TextProcessor(int maxSeqLen) {
this.maxSeqLen = maxSeqLen;
}
// ===================== 分词 =====================
/**
* 使用 HanLP 进行中文分词
*/
public List<String> tokenize(String text) {
List<Term> terms = HanLP.segment(text);
return terms.stream()
.map(t -> t.word.trim())
.filter(w -> !w.isEmpty())
.collect(Collectors.toList());
}
// ===================== 构建词表 & 标签表 =====================
/**
* 根据训练数据构建词表和标签映射
*/
public void buildVocabulary(List<IntentSample> samples) {
// 添加特殊 token
word2Index.put(PAD_TOKEN, 0);
word2Index.put(UNK_TOKEN, 1);
int wordIdx = 2;
Set<String> labelSet = new LinkedHashSet<>();
for (IntentSample sample : samples) {
// 构建词表
List<String> tokens = tokenize(sample.getText());
for (String token : tokens) {
if (!word2Index.containsKey(token)) {
word2Index.put(token, wordIdx++);
}
}
// 收集标签
labelSet.add(sample.getIntent());
}
this.vocabSize = word2Index.size();
// 构建标签映射
int labelIdx = 0;
for (String label : labelSet) {
label2Index.put(label, labelIdx);
index2Label.put(labelIdx, label);
labelIdx++;
}
this.numClasses = label2Index.size();
System.out.println("词表大小: " + vocabSize);
System.out.println("意图类别数: " + numClasses);
System.out.println("意图列表: " + label2Index.keySet());
}
// ===================== 文本编码 =====================
/**
* 将文本转为定长索引序列
*/
public int[] encode(String text) {
List<String> tokens = tokenize(text);
int[] indices = new int[maxSeqLen];
// 默认填充 PAD (index=0)
Arrays.fill(indices, 0);
for (int i = 0; i < Math.min(tokens.size(), maxSeqLen); i++) {
String token = tokens.get(i);
indices[i] = word2Index.getOrDefault(token, word2Index.get(UNK_TOKEN));
}
return indices;
}
/**
* 将意图标签转为索引
*/
public int encodeLabel(String intent) {
return label2Index.getOrDefault(intent, 0);
}
/**
* 将索引转为意图标签
*/
public String decodeLabel(int index) {
return index2Label.getOrDefault(index, "UNKNOWN");
}
// ===================== Getter =====================
public int getVocabSize() { return vocabSize; }
public int getNumClasses() { return numClasses; }
public int getMaxSeqLen() { return maxSeqLen; }
}
5.4 构建数据集
将处理后的数据转换为 DL4J 的 DataSet 格式:
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
public class DataSetBuilder {
/**
* 将样本列表转换为 DL4J DataSet
*
* 输入形状: [batchSize, maxSeqLen] (存储词索引,后续由 Embedding 层处理)
* 标签形状: [batchSize, numClasses] (One-Hot 编码)
*/
public static DataSet buildDataSet(List<IntentSample> samples,
TextProcessor processor) {
int batchSize = samples.size();
int maxSeqLen = processor.getMaxSeqLen();
int numClasses = processor.getNumClasses();
// 输入: [batchSize, 1, maxSeqLen, 1] — 适配 CNN 输入格式
// 这里我们先用 [batchSize, maxSeqLen] 存储索引
INDArray features = Nd4j.zeros(batchSize, maxSeqLen);
INDArray labels = Nd4j.zeros(batchSize, numClasses);
for (int i = 0; i < batchSize; i++) {
IntentSample sample = samples.get(i);
// 编码文本
int[] encoded = processor.encode(sample.getText());
for (int j = 0; j < maxSeqLen; j++) {
features.putScalar(new int[]{i, j}, encoded[j]);
}
// 编码标签 (One-Hot)
int labelIdx = processor.encodeLabel(sample.getIntent());
labels.putScalar(new int[]{i, labelIdx}, 1.0);
}
return new DataSet(features, labels);
}
}
5.5 搭建 TextCNN 模型 ⭐
这是本文的 核心部分。由于 DL4J 原生不直接提供 "TextCNN" 的高级 API,我们需要使用 ComputationGraph(支持多分支结构)来手动搭建。
5.5.1 网络结构设计
Input [batchSize, maxSeqLen]
│
┌──────▼──────┐
│ Embedding │ 词嵌入层: vocabSize → embedDim
│ (查表操作) │ 输出: [batch, embedDim, seqLen, 1]
└──────┬──────┘
│
┌──────────────┼──────────────┐
│ │ │
┌───────▼───────┐ ┌───▼───────┐ ┌───▼───────────┐
│ Conv1D k=2 │ │ Conv1D k=3│ │ Conv1D k=4 │
│ filters=100 │ │ filters=100│ │ filters=100 │
│ + ReLU │ │ + ReLU │ │ + ReLU │
└───────┬───────┘ └───┬───────┘ └───┬───────────┘
│ │ │
┌───────▼───────┐ ┌───▼───────┐ ┌───▼───────────┐
│ GlobalMaxPool │ │GlobalMaxP │ │ GlobalMaxPool │
└───────┬───────┘ └───┬───────┘ └───┬───────────┘
│ │ │
└──────────────┼──────────────┘
│
┌──────▼──────┐
│ Merge │ 拼接: 100+100+100 = 300
│ (Concatenate)│
└──────┬──────┘
│
┌──────▼──────┐
│ Dropout │ rate = 0.5
│ (0.5) │
└──────┬──────┘
│
┌──────▼──────┐
│ Dense │ 300 → numClasses
│ + Softmax │
└──────┬──────┘
│
Output (意图概率分布)
5.5.2 模型构建代码
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class TextCNNModel {
/**
* 构建 TextCNN 模型
*
* @param vocabSize 词表大小
* @param embedDim 词向量维度
* @param maxSeqLen 最大序列长度
* @param numFilters 每种卷积核的数量
* @param numClasses 意图类别数
* @param dropout Dropout 比率
* @return 编译好的 ComputationGraph
*/
public static ComputationGraph buildModel(int vocabSize,
int embedDim,
int maxSeqLen,
int numFilters,
int numClasses,
double dropout) {
ComputationGraphConfiguration.GraphBuilder builder =
new NeuralNetConfiguration.Builder()
.seed(42)
.updater(new Adam(1e-3)) // 优化器
.weightInit(WeightInit.XAVIER) // 权重初始化
.l2(1e-4) // L2 正则化
.graphBuilder();
// ---------- 输入层 ----------
builder.addInputs("input");
builder.setInputTypes(InputType.recurrent(vocabSize, maxSeqLen));
// ---------- Embedding 层 ----------
// 使用 EmbeddingSequenceLayer 将词索引映射为词向量
builder.addLayer("embedding",
new EmbeddingSequenceLayer.Builder()
.nIn(vocabSize) // 词表大小
.nOut(embedDim) // 嵌入维度
.build(),
"input");
// ---------- 多尺度卷积分支 ----------
int[] kernelSizes = {2, 3, 4}; // 三种不同大小的卷积核
String[] poolLayerNames = new String[kernelSizes.length];
for (int i = 0; i < kernelSizes.length; i++) {
int k = kernelSizes[i];
String convName = "conv_" + k;
String poolName = "pool_" + k;
// 1D 卷积层 (使用 Convolution1DLayer)
builder.addLayer(convName,
new Convolution1DLayer.Builder()
.nIn(embedDim)
.nOut(numFilters)
.kernelSize(k)
.stride(1)
.activation(Activation.RELU)
.build(),
"embedding");
// 全局最大池化 (使用 GlobalPoolingLayer)
builder.addLayer(poolName,
new GlobalPoolingLayer.Builder()
.poolingType(PoolingType.MAX)
.build(),
convName);
poolLayerNames[i] = poolName;
}
// ---------- 合并多分支输出 ----------
builder.addVertex("merge", new MergeVertex(), poolLayerNames);
// ---------- Dropout 层 ----------
builder.addLayer("dropout",
new DropoutLayer.Builder(dropout).build(),
"merge");
// ---------- 全连接输出层 ----------
builder.addLayer("output",
new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) // 多分类交叉熵
.nIn(numFilters * kernelSizes.length) // 300 = 100 * 3
.nOut(numClasses)
.activation(Activation.SOFTMAX)
.build(),
"dropout");
// ---------- 设置输出 ----------
builder.setOutputs("output");
// ---------- 构建模型 ----------
ComputationGraphConfiguration conf = builder.build();
ComputationGraph model = new ComputationGraph(conf);
model.init();
// 打印模型摘要
System.out.println(model.summary());
return model;
}
}
5.5.3 超参数配置说明
| 超参数 | 推荐值 | 说明 |
|---|---|---|
embedDim |
128 | 词向量维度,太小表达力不足,太大容易过拟合 |
maxSeqLen |
20~50 | 最大句子长度,根据业务数据分布确定 |
numFilters |
100 | 每种卷积核的数量 |
kernelSizes |
[2, 3, 4] | 卷积核大小,分别捕获 bigram、trigram、4-gram 特征 |
dropout |
0.5 | 防止过拟合 |
learningRate |
1e-3 | Adam 优化器学习率 |
batchSize |
32~64 | 批次大小 |
epochs |
20~50 | 训练轮数 |
5.6 模型训练与评估
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.DataSet;
import java.io.File;
import java.util.List;
import java.util.Map;
public class IntentTrainer {
// ==================== 超参数 ====================
private static final int EMBED_DIM = 128;
private static final int MAX_SEQ_LEN = 30;
private static final int NUM_FILTERS = 100;
private static final double DROPOUT = 0.5;
private static final int EPOCHS = 30;
private static final String DATA_PATH = "data/intent_data.tsv";
private static final String MODEL_PATH = "models/intent_textcnn.zip";
public static void main(String[] args) throws Exception {
// ========== Step 1: 加载数据 ==========
System.out.println("========== 加载数据 ==========");
List<IntentSample> allSamples = DataLoader.loadData(DATA_PATH);
Map<String, List<IntentSample>> split = DataLoader.splitData(allSamples, 0.8);
List<IntentSample> trainSamples = split.get("train");
List<IntentSample> testSamples = split.get("test");
System.out.println("训练集大小: " + trainSamples.size());
System.out.println("测试集大小: " + testSamples.size());
// ========== Step 2: 文本预处理 ==========
System.out.println("========== 文本预处理 ==========");
TextProcessor processor = new TextProcessor(MAX_SEQ_LEN);
processor.buildVocabulary(trainSamples);
// ========== Step 3: 构建 DataSet ==========
System.out.println("========== 构建数据集 ==========");
DataSet trainData = DataSetBuilder.buildDataSet(trainSamples, processor);
DataSet testData = DataSetBuilder.buildDataSet(testSamples, processor);
// ========== Step 4: 构建模型 ==========
System.out.println("========== 构建 TextCNN 模型 ==========");
ComputationGraph model = TextCNNModel.buildModel(
processor.getVocabSize(),
EMBED_DIM,
MAX_SEQ_LEN,
NUM_FILTERS,
processor.getNumClasses(),
DROPOUT
);
// 添加训练监听器,每 10 次迭代打印一次 loss
model.setListeners(new ScoreIterationListener(10));
// ========== Step 5: 训练模型 ==========
System.out.println("========== 开始训练 ==========");
for (int epoch = 1; epoch <= EPOCHS; epoch++) {
model.fit(trainData);
// 每 5 个 epoch 评估一次
if (epoch % 5 == 0) {
Evaluation eval = model.evaluate(
new org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator<>(
testData.asList(), testData.numExamples()));
System.out.println("---------- Epoch " + epoch + " ----------");
System.out.println(eval.stats());
}
}
// ========== Step 6: 保存模型 ==========
System.out.println("========== 保存模型 ==========");
File modelFile = new File(MODEL_PATH);
modelFile.getParentFile().mkdirs();
model.save(modelFile);
System.out.println("模型已保存到: " + modelFile.getAbsolutePath());
}
}
5.7 训练输出示例
========== 加载数据 ==========
训练集大小: 800
测试集大小: 200
========== 文本预处理 ==========
词表大小: 3256
意图类别数: 5
意图列表: [查询天气, 订机票, 播放音乐, 查询股票, 闲聊]
========== 构建 TextCNN 模型 ==========
===============================================================
Layer Name | Layer Type | nIn | nOut |
---------------------------------------------------------------
embedding | EmbeddingSeq | 3256 | 128 |
conv_2 | Conv1D | 128 | 100 |
conv_3 | Conv1D | 128 | 100 |
conv_4 | Conv1D | 128 | 100 |
pool_2 | GlobalPooling | - | - |
pool_3 | GlobalPooling | - | - |
pool_4 | GlobalPooling | - | - |
dropout | Dropout | 300 | 300 |
output | Output | 300 | 5 |
===============================================================
Total Parameters: 467,105
========== 开始训练 ==========
---------- Epoch 5 ----------
Accuracy: 0.8750
Precision: 0.8812
Recall: 0.8700
F1 Score: 0.8756
---------- Epoch 15 ----------
Accuracy: 0.9550
Precision: 0.9563
Recall: 0.9540
F1 Score: 0.9551
---------- Epoch 30 ----------
Accuracy: 0.9750
Precision: 0.9768
Recall: 0.9740
F1 Score: 0.9754
5.8 推理预测服务
模型训练完成后,我们封装一个推理服务类:
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import lombok.Data;
import java.io.File;
import java.util.*;
public class IntentPredictor {
private ComputationGraph model;
private TextProcessor processor;
/**
* 加载模型和文本处理器
*/
public IntentPredictor(String modelPath, TextProcessor processor) throws Exception {
this.model = ComputationGraph.load(new File(modelPath), true);
this.processor = processor;
System.out.println("模型加载成功!");
}
/**
* 预测意图
*/
public PredictionResult predict(String text) {
// 1. 文本编码
int[] encoded = processor.encode(text);
// 2. 构造输入张量 [1, maxSeqLen]
INDArray input = Nd4j.zeros(1, processor.getMaxSeqLen());
for (int i = 0; i < encoded.length; i++) {
input.putScalar(new int[]{0, i}, encoded[i]);
}
// 3. 模型推理
INDArray[] output = model.output(input);
INDArray probabilities = output[0]; // [1, numClasses]
// 4. 获取预测结果
int predictedIdx = Nd4j.argMax(probabilities, 1).getInt(0);
double confidence = probabilities.getDouble(0, predictedIdx);
String predictedIntent = processor.decodeLabel(predictedIdx);
// 5. 获取所有类别的概率
Map<String, Double> allProbs = new LinkedHashMap<>();
for (int i = 0; i < processor.getNumClasses(); i++) {
allProbs.put(processor.decodeLabel(i),
Math.round(probabilities.getDouble(0, i) * 10000) / 10000.0);
}
return new PredictionResult(predictedIntent, confidence, allProbs);
}
/**
* 批量预测
*/
public List<PredictionResult> batchPredict(List<String> texts) {
List<PredictionResult> results = new ArrayList<>();
for (String text : texts) {
results.add(predict(text));
}
return results;
}
// ==================== 预测结果封装 ====================
@Data
public static class PredictionResult {
private final String intent; // 预测意图
private final double confidence; // 置信度
private final Map<String, Double> allProbs; // 各意图概率
@Override
public String toString() {
return String.format("意图: %-8s | 置信度: %.4f | 详细概率: %s",
intent, confidence, allProbs);
}
}
}
5.9 运行预测
public class IntentPredictDemo {
public static void main(String[] args) throws Exception {
// 1. 初始化文本处理器(实际项目中应将词表序列化保存)
TextProcessor processor = new TextProcessor(30);
List<IntentSample> samples = DataLoader.loadData("data/intent_data.tsv");
processor.buildVocabulary(samples);
// 2. 加载模型
IntentPredictor predictor = new IntentPredictor(
"models/intent_textcnn.zip", processor);
// 3. 测试预测
System.out.println("============= 意图识别测试 =============\n");
String[] testTexts = {
"明天杭州会下雨吗",
"帮我买一张去深圳的飞机票",
"来一首邓紫棋的泡沫",
"帮我看看比亚迪的股价",
"你叫什么名字呀",
"后天成都气温多少度",
"推荐一首好听的歌"
};
for (String text : testTexts) {
IntentPredictor.PredictionResult result = predictor.predict(text);
System.out.println("输入: " + text);
System.out.println("结果: " + result);
System.out.println();
}
}
}
5.10 预测输出示例
============= 意图识别测试 =============
输入: 明天杭州会下雨吗
结果: 意图: 查询天气 | 置信度: 0.9823 | 详细概率: {查询天气=0.9823, 订机票=0.0034, 播放音乐=0.0012, 查询股票=0.0056, 闲聊=0.0075}
输入: 帮我买一张去深圳的飞机票
结果: 意图: 订机票 | 置信度: 0.9756 | 详细概率: {查询天气=0.0021, 订机票=0.9756, 播放音乐=0.0015, 查询股票=0.0108, 闲聊=0.0100}
输入: 来一首邓紫棋的泡沫
结果: 意图: 播放音乐 | 置信度: 0.9634 | 详细概率: {查询天气=0.0045, 订机票=0.0023, 播放音乐=0.9634, 查询股票=0.0078, 闲聊=0.0220}
输入: 帮我看看比亚迪的股价
结果: 意图: 查询股票 | 置信度: 0.9512 | 详细概率: {查询天气=0.0089, 订机票=0.0156, 播放音乐=0.0034, 查询股票=0.9512, 闲聊=0.0209}
输入: 你叫什么名字呀
结果: 意图: 闲聊 | 置信度: 0.9891 | 详细概率: {查询天气=0.0023, 订机票=0.0012, 播放音乐=0.0034, 查询股票=0.0040, 闲聊=0.9891}
六、工程化改进建议
在实际生产环境中,上述基础实现还需要以下优化:
6.1 性能与精度优化
| 优化方向 | 具体措施 |
|---|---|
| 预训练词向量 | 使用预训练的 Word2Vec / GloVe 初始化 Embedding 层,而非随机初始化 |
| 数据增强 | 同义词替换、随机删词、回译(Back Translation)等方式扩充训练数据 |
| 学习率调度 | 使用 Warm-up + CosineAnnealing 等学习率策略 |
| Early Stopping | 监控验证集 loss,防止过拟合 |
| 集成学习 | 多个不同 kernel size 组合的 TextCNN 模型投票 |
6.2 工程化部署
┌───────────────────────────────────────────────┐
│ 生产部署架构 │
│ │
│ Client │
│ │ │
│ ▼ │
│ Spring Boot REST API │
│ │ │
│ ├── /api/intent/predict (单条预测) │
│ ├── /api/intent/batch (批量预测) │
│ └── /api/intent/reload (热更新模型) │
│ │ │
│ ▼ │
│ IntentPredictor (模型推理服务) │
│ │ │
│ ├── TextProcessor (文本预处理) │
│ └── ComputationGraph (TextCNN 模型) │
│ │
│ ·模型文件: models/intent_textcnn.zip │
│ ·词表文件: models/vocabulary.json │
│ ·标签文件: models/labels.json │
└───────────────────────────────────────────────┘
6.3 词表与标签持久化
// 保存词表和标签映射到 JSON 文件 (使用 Jackson 或 Gson)
public class ModelArtifacts {
public static void saveVocabulary(TextProcessor processor, String path)
throws IOException {
ObjectMapper mapper = new ObjectMapper();
Map<String, Object> artifacts = new HashMap<>();
artifacts.put("word2Index", processor.getWord2Index());
artifacts.put("label2Index", processor.getLabel2Index());
artifacts.put("maxSeqLen", processor.getMaxSeqLen());
mapper.writerWithDefaultPrettyPrinter()
.writeValue(new File(path), artifacts);
}
public static TextProcessor loadVocabulary(String path) throws IOException {
ObjectMapper mapper = new ObjectMapper();
Map<String, Object> artifacts = mapper.readValue(
new File(path), new TypeReference<>() {});
// 还原 TextProcessor ...
}
}
七、TextCNN 优缺点分析
| 说明 | |
|---|---|
| ✅ 优点 | |
| 结构简单 | 模型参数少,训练速度快 |
| 推理高效 | 相比 RNN/Transformer,推理延迟更低 |
| 短文本效果好 | 对于意图识别这类短文本分类任务表现优异 |
| 易于部署 | 模型文件小,Java 原生部署方便 |
| ❌ 缺点 | |
| 长距离依赖弱 | CNN 只能捕获局部特征,对长文本效果一般 |
| 语序信息丢失 | Max Pooling 丢弃了位置信息 |
| 需要标注数据 | 监督学习方式,依赖标注数据质量和数量 |
八、总结
本文完整介绍了使用 Java + DL4J + TextCNN 实现意图识别的全流程:
📌 核心流程回顾:
1. 数据准备 → 收集意图标注数据,TSV 格式存储
2. 文本预处理 → HanLP 分词 → 构建词表 → 序列编码 → Padding
3. 模型构建 → Embedding + 多尺度 Conv1D + MaxPooling + Merge + Softmax
4. 模型训练 → Adam 优化器 + 交叉熵损失 + Dropout 正则化
5. 模型评估 → Accuracy / Precision / Recall / F1
6. 推理预测 → 加载模型 → 文本编码 → 前向推理 → 输出意图+置信度
TextCNN 虽然是 2014 年提出的"经典"模型,但在 短文本分类 / 意图识别 这类场景下依然具有极强的实用性:训练快、推理快、效果好、易部署。对于 Java 技术栈的团队来说,DL4J + TextCNN 是一个非常务实的选择。
参考文献:
- Yoon Kim. Convolutional Neural Networks for Sentence Classification. EMNLP, 2014.
- DeepLearning4J 官方文档: https://deeplearning4j.konduit.ai/
- HanLP 中文分词: https://github.com/hankcs/HanLP
如果觉得这篇文章对你有帮助,欢迎点赞收藏,也欢迎在评论区交流讨论! 🚀