Deeplearning4j Text Classification: Toutiao News Headlines [Original]

In this lesson we cover text classification. Before the text can be classified, it first has to be preprocessed.

Word segmentation

public static void data(String source, String save) throws IOException {
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(new FileInputStream(new File(source)), "UTF-8"));
        File saveFile = new File(save);
        if (!saveFile.exists()) {
            saveFile.createNewFile();
        }
        OutputStreamWriter writerStream = new OutputStreamWriter(new FileOutputStream(saveFile), "UTF-8");
        BufferedWriter writer = new BufferedWriter(writerStream);
        String line = null;
        long startTime = System.currentTimeMillis();
        while ((line = bufferedReader.readLine()) != null) {
            // Each line of the raw Toutiao file is "_!_"-delimited; array[1] holds the numeric category code and array[3] the headline text
            String[] array = line.split("_!_");
            StringBuilder stringBuilder = new StringBuilder();
            // Segment the headline with HanLP and join the tokens with single spaces
            for (Term term : HanLP.segment(array[3])) {
                if (stringBuilder.length() > 0) {
                    stringBuilder.append(" ");
                }
                stringBuilder.append(term.word.trim());
            }
            // Write out "category code _!_ segmented headline", one sample per line
            writer.write(Integer.parseInt(array[1].trim()) + "_!_" + stringBuilder.toString() + "\n");
        }
        writer.flush();
        writer.close();
        System.out.println(System.currentTimeMillis() - startTime); // elapsed time in milliseconds
        bufferedReader.close();
    }

The method above uses the HanLP word segmenter to split each headline into tokens and writes each sample out as its category code followed by the space-separated words.
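
For reference, HanLP.segment returns a List<Term>, and each Term carries the token in its word field (plus a part-of-speech tag). Below is a minimal sketch of calling the segmenter directly; the headline is made up purely for illustration:

public static void segmentDemo() {
        // The sample headline below is only an illustration; any Chinese sentence works
        List<Term> terms = HanLP.segment("自动驾驶汽车何时才能真正上路");
        for (Term term : terms) {
            System.out.print(term.word + " "); // print the space-separated tokens, as data(...) writes them
        }
        System.out.println();
    }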

Text vectorization tool (Word2Vec)

public static void dataSet(String filePath,String savePath) throws FileNotFoundException {
        SentenceIterator iter = new BasicLineIterator(filePath);
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());
        VocabCache<VocabWord> cache = new AbstractCache<>();
        WeightLookupTable<VocabWord> table = new InMemoryLookupTable.Builder<VocabWord>().vectorLength(100)
                .useAdaGrad(false).cache(cache).build();

        Word2Vec vec = new Word2Vec.Builder()
                .elementsLearningAlgorithm("org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram")
                .minWordFrequency(0).iterations(1).epochs(20).layerSize(100).seed(42).windowSize(8).iterate(iter)
                .tokenizerFactory(t).lookupTable(table).vocabCache(cache).build();

        vec.fit();
        WordVectorSerializer.writeWord2VecModel(vec, savePath);
    }
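
After dataSet(...) has written the vectors, you can load them back and spot-check the embedding by asking for the nearest neighbours of a common word. A minimal sketch; the query words are only examples:

public static void checkVectors(String modelPath) {
        // Reload the Word2Vec model written by dataSet(...)
        Word2Vec vec = WordVectorSerializer.readWord2VecModel(modelPath);
        // The 10 words whose vectors are closest to the query word; a reasonable model puts related news terms here
        System.out.println(vec.wordsNearest("股票", 10));
        // Cosine similarity between two words, between -1 and 1
        System.out.println(vec.similarity("股票", "基金"));
    }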

Building the training data

private static HashMap<String,DataSetIterator> dataSet() throws IOException {
        List<String> trainLabelList = new ArrayList<>(); // training-set labels
        List<String> trainSentences = new ArrayList<>(); // training-set sentences
        List<String> testLabelList = new ArrayList<>();  // test-set labels
        List<String> testSentences = new ArrayList<>();  // test-set sentences
        Map<String, List<String>> map = new HashMap<>();


        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
                new FileInputStream(new File(basePath+"toutiao_data_type_word.txt")), "UTF-8"));
        String line = null;
        int truncateReviewsToLength = 0;
        Random random = new Random(123);
        while ((line = bufferedReader.readLine()) != null) {
            String[] array = line.split("_!_");
            if (map.get(array[0]) == null) {
                map.put(array[0], new ArrayList<String>());
            }
            map.get(array[0]).add(array[1]); // group every sample by its category label
            int length = array[1].split(" ").length;
            if (length > truncateReviewsToLength) {
                truncateReviewsToLength = length; // track the longest sentence in the corpus
            }
        }
        bufferedReader.close();
        for (Map.Entry<String, List<String>> entry : map.entrySet()) {
            for (String sentence : entry.getValue()) {
                if (random.nextInt() % 5 == 0) { // roughly 20% of each category goes to the test set
                    testLabelList.add(entry.getKey());
                    testSentences.add(sentence);
                } else {
                    trainLabelList.add(entry.getKey());
                    trainSentences.add(sentence);
                }
            }

        }
        int batchSize = 64;
        Random rng = new Random(12345);
        Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel(basePath+"toutiao_cat_data_dataset.txt");
        System.out.println("Loading word vectors and creating DataSetIterators");
        DataSetIterator trainIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, trainLabelList, trainSentences, rng);
        DataSetIterator testIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, testLabelList, testSentences, rng);
        HashMap<String,DataSetIterator> data = new HashMap<>();
        data.put("trainIter",trainIter);
        data.put("testIter",testIter);
        return data;

    }
    private static DataSetIterator getDataSetIterator(WordVectors wordVectors, int minibatchSize, int maxSentenceLength,
                                                      List<String> labelList, List<String> sentences, Random rng) {

        LabeledSentenceProvider sentenceProvider = new CollectionLabeledSentenceProvider(sentences, labelList, rng);

        return new CnnSentenceDataSetIterator.Builder().sentenceProvider(sentenceProvider).wordVectors(wordVectors)
                .minibatchSize(minibatchSize).maxSentenceLength(maxSentenceLength).useNormalizedWordVectors(false)
                .build();
    }
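
To confirm the iterators feed the network in the layout the convolution layers expect, you can pull one minibatch and print its shapes; with this setup the features should come out as [minibatch, 1, sentenceLength, vectorSize] and the one-hot labels as [minibatch, numberOfCategories]. A minimal sketch:

    private static void inspectBatch(DataSetIterator iter) {
        // Look at a single minibatch, then rewind the iterator so training still sees all of the data
        DataSet ds = iter.next();
        System.out.println("features: " + Arrays.toString(ds.getFeatures().shape()));
        System.out.println("labels:   " + Arrays.toString(ds.getLabels().shape()));
        iter.reset();
    }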

Building the model

public static ComputationGraph model(int truncateReviewsToLength){
        int vectorSize = 100;            // must match the Word2Vec layerSize used above
        int cnnLayerFeatureMaps = 50;    // number of feature maps (channels) per convolution branch
        PoolingType globalPoolingType = PoolingType.MAX;
        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().weightInit(WeightInit.RELU)
                .activation(Activation.LEAKYRELU).updater(new Nesterovs(0.01, 0.9))
                .convolutionMode(ConvolutionMode.Same).l2(0.0001).graphBuilder().addInputs("input")
                // Eight parallel convolution branches over the word-vector matrix: kernel heights 3-6 with strides 1 and 2
                .addLayer("cnn3",
                        new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(1, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn4",
                        new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(1, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn5",
                        new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(1, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn6",
                        new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(1, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn3-stride2",
                        new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(2, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn4-stride2",
                        new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(2, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn5-stride2",
                        new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(2, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn6-stride2",
                        new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(2, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addVertex("merge1", new MergeVertex(), "cnn3", "cnn4", "cnn5", "cnn6")
                .addLayer("globalPool1", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
                        "merge1")
                .addVertex("merge2", new MergeVertex(), "cnn3-stride2", "cnn4-stride2", "cnn5-stride2", "cnn6-stride2")
                .addLayer("globalPool2", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
                        "merge2")
                .addLayer("fc",
                        new DenseLayer.Builder().nOut(200).dropOut(0.5).activation(Activation.LEAKYRELU).build(),
                        "globalPool1", "globalPool2")
                .addLayer("out",
                        new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                                .activation(Activation.SOFTMAX).nOut(15).build(), // one output per Toutiao news category (15 classes)
                        "fc")
                .setOutputs("out").setInputTypes(InputType.convolutional(truncateReviewsToLength, vectorSize, 1))
                .build();

        ComputationGraph net = new ComputationGraph(config);
        return net;
    }
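
Before training it is worth printing the graph summary to check that the eight convolution branches, the two merge vertices, the global pooling layers and the dense/output layers are wired as intended; for example, right after model.init() in main:

        System.out.println(model.summary()); // one row per layer/vertex, with its inputs and parameter counts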

Training the model

private static void train(ComputationGraph model, DataSetIterator trainIter, DataSetIterator testIter) throws IOException {
        // Start the DL4J training UI and attach an in-memory stats store so progress can be watched in the browser
        UIServer uiServer = UIServer.getInstance();
        StatsStorage statsStorage = new InMemoryStatsStorage();
        uiServer.attach(statsStorage);
        // Log the score every 100 iterations, stream stats to the UI, and evaluate on the test set at the end of each epoch
        model.setListeners(new ScoreIterationListener(100), new StatsListener(statsStorage, 20),
                new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));
        model.fit(trainIter, 10); // train for 10 epochs
        ModelSerializer.writeModel(model, new File(basePath + "mlp.mod"), true);
    }
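
After training finishes, the serialized model can be restored and scored against the held-out test iterator. A minimal sketch, using the mlp.mod file written by train(...):

private static void evaluate(DataSetIterator testIter) throws IOException {
        // Restore the model saved at the end of train(...)
        ComputationGraph restored = ModelSerializer.restoreComputationGraph(new File(basePath + "mlp.mod"));
        // Run a full pass over the test set: accuracy, precision/recall/F1 and the confusion matrix
        Evaluation eval = restored.evaluate(testIter);
        System.out.println(eval.stats());
    }
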
public static void main(String arg[]) throws IOException {
        //data(basePath+"toutiao_cat_data.txt",basePath+"toutiao_data_type_word.txt");
        //dataSet(basePath+"toutiao_cat_data.txt",basePath+"toutiao_cat_data_dataset.txt");

        ComputationGraph model = model(100);
        model.init();
        HashMap<String,DataSetIterator> data = dataSet();
        train(model,data.get("trainIter"),data.get("testIter"));
    }
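
Finally, to classify a single new headline, segment it the same way the training data was segmented and turn it into features with the iterator's loadSingleSentence helper. A minimal sketch, assuming the iterator returned by dataSet() is cast back to CnnSentenceDataSetIterator (which is what getDataSetIterator actually builds):

private static void predict(ComputationGraph model, CnnSentenceDataSetIterator iter, String headline) {
        // Segment the raw headline with HanLP, exactly as in data(...)
        StringBuilder stringBuilder = new StringBuilder();
        for (Term term : HanLP.segment(headline)) {
            if (stringBuilder.length() > 0) {
                stringBuilder.append(" ");
            }
            stringBuilder.append(term.word.trim());
        }
        // Convert the segmented sentence into the same feature layout the model was trained on
        INDArray features = iter.loadSingleSentence(stringBuilder.toString());
        // One probability per category, in the order given by iter.getLabels()
        INDArray probabilities = model.outputSingle(features);
        System.out.println(iter.getLabels());
        System.out.println(probabilities);
    }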

The next lesson will cover object detection in images.

I take on commercial AI model training work. If your company wants to use AI to solve a current service problem, feel free to contact me. WeChat: CompanyAiHelper
