/*
 * 根据《统计学习方法》写的 KdTree 实现。
 * 参考了这个博客的主要思路,但是在如何搜索最近邻上有些不同:
 * 1. 我在发现可能的路径后,把该路径扩展到叶子节点,生成一个新路径后重新计算最近路径;而这个博客中只检查了路径上与超球体相交的点,没有递归搜索。
 * 2. 他的博客利用方差确定分割的方向,我则选用了简单的依次更换策略。
 */
#include <algorithm>
#include <cmath>
#include <iostream>
#include <stack>
#include <vector>
using namespace std;
/// A point in the 2-D plane; also the payload stored in every kd-tree node.
struct Node
{
    double x;  ///< coordinate on the first (X) axis
    double y;  ///< coordinate on the second (Y) axis
};
/// One node of the kd-tree. Every node (not only the leaves) stores a data point.
struct KdTree
{
    Node val{};               ///< data point stored at this node
    int split = 0;            ///< axis used to partition at this node: 0 = x, 1 = y
    KdTree* left = nullptr;   ///< subtree of points whose split coordinate is smaller than val's
    KdTree* right = nullptr;  ///< subtree holding the remaining points
};
KdTree myKdTree{};
const int N = 6;
const int dim = 2;
Node dataSet[N] = {
{ 2,3 },
{ 5,4 },
{ 9,6 },
{ 4,7 },
{ 8,1 },
{ 7,2 }
};
int time = 0;/记录寻找分割次数/
stack<KdTree> search_path;/记录搜索过程的路经*/
/* Outcome of a nearest-neighbour query. */
struct result {
    Node resNode;  // the data point selected by the search
    double dist;   // the distance() value between resNode and the query point
};
/*
 * Comparison predicates for the X and Y dimensions.
 *
 * compareX: true when a's x coordinate is strictly greater than b's, so
 * sorting with it arranges points by descending x.
 */
bool compareX(Node a, Node b) {
    const bool firstIsLarger = b.x < a.x;
    return firstIsLarger;
}
/*
 * compareY: true when a's y coordinate is strictly greater than b's, so
 * sorting with it arranges points by descending y.
 */
bool compareY(Node a, Node b) {
    const bool firstIsLarger = b.y < a.y;
    return firstIsLarger;
}
/*
 * Choose the splitting point of a point set.
 *
 * Sorts unsortSet[0, size) IN PLACE in descending order of the current axis
 * (x when the global counter `time` is even, y when it is odd), copies the
 * element at index (size - 1) / 2 of the sorted range into splitData and then
 * increments `time`.
 *
 * unsortSet  points to partition; reordered by this call
 * splitData  receives the chosen split point
 * size       number of valid elements in unsortSet
 *
 * When size <= 0 there is nothing to choose from: splitData and `time` are
 * left untouched (previously this read unsortSet[-1]).
 */
void chooseSplit(Node unsortSet[], Node& splitData, int size) {
    if (size <= 0) {
        return;
    }

    const bool splitOnX = (time % 2 == 0);
    std::sort(unsortSet, unsortSet + size, splitOnX ? compareX : compareY);

    // Same index as the original even/odd branches:
    // size / 2 - 1 for an even size, size / 2 for an odd size.
    const int mid = (size - 1) / 2;
    splitData = unsortSet[mid];
    time++;
}
/*
 * Build a kd-tree from unsortSet[0, size) and return its root.
 *
 * size       number of points in unsortSet
 * unsortSet  the points; reordered in place because chooseSplit() sorts them
 * tree       ignored - kept only for interface compatibility; the new subtree
 *            is delivered through the return value
 *
 * Returns nullptr when size <= 0. Nodes are allocated with `new`, so the
 * caller owns the returned tree.
 *
 * The split axis of each node comes from the global counter `time`, which
 * advances once per chosen split point; the axis is recorded in the node, so
 * siblings are not guaranteed to share an axis.
 *
 * Every input point ends up in the tree: only one occurrence of the split
 * point itself is removed, and points whose split coordinate equals the
 * pivot's go to the right subtree (matching the "target < node -> left,
 * otherwise right" descent rule used when searching).
 */
KdTree* build(int size, Node unsortSet[], KdTree* tree) {
    (void)tree;
    if (size <= 0) {
        return nullptr;
    }

    Node splitData{};
    chooseSplit(unsortSet, splitData, size);
    // chooseSplit() has already incremented `time`, so an odd value here
    // means the split that was just chosen used the x axis.
    const int split = (time % 2 == 1) ? 0 : 1;
    const double pivotKey = (split == 0) ? splitData.x : splitData.y;

    // Dynamically sized scratch sets: no fixed capacity to overflow.
    std::vector<Node> leftset;
    std::vector<Node> rightset;
    leftset.reserve(static_cast<std::vector<Node>::size_type>(size));
    rightset.reserve(static_cast<std::vector<Node>::size_type>(size));

    bool pivotSkipped = false;
    for (int i = 0; i < size; i++) {
        const Node& point = unsortSet[i];
        // Leave out exactly one copy of the split point; it becomes this node.
        if (!pivotSkipped && point.x == splitData.x && point.y == splitData.y) {
            pivotSkipped = true;
            continue;
        }
        const double key = (split == 0) ? point.x : point.y;
        if (key < pivotKey) {
            leftset.push_back(point);
        }
        else {
            rightset.push_back(point);
        }
    }

    KdTree* node = new KdTree;
    node->val = splitData;
    node->split = split;
    node->left = build(static_cast<int>(leftset.size()), leftset.data(), nullptr);
    node->right = build(static_cast<int>(rightset.size()), rightset.data(), nullptr);
    return node;
}
/*
 * Euclidean (p = 2) distance between two points.
 *
 * Returns sqrt((a.x - b.x)^2 + (a.y - b.y)^2). The square root was previously
 * missing, so the value was the squared distance even though it is compared
 * against plain coordinate offsets and printed as the nearest distance.
 */
double distance(Node a, Node b) {
    return std::hypot(a.x - b.x, a.y - b.y);
}
/*
 * Extend the global search path.
 *
 * Starting at `tree`, walks downwards until a null child is reached and
 * pushes every visited node onto search_path. At each node the target is
 * compared with THAT node's point on the node's split axis: a smaller
 * coordinate goes left, anything else goes right.
 *
 * target  query point
 * tree    subtree to descend into; nullptr pushes nothing
 */
void buildpath(Node target, KdTree* tree) {
    KdTree* cursor = tree;
    while (cursor != nullptr) {
        search_path.push(cursor);
        // Compare against the current node, not the subtree root (the y
        // branch used to read tree->val.y).
        const bool goLeft = (cursor->split == 0)
            ? (target.x < cursor->val.x)
            : (target.y < cursor->val.y);
        cursor = goLeft ? cursor->left : cursor->right;
    }
}
/*
 * Find the stored point nearest to `target`.
 *
 * target  query point
 * tree    root of the kd-tree to search
 *
 * Returns the nearest point together with its distance() value. For an empty
 * tree (tree == nullptr) the result holds the point (0,0) and a dist of -1.
 *
 * Uses the global search_path as its work stack: any leftover entries are
 * discarded on entry and the stack is empty again on return.
 *
 * Algorithm: descend to the end of a path with buildpath(), then backtrack.
 * Every popped node is a candidate. If the node's splitting line is closer to
 * the target than the best distance so far, the subtree on the other side of
 * the line is expanded with buildpath() and searched the same way.
 */
result findnearest (Node target,KdTree* tree){
    if (tree == nullptr) {
        return result{ Node{ 0.0, 0.0 }, -1.0 };
    }

    // Do not let entries from an earlier, unrelated buildpath() call leak in.
    while (!search_path.empty()) {
        search_path.pop();
    }

    /* Initialise the search path. */
    buildpath(target, tree);
    // Seed the best candidate with the deepest node but leave it on the stack
    // so that its other subtree is examined like every other node's.
    Node nearest = search_path.top()->val;
    double dist = ::distance(nearest, target);

    while (!search_path.empty()) {
        KdTree* node = search_path.top();
        search_path.pop();

        // Interior nodes hold data points too, so each popped node competes.
        const double nodeDist = ::distance(node->val, target);
        if (nodeDist < dist) {
            dist = nodeDist;
            nearest = node->val;
        }

        // Projection of the target onto this node's splitting line. Measuring
        // the gap with distance() keeps it in the same units as `dist`.
        Node onPlane = target;
        bool targetIsSmaller = false;
        if (node->split == 0) {
            onPlane.x = node->val.x;
            targetIsSmaller = target.x < node->val.x;
        }
        else {
            onPlane.y = node->val.y;
            targetIsSmaller = target.y < node->val.y;
        }

        // The far side is the child opposite to the one a
        // "smaller -> left, otherwise right" descent takes.
        KdTree* farSide = targetIsSmaller ? node->right : node->left;
        if (farSide != nullptr && ::distance(target, onPlane) < dist) {
            // buildpath() pushes farSide itself, so it is not pushed here.
            buildpath(target, farSide);
        }
    }
    return result{ nearest ,dist };
}
// Helpers for printing the tree structure.
// printNode writes one point to standard output as "(x,y)" and ends the line.
void printNode(Node node) {
    std::cout << "(" << node.x;
    std::cout << "," << node.y;
    std::cout << ")" << std::endl;
}
// Print the tree in pre-order: the node itself, then its left subtree, then
// its right subtree. A null root (empty tree) prints nothing.
void printTree_rootfirst(KdTree* root) {
    if (root == nullptr) {
        return;
    }
    printNode(root->val);
    printTree_rootfirst(root->left);
    printTree_rootfirst(root->right);
}
// Print the tree in in-order: the left subtree, then the node itself, then
// its right subtree. A null root (empty tree) prints nothing.
void printTree_leftfirst(KdTree* root) {
    if (root == nullptr) {
        return;
    }
    printTree_leftfirst(root->left);
    printNode(root->val);
    printTree_leftfirst(root->right);
}
int main() {
KdTree * root = NULL;
root = build(N, dataSet, root);
Node target = {2,4.5};
result res = findnearest(target,root);
cout <<"最近距离:"<< res.dist << endl;
cout <<"X方向:"<< res.resNode.x << endl;
cout << "Y方向:" << res.resNode.y << endl;
system("pause");
}