implicit 库 ALS 算法分析
ALS 算法
推荐模型基于“隐式反馈数据集的协同过滤”一文中描述的算法,其中包含“隐式反馈协同过滤的共轭梯度法的应用”中描述的性能优化。
构造函数
- 模型对象的构造函数,作用是初始化训练模型所需要的各个数值,借此可以实现各种功能
- 输入的参数如下:
- factors(int ,optional) - 要计算的潜在因子的数量,默认设置100
- regularization(float ,optional) - 要使用的正则化因子,默认设置0.01
- dtype(data-type ,optional) - 指定是生成64位还是32位浮点因子,默认设置32位
- use_native(bool ,optional) - 使用原生扩展来加速模型拟合,默认设置 True
- use_cg(bool ,optional) - 使用更快的共轭渐变求解器来计算因子,默认设置 True
- use_gpu(bool ,optional) - 如果可用,安装在GPU上,默认情况下仅在GPU上运行(如果可用)
- iterations(int ,optional) - 拟合数据时使用的ALS迭代次数,默认设置15次
- calculate_training_loss(bool ,optional) - 是否在每次迭代时输出训练损失,默认设置 False
- num_threads(int ,optional) - 用于拟合模型的线程数。这仅适用于本机扩展。指定0表示默认为计算机上的核心数。
fit 函数
-
fit 函数是 model 对象的函数,作用是将 item_users 矩阵归纳,具体作用如下:
- 调用此方法后,将使用输入数据的潜在因子模型初始化成员的 “user_factors” 和 “item_factors” 。
- item_users 矩阵在这里执行双重任务。 它定义了哪些用户喜欢哪些项目(原始论文中的 P_iu),以及我们对用户喜欢该项目(C_iu)的信心程度。
- 隐含地定义了负项:此代码假定 item_users 矩阵中的非零项表示用户喜欢该项。 在这个稀疏矩阵中未设置负数:库将假定所有这些项的 Piu = 0 和 Ciu = 1 。
接下来具体分析
-
首先是输入的参数:
- item_users:
- 类型:csr_matrix,(scipy.sparse.csr.csr_matrix)
- 意义:喜欢的项目的信心矩阵。 此矩阵应为csr_matrix,其中矩阵的行是项目,列是喜欢该项目的用户,值是用户喜欢该项目的置信度。
- show_progress:bool,可选
- 意义:是否在装配期间显示进度条
- item_users:
然后将 item_users 转换为 user_item 矩阵。
-
然后初始化 user_factors 和 item_factors
- 这两个变量的类型为:numpy.ndarray
- 数值范围是:0.001 - 0.000
- 这个数组里都有 3953 组数据,每一组里有 10 个 0.001 - 0.000 的数值
- 作用是 “交替最小二乘算法” 优化时的参数
-
接着初始化函数 solver
- 该函数使用了 functools 库,可以使函数变得灵活
- 该函数用于初始化 训练 函数
最后是 计算优化 user_factors 和 item_factors
fit 函数中的参数示例
- 首先是 item_users 矩阵
ratings = (1, 1) 2.1740600524989357
(1, 6) 2.1236897235732894
(1, 8) 1.8686159930982247
(1, 9) 1.9691045566482313
(1, 10) 1.249025946299109
(1, 18) 1.4689834016015424
(1, 19) 1.6169865477035692
(1, 23) 1.6669149036259538
(1, 34) 1.7310277745931328
(1, 36) 1.252227001367775
(1, 38) 2.069393912218018
(1, 44) 1.7019373471502897
(1, 45) 1.8356939148740512
(1, 48) 1.574793729260897
(1, 49) 1.9978280850518142
(1, 51) 2.434153458553812
(1, 56) 2.1236897235732894
(1, 60) 2.3074706942731167
(1, 65) 1.7529219887936436
(1, 75) 1.6484916172440682
(1, 76) 1.9419776254119916
(1, 78) 1.8859489079105
(1, 92) 1.620388367384877
(1, 96) 2.1333565572573283
(1, 109) 1.9553526853284542
: :
(3952, 4312) 6.895796737954929
(3952, 4448) 4.575881363212303
(3952, 4449) 9.186996293704246
(3952, 4593) 9.764640044171973
(3952, 4607) 6.943696747668272
(3952, 4682) 6.122127042319515
(3952, 4751) 14.914646596002854
(3952, 4802) 6.003344651685048
(3952, 4816) 7.443208670267048
(3952, 4831) 9.001062383106769
(3952, 4834) 9.017481036707101
(3952, 4858) 10.96101783958135
(3952, 5049) 9.222122597824775
(3952, 5074) 6.667498844581781
(3952, 5087) 9.152290674001744
(3952, 5100) 5.156611838083416
(3952, 5205) 8.750278126694095
(3952, 5304) 15.966304840498406
(3952, 5333) 6.3490338715645045
(3952, 5359) 7.414657852046655
(3952, 5405) 11.77330059012851
(3952, 5475) 7.630871525852557
(3952, 5812) 6.009169547331631
(3952, 5837) 6.689536096093871
(3952, 5998) 11.730994953983654
Ciu = (3953, 6041)
- 接着是 user_item 矩阵
Cui[:5] = (1, 1) 2.17406
(1, 48) 25.469677
(1, 150) 3.6122231
(1, 260) 1.377664
(1, 527) 1.7410772
(1, 531) 15.555779
(1, 588) 3.9625092
(1, 594) 6.864531
(1, 595) 4.896073
(1, 608) 1.7385802
(1, 783) 20.22804
(1, 919) 2.59251
(1, 938) 27.18305
(1, 1022) 9.451304
(1, 1028) 5.130934
(1, 1029) 10.0521555
(1, 1035) 5.8114653
(1, 1097) 2.1897688
(1, 1193) 2.3664563
(1, 1207) 4.3305497
(1, 1246) 5.5710964
(1, 1270) 1.886484
(1, 1287) 6.416413
(1, 1545) 51.37579
(1, 1566) 16.32384
: :
(3, 2735) 26.195389
(3, 2858) 1.3211542
(3, 2871) 6.898585
(3, 3168) 11.295114
(3, 3421) 4.07478
(3, 3552) 5.5711293
(3, 3671) 4.4112697
(4, 260) 1.6511813
(4, 480) 2.4938776
(4, 1036) 3.2292256
(4, 1097) 2.6245189
(4, 1198) 1.9135972
(4, 1201) 6.478673
(4, 1214) 2.6565099
(4, 1240) 2.5627942
(4, 1387) 3.3952415
(4, 1954) 5.3040104
(4, 2028) 1.9135972
(4, 2366) 9.647836
(4, 2692) 4.7647886
(4, 2947) 6.2555227
(4, 2951) 10.383282
(4, 3418) 4.93192
(4, 3468) 10.95754
(4, 3702) 6.556637
Cui = (6041, 3953)
- 然后是 user_factors 数组
self.user_factors[:1] = [[5.1226740e-04 2.3522158e-03 4.5723325e-04 6.6150324e-03 3.9459649e-03
6.1708693e-03 2.0300677e-04 8.5077155e-03 1.8972992e-03 2.0382046e-03
9.5186438e-03 5.4341452e-03 7.2569805e-03 8.3699776e-03 9.6579799e-03
6.4331447e-03 6.0637039e-03 9.9095488e-03 7.4764276e-03 5.0485092e-03
1.6186638e-03 4.2158086e-03 9.1464035e-03 1.3624025e-03 2.8236140e-03
7.7461987e-03 5.2876738e-03 4.5593702e-03 2.0289577e-03 5.0300560e-03
3.4632062e-04 9.6785016e-03 5.7958192e-03 5.0499458e-03 4.5029027e-03
7.4713272e-03 7.5995307e-03 3.5817956e-03 1.5400714e-03 5.5456865e-03
9.5722824e-03 7.3524272e-05 7.0764977e-03 8.8191610e-03 3.5753862e-03
7.7350269e-05 2.3987989e-03 1.1091189e-03 8.3196014e-03 3.8097801e-03
1.0482759e-03 1.0113537e-03 1.7453432e-03 4.8129768e-03 4.7269161e-03
6.7377803e-03 5.3967237e-03 6.9357478e-03 4.3213055e-03 8.8212257e-03
8.9827348e-03 7.2493451e-03 9.9861454e-03 6.2127886e-03 5.2154157e-03
1.4604862e-03 2.2060559e-03 7.5461916e-03 4.5615411e-03 3.3119759e-03
8.6249942e-03 8.9962762e-03 2.0078689e-04 8.9933379e-03 6.0593761e-03
1.9120758e-03 6.2310225e-03 2.0863705e-03 9.8393485e-03 8.5211238e-03
2.2520742e-03 5.3290394e-03 2.1112463e-03 5.8315467e-04 7.7766585e-03
2.0383617e-03 3.7987002e-03 6.2283650e-03 6.9619704e-04 9.1746580e-03
2.4119455e-03 2.8836483e-03 5.8108927e-03 7.5439108e-04 2.2920519e-03
7.6556685e-03 8.5246023e-03 6.8681776e-03 8.5503897e-03 7.1971295e-03
3.0294131e-03 1.4615995e-03 8.0523659e-03 1.4213074e-03 1.0314758e-03
9.9696340e-03 7.6575801e-03 1.7297893e-03 1.1224474e-03 3.6195740e-03
3.4666925e-03 3.0255079e-04 8.0479765e-03 4.1560605e-03 9.9898130e-03
7.4630273e-03 7.2351955e-03 9.4736805e-03 1.8989003e-03 5.3447112e-03
7.5252936e-03 6.4640073e-03 1.9048633e-03 4.2088274e-03 9.3266906e-05
1.5658136e-03 3.4540305e-03 1.3873345e-03]]
self.user_factors[:1] = (6041, 128)
- 最后是 item_factors 数组
self.item_factors[:1] = [[4.2947675e-03 6.1421338e-03 2.2593006e-03 9.6851232e-04 8.0828415e-03
8.1364866e-03 3.3353108e-03 9.8315496e-03 3.3877280e-03 6.0345028e-03
2.1222895e-03 8.8947183e-03 9.5906522e-04 8.7277377e-03 6.3836616e-03
6.5069442e-05 3.9558033e-03 8.2298331e-03 5.7909191e-03 2.8879766e-04
1.7520637e-03 7.7279550e-03 7.6245693e-03 1.5885481e-03 6.7957896e-03
8.2463212e-03 7.3174033e-03 6.0746092e-03 2.0345300e-03 9.5943911e-03
8.3695222e-03 4.1773785e-03 5.1691209e-04 7.0946836e-03 3.0858058e-03
6.8540201e-03 3.0306168e-03 8.0440026e-03 3.2406261e-03 2.2302095e-03
9.1938255e-04 7.4884794e-03 5.9711579e-03 4.1350457e-03 1.2515859e-03
2.6157775e-03 5.0506424e-03 2.1029548e-03 2.2212563e-03 1.7169695e-03
3.8359638e-03 4.5402725e-03 5.7884911e-04 7.6607009e-03 6.1318455e-03
6.3904864e-03 9.6114064e-03 3.8021808e-03 7.6144017e-05 3.5354136e-03
5.3026886e-03 5.6798127e-03 8.4759976e-04 1.5025025e-03 6.8341750e-03
8.8900710e-03 8.3640553e-03 2.5437644e-03 2.4673997e-03 1.0014516e-03
9.4416691e-03 5.0879637e-04 7.9685627e-03 8.4503777e-03 8.3975308e-03
6.6174692e-03 7.4964440e-03 5.3146235e-03 1.4990562e-03 9.4615193e-03
2.4433812e-04 9.8566143e-03 7.0711724e-03 3.0148539e-03 5.5398257e-03
7.3834471e-03 8.3977645e-03 4.4447011e-03 3.7807515e-03 4.3067979e-03
3.5194417e-03 6.4200228e-03 1.0479660e-03 9.1948211e-03 8.4213223e-03
6.0157869e-03 5.2273017e-03 7.9428963e-03 2.3604098e-03 1.5468789e-03
1.6676424e-03 3.4407319e-03 1.9700862e-03 9.0078106e-03 3.7092036e-03
2.6234990e-04 7.0038107e-03 1.0468360e-03 7.1541648e-03 2.0921687e-03
8.3596595e-03 9.7496398e-03 4.9834684e-03 7.7716759e-03 3.1272604e-03
3.3663216e-03 5.6318738e-03 6.4339940e-03 4.1630040e-03 4.8983088e-03
1.8529991e-03 8.9196122e-04 6.5733227e-03 3.6214802e-03 8.0824737e-03
7.7728992e-03 1.8295562e-03 6.9596451e-03]]
self.item_factors[:1] = (3953, 128)
solver 函数
由于该函数涉及的知识点较多,尚未清楚怎么工作
用到了 functools 库
RMSE
就是 MSE 开个根号。其实实质跟 MSE 是一样的。只不过用于数据更好的描述。
例如:要做房价预测,每平方是万元(真贵),我们预测结果也是万元。那么差值的平方单位应该是 千万级别的。那我们不太好描述自己做的模型效果。怎么说呢?我们的模型误差是 多少千万?。。。。。。于是干脆就开个根号就好了。我们误差的结果就跟我们数据是一个级别的可,在描述模型的时候就说,我们模型的误差是多少万元。
ALS 算法的 RMSE 结果如下:
ALS 算法的 RMSE 结果就是图中的 LOSS 的数值,数值为 0.0137
SA 算法原准确率计算代码:
SA 算法原准确率计算代码:
if Y_ >= 3.5 and use_a >= 3.5:
data_ture = data_ture + 1
if Y_ < 3.5 and use_a < 3.5:
data_ture = data_ture + 1
if Y_ >= 3.5 and use_a < 3.5:
data_false = data_false + 1
if Y_ < 3.5 and use_a >= 3.5:
data_false = data_false + 1
print('data_ture:', data_ture) # 预测正确的个数
print('data_false:', data_false) # 预测错误的个数
accuracy = data_ture / (data_ture + data_false)
print(accuracy)
endtime = datetime.datetime.now()
print(endtime - starttime)
这里的准确率就是正确的数量除以总数得到了。
转换成 RMSE 标准的话就是需要计算 Y_ 和 use_a 之间的误差了。