如何借用Mapperduce框架进行xgboost分布式预测(java版)

背景:能分布式地预测数据(当然spark-scala框架本身就可以做到)。本文主要目的是通过一个项目,弄明白MR的执行原理

实现步骤:

使用MapReduce进行预测,目前实现3分钟内完成112w个样本,16维特征的数据预测,具体实现思路如下:

1、mapreduce主入口类 main 函数中传入模型所在hdfs的路径及数据输入输出的hdfs路径

2、Mapper类中重写Mapper里的setup()、map()、cleanup()三个方法。

1)setup(Context context)方法获取context调用ml.dmlc.xgboost4j.java.XGBoost.loadModel将训练好的模型load完成

2)map()里解析数据,封装Dmatrix并进行数据的预测,具体如下:

  - 首先将读取的一行行数据封装成 Dmatrix,达到阈值(暂时定为6000)时执行预测,并将预测值写入hdfs路径

  - cleanup里实现最后的清理收尾预测工作

接下来分别讲解每个部分的实现。

mapreduce主入口类 main 函数

//設置conf

Configuration conf = new Configuration();

conf.set("hadoop.tmp.dir",args[0]);  //用于解决自动运行时目录权限问题,可以将此目录指定到一个有权限的目录 

例如 /tmp suffle过程会有数据落到本地磁盘,这里的路径必须有权限

很重要,因为默认的路径可能不具备访问的权限。

conf.set("mapreduce.framework.name","yarn");

设置运行的模式是yarn模式还是local模式

conf.set("mapreduce.map.cpu.vcores","8");//指定这个mapreduce任务运行时cpu的个数

根据数据量来设定合适的,因为集群上默认的map的cpu的核心数是1,未设置之前,任务一度出现map 0% reduce 0%保持不动。

conf.set("mapreduce.map.memory.mb","8296");//一个 Map Task 可使用的内存上限(单位:MB),默认为 1024

要想将模型的路径当成参数变量,传给每个map算子。则要考虑如何将模型的hdfs地址广播出去。

conf.set("xgboost.model", modelPath);

接下来就是设置Job提交的一些配置

        // 设置执行jar名

        Job = Job.getInstance(conf);

       job.setJarByClass(XGBoost.class);

      // 设置文件读取、输出的路径

        FileInputFormat.setInputPaths(job, new Path(inputFile));

        Path outputFile = new Path(ouptFile1);

        FileOutputFormat.setOutputPath(job, outputFile);

       // 设置mapper的类

        job.setMapperClass(XGBoostMapper.class);

       job.setMapOutputKeyClass(Text.class);//map的输出key值

        job.setMapOutputValueClass(IntWritable.class);//map的输出的value值

        //设置InputFormat类 设置 OutputFormat 类

        job.setInputFormatClass(TextInputFormat.class);

       job.setOutputFormatClass(TextOutputFormat.class);

       job.setNumReduceTasks(0);

       //因为对于本的的例子而言,没有reduce阶段,则将reduce的个数设置为0

        // 设置reduce输出的key value 类型

        job.setOutputKeyClass(Text.class);

        job.setOutputValueClass(IntWritable.class);

       // 提交job

        job.submit();

       // 等待执行完成

        boolean noErr = job.waitForCompletion(true);

       System.exit(noErr? 0 : 1);

      XGBoostMapper extends  Mapper<LongWritable, Text, Text, Text>每部分的实现

    public class XGBoostMapper extends Mapper<LongWritable, Text, Text, Text>{

    private  Booster;//模型变量

    private Text oKey = new Text();//write时的key值

    private Text oValue=new Text();//write时的value值

    private final static int ROUND_NUM = 6000;//每达到ROUND_NUM 开始进行预测

    private List<String> acct = new ArrayList<>(ROUND_NUM );//用来缓存待预测数据的pin值

    private List<String[]> preData = new ArrayList<>(ROUND_NUM );//用来缓存待预测数据的value值


@Override protected void setup(Mapper.Context context) throws IOException {

     Configuration conf = context.getConfiguration();//目的是拿到模型的存储地址和待预测数据的地址

      String modelPath = conf.get("xgboost.model");//从conf中拿到模型的地址

      FileSystem fs = FileSystem.get(conf);

      FSDataInputStream open = fs.open(new Path(modelPath));

     //如果直接.loadModel(modelPath )则会将该地址解析为本地路径,就会报错,not file://。因为java版的xgboost没有读取hdfs的API,所以需要借助inputStream

      try {

                  booster = ml.dmlc.xgboost4j.java.XGBoost.loadModel(open);

             } catch (XGBoostError xgBoostError) {

                    xgBoostError.printStackTrace();

               }

          }


@Override protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {

//读取每一行的数据,并将其转换成array

        String[] line = value.toString().split("\001");//输入数据的文件是以\001间隔的,注意检查输入数据的分隔符。即待预测数据的hive的数据存储时的间隔符

        try {

predictUntilRoundNum(context, line);

//当输入数据达到某个值的时候开始预测

} catch (XGBoostError xgBoostError) {

xgBoostError.printStackTrace();

System.out.println("执行失败");

}

    }

private void predictUntilRoundNum(Context context, String[] line) throws IOException, InterruptedException, XGBoostError {

acct.add(line[0]); //存放账号,hive表的第一列是用户的pin

preData.add(Arrays.copyOfRange(line,3,line.length));//所选用的特征列从第3列开始

//如果达到ROUND_NUM行数据

 if (acct.size() == ROUND_NUM){

predictNow(context, acct,preData);

//清空数据

            acct = new ArrayList<>();//清空存储账号的list

            preData = new ArrayList<>();//清空要组装成Dmatrix的list

        }

    }

private void predictNow(Context context, List<String> acct, List<String[]> preData) throws InterruptedException, XGBoostError, IOException {

if (!preData.isEmpty()) {

DMatrix dMatrix = buildDMatrix(preData);

//开始预测

 float[][] predict = booster.predict(dMatrix);

System.out.println("rowPredixt=" + predict.length + " colPredict=" + predict[0].length);

//write data to hdfs use context.write

for (int i = 0; i < acct.size(); i++) {

oKey.set(acct.get(i));//用户名作为输出key

oValue.set(String.valueOf(predict[i][0]));

context.write(oKey, oValue);

}

    }

}

/**多维数组转成dmatrix数据

     *@param data:输入的数据,是一个二维数据

     *@return

:返回的是一个Dmatrix的数据

     */

    private DMatrix buildDMatrix(List<String[]> data) throws XGBoostError {

//        System.out.println("数组的row:"+data.size()+"数组的列:"+data.get(0).length);

        int num = 0;

int col = data.get(0).length;

int row = data.size();//行数,也就是Dmatrix的行数

float[] resData = new float[row*col];

for (String[] str: data){

for (String aStr : str) {

resData[num] = Float.valueOf(aStr);

num++;

}

        }

return new DMatrix(resData, row, col);

//其中Dmatrix的构造方法的第二列表示样本数,col表示feature的个数    }

其中为什么设计了读到多少行后才开始进行Dmatrix的封装和预测。因为在封装Dmatrix和预测的过程中本身就很耗时。但是不全部进行预测的是因为当数据量太大的时候,内存可能不够。

@Override

protected void cleanup(Context context)throws IOException, InterruptedException {

if (!acct.isEmpty()) {

try {

predictNow(context, acct, preData);

        }catch (XGBoostError xgBoostError) {

xgBoostError.printStackTrace();

        }

System.out.println("执行失败");

    }

}

该步是为了计算最后的数据。因为不可能保证数据的行数是ROUND_NUM的倍数,在map执行完后,list中还会有一些数据未被预测,所以需要在最后进行最后数据的预测工作。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,907评论 6 506
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,987评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,298评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,586评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,633评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,488评论 1 302
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,275评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,176评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,619评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,819评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,932评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,655评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,265评论 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,871评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,994评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,095评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,884评论 2 354

推荐阅读更多精彩内容