Spark MLlib的协同过滤

Spark的MLlib实现了协同过滤(Collaborative Filtering)这个功能。官网文档链接

熟悉推荐算法的同学可能也有这个认识:协同过滤主要分为3大类——1、基于User的协同过滤;2、基于Item的协同过滤;3、基于Model的协同过滤。前面两个比较简单不多描述了,主要讲下基于Model的协同过滤。在网上找到一个对基于Model的协同过滤的算法总结包括:Aspect Model,pLSA,LDA,聚类,SVD,Matrix Factorization等。不管这句话说的是否严谨(比如还有二分图模型),总之我认为Spark MLlib目前(2.2.0版本)并不能算是完整的协同过滤。只是做了基于Model的协同过滤中的矩阵分解内容。当然做好了矩阵分解,接下来再做别的也就轻松了。

关于基于Model的矩阵分解,可以参考矩阵分解在协同过滤推荐算法中的应用。Spark的MLlib中使用的是ALS(Alternating Least Squares (ALS) matrix factorization)算法。这个可以看成是对FunkSVD的一种求解实现。不过考虑到有时候我们输入的User-Item的rating可能不是某种评判的数值打分,而是User对于Item的某种偏好,此时使用ALS-WR(alternating-least-squares with weighted-λ-regularization)通过置信度权重来重新定义目标函数,从而得到新的结果。关于ALS和ALS-WR可以参考协同过滤之ALS-WR算法机器学习(十四)——协同过滤的ALS算法(2)、主成分分析以及协同过滤 CF & ALS 及在Spark上的实现

上面主要是理论基础部分,熟悉了理论基础后,我们看下通过Spark的MLlib的落地实现,我们需要做哪些工作。同时依然建议参考另2篇文章ALS-WR(协同过滤推荐算法) in ML深入理解Spark ML:基于ALS矩阵分解的协同过滤算法与源码分析

Collaborative filtering

正如前面所讲的,我们的工作是要把评分矩阵用User和Item的latent factors表达出来。MLlib通过ALS算法来学习得到User以及Item的latent factors,在具体的实现中需要以下参数:

  • numBlocks is the number of blocks the users and items will be partitioned into in order to parallelize computation (defaults to 10). 用于并行计算,同时设置User和Item的block数目,还可以使用numUserBlocksnumItemBlocks分别设置User和Item的block数目。
  • rank is the number of latent factors in the model (defaults to 10). 表示latent factors的长度。对于这个值的设置参见What is recommended number of latent factors for the implicit collaborative filtering using ALS
  • maxIter is the maximum number of iterations to run (defaults to 10). 交替计算User和Item的latent factors的迭代次数。
  • regParam specifies the regularization parameter in ALS (defaults to 1.0). L2正则的系数lambda
  • implicitPrefs specifies whether to use the explicit feedback ALS variant or one adapted for implicit feedback data (defaults to false which means using explicit feedback). 表示原始User和Item的rating矩阵的值是否是评判的打分值,False表示是打分值,True表示是矩阵的值是某种偏好。
  • alpha is a parameter applicable to the implicit feedback variant of ALS that governs the baseline confidence in preference observations (defaults to 1.0). 当implicitPrefs为true时,表示对原始rating的一个置信度系数,用于和rate相乘,是一个常值。可以根据对于原始数据的观察,统计先设置一个值,然后再进行后续的tuning。
  • nonnegative specifies whether or not to use nonnegative constraints for least squares (defaults to false). 对应于选择求解最小二乘的方法:if (nonnegative) new NNLSSolver else new CholeskySolver。如果True就是用非负正则化最小二乘(NNLS),False就是用乔里斯基分解(Cholesky)

Note: 基于DataFrame的MLlib API目前只支持integer类型的user和Item的id。其他numeric类型的user和item id列也支持,不过ids必须在integer的取值范围内。这里的numeric类型指的是java.lang.Number,看了下源码感觉负值也应该是可以的。

读取ID,如果是Int直接使用,Number进行Cast并检查

除了上面文档中的参数,还有一些别的参数设置也有必要列出来(下面的Dataset<Row>即为DataFrame):

  • userCol:用户列的名字,String类型。对应于后续调用fit()操作时输入的Dataset<Row>入参时用户id所在schema中的name
  • itemCol:item列的名字,String类型。对应于后续调用fit()操作时输入的Dataset<Row>入参时item id所在schema中的name
  • ratingCol:rating列的名字,String类型。对应于后续调用fit()操作时输入的Dataset<Row>入参时rating值所在schema中的name
  • predictionCol:String类型。做transform()操作时输出的预测值在Dataset<Row>结果的schema中的name,默认是“prediction”
  • coldStartStrategy:String类型。有两个取值"nan" or "drop"。这个参数指示用在prediction阶段时遇到未知或者新加入的user或item时的处理策略。尤其是在交叉验证或者生产场景中,遇到没有在训练集中出现的user/item id时。"nan"表示对于未知id的prediction结果为NaN。"drop"表示对于transform()的入参DataFrame中出现未知ids的行,将会在包含prediction的返回DataFrame中被drop。默认值是"nan"

Explicit和implicit feedback

标准的协同过滤中的矩阵分解(matrix factorization)都是对user-item的打分矩阵做因子分解,比如用户对电影的打分,也称为显式反馈(explicit feedback)。

不过在现实情况中,很多user-item都不是某种特定意义的评分,而是一些比如用户的购买记录、搜索关键字,甚至是鼠标的移动。我们将这些间接用户行为称之为隐式反馈(implicit feedback)。

在Spark中处理隐式反馈的算法是ALS-WR。可以重点看下前面给出的参考链接中的算法结果,观察损失函数,就可以知道大致过程。

正则化系数

这里指的是在ALS算法中L2正则项的系数,用来防止过拟合,也能使矩阵的因子分解后的U和V矩阵的值不会太震荡,方便接下来对U和V矩阵再做进一步的利用。

而且Spark通过ALS-WR算法使得 regParam 较少的被数据集的规模所影响。这样可以使得在样本子集中学习得到的最佳参数可以应用在数据全集上而且获得相似的性能。

冷启动策略

我们使用训练后的 ALSModel 对test数据进行预测,不过可能会遇到没有出现在训练模型中的user或者item id,这是由以下两种情况产生引起的:

  • 在生成中:本来就会有新的user或者item上线,是之前训练时不曾有的(这也称之为“cold start problem”)
  • 在交叉验证阶段:不管是用Spark的 CrossValidator 或者 TrainValidationSplit 都有可能出现验证集中的id是训练集中没有出现过的。

默认Spark使用NaN来表示对于未知id的rate的预测结果,这样在生产中可以提示系统有新的id加入,作为接下来是否采取措施的依据。
不过在交叉验证阶段,NaN会妨碍接下来的评分度量 (比如使用 RegressionEvaluator ),此时可以选择"drop"来使得出现NaN的行都丢掉。方便调参时做模型选择。

举个栗子

下面这个栗子也是官网文档中的栗子。首先看下数据的模样:

sample_movielens_ratings.txt

然后是代码:

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// $example on$
import java.io.Serializable;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
// $example off$

public class JavaALSExample {

  // $example on$
  public static class Rating implements Serializable {
    private int userId;
    private int movieId;
    private float rating;
    private long timestamp;

    public Rating() {}

    public Rating(int userId, int movieId, float rating, long timestamp) {
      this.userId = userId;
      this.movieId = movieId;
      this.rating = rating;
      this.timestamp = timestamp;
    }

    public int getUserId() {
      return userId;
    }

    public int getMovieId() {
      return movieId;
    }

    public float getRating() {
      return rating;
    }

    public long getTimestamp() {
      return timestamp;
    }

    public static Rating parseRating(String str) {
      String[] fields = str.split("::");
      if (fields.length != 4) {
        throw new IllegalArgumentException("Each line must contain 4 fields");
      }
      int userId = Integer.parseInt(fields[0]);
      int movieId = Integer.parseInt(fields[1]);
      float rating = Float.parseFloat(fields[2]);
      long timestamp = Long.parseLong(fields[3]);
      return new Rating(userId, movieId, rating, timestamp);
    }
  }
  // $example off$

  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaALSExample")
      .getOrCreate();

    // $example on$
    JavaRDD<Rating> ratingsRDD = spark
      .read().textFile("data/mllib/als/sample_movielens_ratings.txt").javaRDD()
      .map(Rating::parseRating);
    Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
    Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
    Dataset<Row> training = splits[0];
    Dataset<Row> test = splits[1];

    // Build the recommendation model using ALS on the training data
    ALS als = new ALS()
      .setMaxIter(5)
      .setRegParam(0.01)
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRatingCol("rating");
    ALSModel model = als.fit(training);
    model.userFactors();
    model.itemFactors();

    // Evaluate the model by computing the RMSE on the test data
    // Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics
    model.setColdStartStrategy("drop");
    Dataset<Row> predictions = model.transform(test);

    RegressionEvaluator evaluator = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction");
    Double rmse = evaluator.evaluate(predictions);
    System.out.println("Root-mean-square error = " + rmse);

    // Generate top 10 movie recommendations for each user
    Dataset<Row> userRecs = model.recommendForAllUsers(10);
    // Generate top 10 user recommendations for each movie
    Dataset<Row> movieRecs = model.recommendForAllItems(10);

    // Generate top 10 movie recommendations for a specified set of users
    //todo: Those API @Since("2.3.0")
//    Dataset<Row> users = ratings.select(als.getUserCol()).distinct().limit(3);
//    Dataset<Row> userSubsetRecs = model.recommendForUserSubset(users, 10);
//    // Generate top 10 user recommendations for a specified set of movies
//    Dataset<Row> movies = ratings.select(als.getItemCol()).distinct().limit(3);
//    Dataset<Row> movieSubSetRecs = model.recommendForItemSubset(movies, 10);
    // $example off$
    userRecs.show();
    movieRecs.show();
//    userSubsetRecs.show();
//    movieSubSetRecs.show();

    spark.stop();
  }
}

代码还是不难的,建议在IDEA中阅读看下。实际使用时还需要加上tuning环节来对rankmaxIterregParamalpha 甚至numBlocks进行调参。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 211,884评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,347评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,435评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,509评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,611评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,837评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,987评论 3 408
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,730评论 0 267
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,194评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,525评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,664评论 1 340
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,334评论 4 330
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,944评论 3 313
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,764评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,997评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,389评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,554评论 2 349

推荐阅读更多精彩内容