概述
kNN算法是最常用的分类算法之一,属于监督学习的一种。
监督学习:简单来说就是训练数据集拥有“答案”,监督学习需要有明确的目标,很清楚自己想要什么结果。比如:按照“既定规则”来分类、预测某个具体的值。
- 银行积累了一定的用户信息以及他们的信用情况,预测新的用户的信用情况。
- 积累了一定图片并且清楚图片上的物品属于什么,预测新的图片上的内容。
- ...
上面这些都属于监督学习的范畴。
kNN算法介绍
kNN的全称是K Nearest Neighbors,意思是K个最近的邻居。
目标:判断未知物体P属于什么分类。
方法:找到最接近该物体的k个物体,k个物体中占比最大的分类即为P的分类。
我们以下图为例:
- 当k取3时,由于最近的三个物体中,红色三角占比较多,判断未知物体P也就是绿色圆点的分类是红色三角。
- 当k取5时,由于最近的三个物体中,蓝色矩形占比较多,判断未知物体P也就是绿色圆点的分类是蓝色矩形。
KnnClassification.png
从上面的例子可以看出,k的取值在预测的过程中扮演了较为重要的角色,通过调整k的值可以得到更好的或者说正确率更高的预测结果。
在预测的过程中,除了k还有一些别的参数可以开发人员自定义,也就是超参数:
距离
明可夫斯基距离
度量空间中点的距离,方式有很多,比如二维平面中使用的欧式距离:
拓展到多维平面,就变成了这样:
将欧拉距离做一定的变换就可以得出明可夫斯基距离
可以看出,当p取值为2的时候,就得到了欧拉距离,所以在预测的过程中,同样可以通过调整p的取值来优化预测结果,但是超参数p仅当使用明可夫斯基距离时有效。
别的距离定义
除了明可夫斯基距离,我们同样可以使用别的方式来定义距离,比如皮尔森相关系数,比如向量空间余弦相似度等等。
在使用kNN算法的过程中,还有更多的超参数可以去自定义,可以通过scikit-learn官方文档来了解他们。
通过sklearn完成预测
我们以sklearn自带的digits数据集为例展示其使用方式:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
digits = datasets.load_digits()
X = digits.data
y = digits.target
## 将数据集随机切分为两部分,一部分用于训练,一部分用于测试准确度
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=321)
knn_clf = KNeighborsClassifier(n_neighbors=3)
knn_clf.fit(X_train, y_train) ## 拟合
knn_clf.predict(X_test) ## 进行预测
knn_clf.score(X_test, y_test) ## 得到预测的准确度
调参
对于如何寻找出更优的超参数取值,sklearn也提供了一套方法:
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
## 加载数据集
digits = datasets.load_digits()
X = digits.data
y = digits.target
## 将数据集随机切分为两部分,一部分用于训练,一部分用于测试准确度
X_digit_train, X_digit_test, y_digit_train, y_digit_test = train_test_split(X, y, test_size=0.2, random_state=333)
## 定义需要测试的参数 详见文档
param_grid = [
{
'weights': ['uniform'], ## 距离不设置权重
'n_neighbors': [i for i in range(1, 11)] ## k的取值
},
{
'weights': ['distance'], ## 根据距离长短占相应的权重
'n_neighbors': [i for i in range(1, 11)],
'p': [i for i in range(1, 8)] ## p的取值
}
]
knn_clf = KNeighborsClassifier()
grid_search = GridSearchCV(knn_clf, param_grid, n_jobs=-1, verbose=2)
grid_search.fit(X_digit_train, y_digit_train)
通过上面的代码,就可以对列出的超参数进行网格搜索,找到更优解。
我们可以通过best_estimator_来查看范围内最优解的各项参数:
print(grid_search.best_estimator_)
## KNeighborsClassifier(n_neighbors=4, p=3, weights='distance')
可以看出打印出的就是一个KNeighborsClassifier的实例,我们同样可以通过best_score_来得到准确度的值。
总结
优点
- 训练时间复杂度为O(n);
- 对数据没有假设,准确度高,对outlier不敏感;
- kNN是一种在线技术,新数据可以直接加入数据集而不必进行重新训练;
- kNN理论简单,容易实现;
缺点
- 样本不平衡问题(即有些类别的样本数量很多,而其它样本的数量很少)效果差;
- 需要大量内存;
- 对于样本容量大的数据集计算量比较大(体现在距离计算上);
- 样本不平衡时,预测偏差比较大。如:某一类的样本比较少,而其它类样本比较多;
- KNN每一次分类都会重新进行一次全局运算;
参考资料:https://github.com/liuyubobobo/Play-with-Machine-Learning-Algorithms