SKIL/工作流程/SKIL上的分布式训练

SKIL中的分布式训练

SKIL提供了skil spark命令,用于在spark 集群上对DL4J模型进行分布式训练。它几乎类似于使用带有一些附加功能的spark-submit命令,以便能够查看DL4J UI上的训练并通过给定的模型历史服务器详细信息维护模型历史。

先决条件
你需要遵循以下步骤:

  1. SKIL
  2. Spark 集群 (或者你可以在本地使用spark,并将master指定为local)

使用“skil spark”参与分布式训练的组件

1. 实现DataSetProvider接口

为了给你的网络提供数据你需要从org.deeplearning4j.spark.data.DataSetProvider接口实现。接口定义如下:

import org.apache.spark.SparkContext;
import org.apache.spark.rdd.RDD;
import org.datavec.api.transform.TransformProcess;
import org.nd4j.linalg.dataset.DataSet;

public interface DataSetProvider {
    //此函数提供训练模型所需的数据。
    RDD<DataSet> data(SparkContext var1); 
    //此函数定义并返回一个转换过程,该转换过程将保存在模型的ETL JSON中。
    TransformProcess transformProcess(); 
}
image.gif

SKIL提供此接口(io.skymind.skil.train.spark.MnistProvider)的默认实现,用于提供MNIST数据。在Java中的实现如下:

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
import org.datavec.api.transform.TransformProcess;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.spark.data.DataSetProvider;
import org.nd4j.linalg.dataset.DataSet;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class MnistProvider implements DataSetProvider {
    @Override
    public RDD<DataSet> data(SparkContext sparkContext) {
        try {
            MnistDataSetIterator mnist = new MnistDataSetIterator(16, 60000);
            List<DataSet> data = new ArrayList<>();

            while (mnist.hasNext()) {
                data.add(mnist.next());
            }
            return new JavaSparkContext(sparkContext).parallelize(data).rdd();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public TransformProcess transformProcess() {
        return null;
    }
}
image.gif

2. Training master

DL4J中的TrainingMaster是一个抽象(接口),允许将多个不同的训练实现与SparkDl4jMultiLayerSparkComputationGraph一起使用。
目前,DL4J有一个实现,即ParameterAveragingTrainingMaster。基本的TrainingMaster如下:

import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster

//spark训练配置: see http://deeplearning4j.org/spark for
//解释这些配置选项
//指定每个DataSet对象中有多少个示例
val tm = new ParameterAveragingTrainingMaster.Builder(dataSetObjectSize) 
    //平均和重新分布参数的频率
    .averagingFrequency(5) 
    //如何处理异步预取多个小批量。0禁用预取,较大的值在预取时使用更多内存。
    .workerPrefetchNumBatches(2) 
    //每个工作机线程的最小批处理大小:每个工作机线程中用于每个参数更新的示例数
    .batchSizePerWorker(batchSize) 
    .build();
image.gif

可以在此处找到有关TrainingMaster Builder配置的更多信息。

3. 神经网络配置

最后,要训练的神经网络配置。示例配置(在scala中)如下(对于MultiLayerNetwork

import org.deeplearning4j.nn.api.Model
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer
import org.deeplearning4j.nn.conf.layers.DenseLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster
import org.deeplearning4j.util.ModelSerializer
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.learning.config.Nesterovs
import org.nd4j.linalg.lossfunctions.LossFunctions

import java.io.File;
//如上所述的训练迭代
var builder = new NeuralNetConfiguration.Builder().seed(230) 
        .l2(0.0005)
        .weightInit(WeightInit.XAVIER)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(Nesterovs.builder().learningRate(0.01).momentum(0.9).build())
        .list()
        .layer(0, new ConvolutionLayer.Builder(5, 5)

                //nIn和nOut指定深度。这里是nChannels,nOut是要应用的过滤器的数量。
                .nIn(1).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())
        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
                .stride(2, 2).build())
        .layer(2, new ConvolutionLayer.Builder(5, 5)

                //请注意,在后面的层中不需要指定nIn
                .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build())
        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
                .stride(2, 2).build())
        .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10)
                .activation(Activation.SOFTMAX).build())
         //见下文注释
        .setInputType(InputType.convolutionalFlat(28, 28, 1)) 
        .backprop(true).pretrain(false);
var network = new MultiLayerNetwork(builder.build());
network.init();

//将模型写入文件
var neuralNet = new File("/tmp", "neuralnet.bin")
neuralNet.createNewFile()
ModelSerializer.writeModel(network, neuralNet, true);
image.gif

使用“skil spark”命令
使用skil spark命令时,需要使用TrainingMaster#toJson函数将TrainingMaster配置保存在JSON文件中。此外,还需要使用ModelSerializer#writeModel函数将模型配置写入文件。你可以用Zeppelin中这个样本笔记本看看是怎么做到的。

样品使用
下面的shell脚本显示了如何使用skil spark命令:

$SKIL_HOME/sbin/skil login --userId admin --password admin # You might have a different userId and password, replace them accordingly.

$SKIL_HOME/sbin/skil spark --master local[*] --trainingMasterPath "/tmp/parameteraveraging.json" --modelPath "/tmp/neuralnet.bin" --dataSetProvider MnistProvider --numEpochs 5 --supervise false --uiUrl localhost:9002
image.gif

在上面的命令中,--dataSetProvider 的值为MnistProvider,而不是完整的类名(io.skymind.skil.train.spark.MnistProvider)。这种格式也是有效的,应用程序将扫描捆绑在“uberjar”中的类,并使用与给定名称匹配的第一个类。如果要防止其他可能共享相同名称的类之间的不匹配,建议使用该类的全名。

日志

运行skil spark命令后,可以在日志文件中查看skil的日志,该文件位于路径下:/var/log/skil/skil.log。

DL4J用户界面监控训练
你还可以在--uiUrl参数指定的相同地址上监视DL4J UI上的分布式训练。如果你没有在skil上启动DL4J UI服务器,可以使用$skil_home/sbin/skil ui --uiport 9002命令启动它。下图显示了它的外观:

image.gif

显示分布式训练统计信息的用户界面
“skil spark”命令的自定义参数
下表列出了skil spark命令的参数、说明和默认值:

| 变量 | 描述 | 默认值 |
|

--master

|

参数服务器要连接到的master URL。spark://host:port, mesos://host:port, yarn, 或 local

| "local[*]" |
| --deploy-mode |

是否在本地启动驱动程序(“client”)或
在集群中的一台工作机上(“cluster”)。

| "client" |
| --class | 应用程序的主类(对于Java/Scala应用程序) | "io.skymind.skil.train.spark.SKILSparkMain" |
| --jars | 驱动程序上要包含的本地jar的逗号分隔列表
和执行器类路径 |
|
| --name | 应用名称 |
|
| --packages |

驱动程序上要包含的本地jar的maven坐标逗号分隔列表和执行器类路径。

将搜索本地maven仓库,然后中央仓库和任何由repositories提供的额外的远程仓库

坐标格式为:groupId:artifactid:version

|
|
| --exclude-packages |

groupid:artifactid的逗号分隔列表,在解析--packages中提供的依赖项以避免依赖项冲突时排除

|
|
| --properties-file |

要从中加载额外属性的文件的路径。如果未指定,将查找conf/spark-defaults.conf

|
|
| --repositories | 以逗号分隔的其他远程存储库列表,用于搜索随包提供的maven坐标 |
|
| --files | 将放在每个执行器工作目录中的文件的逗号分隔列表 |
|
| --driver-memory | 驱动器内存(例如1000M、2G)(默认值:1024M) |
|
| --driver-java-options | 传递给驱动程序的额外Java选项 |
|
| --driver-library-path | 要传递给驱动程序的额外库路径条目 |
|
| --driver-class-path | 要传递给驱动程序的额外类路径条目。注意,随--jar添的jar会自动包含在类路径中。 |
|
| --executor-memory | 每个执行器的内存(例如1000M、2G)(默认值:1G) |
|
| --proxy-user | 提交应用程序时要模拟的用户 |
|
| --driver-cores | 用于驱动程序核心 | 1 |
| --yarn-queue | 用于提交到的YARN队列 | default |
| --num-executors | 要启动的执行器数 | 2 |
| --principal | 在安全hdfs上运行时用于登录kdc的主体 |
|
| --key-tab | 包含上述主体的keytab的文件的完整路径。此keytab将通过安全的分布式缓存复制到运行应用程序主服务器的节点,用于定期更新登录票证和委派令牌。 |
|
| --supervise | 如果给定,则在出现故障时重新启动驱动程序 |
|
| --kill | 如果给定,则杀死指定的驱动程序 |
|
| --status | 如果给定,则请求指定的驱动程序的状态 |
|
| --total-executor-cores | 所有执行者的核心总数 | 1 |
| --trainingMasterPath | TrainingMaster的路径 |
|
| --modelPath | 模型路径 |
|
| --uiUrl | 界面Url |
|
| --dataSetProvider | 数据集提供者 |
|
| --jarPath | jar文件路径 |
|
| --pArgs | spark作业程序参数 | new ArrayList<String>() |
| --modelHistoryUrl | 模型历史Url | null |
| --modelHistoryId | 模型历史Id | null |
| --evalType | 评估类型,可能的值有:evaluation, evaluationbinary, roc, rocbinary, rocmulticlass, regressionevaluation | null |
| --numEpochs | 训练的轮数 | 5 |
| --evalDataSetProviderClass | 评估数据集提供者 | MnistProvider |
| --multiDataSet | 是否为多数据集 | false |
| --modelInstanceId | 模型实例ID | null |
| --doInference | 用指定的模型URI进行推理 | false |
| --outputPath | 进行推理时,保存结果的路径 | null |
| --batchSize | 用于推理的批量大小 | 16 |
| --verbose | 打印调试信息 | false |

skil spark命令从零开始创建一个“uber jar”,它捆绑了所有skil jar,每次更新或添加新插件时,它都会重新生成该uber jar。创建这个jar需要很长时间。创建jar文件后,将启动训练过程,可以使用tail -f/var/log/skil/skil.log 命令查看其输出。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,651评论 6 501
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,468评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,931评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,218评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,234评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,198评论 1 299
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,084评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,926评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,341评论 1 311
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,563评论 2 333
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,731评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,430评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,036评论 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,676评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,829评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,743评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,629评论 2 354

推荐阅读更多精彩内容