Deeplearning4j图片分类——手写数字识别【原创】

本节课我们讲解手写数字图片的分类识别。

训练数据


public DataSetIteratorimageDateSet(String dataLocalPath,int seed,int width,int height,int channels)throws IOException {

String [] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;

    Random randNumGen =new Random(seed);

    File parentDir=new File(dataLocalPath);

    FileSplit filesInDir =new FileSplit(parentDir, allowedExtensions, randNumGen);

    ParentPathLabelGenerator labelMaker =new ParentPathLabelGenerator();

    BalancedPathFilter pathFilter =new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker);

    InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 100);

    System.out.println("---------"+filesInDirSplit.length);

    InputSplit trainData = filesInDirSplit[0];

    ImageRecordReader recordReader =new ImageRecordReader(height,width,channels,labelMaker);

    ImageTransform transform =new MultiImageTransform(randNumGen,new ShowImageTransform("Display - before "));

    recordReader.initialize(trainData,transform);

    int outputNum = recordReader.numLabels();

    int batchSize =10;

    int labelIndex =1;

    DataSetIterator dataIter =new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, outputNum);

    return dataIter;

}

模型搭建


public MultiLayerNetworkmodel(){

try{

MultiLayerConfiguration.Builder builder =new NeuralNetConfiguration.Builder()

.seed(12345)

.weightInit(WeightInit.XAVIER)

.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)

.updater(Updater.ADAM)

.list()

.layer(0, new ConvolutionLayer.Builder(5, 5)

.nIn(1)

.stride(1, 1)

.nOut(32)

.activation(Activation.LEAKYRELU)

.build())

.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)

.kernelSize(2,2)

.stride(2,2)

.build())

.layer(2, new ConvolutionLayer.Builder(5, 5)

.stride(1, 1)

.nOut(64)

.activation(Activation.LEAKYRELU)

.build())

.layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)

.kernelSize(2,2)

.stride(2,2)

.build())

.layer(4, new DenseLayer.Builder().activation(Activation.LEAKYRELU)

.nOut(500).build())

.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)

.nOut(10)

.activation(Activation.SOFTMAX)

.build())

.setInputType(InputType.convolutionalFlat(28, 28, 1));

        MultiLayerConfiguration conf = builder.build();

        MultiLayerNetwork model =new MultiLayerNetwork(conf);

        return model;

    }catch (Exception e){

e.printStackTrace();

    }

return null;

}

开始训练


public static Booleantrain(MultiLayerNetwork mlp,DataSetIterator trainIter,DataSetIterator testIter){

for(int i =0; i <1; ++i ){

mlp.fit(trainIter);    //训练模型

        Evaluation trainEval = mlp.evaluate(trainIter);    //在验证集上进行准确性测试

        Evaluation testEval = mlp.evaluate(testIter);

        trainIter.reset();

        testIter.reset();

    }

return Boolean.TRUE;

}

监控训练过程并保存模型


/**
 * Entry point: loads the training and test image sets, builds the model, attaches
 * a {@link StatsListener} backed by a file in the system temp directory to the
 * training UI, trains the model and writes it to {@code basePath + "mlp.mod"}.
 *
 * @param arg command-line arguments (not used)
 * @throws IOException if loading the data, creating the state file or writing the model fails
 * @throws IllegalStateException if the model could not be created
 */
public static void main(String[] arg) throws IOException {
    DataSetIterator trainData = lesson4.imageDateSet(basePath + "training/", 12345, 28, 28, 1);
    DataSetIterator testData = lesson4.imageDateSet(basePath + "testing/", 12345, 28, 28, 1);

    MultiLayerNetwork model = lesson4.model();
    // Stop with a clear message rather than a NullPointerException on the next use.
    if (model == null) {
        throw new IllegalStateException("Model could not be created; see the earlier stack trace");
    }

    // Creates an empty "state" file under basePath if it does not exist yet;
    // nothing else in this method reads it.
    File stateFile = new File(basePath + "state");
    stateFile.createNewFile();

    // Route training statistics to a file-backed storage and show them in the UI.
    UIServer uiServer = UIServer.getInstance();
    StatsStorage statsStorage = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "ui-stats.dl4j"));
    int listenerFrequency = 1;
    model.setListeners(new StatsListener(statsStorage, listenerFrequency));
    uiServer.attach(statsStorage);

    train(model, trainData, testData);

    // The final 'true' asks the serializer to save the updater state as well.
    ModelSerializer.writeModel(model, new File(basePath + "mlp.mod"), true);
}

关于训练数据的下载可见第二节课所讲内容。

下一节课讲解文本分类

本人诚接各类商业AI模型训练工作,如果您是一家公司,想借助AI解决当前服务问题,可以联系我。微信号:CompanyAiHelper

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。