一、假设数据集
1、简单介绍
昨天上午刚自学了机器学习的k-近邻算法,下午一个亲戚就来说买了一个手机保护壳,然后我就思考了下我自己寿命的问题,想着想着突然问自己为啥有人要手机保护壳?不还是为了手机使用寿命能更长点。。所以写下了这个代码,本代码是在Anaconda最新版下的Jupyter Notebook下写出来的,使用的电脑系统为Windows 10专业版。
2、数据准备及自定义各个参考指标
由于没有找到手机寿命方面的数据集,所以我就假设一个数据集出来,假设一部手机最高寿命为10年,数字“1”代表寿命长,数字“0”代表寿命短,并定义小于5年为短寿命,除外都是长寿命。
内容如下:
每天玩手机的时间(小时) | 点击频率(大于0且小于1) | 寿命(年) |
---|---|---|
3 | 0.87 | 1 |
0.5 | 0.9 | 1 |
5 | 0.258 | 1 |
6.5 | 0.52 | 1 |
7 | 0.36 | 0 |
10 | 0.45 | 1 |
14 | 0.13 | 0 |
18 | 0.67 | 0 |
20 | 0.89 | 0 |
24 | 0.77 | 0 |
利用 k-近邻算法 预测手机寿命(数据纯属瞎编,请勿轻信)
二、算法实现及其他
创建好训练数据集和测试数据及作图
# 导入第三方库
import numpy as np
from math import sqrt
import matplotlib.pyplot as plt
from collections import Counter
# 原始数据
X_original_train = [[3, 0.87],
[0.5, 0.9],
[5, 0.258],
[6.5, 0.52],
[7, 0.36],
[10, 0.45],
[14, 0.13],
[18, 0.67],
[20, 0.89],
[24, 0.77]]
Y_original_train = [1, 1, 1, 1, 0, 1, 0, 0, 0, 0]
# 转换成可以操作的科学计算中的数组
X_train = np.array(X_original_train)
Y_train = np.array(Y_original_train)
X_train
输出结果:
array([[ 3. , 0.87 ],
[ 0.5 , 0.9 ],
[ 5. , 0.258],
[ 6.5 , 0.52 ],
[ 7. , 0.36 ],
[10. , 0.45 ],
[14. , 0.13 ],
[18. , 0.67 ],
[20. , 0.89 ],
[24. , 0.77 ]])
Y_train
输出结果:
array([1, 1, 1, 1, 0, 1, 0, 0, 0, 0])
# 画原始数据的散点图
plt.scatter(X_train[Y_train == 0, 0], X_train[Y_train == 0, 1], color = 'g')
plt.scatter(X_train[Y_train == 1, 0], X_train[Y_train == 1, 1], color = 'r')
plt.show()
输出结果:
假设一部手机每天被蹂躏11个小时且在被蹂躏的期间点击率为0.5,预测此手机使用寿命的长短
# 创建测试数据且作包含测试数据点的散点图
X_test = np.array([11, 0.5])
plt.scatter(X_train[Y_train == 0, 0], X_train[Y_train == 0, 1], color = 'g')
plt.scatter(X_train[Y_train == 1, 0], X_train[Y_train == 1, 1], color = 'r')
plt.scatter(X_test[0], X_test[1], color = 'b') # 测试数据的坐标点颜色为蓝色
plt.show()
输出结果:
# 计算每个点与测试坐标点的距离
distances = [] # 用来存储每个点与测试点的距离
for x_train in X_train:
distance = sqrt(sum(((x_train - X_test) ** 2)))
distances.append(distance)
distances
输出结果
[8.008551679298822,
10.507616285342742,
6.004878350141658,
4.500044444224968,
4.002449250146715,
1.0012492197250393,
3.022730553655089,
7.002063981427191,
9.008446036914469,
13.002803543851611]
# 对距离进行排序,找出最近的五个点(这里不一定非要找五个点,也可以找三个点或者其他个数的点,但不能超过数据集的大小)
point_sort = np.argsort(distances)
k = 5 # 这里设置为最近的五个点
near_point = [Y_train[i] for i in point_sort[:k]]
near_point
输出结果:
[1, 0, 0, 1, 1]
# 找到出现次数最多的那个点
More_point = Counter(near_point)
More_point.most_common(1)[0][0]
输出结果:
1
三、总结
至此,k-近邻算法给出了我们答案,对于一天玩11个小时,玩期间屏幕点击率为0.5的玩法,手机的使用寿命很长。目前的这个算法所使用的数据由于是自己胡编乱造的,所以并不合理,其实第一列玩手机的时间这个数据的值用个均值方差归一化会更好,后面再改算了。
作者:无聊的SEVEN