SparkML 实现 ALS 算法

引入依赖

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.12</artifactId>
    <version>2.4.4</version>
    <exclusions>
        <exclusion>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
        </exclusion>
    </exclusions>
</dependency>
<dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>
    <version>14.0.1</version>
</dependency>

数据准备

门店数据
  • 通过 dml.sql 导入了 400 条数据;
行为数据
  • 保存在文件 behavior.csv 中,总共 3 列,第一列 userId,第二列 shopId,第三列用户对这个门店的钟爱度打分;
  • behavior.csv 中大概有 2 万多条数据;

离线 ALS 召回模型的训练

离线 ALS 召回模型的训练 | 过程
  • 读行为数据 behavior.csv 到内存中;
  • 转换数据结构:JavaRDD<String> -> JavaRDD<Rating> -> Dataset<Row>;
  • 按 8-2 分,将行为数据集分成 2 份,一份训练用,一份测试用;
  • 设置 ALS 模型的参数:.setMaxIter(10).setRank(5).setRegParam(0.01)
  • 生成模型;
  • 生成模型测评器;
  • 用测试行为数据,测试生成的模型,得到 rmse 得分;
  • 生成的模型可以保存在磁盘;
模型生成的结果
  • alsmodel
    • itemFactor - 存储门店训练出来的特征值;
    • metadata
    • userFactors - 存储用户训练出来的特征值,二进制的;
离线 ALS 召回模型的训练 | 代码
package tech.lixinlei.dianping.recommand;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.io.IOException;
import java.io.Serializable;


/**
 * ALS 召回算法的训练
 * 实现 Serializable 是因为,Spark 的程序可以运行在不同的机器上;
 */
public class AlsRecallTrain implements Serializable {

    public static void main(String[] args) throws IOException {

        //初始化spark运行环境
        SparkSession spark = SparkSession.builder().master("local").appName("DianpingApp").getOrCreate();

        JavaRDD<String> csvFile = spark.read().textFile("file:///home/lixinlei/project/gitee/dianping/src/main/resources/behavior.csv").toJavaRDD();

        JavaRDD<Rating> ratingJavaRDD = csvFile.map(new Function<String, Rating>() {
            /**
             * 将 behavior.csv 中的一行,从 String 转成 Rating;
             * @param v1 behavior.csv 中数据的一行
             * @return
             * @throws Exception
             */
            @Override
            public Rating call(String v1) throws Exception {
                return Rating.parseRating(v1);
            }
        });

        // Dataset 可以理解为 MySQL 中的一张表,row 中 column 的定义遵从 Rating 的定义;
        Dataset<Row> rating = spark.createDataFrame(ratingJavaRDD, Rating.class);

        // 将所有的 rating 数据分成 8-2 分,80% 的数据用来做训练,20% 的训练用来做测试
        Dataset<Row>[] splits = rating.randomSplit(new double[]{0.8, 0.2});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testingData = splits[1];

        // .setMaxIter(10) 设置最大拟合次数
        // .setRank(5) 分解矩阵后 feature 的数量
        // .setRegParam(0.01) 正则化系数,增大正则化的值,可以防止过拟合的情况
        // 过拟合:指得是模型训练出来的内容,过分的逼近真实数据,导致一旦真实数据出现一些误差,预测的结果反而不尽如人意;
        // 欠拟合:模型训练出来的内容,没有达到收敛于真是数据,使得预测结果的偏差距离真实结果太大;
        // 过拟合的解决方案:1)增大数据规模 2)减少 RANK,即特征的数量,使得模型预测的能力更加松散 3)增大正则化的系数
        // 欠拟合的解决方案:1)增加 RANK 2)减少正则化系数
        ALS als = new ALS().setMaxIter(10).setRank(5).setRegParam(0.01).
                setUserCol("userId").setItemCol("shopId").setRatingCol("rating");

        // 模型训练
        ALSModel alsModel = als.fit(trainingData);

        // 模型评测:测评的时候,用到了 testingData 中的 userId 和 shopId 字段的值,没有用 rating 字段的值,而且计算出了一个新字段,叫 prediction
        Dataset<Row> predictions = alsModel.transform(testingData);

        // rmse 均方根误差,预测值与真实值的偏差的平方除以观测次数(testingData 的条数),开个根号
        // rmse 的值越小,标识模型在测试数据集上的表现越好;
        RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse")
                .setLabelCol("rating").setPredictionCol("prediction");
        double rmse = evaluator.evaluate(predictions);
        System.out.println("rmse = " + rmse);

        alsModel.save("file:///home/lixinlei/project/gitee/dianping/src/main/resources/alsmodel");
    }

    /**
     * 自定义数据结构,用来承接 behavior.csv 中的一行数据;
     */
    public static class Rating implements Serializable{

        private int userId;
        private int shopId;
        private int rating;

        /**
         * 将 hebavior.csv 中的一行数据,组装成 Rating 对象返回;
         * @param str behavior.csv 文件的一行输入
         * @return
         */
        public static Rating parseRating(String str){
            str = str.replace("\"","");
            String[] strArr = str.split(",");
            int userId = Integer.parseInt(strArr[0]);
            int shopId = Integer.parseInt(strArr[1]);
            int rating = Integer.parseInt(strArr[2]);
            return new Rating(userId,shopId,rating);
        }

        public Rating(int userId, int shopId, int rating) {
            this.userId = userId;
            this.shopId = shopId;
            this.rating = rating;
        }

        public int getUserId() {
            return userId;
        }

        public int getShopId() {
            return shopId;
        }

        public int getRating() {
            return rating;
        }
    }

}

使用离线 ALS 召回模型为活跃的 5 个用户召回(粗排)门店信息

召回 | 步骤
  • 先加载训练出的离线模型 ALSModel;
  • 再加载行为数据 behavior.csv;
  • 再选 5 个用户做预测;
  • 解析预测结果存入数据库;
召回 | 代码实现
package tech.lixinlei.dianping.recommand;


import org.apache.commons.lang3.StringUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.ForeachPartitionFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.*;

/**
 * 加载生成的模型,预测比较活跃的用户,并且生成离线数据的候选集;
 */
public class AlsRecallPredict {

    public static void main(String[] args) {

        // 初始化spark运行环境
        SparkSession spark = SparkSession.builder().master("local").appName("DianpingApp").getOrCreate();

        // 加载模型进内存
        ALSModel alsModel = ALSModel.load("file:///home/lixinlei/project/gitee/dianping/src/main/resources/alsmodel");

        JavaRDD<String> csvFile = spark.read().textFile("file:///home/lixinlei/project/gitee/dianping/src/main/resources/behavior.csv").toJavaRDD();
        JavaRDD<Rating> ratingJavaRDD = csvFile.map(new Function<String, Rating>() {
            @Override
            public Rating call(String v1) throws Exception {
                return Rating.parseRating(v1);
            }
        });
        Dataset<Row> rating = spark.createDataFrame(ratingJavaRDD, Rating.class);

        // 给 5 个用户做离线的召回结果预测
        Dataset<Row> users = rating.select(alsModel.getUserCol()).distinct().limit(5);
        // userRecs 就是预测的结果
        Dataset<Row> userRecs = alsModel.recommendForUserSubset(users,20);

        userRecs.foreachPartition(new ForeachPartitionFunction<Row>() {
            @Override
            public void call(Iterator<Row> t) throws Exception {
                Connection connection = DriverManager.
                        getConnection("jdbc:mysql://127.0.0.1:3306/dianping?" +
                                "user=root&password=Jiangdi_2018&useUnicode=true&characterEncoding=UTF-8");
                PreparedStatement preparedStatement = connection.
                        prepareStatement("insert into recommend(id, recommend) values (?, ?)");

                List<Map<String,Object>> data =  new ArrayList<Map<String, Object>>();
                t.forEachRemaining(action -> {
                    int userId = action.getInt(0);
                    List<GenericRowWithSchema> recommendationList = action.getList(1);
                    List<Integer> shopIdList = new ArrayList<Integer>();
                    recommendationList.forEach(row->{
                        Integer shopId = row.getInt(0);
                        shopIdList.add(shopId);
                    });
                    String recommendData = StringUtils.join(shopIdList,",");
                    Map<String,Object> map = new HashMap<String, Object>();
                    map.put("userId",userId);
                    map.put("recommend",recommendData);
                    data.add(map);
                });

                data.forEach(stringObjectMap -> {
                    try {
                        preparedStatement.setInt(1, (Integer) stringObjectMap.get("userId"));
                        preparedStatement.setString(2, (String) stringObjectMap.get("recommend"));

                        preparedStatement.addBatch();
                    } catch (SQLException e) {
                        e.printStackTrace();
                    }

                });
                preparedStatement.executeBatch();
                connection.close();
            }
        });

    }

    public static class Rating implements Serializable {

        private int userId;
        private int shopId;
        private int rating;

        public static Rating parseRating(String str){
            str = str.replace("\"","");
            String[] strArr = str.split(",");
            int userId = Integer.parseInt(strArr[0]);
            int shopId = Integer.parseInt(strArr[1]);
            int rating = Integer.parseInt(strArr[2]);

            return new Rating(userId,shopId,rating);
        }

        public Rating(int userId, int shopId, int rating) {
            this.userId = userId;
            this.shopId = shopId;
            this.rating = rating;
        }

        public int getUserId() {
            return userId;
        }

        public int getShopId() {
            return shopId;
        }

        public int getRating() {
            return rating;
        }

    }

}
召回的结果

SELECT * FROM dianping.recommend;

# id, recommend
'148', '400,216,145,131,421,464,128,257,332,479,283,248,447,138,494,292,228,186,231,378'
'463', '202,323,479,420,255,154,484,318,405,135,206,345,324,382,262,199,123,494,201,388'
'471', '216,479,127,191,464,172,202,125,389,494,411,303,455,226,249,369,291,105,211,434'
'1088', '324,465,402,135,294,199,163,203,255,185,147,323,130,430,388,313,112,145,219,481'
'1238', '268,438,130,383,313,324,465,203,180,148,222,353,252,402,481,368,142,428,448,198'
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,734评论 6 505
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,931评论 3 394
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,133评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,532评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,585评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,462评论 1 302
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,262评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,153评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,587评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,792评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,919评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,635评论 5 345
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,237评论 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,855评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,983评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,048评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,864评论 2 354