Java 接入 ALS & LR 为用户推荐商户

召回(ALS)接入

  • 之前离线召回(ALS)的结果已经保存在 MySQL 中;
  • 这里直接按 userId 查询即可;若该用户没有召回记录,则回退到默认记录(id 为 9999999);
package tech.lixinlei.dianping.recommand;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import tech.lixinlei.dianping.dal.RecommendModelMapper;
import tech.lixinlei.dianping.model.RecommendModel;

@Service
public class RecommendService {

    /** Primary key of the fallback row used when a user has no personalised recall result. */
    private static final int DEFAULT_RECOMMEND_USER_ID = 9999999;

    @Autowired
    private RecommendModelMapper recommendModelMapper;

    /**
     * Recalls candidate shop ids for a user from the offline (ALS) results stored in MySQL.
     *
     * <p>If there is no row for {@code userId}, the row keyed by
     * {@link #DEFAULT_RECOMMEND_USER_ID} is used instead. The stored recommend column is a
     * comma-separated list of shop ids. Blank tokens are skipped and surrounding whitespace is
     * trimmed.
     *
     * @param userId id of the user to recall shops for
     * @return shop ids in the order they are stored; an empty list if neither the user's row nor
     *     the fallback row exists, or if the stored list is null/empty
     * @throws NumberFormatException if a non-blank token is not an integer
     */
    public List<Integer> recall(Integer userId) {
        RecommendModel recommendModel = recommendModelMapper.selectByPrimaryKey(userId);
        if (recommendModel == null) {
            recommendModel = recommendModelMapper.selectByPrimaryKey(DEFAULT_RECOMMEND_USER_ID);
        }
        // Even the fallback row may be absent (or have no data): degrade to "no candidates"
        // instead of throwing a NullPointerException into the request.
        if (recommendModel == null || recommendModel.getRecommend() == null) {
            return new ArrayList<>();
        }
        String[] tokens = recommendModel.getRecommend().split(",");
        List<Integer> shopIdList = new ArrayList<>(tokens.length);
        for (String token : tokens) {
            String trimmed = token.trim();
            // "".split(",") yields [""], and "1,,2" yields an empty token; neither is a shop id.
            if (!trimmed.isEmpty()) {
                shopIdList.add(Integer.valueOf(trimmed));
            }
        }
        return shopIdList;
    }

}

排序(LR)接入

package tech.lixinlei.dianping.recommand;

import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

@Service
public class RecommendSortService {

    /** JVM system property that overrides the location of the saved LR model. */
    private static final String MODEL_PATH_PROPERTY = "dianping.lrmodel.path";

    /** Fallback model location (the absolute path the original author's machine used). */
    private static final String DEFAULT_MODEL_PATH =
            "file:///home/lixinlei/project/gitee/dianping/src/main/resources/lrmode";

    /** Index of the positive-class entry in the probability vector returned by the model. */
    private static final int POSITIVE_CLASS_INDEX = 1;

    private SparkSession spark;

    private LogisticRegressionModel lrModel;

    /**
     * Starts a local Spark session and loads the trained LR model.
     *
     * <p>The model is read from the {@code dianping.lrmodel.path} system property if set,
     * otherwise from {@link #DEFAULT_MODEL_PATH}. A path hard-wired into one developer's home
     * directory cannot be deployed anywhere else.
     */
    @PostConstruct
    public void init() {
        // Create (or reuse) a local Spark session before loading the model.
        spark = SparkSession.builder().master("local").appName("DianpingApp").getOrCreate();
        String modelPath = System.getProperty(MODEL_PATH_PROPERTY, DEFAULT_MODEL_PATH);
        lrModel = LogisticRegressionModel.load(modelPath);
    }

    /**
     * Ranks recalled shops by the LR model's positive-class probability, highest first.
     *
     * @param shopIdList candidate shop ids (e.g. the output of recall); may be null or empty
     * @param userId the user being recommended to; not yet used by the placeholder features
     * @return shop ids in descending score order. Ties keep their input order because
     *     {@code List.sort} is stable. Returns an empty list for null or empty input.
     */
    public List<Integer> sort(List<Integer> shopIdList, Integer userId) {
        if (shopIdList == null || shopIdList.isEmpty()) {
            return new ArrayList<>();
        }
        List<ShopSortModel> scored = new ArrayList<>(shopIdList.size());
        for (Integer shopId : shopIdList) {
            ShopSortModel shopSortModel = new ShopSortModel();
            shopSortModel.setShopId(shopId);
            shopSortModel.setScore(predictPositiveProbability(buildFeatures(shopId, userId)));
            scored.add(shopSortModel);
        }
        // Double.compare is a total order (NaN included), unlike a hand-rolled </> chain, so the
        // comparator contract always holds. Arguments are swapped to sort in descending order.
        scored.sort((a, b) -> Double.compare(b.getScore(), a.getScore()));
        return scored.stream().map(ShopSortModel::getShopId).collect(Collectors.toList());
    }

    /** Returns the model's probability that the sample described by {@code features} is positive. */
    private double predictPositiveProbability(Vector features) {
        return lrModel.predictProbability(features).toArray()[POSITIVE_CLASS_INDEX];
    }

    /**
     * Builds the 11-dimensional feature vector the LR model expects for a (user, shop) pair.
     *
     * <p>This currently returns a fixed placeholder vector, so every shop receives the same
     * score. Real features (gender, age, rating, price, ...) should be loaded from the database
     * or a cache and converted here.
     */
    private Vector buildFeatures(Integer shopId, Integer userId) {
        return Vectors.dense(1, 0, 0, 0, 0, 1, 0.6, 0, 0, 1, 0);
    }

}

修改原来的 recommend 方法的实现

  • 先召回,再排序;
package tech.lixinlei.dianping.service.impl;

@Service
public class ShopServiceImpl implements ShopService {

    /**
     * User whose recall list is used for every request.
     * NOTE(review): hard-coded demo value; should come from the current user/session.
     */
    private static final int DEMO_USER_ID = 148;

    /** Placeholder cover image applied to every recommended shop. */
    private static final String DEMO_ICON_URL = "/static/image/shopcover/xchg.jpg";

    /** Placeholder distance applied to every recommended shop. */
    private static final int DEMO_DISTANCE = 100;

    @Autowired
    RecommendService recommendService;

    @Autowired
    RecommendSortService recommendSortService;

    /**
     * Recommends shops by first recalling candidates (ALS), then ranking them (LR).
     *
     * <p>This replaces the earlier location-based query
     * ({@code shopModelMapper.recommend(longitude, latitude)}).
     *
     * @param longitude currently unused
     * @param latitude currently unused
     * @return the ranked shops with placeholder icon URL and distance filled in. Ids for which
     *     {@code get} yields null are skipped; this presumably means the shop was removed after
     *     the offline recall ran, so verify it against {@code get}.
     */
    @Override
    public List<ShopModel> recommend(BigDecimal longitude, BigDecimal latitude) {
        List<Integer> recalled = recommendService.recall(DEMO_USER_ID);
        List<Integer> ranked = recommendSortService.sort(recalled, DEMO_USER_ID);
        return ranked.stream()
                .map(id -> get(id))
                // Offline recall results can reference shops that no longer resolve; skip them
                // instead of throwing a NullPointerException on the setters below.
                .filter(shopModel -> shopModel != null)
                .map(shopModel -> {
                    shopModel.setIconUrl(DEMO_ICON_URL);
                    shopModel.setDistance(DEMO_DISTANCE);
                    return shopModel;
                })
                .collect(Collectors.toList());
    }

}
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。