KD-Tree 算法总结
KD-Tree 是什么
简而言之,KD-Tree是一种能维护高维数据空间的结构,主要支持几个操作:
1.插入点
2.进行距离查询(例如:查询距离某个点第k近的点)
KD-Tree 是一棵二叉搜索树。与普通的二叉搜索树一样,它具有左儿子比父亲小,右儿子比父亲大的特点。但是,比较点的大小是没有实际意义的,因此,KD-Tree并不是整体比较点的大小,而是比较某一维的大小。
上图中,加粗的数字表示当前选中的维数,也就是当前分割左儿子右儿子的关键字
为方便叙述,下文的KD-Tree均为2维KD-Tree
建树(build)
考虑两个问题:
1.如何选择划分的维度,使得KD-Tree的结构尽可能更优秀
2.如何选择当前的根节点,使得子树的深度尽量最小
显然,按照1,2,3,\dots k维的顺序来划分并不一定最好,我们考虑这样一种情况:假如在一个二维平面上,这k个点排成一条与x轴平行的直线,那么按y的大小来划分就会出现一些很尴尬的情况。
其实还有一种划分方法,我们不按1,2,\dots k这样顺序划分,而是按方差最大的那一维划分,将点最分散的那一维化成两部分,这也是我们希望看到的结果。
其实在实际应用中,顺序划分是一种最常见的方式,因为求方差的时间复杂度很高,而顺序划分对于随机数据来说表现也很出色
至于第二个问题,很明显的一件事就是我们可以选择中位数,左边一半,右边一半,这样很平均。
因此上面那个图的树应该建成这样:
关于怎么求中位数,algorithm
头文件中很贴心的为我们准备了一个函数nth_element
template<class _RanIt> inline
void nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last)
对于[First,Last)区间内的数重新排序,使得位置为Nth的值是第Nth小值
程序实现如下
l,r:当前区间,区间内的点等待插入
d:表示当前的维数,x为0,y为1
T:KD-Tree数组
ps:表示点的具体位置
ch:左儿子编号/右儿子编号
ncnt:当前点数
int build(int l,int r,int d) {
if (l>r) return 0;
del=d;int mid=(l+r)>>1,at=++ncnt;
nth_element(ps+l,ps+mid,ps+r+1,cmp);//查找中位数
T[at]=Tree(ps[mid],mid);//加点
T[at].ch[0]=build(l,mid-1,d^1),T[at].ch[1]=build(mid+1,r,d^1);//递归建子树
pushup(at);return at;
}
查询(query)
最近点(BZOJ 2648)
模拟插入的过程,找到最后待查询点在树上的位置。查找左/右子树的关键字是距离,即到待查点距离小的子树优先查。
注意“距离”指的不是到某个点的距离,而是到子节点所表示的矩形边界的最短距离。
但是这样会出现一些奇怪的情况,有可能子树的另一个儿子的某一个后继结点反而比当前所在的儿子节点更优秀,比如下面这样
带星号的点为待查点。
点(4,7)与待查点在同一个儿子,但是它并不是最优点,距离待查点距离更短。如果我们只考虑当前所在子树中的点,就不会考虑这个实际上更优秀的点,点(4,7)与待查点在点(5,4)的不同子树中。
解决方法其实很简单,判断一下另一个儿子节点是否更优秀即可。
dis():两点距离
成员函数dis():点到当前节点所表示的矩形边界的最短距离
ans:最终答案
p:待查点
now:当前所在节点
int ans;
void query(int now,Point p) {
if (!now) return;
ans=min(ans,dis(T[now].p,p);//更新答案
double dis[2]={T[T[now].ch[0]].dis(p),T[T[now].ch[1]].dis(p)};//左/右儿子所表示的矩形到待查点的最短距离
int next=dis[0]>dis[1];//选择左/右儿子(左儿子为0,右儿子为1)
query(T[now].ch[next],p);
if (dis[next^1]<ans) query(T[now].ch[next^1],p);//判断另一个儿子是否有可能对答案也有贡献
}
第k远点(洛谷 P2093)
算法与查询最近点类似,用一个优先队列维护距离,距离小的优先。每次判断是否能更新堆顶即可。
注意是k远点,并且当距离相等时,编号大的点优先。
id:点的编号
node:优先队列的元素,第一关键字是距离,第二关键字是编号
priority_queue<node> q;
void query(int now,Point p) {
if (!now) return;
node st=node(dis(T[now].p,p),T[now].p.id);
if (st<q.top()) q.pop(),q.push(st);
double dis[2]={T[T[now].ch[0]].dis(p),T[T[now].ch[1]].dis(p)};
int next=dis[0]<dis[1];query(T[now].ch[next],p);
if (node(dis[next^1],T[T[now].ch[next^1]].id)<q.top()) query(T[now].ch[next^1],p);
}
最终实现
这里仅给出k远点的实现
注:参考了k-d树学习笔记
二维点
struct Point {
int x,y,id;
Point (int x=0,int y=0):x(x),y(y) {}
}ps[N];
比较当前维数大小
del:当前维数,x为0,y为1
bool del;
bool cmp(Point p1,Point p2) {
if (!del) return (p1.x<p2.x||(p1.x==p2.x&&p1.y<p2.y));
return (p1.y<p2.y||(p1.y==p2.y&&p1.x<p2.x));
}
两点间距离
double dis(Point p1,Point p2) {
return (double)(p1.x-p2.x)*(p1.x-p2.x)+(double)(p1.y-p2.y)*(p1.y-p2.y);
}
KD-Tree结构体
r1:该节点表示矩形的左下角
r2:该节点表示矩形的右上角
注意dis函数的正确调用
struct Tree {
int ch[2],id;Point p,r1,r2;
Tree(Point p=Point(),int id=0):p(p),r1(p),r2(p),id(id) {}
double dis(Point p) {
if (!id) return -inf;
return max(max(::dis(p,r1),::dis(p,r2)),max(::dis(p,Point(r1.x,r2.y)),::dis(p,Point(r2.x,r1.y))));
}
}T[N];
维护r1,r2
void pushup(int rt) {
T[rt].r1.x=min(min(T[T[rt].ch[0]].r1.x,T[T[rt].ch[1]].r1.x),T[rt].r1.x);
T[rt].r1.y=min(min(T[T[rt].ch[0]].r1.y,T[T[rt].ch[1]].r1.y),T[rt].r1.y);
T[rt].r2.x=max(max(T[T[rt].ch[0]].r2.x,T[T[rt].ch[1]].r2.x),T[rt].r2.x);
T[rt].r2.y=max(max(T[T[rt].ch[0]].r2.y,T[T[rt].ch[1]].r2.y),T[rt].r2.y);
}
初始化
T[0]是一个不存在的点
void init() {
T[0].r1=Point(0x3f3f3f3f,0x3f3f3f3f),T[0].r2=Point(-0x3f3f3f3f,-0x3f3f3f3f);
}
建树
int build(int l,int r,int d) {
if (l>r) return 0;
del=d;int mid=(l+r)>>1,at=++ncnt;
nth_element(ps+l,ps+mid,ps+r+1,cmp);
T[at]=Tree(ps[mid],mid);
T[at].ch[0]=build(l,mid-1,d^1),T[at].ch[1]=build(mid+1,r,d^1);
pushup(at);return at;
}
优先队列中的点(注意优先队列是大根堆)
struct node {
double dis;int id;
node(double dis=0,int id=0):dis(dis),id(id) {}
bool operator < (node b) const {
return dis>b.dis||(dis==b.dis&&id<b.id);
}
};
查询
priority_queue<node> q;
void query(int now,Point p) {
if (!now) return;
node st=node(dis(T[now].p,p),T[now].p.id);
if (st<q.top()) q.pop(),q.push(st);
double dis[2]={T[T[now].ch[0]].dis(p),T[T[now].ch[1]].dis(p)};
int next=dis[0]<dis[1];query(T[now].ch[next],p);
if (node(dis[next^1],T[T[now].ch[next^1]].id)<q.top()) query(T[now].ch[next^1],p);
}
总代码实现如下
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <algorithm>
using namespace std;
#define N 100010
const double inf=1e300;
struct Point {
int x,y,id;
Point (int x=0,int y=0):x(x),y(y) {}
}ps[N];
bool del;
bool cmp(Point p1,Point p2) {
if (!del) return (p1.x<p2.x||(p1.x==p2.x&&p1.y<p2.y));
return (p1.y<p2.y||(p1.y==p2.y&&p1.x<p2.x));
}
double dis(Point p1,Point p2) {
return (double)(p1.x-p2.x)*(p1.x-p2.x)+(double)(p1.y-p2.y)*(p1.y-p2.y);
}
struct Tree {
int ch[2],id;Point p,r1,r2;
Tree(Point p=Point(),int id=0):p(p),r1(p),r2(p),id(id) {}
double dis(Point p) {
if (!id) return -inf;
return max(max(::dis(p,r1),::dis(p,r2)),max(::dis(p,Point(r1.x,r2.y)),::dis(p,Point(r2.x,r1.y))));
}
}T[N];
void pushup(int rt) {
T[rt].r1.x=min(min(T[T[rt].ch[0]].r1.x,T[T[rt].ch[1]].r1.x),T[rt].r1.x);
T[rt].r1.y=min(min(T[T[rt].ch[0]].r1.y,T[T[rt].ch[1]].r1.y),T[rt].r1.y);
T[rt].r2.x=max(max(T[T[rt].ch[0]].r2.x,T[T[rt].ch[1]].r2.x),T[rt].r2.x);
T[rt].r2.y=max(max(T[T[rt].ch[0]].r2.y,T[T[rt].ch[1]].r2.y),T[rt].r2.y);
}
int ncnt;
void init() {
T[0].r1=Point(0x3f3f3f3f,0x3f3f3f3f),T[0].r2=Point(-0x3f3f3f3f,-0x3f3f3f3f);
}
int build(int l,int r,int d) {
if (l>r) return 0;
del=d;int mid=(l+r)>>1,at=++ncnt;
nth_element(ps+l,ps+mid,ps+r+1,cmp);
T[at]=Tree(ps[mid],mid);
T[at].ch[0]=build(l,mid-1,d^1),T[at].ch[1]=build(mid+1,r,d^1);
pushup(at);return at;
}
struct node {
double dis;int id;
node(double dis=0,int id=0):dis(dis),id(id) {}
bool operator < (node b) const {
return dis>b.dis||(dis==b.dis&&id<b.id);
}
};
priority_queue<node> q;
void query(int now,Point p) {
if (!now) return;
node st=node(dis(T[now].p,p),T[now].p.id);
if (st<q.top()) q.pop(),q.push(st);
double dis[2]={T[T[now].ch[0]].dis(p),T[T[now].ch[1]].dis(p)};
int next=dis[0]<dis[1];query(T[now].ch[next],p);
if (node(dis[next^1],T[T[now].ch[next^1]].id)<q.top()) query(T[now].ch[next^1],p);
}
int main() {
init();int n,m;scanf("%d",&n);
for (int i=1;i<=n;i++) scanf("%d%d",&ps[i].x,&ps[i].y),ps[i].id=i;
build(1,n,0),scanf("%d",&m);
while (m--) {
int x,y,k;scanf("%d%d%d",&x,&y,&k);
while (!q.empty()) q.pop();
for (int i=1;i<=k;i++) q.push(node(-inf));
query(1,Point(x,y)),printf("%d\n",q.top().id);
}
return 0;
}