上一节课我们讲了训练数据的准备,本节课我们讲AI学习的基本结构。鸢尾花分类是一个典型的数据分类模型,训练数据的下载详见第二节课程,PS:我们需要对原始数据中的英文替换成数字。
训练数据的准备
public HashMap<String, DataSetIterator> trainData(){
try{
int batchSize = 50;
long seed = 12345L;
String tainFilePath = basePath + "bezdekIris.csv";
RecordReader reader1 = new CSVRecordReader();
FileSplit split = new FileSplit(new File(tainFilePath));
reader1.initialize(split);
DataSetIterator trainData = new RecordReaderDataSetIterator(reader1,batchSize,4,labelNum);
DataSet data = trainData.next();
data.shuffle(seed);
SplitTestAndTrain testAndTrain = data.splitTestAndTrain(0.75);
DataSetIterator trainIter = new ListDataSetIterator(testAndTrain.getTrain().asList() , batchSize);
DataSetIterator testIter = new ListDataSetIterator(testAndTrain.getTest().asList() , batchSize);
System.out.println(trainData.next().asList().size());
System.out.println(testIter.next().asList().size());
HashMap<String, DataSetIterator> dataMap = new HashMap<>();
dataMap.put("trainData",trainIter);
dataMap.put("testData",testIter);
return dataMap;
}catch (Exception e){
e.printStackTrace();
}
return null;
}
说明:我们将加载的CSV数据分为训练集与校验集两部分。
模型搭建
public MultiLayerNetwork model(){
double learningRate = 1e-3;
double lrMomentum = 0.9;
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Nesterovs.Builder().learningRate(learningRate).momentum(lrMomentum).build())
.list()
.layer(0, new DenseLayer.Builder().activation(Activation.LEAKYRELU)
.nIn(4).nOut(2).build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(2).nOut(3).build());
MultiLayerConfiguration conf = builder.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
return model;
}
说明:MultiLayerConfiguration多层配置器来设置模型超参数,以上逻辑定义了一个全连接神经网络。此外,我们还定义了很多超参数,如学习率、参数的初始分布、优化算法及优化器、激活函数。这些超参数对最后网络参数的收敛有直接影响,具体在后续的课程中会有详细讨论,这里不再赘述。
模型训练,保存与监控
public static void main(String arg[]) throws IOException, InterruptedException {
Lesson3 lesson3 = new Lesson3();
HashMap<String, DataSetIterator> tranData = lesson3.trainData();
MultiLayerNetwork model = lesson3.model();
UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "ui-stats.dl4j"));
model.setListeners(new ScoreIterationListener());
uiServer.attach(statsStorage);
for( int i = 0; i < 20; ++i ){
model.fit(tranData.get("trainData")); //训练模型
tranData.get("trainData").reset();
Evaluation eval = model.evaluate(tranData.get("testData")); //在验证集上进行准确性测试
System.out.println(eval.stats());
tranData.get("testData").reset();
}
ModelSerializer.writeModel(model, new File(basePath + "mlp.mod"), true);
}
说明:使用UIServer监控模型的学习进度,for循环来训练模型20次,使用testData验证模型准确率。最后借助ModelSerializer将训练信息保存到本地硬盘。
训练
打开:http://localhost:9000进入学习监督界面,如下图所示:
最后会在目录下生成一个mlp.mod就是我们训练出来的模型。 下一节我们讲解Word2Vec,向量计算工具。
本人诚接各类商业AI模型训练工作,如果您是一家公司,想借助AI解决当前服务问题,可以联系我。微信号:CompanyAiHelper