KD-Tree 算法的 C++ 实现

阅读本文前,建议查阅相关资料,了解 KNN 算法与 KD 树。

基础知识

如图所示,假设一个点 a 目前的最近邻点为 b,如果存在相对于 ba 更近的点,那么这个点一定在以 a 为圆心,ab 为半径的圆内。
现右侧的区域是未知的,如果 a 到分界线的距离 l 大于目前的最近距离 L(圆半径),则没有必要在右侧的未知区域继续寻找最近邻点(如图一),反之,则要继续寻找(如图二)。
相应的,投射到多维空间,假如切分边界为第 i 维,切分点的值为 v(标量),当前最近邻点为 y(向量),如果目标点 x(向量) 到切分边界的距离 |x[i] - v| 满足以下关系


时,需要在另一侧继续搜索。

图1:不需要在右侧未知区域继续搜索的情况

图2:需要在右侧未知区域继续搜索的情况

通常地,一个机器学习算法分为 fitpredict 两个阶段,基于线性搜索的 KNN 是一种惰性算法,它将全部的计算任务放到了 predict 阶段,predict 的时间复杂度为 O(n),KD 树之所以比线性搜索快,就是因为它将一部分任务放到了 fit(建立 KD 树) 阶段,从而在搜索时可以略去大量不必搜索的结点(最优情况下时间复杂度为 O(1))。
上面说的比较简单,关于 KNN 算法和 KD 树的详细内容,请参考李航博士的《统计学习方法》。

代码

我们给出部分关键性的代码。

基本数据结构

  • 训练集用一个一维数组 double *data 表示,它的长度为 n_samples * n_features,标签集也用一个一维数组 double *labels 表示,它的长度为 n_samples
  • 树的结点用以下数据结构表示
     struct tree_node
     {
         size_t id;               // 表示训练集中的第 i 个数据
         size_t split;            // 切分的维度
         tree_node *left, *right; // 左、右子树
     };
    
  • 一个 KD 树的模型可用以下结构表示
     struct tree_model
     {
         tree_node *root;        // 根结点
         const double *datas;    // X
         const double *labels;   // y
         size_t n_samples;       // 样例数
         size_t n_features;      // 每个样例的特征数
         double p;               // 距离度量
     };
    
  • 求 K-近邻时需要用到大顶堆,我们直接用 C++ 的优先队列来表示,堆内现有的 n(n <= k) 个近邻点中,距离测试点最远的在堆顶
    struct neighbor_heap_cmp {
        bool operator()(const std::tuple<size_t, double> &i, 
                        const std::tuple<size_t, double> &j) {
              return std::get<1>(i) < std::get<1>(j);
          }
      };
    
    typedef std::tuple<size_t, double> neighbor;
    typedef std::priority_queue<neighbor,
            std::vector<neighbor>, neighbor_heap_cmp> neighbor_heap_;
    
    neighbor_heap k_neighbor_heap_;
    

KD-Tree 类

我们用类 KDTree 表示一个 KD 树类,它应该具有的功能有建树搜索

//(简化的代码,完整的代码详见最后)
class KDTree {
public:
    // 建树
    KDTree(const double *datas, const double *labels, size_t rows, size_t cols, double p)
    // 返回树
    tree_node *GetRoot() { return root; }
    // 求一个测试点的 k 邻
    std::vector<std::tuple<size_t, double>> FindKNearests(const double *coor, size_t k);
private:
    tree_node *root_;
}

寻找切分维和切分点

在建树之前,我们还要考虑如何选择切分维度和切分点。切分维度的选择有许多,一般的,可以取 dim = floor % n_features,即当前树的层数对特征数取余,我们在这里使用 dim = argmax(nmax - nmin),即选取当前结点集合中极差最大的维度。

(这里是不完整的代码,有些工具函数的定义请详见完整源代码)
size_t KDTree::FindSplitDim(const std::vector<size_t> &points) {
    if (points.size() == 1)
        return 0;
    size_t cur_best_dim = 0;
    double cur_largest_spread = -1;
    double cur_min_val;
    double cur_max_val;
    for (size_t dim = 0; dim < n_features; ++dim) {
        cur_min_val = GetDimVal(points[0], dim);
        cur_max_val = GetDimVal(points[0], dim);
        for (const auto &id : points) {
            if (GetDimVal(id, dim) > cur_max_val)
                cur_max_val = GetDimVal(id, dim);
            else if (GetDimVal(id, dim) < cur_min_val)
                cur_min_val = GetDimVal(id, dim);
        }

        if (cur_max_val - cur_min_val > cur_largest_spread) {
            cur_largest_spread = cur_max_val - cur_min_val;
            cur_best_dim = dim;
        }
    }
    return cur_best_dim;
}

选择完切分维 k 之后,我们需选取当前结点集合中的结点在第 k 维的值的中位数 x 作为切分点的值,除去该点之外的点,第 k 维的值小于等于 x 的,放入左子树,反之放入右子树。
在求中位数时,不要全排序,然后取中间的点,可以采用类似快排的方法,找到中位数时就停止排序,这里我们就不写算法了,直接用 C++ 的函数。

std::tuple<size_t, double> KDTree::MidElement(const std::vector<size_t> &points, size_t dim) {
    size_t len = points.size();
    for (size_t i = 0; i < points.size(); ++i)
        get_mid_buf_[i] = std::make_tuple(points[i], GetDimVal(points[i], dim));
    std::nth_element(get_mid_buf_,
                     get_mid_buf_ + len / 2,
                     get_mid_buf_ + len,
                     [](const std::tuple<size_t, double> &i, const std::tuple<size_t, double> &j) {
                         return std::get<1>(i) < std::get<1>(j);
                     });
    return get_mid_buf_[len / 2];
}

建树

建树直接按照建立二叉树的方法即可

tree_node *KDTree::BuildTree(const std::vector<size_t> &points) {
    size_t dim = FindSplitDim(points);
    std::tuple<size_t, double> t = MidElement(points, dim);
    size_t arg_mid_val = std::get<0>(t);
    double mid_val = std::get<1>(t);

    tree_node *node = Malloc(tree_node, 1);
    node->left = nullptr;
    node->right = nullptr;
    node->id = arg_mid_val;
    node->split = dim;
    std::vector<size_t> left, right;
    for (auto &i : points) {
        if (i == arg_mid_val)
            continue;
        if (GetDimVal(i, dim) <= mid_val)
            left.emplace_back(i);
        else
            right.emplace_back(i);
    }
    if (!left.empty())
        node->left = BuildTree(left);
    if (!right.empty())
        node->right = BuildTree(right);
    return node;
}

搜索 K-近邻的规则

一般书上所讲的都是搜索最近邻,但是我们这里是搜索 K-近邻,需要对书上的算法做少许的扩充。
搜索最近邻时,我们一般设置两个变量 cur_min_idcur_min_dist,如果当前搜索到的点到测试点的距离 l < cur_min_dist 时,我们将上述两个变量更新为新点的 iddist
相应的,在搜索 K-近邻时,我们可以设置一个最多有 k 个元素的大顶堆,这样,在搜索时,当堆满时,只需比较当前搜索点的 dist 是否小于堆顶点的 dist,如果小于,堆顶出堆,并将当前搜索点压入,反之,则不变;当堆未满时,直接将该搜索点压入。

搜索 K-近邻的算法

我们直接使用二叉树深度优先遍历的非递归算法(具体的描述详见《统计学习方法》第 43 页算法 3.3)。

std::vector<std::tuple<size_t, double>> KDTree::FindKNearests(const double *coor, size_t k) {
    std::memset(visited_buf_, 0, sizeof(bool) * n_samples);
    std::stack<tree_node *> paths;
    tree_node *p = root;

    while (p) {
        HeapStackPush(paths, p, coor, k);
        p = coor[p->split] <= GetDimVal(p->id, p->split) ? p = p->left : p = p->right;
    }
    while (!paths.empty()) {
        p = paths.top();
        paths.pop();

        if (!p->left && !p->right)
            continue;

        if (k_neighbor_heap_.size() < k) {
            if (p->left)
                HeapStackPush(paths, p->left, coor, k);
            if (p->right)
                HeapStackPush(paths, p->right, coor, k);
        } else {
            double node_split_val = GetDimVal(p->id, p->split);
            double coor_split_val = coor[p->split];
            double heap_top_val = std::get<1>(k_neighbor_heap_.top());
            if (coor_split_val > node_split_val) {
                if (p->right)
                    HeapStackPush(paths, p->right, coor, k);
                if ((coor_split_val - node_split_val) < heap_top_val && p->left)
                    HeapStackPush(paths, p->left, coor, k);
            } else {
                if (p->left)
                    HeapStackPush(paths, p->left, coor, k);
                if ((node_split_val - coor_split_val) < heap_top_val && p->right)
                    HeapStackPush(paths, p->right, coor, k);
            }
        }
    }
    std::vector<std::tuple<size_t, double>> res;

    while (!k_neighbor_heap_.empty()) {
        res.emplace_back(k_neighbor_heap_.top());
        k_neighbor_heap_.pop();
    }
    return res;
}

完整代码

详见 https://github.com/WiseDoge/libkdtree
完整代码中除了 KD-Tree 的代码外,还给出了测试代码和 Python 接口代码,以及一些调用第三方库来加速的手段。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,864评论 6 494
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,175评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,401评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,170评论 1 286
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,276评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,364评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,401评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,179评论 0 269
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,604评论 1 306
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,902评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,070评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,751评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,380评论 3 319
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,077评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,312评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,924评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,957评论 2 351

推荐阅读更多精彩内容