In this lesson we cover text classification. Before a model can be trained, the text has to be preprocessed.
Word segmentation
public static void data(String source, String save) throws IOException {
    // Read the raw data file line by line (UTF-8).
    BufferedReader bufferedReader = new BufferedReader(
            new InputStreamReader(new FileInputStream(new File(source)), "UTF-8"));
    File saveFile = new File(save);
    if (!saveFile.exists()) {
        saveFile.createNewFile();
    }
    OutputStreamWriter writerStream = new OutputStreamWriter(new FileOutputStream(saveFile), "UTF-8");
    BufferedWriter writer = new BufferedWriter(writerStream);
    String line = null;
    long startTime = System.currentTimeMillis();
    while ((line = bufferedReader.readLine()) != null) {
        // Fields are separated by "_!_"; array[1] is the category code, array[3] is the headline text.
        String[] array = line.split("_!_");
        // Segment the headline with HanLP and join the tokens with single spaces.
        StringBuilder stringBuilder = new StringBuilder();
        for (Term term : HanLP.segment(array[3])) {
            if (stringBuilder.length() > 0) {
                stringBuilder.append(" ");
            }
            stringBuilder.append(term.word.trim());
        }
        // Write "categoryCode_!_segmented text" to the output file.
        writer.write(Integer.parseInt(array[1].trim()) + "_!_" + stringBuilder.toString() + "\n");
    }
    writer.flush();
    writer.close();
    // Print how long the segmentation took, in milliseconds.
    System.out.println(System.currentTimeMillis() - startTime);
    bufferedReader.close();
}
The method above uses the HanLP segmenter to split each headline into words.
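Each raw line is assumed to follow the Toutiao news layout, with fields separated by "_!_": the code only uses array[1] (the numeric category code) and array[3] (the headline text), and writes out lines of the form "categoryCode_!_word word word ...". To see what HanLP produces for a single headline, a small hypothetical helper like the one below can be used (the sample sentence is arbitrary):

private static void segmentDemo() {
    // Segment one headline with HanLP and join the tokens with spaces,
    // exactly as data() does for every line of the corpus.
    StringBuilder sb = new StringBuilder();
    for (Term term : HanLP.segment("京城最值得你来场文化之旅的博物馆")) {
        if (sb.length() > 0) {
            sb.append(" ");
        }
        sb.append(term.word.trim());
    }
    System.out.println(sb.toString());
}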
Text vectorization tool
public static void dataSet(String filePath, String savePath) throws FileNotFoundException {
    // Iterate over the corpus one sentence (line) at a time.
    SentenceIterator iter = new BasicLineIterator(filePath);
    // Tokenize on whitespace and apply the default token preprocessing.
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    // In-memory vocabulary cache and lookup table for 100-dimensional vectors.
    VocabCache<VocabWord> cache = new AbstractCache<>();
    WeightLookupTable<VocabWord> table = new InMemoryLookupTable.Builder<VocabWord>()
            .vectorLength(100).useAdaGrad(false).cache(cache).build();
    // Train a skip-gram Word2Vec model: window size 8, 20 epochs, 100-dimensional vectors.
    Word2Vec vec = new Word2Vec.Builder()
            .elementsLearningAlgorithm("org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram")
            .minWordFrequency(0).iterations(1).epochs(20).layerSize(100).seed(42).windowSize(8)
            .iterate(iter).tokenizerFactory(t).lookupTable(table).vocabCache(cache).build();
    vec.fit();
    // Persist the trained word vectors so they can be reused when building the DataSetIterators.
    WordVectorSerializer.writeWord2VecModel(vec, savePath);
}
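Once the vectors have been written, a quick sanity check is to reload the model and inspect the nearest neighbours of a common word. A minimal sketch; the helper name and query words are illustrative:

private static void checkVectors(String modelPath) {
    // Reload the saved Word2Vec model from disk.
    Word2Vec vec = WordVectorSerializer.readWord2VecModel(modelPath);
    // Print the 10 tokens closest to a sample word, and a pairwise cosine similarity.
    System.out.println(vec.wordsNearest("足球", 10));
    System.out.println(vec.similarity("足球", "篮球"));
}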
Building the training data
private static HashMap<String, DataSetIterator> dataSet() throws IOException {
    List<String> trainLabelList = new ArrayList<>();   // training-set labels
    List<String> trainSentences = new ArrayList<>();   // training-set sentences
    List<String> testLabelList = new ArrayList<>();    // test-set labels
    List<String> testSentences = new ArrayList<>();    // test-set sentences
    Map<String, List<String>> map = new HashMap<>();
    BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
            new FileInputStream(new File(basePath + "toutiao_data_type_word.txt")), "UTF-8"));
    String line = null;
    int truncateReviewsToLength = 0;
    Random random = new Random(123);
    while ((line = bufferedReader.readLine()) != null) {
        String[] array = line.split("_!_");
        if (map.get(array[0]) == null) {
            map.put(array[0], new ArrayList<String>());
        }
        map.get(array[0]).add(array[1]);               // group all samples by category
        int length = array[1].split(" ").length;
        if (length > truncateReviewsToLength) {
            truncateReviewsToLength = length;          // track the longest sentence in the corpus
        }
    }
    bufferedReader.close();
    for (Map.Entry<String, List<String>> entry : map.entrySet()) {
        for (String sentence : entry.getValue()) {
            if (random.nextInt(5) == 0) {              // hold out roughly 20% of each category as the test set
                testLabelList.add(entry.getKey());
                testSentences.add(sentence);
            } else {
                trainLabelList.add(entry.getKey());
                trainSentences.add(sentence);
            }
        }
    }
    int batchSize = 64;
    Random rng = new Random(12345);
    // Load the Word2Vec model trained by dataSet(String, String).
    Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel(basePath + "toutiao_cat_data_dataset.txt");
    System.out.println("Loading word vectors and creating DataSetIterators");
    DataSetIterator trainIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, trainLabelList, trainSentences, rng);
    DataSetIterator testIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, testLabelList, testSentences, rng);
    HashMap<String, DataSetIterator> data = new HashMap<>();
    data.put("trainIter", trainIter);
    data.put("testIter", testIter);
    return data;
}
private static DataSetIterator getDataSetIterator(WordVectors wordVectors, int minibatchSize, int maxSentenceLength,
                                                  List<String> labelList, List<String> sentences, Random rng) {
    // Pair each sentence with its label and feed them to a CNN-style sentence iterator,
    // which converts every sentence into a word-vector matrix of up to maxSentenceLength rows.
    LabeledSentenceProvider sentenceProvider = new CollectionLabeledSentenceProvider(sentences, labelList, rng);
    return new CnnSentenceDataSetIterator.Builder().sentenceProvider(sentenceProvider).wordVectors(wordVectors)
            .minibatchSize(minibatchSize).maxSentenceLength(maxSentenceLength).useNormalizedWordVectors(false)
            .build();
}
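To confirm that the iterators produce what the network expects, one batch can be pulled and its shapes printed. A minimal sketch, assuming dataSet() has already been called (the helper name is illustrative):

private static void inspectBatch() throws IOException {
    // Each batch has features of shape [batchSize, 1, maxSentenceLength, vectorSize]
    // and one-hot labels of shape [batchSize, numClasses].
    DataSetIterator trainIter = dataSet().get("trainIter");
    DataSet batch = trainIter.next();
    System.out.println("Features: " + Arrays.toString(batch.getFeatures().shape()));
    System.out.println("Labels:   " + Arrays.toString(batch.getLabels().shape()));
    trainIter.reset();
}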
Building the model
public static ComputationGraph model(int truncateReviewsToLength) {
    int vectorSize = 100;             // must match the Word2Vec vector length
    int cnnLayerFeatureMaps = 50;     // feature maps produced by each convolution branch
    PoolingType globalPoolingType = PoolingType.MAX;
    // Sentence-classification CNN: parallel convolutions with kernel heights of 3 to 6 words
    // (at strides 1 and 2) slide over the [sentenceLength x vectorSize] word-vector matrix;
    // their outputs are merged, globally max-pooled, and fed through a dense layer into a
    // 15-way softmax (one class per news category).
    ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().weightInit(WeightInit.RELU)
            .activation(Activation.LEAKYRELU).updater(new Nesterovs(0.01, 0.9))
            .convolutionMode(ConvolutionMode.Same).l2(0.0001).graphBuilder().addInputs("input")
            // Stride-1 branches with kernel heights 3, 4, 5 and 6.
            .addLayer("cnn3",
                    new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(1, vectorSize)
                            .nOut(cnnLayerFeatureMaps).build(),
                    "input")
            .addLayer("cnn4",
                    new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(1, vectorSize)
                            .nOut(cnnLayerFeatureMaps).build(),
                    "input")
            .addLayer("cnn5",
                    new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(1, vectorSize)
                            .nOut(cnnLayerFeatureMaps).build(),
                    "input")
            .addLayer("cnn6",
                    new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(1, vectorSize)
                            .nOut(cnnLayerFeatureMaps).build(),
                    "input")
            // Stride-2 branches with the same kernel heights.
            .addLayer("cnn3-stride2",
                    new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(2, vectorSize)
                            .nOut(cnnLayerFeatureMaps).build(),
                    "input")
            .addLayer("cnn4-stride2",
                    new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(2, vectorSize)
                            .nOut(cnnLayerFeatureMaps).build(),
                    "input")
            .addLayer("cnn5-stride2",
                    new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(2, vectorSize)
                            .nOut(cnnLayerFeatureMaps).build(),
                    "input")
            .addLayer("cnn6-stride2",
                    new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(2, vectorSize)
                            .nOut(cnnLayerFeatureMaps).build(),
                    "input")
            // Merge each group of branches and apply global max pooling, so every branch yields a
            // fixed-length vector regardless of sentence length.
            .addVertex("merge1", new MergeVertex(), "cnn3", "cnn4", "cnn5", "cnn6")
            .addLayer("globalPool1", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
                    "merge1")
            .addVertex("merge2", new MergeVertex(), "cnn3-stride2", "cnn4-stride2", "cnn5-stride2", "cnn6-stride2")
            .addLayer("globalPool2", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
                    "merge2")
            // Fully connected layer with dropout, followed by the 15-class softmax output.
            .addLayer("fc",
                    new DenseLayer.Builder().nOut(200).dropOut(0.5).activation(Activation.LEAKYRELU).build(),
                    "globalPool1", "globalPool2")
            .addLayer("out",
                    new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nOut(15).build(),
                    "fc")
            .setOutputs("out").setInputTypes(InputType.convolutional(truncateReviewsToLength, vectorSize, 1))
            .build();
    ComputationGraph net = new ComputationGraph(config);
    return net;
}
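Before training, it can be useful to initialize the graph and print its layer summary to check the shapes and parameter counts. A small sketch (100 stands in for the real truncateReviewsToLength):

ComputationGraph net = model(100);
net.init();
// Prints each layer and vertex with its input/output shapes and parameter count.
System.out.println(net.summary());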
Training the model
private static void train(ComputationGraph model, DataSetIterator trainIter, DataSetIterator testIter) throws IOException {
    // Attach the DL4J training UI so loss and parameter statistics can be monitored in the browser.
    UIServer uiServer = UIServer.getInstance();
    StatsStorage statsStorage = new InMemoryStatsStorage();
    uiServer.attach(statsStorage);
    // Log the score every 100 iterations, stream stats to the UI, and evaluate on the test set after every epoch.
    model.setListeners(new ScoreIterationListener(100), new StatsListener(statsStorage, 20),
            new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));
    // Train for 10 epochs, then save the model to disk.
    model.fit(trainIter, 10);
    ModelSerializer.writeModel(model, new File(basePath + "mlp.mod"), true);
}
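Besides the EvaluativeListener output during training, the finished model can also be evaluated explicitly on the held-out data. A minimal sketch:

// Rewind the test iterator in case the listener already consumed it, then evaluate.
testIter.reset();
Evaluation evaluation = model.evaluate(testIter);
// stats() prints accuracy, precision, recall, F1 and the confusion matrix.
System.out.println(evaluation.stats());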
public static void main(String[] args) throws IOException {
    // Step 1: segment the raw corpus (run once, then leave commented out).
    //data(basePath+"toutiao_cat_data.txt",basePath+"toutiao_data_type_word.txt");
    // Step 2: train the Word2Vec vectors (run once, then leave commented out).
    //dataSet(basePath+"toutiao_cat_data.txt",basePath+"toutiao_cat_data_dataset.txt");
    // Step 3: build the network (100 is used as the input height here; the actual maximum
    // sentence length is computed inside dataSet()), prepare the iterators, and train.
    ComputationGraph model = model(100);
    model.init();
    HashMap<String, DataSetIterator> data = dataSet();
    train(model, data.get("trainIter"), data.get("testIter"));
}
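After training, the saved model can be reloaded and used to classify a new, already-segmented headline. The lines below are a hypothetical sketch that could be appended to main(): they reuse the test iterator to encode a single sentence (the sample headline is arbitrary) and assume the "mlp.mod" file written by train():

// Reload the graph saved by train().
ComputationGraph restored = ModelSerializer.restoreComputationGraph(new File(basePath + "mlp.mod"));
// Reuse the CNN sentence iterator to turn one whitespace-separated headline into the
// [1, 1, sentenceLength, vectorSize] feature tensor the network expects.
CnnSentenceDataSetIterator encoder = (CnnSentenceDataSetIterator) data.get("testIter");
INDArray features = encoder.loadSingleSentence("京城 最 值得 你 来 场 文化 之 旅 的 博物馆");
INDArray probabilities = restored.outputSingle(features);
// getLabels() gives the label order used by the iterator, so the argmax index maps back to a category.
System.out.println("Predicted label: " + encoder.getLabels().get(Nd4j.argMax(probabilities, 1).getInt(0)));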
In the next lesson we will cover object detection in images.
I take on commercial AI model-training work of all kinds. If your company wants to use AI to solve a current business problem, feel free to contact me. WeChat: CompanyAiHelper