k近邻算法及其实现

1. KNN (k-Nearest Neighbor)

k近邻算法是一种基本分类与回归方法。k近邻法假设给定一个训练数据集,其中的实例类别一定。分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决等方法进行预测。因此k近邻算法不具有显式的学习过程。k近邻实际上是利用训练数据集对特征向量空间进行划分,并作为其分类的模型。
k近邻的三个基本要素是:k值的选择,距离的度量以及分类决策规则。

1.1 距离的度量

特征空间中两个实例点的距离是两个实例点相似程度的反映,常见的距离度量有:欧式距离,Lp距离等等(距离度量可以参考这篇博文: 从K近邻算法、距离度量谈到KD树、SIFT+BBF算法 - July_ - 博客园)。不同的距离度量得到的结果可能是不一样的。

1.2 k值的选择

如果选择较小的k,就相当于用较小的领域中的训练实例进行预测,只有与输入实例较近的训练实例才会对预测结果起作用,但是这样会导致预测结果对近邻点非常敏感。如果近邻的实例点恰巧是噪声,预测就会出错。也就是说,k值的减少就意味着整体模型变得复杂,容易过拟合。
如果选择较大的k值,与输入实例较远的(不相似的)训练实例也会对预测起作用,使得预测发生错误。k值的增加意味着整体模型变得简单。

1.3分类决策规则

可以选择多数表决规则,甚至加上距离的远近(即把距离当做权重),决定输入实例是哪个类别。

2.kd树

实现k近邻算法是,主要考虑的问题是如何对训练数据进行快速k近邻搜索。这点在特征空间的维数大及训练数据容量大时尤其必要。为了提高k近邻搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少距离计算次数。可以采用kd-tree。
k近邻搜索算法思路如下:
输入:已构造的kd树:目标点x;(辅助结构,数组)
输出:x的k近邻
公共操作P:在访问每个结点时,若数组容量不足k,则将该结点加入数组,若堆容量以达到k,则比较当前节点是否比数组尾元素与x的距离更近,若更近则以当前节点代替数组尾结点,并调整数组。
(1)从根节点出发,递归地向下访问kd树,若目标x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,知道结点为叶节点为止。执行公共操作P。
(2)递归的向上回退,在每个节点进行以下操作:
(a)执行公共操作P。
(b)检查该子结点的兄弟结点区域是否有比堆顶元素更近的点或堆容量未满。具体的,检查另一子结点对应的区域是否与以目标点为求心,以目标点与堆顶元素距离为半径的球体相交。
如果相交或容量未满,以另一子结点为根节点执行(1)。
(4)当回退到根节点时,搜索结束,堆中实例即为所求实例。

注:前几天刚做完机器学习的大作业,实现了KNN算法,是针对iris数据集的。特此总结

代码实现:

代码不友好!!!!
kd_tree.h
#include<stdlib.h>
#include<vector>
#include<math.h>
#include<algorithm>
#include<iostream>
using namespace std;

#define  K   4    ////输入数据的维度

class kd_tree_node{
//成员对象
public:
  vector<float> node_data;    //存储该节点样本数据
  string node_type;           //是叶节点还是树干(树枝)
  int numpoints;              //训练数据的个数,或者说这个二叉树有多少个节点
  int index;                  //节点数据在原数据中的索引位置
  int splitdim;               //该节点进行分裂是的,选择的分裂维度
  double splitval;            //该节点选择的分裂值
  kd_tree_node* left_node, *right_node,*parents;
};
vector<int> median_data(vector<vector<float>>data, vector<int> index, int splitdim_num);//排    序函数,返回排好序的索引序列

/ /递归实现创建kd_tree
kd_tree_node* create_kd_tree(vector<vector<float>>data,int split_dim_num,vector<int>index,kd_tree_node *parent){

//初始化,构造根节点。创建一个节点kd_tree_node;
kd_tree_node * root = new kd_tree_node;
root->numpoints = data.size();

//判断结束条件
if (index.size() == 1){
    //设置成员变量
    root->left_node = NULL;
    root->right_node = NULL;
    root->node_type = "leaf";
    root->splitdim = -1;
    root->splitval = 0;
    root->parents = parent;
    root->node_data = data[index[0]];
    root->index = index[0];
}
else{
    //排序,分裂
    index = median_data(data, index, split_dim_num);
    int length = index.size();
    vector<int>left, right;
    for (int i = 0; i < index.size(); i++){
        if (i < length  / 2)
            left.push_back(index[i]);
        else{
            if (i>length/2)
                right.push_back(index[i]);
        }
    }
    //设置类成员变量
    if (left.size() >= 1){
        root->left_node = create_kd_tree(data, split_dim_num  % K + 1, left, root);
    }
    else
        root->left_node = NULL;
    if (right.size() >= 1){
        root->right_node = create_kd_tree(data, split_dim_num  % K + 1, right, root);
    }
    else
        root->right_node = NULL;
    root->node_type = "body";
    root->splitdim = split_dim_num;
    root->splitval = data[index[length/2]][split_dim_num - 1];  //(<)
    root->parents = parent;
    root->node_data = data[index[length/2]];
    root->index = index[length / 2];
}
return root;

}

//排序函数,返回排好序的索引序列
vector<int> median_data(vector<vector<float>>data, vector<int> index, int splitdim_num){
vector<float>temp;
int length = index.size();
for (int i = 0; i < length; i++){
    temp.push_back(data[index[i]][splitdim_num - 1]);
}
//升序排序,冒泡法
int index_temp = 0;
float a = 0;

for (int i = 0; i < length - 1; i++){
    for (int j = 0; j < length -i- 1; j++){
        if (temp[j]>temp[j + 1]){
            a = temp[j + 1];
            temp[j + 1] = temp[j];
            temp[j] = a;

            index_temp = index[j + 1];
            index[j + 1] = index[j];
            index[j] = index_temp;
        }
    }
}
return index;

}

//k-近邻搜索算法
/*****公共操作P:在访问每个结点时,若最大堆容量不足k,则将该结点加入最大堆,若堆容量以达到k,则    比较当前节点是否比堆顶元素与x的距离更近,若更近则以当前节点代替堆顶结点,并调整堆。
(1)从根节点出发,递归地向下访问kd树,若目标x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,知道结点为叶节点为止。执行公共操作P。
(2)递归的向上回退,在每个节点进行以下操作:
(a)执行公共操作P。
(b)检查该子结点的兄弟结点区域是否有比堆顶元素更近的点或堆容量未满。具体的,检查另一子结点对应的区域是否与以目标点为求心,以目标点与堆顶元素距离为半径的球体相交。
如果相交或容量未满,以另一子结点为根节点执行(1)。
(4)当回退到根节点时,搜索结束,堆中实例即为所求实例。
  ****/
  /*************
function: knn_k_search()

input:
  test_data:测试数据
  near_num:需要寻找几个近邻元素,near_num
  root:kd树的根节点

output: 返回找到原数据中near_num个近邻点在原数据中的index(索引)数组。
*************/
vector<int> knn_k_search(vector<float>test_data, int near_num, kd_tree_node *root){
  vector<int> near_k_node_index(0);           //记录下k个近邻点的索引
  vector<double>near_k_nodedist(0);           //记录下k个紧邻点的距离
  vector<kd_tree_node*> near_k_nodepoint;     //记录下k个近邻点的kd_tree指针

if (root->numpoints < near_num){
    cout << "do not have enough points" << endl;
    return near_k_node_index;
}

//首先找到叶节点,并记录下搜索的路径
kd_tree_node * leaf_node = NULL;
int split_dim = 1;
leaf_node = root;
vector<kd_tree_node*>path;
path.push_back(leaf_node);
while (leaf_node->node_type != "leaf"){
    split_dim = leaf_node->splitdim;
    if (test_data[split_dim - 1] <= leaf_node->splitval)//分裂
        leaf_node = leaf_node->left_node;
    else{
        if (leaf_node->right_node == NULL)
            leaf_node = leaf_node->left_node;//如果只有左子树,那么叶节点就选是左子树
        else
            leaf_node = leaf_node->right_node;
    }
    path.push_back(leaf_node);
}
path.pop_back();

//copy一份路径
vector<kd_tree_node*>path_copy = path;

//k近邻搜索,回溯,,找到K个最接近给定测试数据的样本,统计出现频率
//计算两点之间的距离,从叶子节点开始;  

//test_data所在的叶节点指针一直存储在leaf_node中
double dist1 = 0, max_dist = 0;

//计算距离
for (int i = 0; i < test_data.size(); i++){
    dist1 += (leaf_node->node_data[i] - test_data[i])*(leaf_node->node_data[i] - test_data[i]);
}
dist1 = sqrt(dist1);
max_dist = dist1;

//压入数据
near_k_nodepoint.push_back(leaf_node);
near_k_nodedist.push_back(max_dist);

//定义一个指针,该值针,指向上一个分支。
kd_tree_node * rl_node = leaf_node;  //也就是表示该分支已经被访问过了
while (path.size() != 0){

    //回溯到父节点(不一定是父节点,是搜索队列中,栈顶元素)
    kd_tree_node *back_point = path[path.size() - 1];
    path.pop_back();
    int split_s = back_point->splitdim - 1; 
    double dist2 = 0;
    for (int i = 0; i < test_data.size(); i++){
        dist2 += (back_point->node_data[i] - test_data[i])*(back_point->node_data[i] - test_data[i]);
    }
    dist2 = sqrt(dist2);

    //判断是否加入队列,两个:队列是否已满?未满直接加入,更新最大距离,已满的话判断是否大于最大距离
    if (near_k_nodepoint.size() == near_num && dist2 < max_dist)//队列已满,且小于最大距离
    {
        near_k_nodepoint.pop_back();
        near_k_nodedist.pop_back();
        //此时队列是不满的
    }

    if (near_k_nodepoint.size() < near_num)//如果队列未满的话,压入队列
    {
        if (near_k_nodepoint.size() == 0){   // 当队列为空时
            near_k_nodepoint.push_back(back_point);
            near_k_nodedist.push_back(dist2);
            max_dist = dist2;
        }
        else{
            int i = 0;
            while (dist2>near_k_nodedist[i]){
                i++;
                if (i == near_k_nodepoint.size())
                    break;
            }
            //更新最大距离
            max_dist = near_k_nodedist[near_k_nodedist.size() - 1];
            if (i == near_k_nodepoint.size())
                max_dist = dist2;
            //插入对i之前,对near_k_nodepoint和near_k_nodepoint;
            near_k_nodepoint.insert(near_k_nodepoint.begin() + i, back_point);
            near_k_nodedist.insert(near_k_nodedist.begin() + i, dist2);
        }

    }

    if (back_point->node_type == "leaf"){
        continue;//到达叶节点就继续下一轮
    }
    double dist3 = abs(test_data[split_s] - back_point->node_data[split_s]);
    //判断是否需要进入另一个分支
    if (near_k_nodepoint.size() < near_num || (dist3<max_dist)){
        
        //判断back_point 是否是test_data搜索路径中某个节点
        bool flag = false;
        for (int i = path_copy.size()-1; i >=0; i--){
            if (back_point == path_copy[i]){
                flag = true;
            }
        }
        if (flag){
            double flag = test_data[split_s];
            double flag2 = back_point->node_data[split_s];
            if (flag <= flag2){
                if (back_point->right_node != NULL)
                    back_point = back_point->right_node;//可能只有左子树,//如果只有左子树,那么叶节点就选是左子树
                else
                    back_point = back_point->left_node;
            }
            else{
                back_point = back_point->left_node;
            }
            path.push_back(back_point);
        }
        else{
            if (back_point->right_node != NULL)  //右节点压入栈中
                path.push_back(back_point->right_node);
            if (back_point->left_node != NULL)   //左节点压入栈中
                path.push_back(back_point->left_node);
        }
    }
}
//返回索引向量
for (int i = 0; i < near_k_nodepoint.size(); i++){
    near_k_node_index.push_back(near_k_nodepoint[i]->index);
}
return near_k_node_index;

}

knn.cpp

#include"kd_tree.h"
#include<fstream>
#include<string>

#define  label_type  3   //有三种样本

using namespace std;
string iris_name[label_type] = {"Iris-setosa","Iris-versicolor","Iris-virginica"}; //三种iris花的名字

void main(){

//读取数据阶段
/**数据分为train.txt和test.txt
    每个数据有五个分量,最后一个分量是样本所属的类型
    读得数据分别存储在data和label里,分为train_data,train_label.
    分隔符是空格符
**/
//训练数据
string train_file = "train2.txt";
ifstream ist(train_file.c_str());
vector<vector<float>>train_data;
vector<int>train_label;
while (!ist.eof()){
    vector<float> single_data;
    for (int i = 0; i < K; i++){
        float temp = 0;
        ist >> temp;
        single_data.push_back(temp);
    }
    int label = 0;
    ist >> label;
    train_label.push_back(label);
    train_data.push_back(single_data);
    single_data.resize(0);
}
ist.close();

//测试数据
string test_file = "test2.txt";
ifstream ist2(test_file.c_str());
vector<vector<float>>test_data;
vector<int>test_label;
while (!ist2.eof()){
    vector<float> single_data;
    for (int i = 0; i < K; i++){
        float temp = 0;
        ist2 >> temp;
        single_data.push_back(temp);
    }
    int label = 0;
    ist2 >> label;
    test_label.push_back(label);
    test_data.push_back(single_data);
    single_data.resize(0);
}
ist2.close();

int NUM = 0;                 //NUM是K近邻的所选取的近邻点的数目
for (NUM = 1; NUM < 121; NUM++){
    //创建kd树
    kd_tree_node *iris_kd_tree = NULL;
    int numpoints = train_label.size();
    vector<int>index;
    for (int i = 0; i < numpoints; i++){
        index.push_back(i);
    }
    iris_kd_tree = create_kd_tree(train_data, 1, index, NULL);  //根据训练数据创建kd_tree

    //测试样本的准确率
    int sum_num[label_type];     //各类样本的总数
    int right_num[label_type];   //各类样本的正确判断数目
    int error_num[label_type];   //各类样本的错误识别率

    //初始化
    for (int i = 0; i < label_type; i++){
        sum_num[i] = 0;
        right_num[i] = 0;
        error_num[i] = 0;
    }

    //k近邻搜索,判断样本类型
    for (int i = 0; i < test_label.size(); i++){
        vector<int>k_index;
        vector<int>count_label;
        for (int j = 0; j < label_type; j++){
            count_label.push_back(0);
        }
        k_index = knn_k_search(test_data[i], NUM, iris_kd_tree);//k近邻搜索

        for (int j = 0; j < k_index.size(); j++){
            int flag = train_label[k_index[j]];
            count_label[flag]++;   //统计k近邻各类样本出现的次数
        }
        int max = count_label[0];
        int label_flag = 0;
        for (int j = 1; j < label_type; j++){
            if (max < count_label[j]){
                max = count_label[j];
                label_flag = j;
            }
        }
        
        if (label_flag == test_label[i]){
            right_num[test_label[i]]++;
        }
        else{
            error_num[label_flag]++;
        }
        sum_num[test_label[i]]++;
    }

    //统计结果,并打印出结果
    int sum = 0;
    int error = 0;
    for (int i = 0; i < label_type; i++){
        sum += sum_num[i];
        error += error_num[i];
    }
    cout << NUM << ":" << endl;
    for (int i = 0; i < label_type; i++){
        cout << iris_name[i] << "测试样本总数为:" << sum_num[i] << ",正确率为:" << right_num[i] / (sum_num[i] * 1.0) << ",错误识别为该样本的数目为:" << error_num[i] << endl;
    }
    cout << "总的正确率为:" << 1-error*1.0/sum<<endl;
    cout << endl;
}
//画出kd_树(选做)
system("pause");

}

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,293评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,604评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,958评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,729评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,719评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,630评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,000评论 3 397
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,665评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,909评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,646评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,726评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,400评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,986评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,959评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,197评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,996评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,481评论 2 342

推荐阅读更多精彩内容

  • 一.朴素贝叶斯 1.分类理论 朴素贝叶斯是一种基于贝叶斯定理和特征条件独立性假设的多分类的机器学习方法,所...
    wlj1107阅读 3,070评论 0 5
  • 保留初心,砥砺前行 k-nearest neighbor, k-NN是一种可以用于多分类和回归的方法。knn是一...
    加勒比海鲜王阅读 1,289评论 3 7
  • 原文章为scikit-learn中"用户指南"-->"监督学习的第六节:Nearest Neighbors"###...
    HabileBadger阅读 7,105评论 0 7
  • 机器学习是做NLP和计算机视觉这类应用算法的基础,虽然现在深度学习模型大行其道,但是懂一些传统算法的原理和它们之间...
    在河之简阅读 20,475评论 4 65
  • 这片天 蓝得深沉 灰得浅淡 高原的风 吹不淡你的忧郁愁肠 世俗的尘 遮不住你的斐然才情 凛冽的寒冬 冻不住你心头的...
    萧七_阅读 155评论 0 0