因为自己的好奇心,所以做了这一篇关于KNN 算法的笔记。
一、简介
K-近邻算法是一种常用的监督学习的方法,其与K-Means算法有点类似,其原理是:在给定的样本数据中,基于某种距离(欧式距离或马氏距离等等)找出与当前样本数据距离最近的K个样本数据,然后再基于这K个“邻居”的信息来进行预测。
这个算法在生活中应用的其实也很多,比如电影、新闻等信息的归类,它可以很简单的就做到将电影、新闻进行分门别类。
除此之外呢,就是K近邻算法有着不同于其他学习方法的独特之处,它几乎没有显式训练的过程,它的训练时间为零,也就是说它是对测试样本直接进行处理,它这种方式被称为“懒惰学习”;相应的,那些在训练阶段就对样本进行学习处理的方法,被称为“急切学习”。
二、KNN算法实现
2.1实现步骤
在真正实现之前,首先理一下思路。
(1)分别计算已知类别数据的点到当前点(未知类别)之间的距离,通常选用欧式距离,当然也可以使用其他距离来进行计算。
(2)按照距离将已知类别的点进行从小到大排序。
(3)选取距离当前点最近的前k个点,并统计前k个点中各个类别出现的频率。
(4)返回前k个点中出现频率最高的类别,以此作为当前点的预测类别。
2.2代码实现
#-*- coding=utf-8 -*-
#@Time : 2020/9/21 13:49
#@File : KNN1.0.py
#@Software : PyCharm
from numpy import *
import operator
import matplotlib.pyplot as plt
# 通用的一些函数
"""
函数说明:创建数据集
"""
def createDataSet():
group=array([[1,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels=['A','A','B','B']
return group, labels
"""
函数说明:KNN分类器,找到距离inX最近的前k个样本
"""
def classify(inX,dataSet,labels,k):
dataSetSize=dataSet.shape[0]
diffMat=tile(inX,(dataSetSize,1))-dataSet #tile,瓦片之意,也就是按照某个东西来拼接出新的东西,这里是用于计算出已知点到当前点的距离
sqDiffMat=diffMat**2
sqDistances=sqDiffMat.sum(axis=1) #横向求和
distances=sqDistances**0.5
sortedDistIndices=distances.argsort() #从小到大进行排序
classCount={}
for i in range(k): #统计前k个点出现的频率
voteIlabel=labels[sortedDistIndices[i]]
classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True) #按照键值对中的值进行排序
return sortedClassCount[0][0]
if __name__ == '__main__':
group,labels=createDataSet()
print("group:\n",group,'\n',"labels:\n",labels)
#绘制图像
fig=plt.figure() #创建画布
ax=fig.add_subplot(111) #添加子图,并获取坐标轴
ax.scatter(group[:,0],group[:,1],c='b',marker='o') #通过坐标轴添加散点图
ax.scatter(0,0,c="r",marker='*')
plt.show() #显示图像
result = classify([0,0],group,labels,3)
print("预测结果为:",result)
实现效果:
三、相关测试
关于一个算法,有两点的问题是我们每一个人都会关心的:效率和正确率。因为处于一种学习算法的阶段,所以我对Python的效率问题就没有什么高的要求,那么下面主要会对KNN算法的正确率方面进行相关的测试。
测试的例子就采用了一本书中的例子“海伦约会”,代码如下所示:
#-*- coding=utf-8 -*-
#@Time : 2020/9/21 16:54
#@Author : wangjy
#@File : KNN2.0.py
#@Software : PyCharm
from numpy import *
import operator
from tkinter import filedialog
import matplotlib.pyplot as plt
"""
函数说明:KNN分类器,找到距离inX最近的前k个样本
"""
def classify(inX,dataSet,labels,k):
dataSetSize=dataSet.shape[0]
diffMat=tile(inX,(dataSetSize,1))-dataSet #tile,瓦片之意,也就是按照某个东西来拼接出新的东西
sqDiffMat=diffMat**2
sqDistances=sqDiffMat.sum(axis=1)
distances=sqDistances**0.5
sortedDistIndices=distances.argsort()
classCount={}
for i in range(k):
voteIlabel=labels[sortedDistIndices[i]]
classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
"""
函数说明:读取文件中的数据,转为矩阵
"""
def file2matrix(filename):
if filename==None:
return ;
fr=open(filename)
arrayOfLines=fr.readlines()
numberOfLines=len(arrayOfLines)
returnMat=zeros((numberOfLines,3))
classLabelVector=[]
types={}
index=0
for line in arrayOfLines:
line=line.strip()
listFromLine=line.split('\t')
returnMat[index,:]=listFromLine[0:3]
typeName=listFromLine[-1]
if types.get(typeName.strip(),-1)==-1:
types[typeName]=len(types)
classLabelVector.append(types[typeName])
index+=1
return returnMat,classLabelVector
"""
函数说明:归一化特征值
"""
def autoNorm(dataSet):
minVals=dataSet.min(0)
maxVals=dataSet.max(0)
ranges=maxVals-minVals
normDataSet=zeros(shape(dataSet))
m=dataSet.shape[0]
normDataSet=dataSet-tile(minVals,(m,1))
normDataSet=normDataSet/tile(ranges,(m,1))
return normDataSet,ranges,minVals
"""
函数说明:测试KNN算法对数据预测的错误率
"""
def datingClassTest():
#打开的文件名
dlg = filedialog.askopenfile(title='打开文件', filetypes=[("文本文件", "*.txt"), ('Python源文件', '*.py')])
fileName = dlg.name # 获取数据文件路径
datingDataMat, datingLabels = file2matrix(fileName)
hoRatio = 0.10 #取所有数据的百分之十作为测试样本
normMat, ranges, minVals = autoNorm(datingDataMat) #因为各类数据的范围有大有小,所以需要进行数据归一化(等权重)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0 #记录分类错误个数
for i in range(numTestVecs):
classifierResult = classify(normMat[i,:], normMat[numTestVecs:m,:],
datingLabels[numTestVecs:m], 3)
print("classifierResult:%s\tdatingLabels:%d" % (classifierResult, datingLabels[i]))
if classifierResult != datingLabels[i]:
errorCount += 1.0
print("the total error rate is:%%%.3f" %(errorCount/float(numTestVecs)*100))
print("the total error count is :%d" % errorCount)
if __name__ == '__main__':
datingClassTest()
测试结果:
四、小结
总的来说,KNN算法虽然简单,但在处理数据进行分类时仍是一个优秀的方法,限制其应用的主要是其在处理大量数据时所带来的效率问题,像Python或MATLAB中虽然有着很优秀的矩阵处理的模块,但是仍不能避免大量的数值运算这种情况。所以如果对数据处理效率有着更高的要求,可以考虑借助k-d树、四叉树或八叉树等数据结构来进一步提高获取最邻近点的速度。
参考资料:《机器学习》《机器学习实战》