基于kd树的KNN算法的实现

记得大三初期,刚从大连理工大学回来,眼巴巴的望着同学各自都有着落了,就我一副“初出茅庐,不谙世事”的样子,于是不得不觍着脸厚着皮去找老师,恳求他让我去海洋所实习。他给我的第一份差事便是将几个G的图片里的数字输入到excel,我整整输了一个国庆节假日。当时就在到处询问,有没有那种算法可以让自动识别图片里的数字,存入到excel中去,想来,那时的自己也是够拼的。
如今这个自动识别数字的算法算是写出来了吧, 我至少可以这样自我安慰到。
KNN算法的理论算的上是最简单最直观的一种了,比起前几次的支持向量机、贝叶斯、逻辑斯特回归那是简单太多了,都不用推导半个公式。这周的核心都是在完成k-近邻中kd树的构建和搜索,几乎都是自己完成的,也没有经过周密的测试,只是调试调通了。
我想,它的用例定不止于此,但这个用例说出去可谓是最唬人的了。
识别下面的“图片”为数字2 0 8


原理就不多讲了,感兴趣的网上都有,理论很简单,只是构建和搜索kd树可能会有些麻烦,而kd树只是为了让它运行的更快,其实用最简单粗暴的方法计算目标点与每个训练集点的距离也未尝不可。

程序运行的效果还行,识别近千个数字只错了10个,错误率1%左右。效果如下,为了好看,我就仅截图出识别几个数字的效果:


下面为实现的程序,注释写的很明白,训练数据在网上也不难找到:

KnnHelper.py

import numpy as np
'''
Created on 2017年7月17日

@author: fujianfei
'''

class KDNode(object):
    '''
    定义KD节点:
    point:节点里面的样本点,指的就是一个样本点
    split:分割纬度(即用哪个纬度的数据进行切分,比如4维数据,split=3,则表示按照第4列的数据进行切分空间)
    left:节点的左子节点
    right:节点的右子节点
    '''


    def __init__(self, point=None, split=None, left=None, right=None):
        '''
        Constructor
        '''
        self.point = point
        self.split = split
        self.left = left
        self.right = right
        
        
class KDTree(object):
    '''
    定义:
    KDNode:kd-tree的节点
    dimensions:数据的纬度
    right:节点的右子节点
    left:节点的左子节点
    curr_axis:当前需要切分的纬度
    next_axis:下一次需要切分的纬度
    '''


    def __init__(self, data=None):
        '''
        Constructor
        '''        
        def createNode(split=None, data_set=None):
            '''
            创建KD节点
            输入值:split:分割纬度 data_set:需要分割的样本点集合
            返回值:KDNode:KD节点
            '''
            if  len(data_set) == 0:    # 数据集为空,作为递归的停止条件
                return None
            #找到split维的中位数median,先对数据进行排序,按照split维的数据大小排序
            data_set = list(data_set)
            data_set.sort(key=lambda x: x[split])#对data_set进行排序,lambda是隐函数,具体用法请百度。排序方式为按照split维的数据大小排序
            data_set = np.array(data_set)
            median = len(data_set) // 2#//为python的整数除法,找到中间点的位置median,按照这个位置进行空间切分
            #返回KD节点
            #输入的变量分别是:
            #data_set[median],中间点位置的样本点,传入KDNode即节点里面包含的数据
            #split,该节点的纬度分度位置
            #createNode(maxVar(data_set[:median]),data_set[:median]),该节点的左节点,maxVar(data_set[:median])为左节点的纬度分度位置,data_set[:median]为左节点包含的空间里的所有数据
            #同理,createNode(maxVar(data_set[median+1:]),data_set[median+1:]),为右节点。
            #用的是函数的递归创建树,因为要不断的调用函数,这个方法速度不快,用基本语句(判断、循环)去构建树的方法会更快
            return KDNode(data_set[median], split, createNode(maxVar(data_set[:median]),data_set[:median]), createNode(maxVar(data_set[median+1:]),data_set[median+1:]))
            
        def maxVar(data_set=None):
            '''
            按纬度计算样本集的最大方差纬度
            输入值:data_set:样本集
            输出值:split:最大方差的纬度,作为createNode的输入值
            '''
            if  len(data_set) == 0:    # 数据集为空,作为递归的停止条件
                return 0
            data_mean = np.mean(data_set,axis=0)#axis=0表示按列求均值
            mean_differ = data_set - data_mean#均值差
            data_var = np.sum(mean_differ ** 2,axis=0)/len(data_set)#按列求均值差平方之和,再除以样本数,便是方差
            re = np.where(data_var == np.max(data_var))#寻找方差最大的位置,也就是第几纬方差最大,返回它
            return re[0][0]
        
        self.root = createNode(maxVar(data),data)#定义根节点,分割纬度是使得样本点方差最大的纬度,需要分割的样本点为全数据
 
def computeDist(pt1, pt2):
    """
    计算两个数据点的距离
    return:pt1和pt2之间的距离
    """
    sum = 0.0
    for i in range(len(pt1)):
        sum = sum + (pt1[i] - pt2[i]) * (pt1[i] - pt2[i])
    return np.math.sqrt(sum)
        
def preOrder(root):
    '''
    KD树的前序遍历
    '''
    print(root.point)
    if root.left:
        preOrder(root.left)
    if root.right:
        preOrder(root.right)

def updateNN(min_dist_array=None, tmp_dist=0.0, NN=None, tmp_point=None, k=1):
    '''
    /更新近邻点和对应的最小距离集合
    min_dist_array为最小距离的集合
    NN为近邻点的集合
    tmp_dist和tmp_point分别是需要更新到min_dist_array,NN里的近邻点和距离
    '''
    
    if tmp_dist <= np.min(min_dist_array) : 
            for i in range(k-1,0,-1) :
                min_dist_array[i] = min_dist_array[i-1]
                NN[i] = NN[i-1]    
            min_dist_array[0] = tmp_dist
            NN[0] = tmp_point                
            return NN,min_dist_array
    for i in range(k) :
        if (min_dist_array[i] <= tmp_dist) and (min_dist_array[i+1] >= tmp_dist) :
            #tmp_dist在min_dist_array的第i位和第i+1位之间,则插入到i和i+1之间,并把最后一位给剔除掉
            for j in range(k-1,i,-1) : #range反向取值
                min_dist_array[j] = min_dist_array[j-1]
                NN[j] = NN[j-1]
            min_dist_array[i+1] = tmp_dist
            NN[i+1] = tmp_point
            break
    return NN,min_dist_array

def searchKDTree(KDTree=None, target_point=None, k=1):  
    '''
    /搜索kd树
    /输入值:KDTree,kd树;target_point,目标点;k,距离目标点最近的k个点的k值
    /输出值:k_arrayList,距离目标点最近的k个点的集合数组
    '''      
    if k == 0 : return None
    #从根节点出发,递归地向下访问kd树。若目标点当前维的坐标小于切分点的坐标,则移动到左子节点,否则移动到右子节点
    tempNode = KDTree.root#定义临时节点,先从根节点出发
    NN = [tempNode.point] * k#定义最邻近点集合,k个元素,按照距离远近,由近到远。初始化为k个根节点
    min_dist_array = [float("inf")] * k#定义近邻点与目标点距离的集合.初始化为无穷大
#     for i in range(k) :
#         NN[i] = tempNode.point#定义最邻近点集合,k个元素,按照距离远近,由近到远。初始化为k个根节点以下往左的集合
#         min_dist_array[i] = computeDist(NN[i],target_point)#定义近邻点与目标点距离的集合
#         tempNode = tempNode.left
    nodeList = []#我们是用二分查找建立路径,定义依次查找节点的list

    def buildSearchPath(tempNode=None, nodeList=None, min_dist_array=None, NN=None, target_point=None):
        '''
        P:此方法是用来建立以tempNode为根节点,以下所有节点的查找路径,并将它们存放到nodeList中
        nodeList为一系列节点的顺序组合,按此先后顺序搜索最邻近点
        tempNode为"根节点",即以它为根节点,查找它以下所有的节点(空间)
        '''
        while tempNode :
            nodeList.append(tempNode)
            split = tempNode.split#节点的分割纬度
            point = tempNode.point#节点包含的数据,当前实例点
            tmp_dist = computeDist(point,target_point)
            if tmp_dist < np.max(min_dist_array) : #小于min_dist_array中最大的距离
                NN,min_dist_array = updateNN(min_dist_array, tmp_dist, NN, point, k)#更新最小距离和最邻近点
            if  target_point[split] <= point[split] : #如果目标点当前维的值小于等于切分点的当前维坐标值,移动到左节点
                tempNode = tempNode.left
            else : #如果目标点当前维的值大于切分点的当前维坐标值,移动到右节点
                tempNode = tempNode.right
        return NN,min_dist_array
    #建立查找路径
    NN,min_dist_array = buildSearchPath(tempNode,nodeList,min_dist_array, NN, target_point)
    #回溯查找
    while nodeList :
        back_node = nodeList.pop()#将nodeList里的元素从后往前一个个推出来
        split = back_node.split
        point = back_node.point
        #判断是否需要进入父节点搜素
        #如果当前纬度,目标点减实例点大于最小距离,就没必要进入父节点搜素了
        #因为目标点到切割超平面的距离很大,那邻近点肯定不在那个切割的空间里,即没必要进入那个空间搜素了
        if not abs(target_point[split] - point[split]) >= np.max(min_dist_array) :
            #判断是搜索左子节点,还是搜索右子节点
            if (target_point[split] <= point[split]) :
                #如果目标点在左子节点的空间,则搜索右子节点,查看右节点是否有更邻近点
                tempNode = back_node.right
            else :
                #如果目标点在右子节点的空间,则搜索左子节点,查看左节点是否有更邻近点
                tempNode = back_node.left
            
            if tempNode :
                #把tempNode(此时它为另一个全新的未搜素的空间,需要将它放入nodeList,进行最近邻搜索)放入nodeList
                #nodeList.append(tempNode)
                #不能单纯地将tempNode存放到nodeList,这样下次只会搜索这一个节点
                #因为tempNode可做为一个全新的空间,故而需重新以它为根节点,构建查找路径,搜索它名下所有的节点
                NN,min_dist_array = buildSearchPath(tempNode,nodeList,min_dist_array, NN, target_point)
#                 curr_dist = computeDist(tempNode.point,target_point)
                #是否该节点为更邻近点,如果是,赋值给最邻近点
#                 if curr_dist < np.max(min_dist_array) :
#                     NN,min_dist_array = updateNN(min_dist_array, curr_dist, NN, tempNode.point, k)#更新最小距离和最邻近点
    return NN,min_dist_array 

def classify0(inX, dataSet, labels, k):
    '''
    k近邻算法的分类器
    \输入:
    inX:目标点
    dataSet:训练点集合
    labels:训练点对应的标签
    k:k值
    \这个方法的目的:已知训练点dataSet和对应的标签labels,确定目标点inX对应的labels
    ''' 
    kd = KDTree(dataSet)#构建dataSet的kd树
    NN,min_dist_array = searchKDTree(kd, inX, k)#搜索kd树,返回最近的k个点的集合NN,和对应的距离min_dist_array
    dataSet = dataSet.tolist()
    voteIlabels = []
    #多数投票法则确定inX的标签,为防止边界处分类不准的情况,以距离的倒数为权重,即距离越近,权重越大,越该认为inX是属于该类
    for i in range(k) :
        #找到每个近邻点对应的标签
        nni = list(NN[i])
        voteIlabels.append(labels[dataSet.index(nni)])
        
#     #开始记数,加权重的方法
#     uniques = np.unique(voteIlabels)
#     counts = [0.0] * len(uniques)
#     for i in range(len(voteIlabels)) :
#         for j in range(len(uniques)) :
#             if voteIlabels[i] == uniques[j] :
#                 counts[j] = counts[j] + uniques[j] / min_dist_array[i] #权重为距离的倒数
#                 break
    #开始记数,不加权重的方法
    uniques, counts = np.unique(voteIlabels, return_counts=True)
    return uniques[np.argmax(counts)]

**

HandWriting.py

**

import numpy as np
from os import listdir
from KNN import KnnHelper
'''
Created on 2017年7月23日

@author: fujianfei
'''
def img2vector(filename):
    '''
    \将32x32图像转化为1x1024的向量
    '''
    returnVect = np.zeros((1,1024))
    fr = open(filename)
    for i in range(32) :
        lineStr = fr.readline()
        for j in range(32) :
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

def handwritingClassTest():
    '''
    \识别手写数字
    '''
    hwLabels = []#定义训练数据对应的标签集合,即数字0-9
    trainingFileList = listdir('trainingDigits')#获取trainingDigits目录下所有的文件名,存在trainingFileList中
    m = len(trainingFileList)
    trainingMat = np.zeros((m,1024))
    for i in range(m) :
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)#将文件转化为矩阵
    testFileList = listdir('testDigits')#获取testDigits目录下所有的文件名,存在testFileList中
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest) :
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)#将文件转化为矩阵
        vectorUnderTest = list(vectorUnderTest[0])
        classifierResult = KnnHelper.classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("算法识别的数字为  : %d , 真实的数字为 : %d " % (classifierResult, classNumStr))
        if (classifierResult != classNumStr) : errorCount += 1.0
    print("\n总共出错的次数为  : %d" % errorCount)
    print("\n出错率为  : %f" % (errorCount/float(mTest)))

init.py

import numpy as np
from KNN import KnnHelper,HandWriting
  
# data = [[4,1,3,5],[3,6,5,7],[5,2,6.5,5],[4.8,4.2,5,8],[1,1,8,6],[1,6,5,3],[4.1,3.7,2,5],[4.7,4.1,5,9],[2,4,6,8.7]]  # samples
# kd = KnnHelper.KDTree(data)
# # KnnHelper.preOrder(kd.root)
# ret = KnnHelper.searchKDTree(kd, [4.8,3.8,2,4], 9)
# print (ret)
HandWriting.handwritingClassTest()
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,047评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,807评论 3 386
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,501评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,839评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,951评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,117评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,188评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,929评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,372评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,679评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,837评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,536评论 4 335
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,168评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,886评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,129评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,665评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,739评论 2 351

推荐阅读更多精彩内容