Contents
[toc]
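1. K-Nearest Neighbors Basics
1.1 Model and Algorithm
K-nearest neighbor (KNN) is one of the most basic machine learning models. It falls into these categories:
- Classification (√), regression (√), sequence labeling
- Probabilistic soft classification, non-probabilistic hard classification (√)
- Supervised (√), unsupervised, reinforcement
- Linear, nonlinear (√)
- Discriminative (√), generative
KNN can be used for both classification and regression. The two models are essentially the same: a classification model discretizes the output of a regression model. Generally speaking, regression approximates a true quantity numerically and usually produces continuous values, while classification assigns a qualitative label to an object and usually produces discrete values.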
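Classification model:
Input:
1. Training set $T=\{(x_1,y_1),(x_2,y_2),\dots,(x_N,y_N)\}$, where $x_i\in\mathcal{X}\subseteq\mathbb{R}^n$ is a training sample and $y_i\in\mathcal{Y}=\{c_1,c_2,\dots,c_M\}$ is its class, $i=1,2,\dots,N$.
2. Test point $x$.
Output:
The class $y$ of the test point $x$.
Algorithm:
1. Using the given distance metric, find the $k$ points in $T$ nearest to $x$. Denote the neighborhood of $x$ that covers these $k$ points by $N_k(x)$.
2. Determine the class $y$ of $x$ by majority vote:
$$y=\arg\max_{c_j}\sum_{x_i\in N_k(x)} I(y_i=c_j),\quad i=1,2,\dots,N;\ j=1,2,\dots,M$$
where $I$ is the indicator function: $I(y_i=c_j)=1$ when $y_i=c_j$, and $0$ otherwise.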
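The two classification steps above can be sketched as a brute-force function in plain Python. This sketch assumes Euclidean distance and training data given as a list of `(point, label)` pairs. The name `knn_classify` is hypothetical. Ties in the vote are broken by `Counter.most_common`, which favors the tied label that appears first in the distance-sorted neighbor list.

```python
import math
from collections import Counter


def knn_classify(train, x, k):
    """Predict the label of x by majority vote among its k nearest training points.

    train: list of (point, label) pairs, where point is a tuple of numbers.
    """
    # Step 1: linear scan, sort all training samples by Euclidean distance to x.
    neighbors = sorted(train, key=lambda pair: math.dist(pair[0], x))[:k]
    # Step 2: majority vote, argmax over classes c of sum_i I(y_i == c).
    votes = Counter(label for _, label in neighbors)
    return votes.most_common(1)[0][0]


train = [((0.0, 0.0), "a"), ((0.1, 0.2), "a"), ((0.2, 0.1), "a"),
         ((1.0, 1.0), "b"), ((0.9, 1.1), "b")]
print(knn_classify(train, (0.15, 0.15), k=3))  # a
```

With `k=5` every training point votes, so the prediction becomes the overall majority class regardless of the query. This is the $K=N$ case discussed in Section 1.3.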
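Regression model:
Input:
1. Training set $T=\{(x_1,y_1),(x_2,y_2),\dots,(x_N,y_N)\}$, where $x_i\in\mathcal{X}\subseteq\mathbb{R}^n$ is a training sample and $y_i\in\mathbb{R}$ is its target value.
2. Point to predict $x$.
Output:
The value $y$ corresponding to $x$.
Algorithm:
1. Using the given distance metric, find the $k$ points in $T$ nearest to $x$. Denote the neighborhood of $x$ that covers these $k$ points by $N_k(x)$.
2. Determine the value $y$ of $x$ from the target values of these $k$ neighbors, here their mean:
$$y=\frac{1}{k}\sum_{x_i\in N_k(x)} y_i$$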
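A matching sketch for regression averages the targets of the $k$ nearest neighbors (uniform weights). It assumes Euclidean distance and `(point, value)` pairs. The name `knn_regress` is hypothetical.

```python
import math


def knn_regress(train, x, k):
    """Predict a value for x as the mean target of its k nearest training points.

    train: list of (point, value) pairs, where point is a tuple of numbers.
    """
    # Same neighbor search as for classification: sort by Euclidean distance.
    neighbors = sorted(train, key=lambda pair: math.dist(pair[0], x))[:k]
    # Replace the majority vote with the average of the neighbors' targets.
    return sum(value for _, value in neighbors) / len(neighbors)


# 1-D toy data sampled from y = x^2.
train = [((0.0,), 0.0), ((1.0,), 1.0), ((2.0,), 4.0), ((3.0,), 9.0)]
print(knn_regress(train, (1.2,), k=2))  # 2.5
```

The nearest two points to 1.2 are 1.0 and 2.0, so the prediction is $(1+4)/2=2.5$. Distance-weighted averaging is a common alternative, but it is not what the formula above describes.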
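1.2 Distance Metrics
The algorithm in the previous section relies on a distance metric. The most common one is the Euclidean distance, i.e. the 2-norm distance, where $x_i^{(l)}$ denotes the $l$-th coordinate of $x_i\in\mathbb{R}^n$:
$$L_2(x_i,x_j)=\left(\sum_{l=1}^{n}\left|x_i^{(l)}-x_j^{(l)}\right|^2\right)^{\frac{1}{2}}$$
It can also be the 1-norm distance, also called the Manhattan distance:
$$L_1(x_i,x_j)=\sum_{l=1}^{n}\left|x_i^{(l)}-x_j^{(l)}\right|$$
The Manhattan distance fits the following scenario: in a city divided by vertical and horizontal streets, the walking distance from one intersection to another is the Manhattan distance. In the figure below, the length of the green line is the Euclidean distance, and the lines of the other three colors all have the same length, the Manhattan distance.
It can also be the ∞-norm distance, which equals the largest coordinate-wise distance:
$$L_\infty(x_i,x_j)=\max_{l}\left|x_i^{(l)}-x_j^{(l)}\right|$$
The negative-infinity norm is the opposite: it equals the smallest coordinate-wise distance:
$$L_{-\infty}(x_i,x_j)=\min_{l}\left|x_i^{(l)}-x_j^{(l)}\right|$$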
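All of these distances can be computed with a single Minkowski-style helper. The plain-Python sketch below does that; the name `lp_distance` is hypothetical. For $1\le p\le\infty$ the result is a metric. The $-\infty$ case is not a true norm or metric, because two distinct points that share one coordinate get distance 0.

```python
import math


def lp_distance(a, b, p):
    """L_p distance between equal-length vectors a and b.

    p=1: Manhattan, p=2: Euclidean, p=math.inf: largest coordinate gap,
    p=-math.inf: smallest coordinate gap (not a true metric).
    """
    diffs = [abs(ai - bi) for ai, bi in zip(a, b)]
    if p == math.inf:
        return max(diffs)
    if p == -math.inf:
        return min(diffs)
    return sum(d ** p for d in diffs) ** (1.0 / p)


for p in (1, 2, math.inf, -math.inf):
    print(p, lp_distance((0, 0), (3, 4), p))  # 7.0, 5.0, 4, 3
```

For any pair of points, $L_\infty\le L_2\le L_1$. This is why the unit "ball" grows from a diamond ($L_1$) to a circle ($L_2$) to a square ($L_\infty$).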
The relationship between these norm distances is shown in the figure below:
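1.3 Choosing K
The choice of K affects the result of the algorithm.
A small K means predicting from the training samples in a small neighborhood. This gives a small empirical (training) error but overfits easily: the generalization error becomes large and the model generalizes poorly.
A larger K has a smoothing effect. As K increases, the generalization error first decreases and then increases, while the empirical error tends to keep increasing.
If K = N, then whatever the input, the prediction is simply the most frequent class among the training samples (classification) or the mean of all training targets (regression).
In practice, K is usually a small value, and the optimal K is typically chosen by cross-validation.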
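Choosing K by cross-validation can be sketched with plain k-fold validation in the standard library. The toy data (two overlapping Gaussian clusters) and the helper names `knn_classify` and `cv_accuracy` are hypothetical. Odd candidate values avoid vote ties in this two-class setting.

```python
import math
import random
from collections import Counter


def knn_classify(train, x, k):
    """Majority vote among the k training points nearest to x (Euclidean distance)."""
    neighbors = sorted(train, key=lambda pair: math.dist(pair[0], x))[:k]
    return Counter(label for _, label in neighbors).most_common(1)[0][0]


def cv_accuracy(data, k, n_folds=5):
    """Mean validation accuracy of k-NN over n_folds contiguous folds of shuffled data."""
    fold_size = len(data) // n_folds  # any remainder always stays in the training part
    scores = []
    for f in range(n_folds):
        start, stop = f * fold_size, (f + 1) * fold_size
        val, train = data[start:stop], data[:start] + data[stop:]
        correct = sum(knn_classify(train, x, k) == y for x, y in val)
        scores.append(correct / len(val))
    return sum(scores) / n_folds


# Hypothetical toy data: two overlapping 2-D Gaussian clusters labelled 0 and 1.
rng = random.Random(0)
data = [((rng.gauss(0, 1), rng.gauss(0, 1)), 0) for _ in range(60)]
data += [((rng.gauss(1.5, 1), rng.gauss(1.5, 1)), 1) for _ in range(60)]
rng.shuffle(data)

candidates = [1, 3, 5, 9, 15, 25]
scores = {k: cv_accuracy(data, k) for k in candidates}
best_k = max(scores, key=scores.get)  # ties go to the smallest K listed first
print(scores, best_k)
```

Preferring the smallest K among ties is one arbitrary choice. Another common one is the largest K within one standard error of the best score.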
The figure below shows how different values of K affect regression performance; the code is in Appendix 4.1:
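1.4 Neighbor Search Algorithms
KNN must search the training set $T$ for the $k$ points nearest to $x$. The most direct way is to compute the distance from $x$ to every point in $T$, sort the distances, and keep the $k$ smallest, i.e. a linear scan. Every query then costs $N$ distance computations, so when the training set is large this becomes very time-consuming, even infeasible.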
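A linear scan does not need a full sort. Keeping only the current $k$ best candidates in a max-heap brings the cost to $O(N\log k)$ per query, although every one of the $N$ points is still visited. A stdlib sketch follows; the function name `k_nearest_linear` is hypothetical. The standard library's `heapq.nsmallest(k, points, key=...)` does the same job in one call.

```python
import heapq
import math


def k_nearest_linear(points, x, k):
    """Return the k points closest to x (nearest first) by scanning every point once."""
    if k < 1:
        return []
    heap = []  # max-heap of the current k best, stored as (-distance, index)
    for i, p in enumerate(points):
        d = math.dist(p, x)
        if len(heap) < k:
            heapq.heappush(heap, (-d, i))
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, i))  # evict the current farthest candidate
    # Largest -distance first means smallest distance first.
    return [points[i] for _, i in sorted(heap, reverse=True)]


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
print(k_nearest_linear(points, (6, 3), 3))
```

The heap plays the same role as the buffer of k nearest points that the tree-search pseudocode in Chapters 2 and 3 maintains.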
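In practice, two structures are commonly used: the kd-tree (k-dimensional tree) and the ball tree. The ball tree is an improvement on the kd-tree. When the data dimension exceeds roughly 20, kd-tree performance drops sharply, while the ball tree performs better on high-dimensional data.
The kd-tree and the ball tree are introduced in Chapters 2 and 3 of this article.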
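2. kd-tree Algorithm
The core of KNN is finding the k nearest neighbors of a test sample in the training set. If the training set is very large, the traditional approach of scanning every sample to find the k nearest neighbors makes performance drop sharply.
A kd-tree trades space for time. Using the training points, it partitions the k-dimensional space along each dimension in turn and builds a binary tree, relying on divide and conquer to greatly speed up search. Binary search has complexity $O(\log N)$, and kd-tree search efficiency can come close to that, depending on whether the constructed kd-tree is close to balanced. The figure below shows how the training samples partition the space, together with the corresponding kd-tree. The solid green star is the test sample; the kd-tree search quickly finds its 3 nearest training samples (marked with hollow stars).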
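2.1 Building a kd-tree
The kd-tree is built as follows. First create the root node, which corresponds to the k-dimensional hyperrectangle containing all training points. Then recursively build the left and right child nodes by splitting the points of the current node at their median point along dimension $i$. The median point is stored in the current node as its splitting point on dimension $i$. Points whose $i$-th coordinate is smaller go to the left child, and points whose $i$-th coordinate is larger go to the right child. The root splits on dimension 0, and the splitting dimension increases by 1 with each level of depth, i.e. $i = j \bmod k$, where $j$ is the depth of the node (the root has depth 0) and $k$ is the dimension of the data.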
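Cycling through the dimensions in order does not give the most efficient kd-tree. If each node instead splits along the dimension in which its points have the largest variance, the split has the highest resolution and search is more efficient. In typical implementations, however, cycling through the dimensions is already good enough.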
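The two rules for picking the splitting dimension can be compared in a small stdlib sketch. Both function names are hypothetical. `split_dim_max_variance` uses the population variance (`statistics.pvariance`) of each coordinate over the points in the current node.

```python
from statistics import pvariance


def split_dim_cyclic(depth, n_dims):
    """Splitting dimension when cycling through the axes: i = depth mod k."""
    return depth % n_dims


def split_dim_max_variance(points):
    """Splitting dimension along which the node's points have the largest variance."""
    n_dims = len(points[0])
    variances = [pvariance([p[d] for p in points]) for d in range(n_dims)]
    return max(range(n_dims), key=variances.__getitem__)


pts = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
print(split_dim_cyclic(0, 2), split_dim_max_variance(pts))  # 0 0
```

For these six points the x-coordinates are more spread out than the y-coordinates, so both rules start with dimension 0. They diverge as soon as a node's points are stretched along a different axis than the cycle prescribes.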
The pseudocode for building a kd-tree is as follows; the full code is in Appendix 4.2:
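function fit_kd_tree is
    input:
        x, y: training samples and their labels
        dim: splitting dimension of the current node (child dimension = (dim+1) % number of dimensions)
    output:
        node: root of the constructed kd-tree
    if x is empty then
        return None
    else if x contains only one point then
        create a leaf node containing this single point:
            node.point := x[0]
            node.label := y[0]
            node.son1 := None
            node.son2 := None
        return node
    else
        let p be the median point along dimension dim (sort x by dimension dim and take the middle point; for an even count take the upper of the two middle points)
        let xl be the left set (points sorted before p along dim, i.e. dim-coordinate ≤ p[dim])
        let xr be the right set (points sorted after p along dim, i.e. dim-coordinate ≥ p[dim])
        split the labels accordingly into yl, yr
        create a node with two children:
            node.point := p
            node.label := label of p
            node.son1 := fit_kd_tree(xl, yl, (dim+1) % number of dimensions)
            node.son2 := fit_kd_tree(xr, yr, (dim+1) % number of dimensions)
        return node
    end if
end function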
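The pseudocode above can be turned into a short standalone Python sketch (standard library only). The `KDNode` dataclass is an illustration of ours, not the dict-based structure used in Appendix 4.2. The sketch follows the same rules: median split along `dim`, the upper middle point for an even count, children built on `(dim + 1) % n_dims`, and `None` for an empty subset.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

Point = Tuple[float, ...]


@dataclass
class KDNode:
    point: Point
    label: object
    dim: int                          # splitting dimension used at this node
    son1: Optional["KDNode"] = None   # points sorted before the median along dim
    son2: Optional["KDNode"] = None   # points sorted after the median along dim


def fit_kd_tree(x: Sequence[Point], y: Sequence[object], dim: int = 0) -> Optional[KDNode]:
    """Build a kd-tree by median splits, cycling dim = (dim + 1) % n_dims per level."""
    if len(x) == 0:
        return None                   # empty subset: no child here
    order = sorted(range(len(x)), key=lambda i: x[i][dim])
    mid = len(order) // 2             # even count: the upper of the two middle points
    m = order[mid]
    next_dim = (dim + 1) % len(x[0])
    left, right = order[:mid], order[mid + 1:]
    return KDNode(
        point=x[m],
        label=y[m],
        dim=dim,
        son1=fit_kd_tree([x[i] for i in left], [y[i] for i in left], next_dim),
        son2=fit_kd_tree([x[i] for i in right], [y[i] for i in right], next_dim),
    )


# Usage on a small 2-D example (labels are arbitrary placeholders).
pts = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = fit_kd_tree(pts, ["a", "a", "b", "a", "b", "b"])
print(root.point, root.son1.point, root.son2.point)  # (7, 2) (5, 4) (9, 6)
```

A single point needs no special case here. It becomes its own median with two empty (`None`) children, which is the same leaf the pseudocode builds explicitly.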
2.2 k-Nearest-Neighbor Search in a kd-tree
The search pseudocode is as follows; the full code is in Appendix 4.2:
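function kd_tree_search is
    global:
        Q, buffer of the k nearest points found so far (initially holds one point at infinity)
        q, distances from the test point to the points in Q, kept in sync with Q
    input:
        k, number of nearest neighbors to find
        t, test point
        node, current node
        dim, splitting dimension of the current node (child dimension = (dim+1) % number of dimensions)
    output:
        none
    if distance(t, node.point) < max(q) then
        add node.point to Q and update q accordingly
        if Q holds more than k points, remove the point farthest from t and update q accordingly
    end if
    max(q) is the distance from t to the farthest point in Q.
    Check whether the interval [t[dim]-max(q), t[dim]+max(q)] along dimension dim overlaps the two half-spaces split at node.point[dim]:
    if the lower end lies below the split, recursively search the left child;
    if the upper end lies above the split, recursively search the right child.
    (max(q) is re-evaluated before the second check, since searching the left child may have shrunk it.)
    if t[dim] - max(q) < node.point[dim] and node.son1 is not None then
        kd_tree_search(k, t, node.son1, (dim+1) % number of dimensions)
    end if
    if t[dim] + max(q) > node.point[dim] and node.son2 is not None then
        kd_tree_search(k, t, node.son2, (dim+1) % number of dimensions)
    end if
end function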
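Below is a standalone sketch of the same search in plain Python, assuming Euclidean distance. It departs from the pseudocode in two ways. First, the Q/q buffers are replaced by a size-k max-heap (`heapq` with negated distances). Second, the child on the test point's side of the split is searched first, which usually shrinks the search radius sooner. The pruning test is the same interval check, written as $|t[dim]-node.point[dim]| <$ current k-th distance. `build` and `kd_knn` are hypothetical names, and the random data only serves the self-check against a linear scan.

```python
import heapq
import math
import random


def build(points, depth=0):
    """Compact kd-tree: node = (point, dim, left, right), median split, dims cycled by depth."""
    if not points:
        return None
    dim = depth % len(points[0])
    pts = sorted(points, key=lambda p: p[dim])
    mid = len(pts) // 2
    return (pts[mid], dim, build(pts[:mid], depth + 1), build(pts[mid + 1:], depth + 1))


def kd_knn(root, t, k):
    """Return [(distance, point), ...] for the k points nearest to t, nearest first."""
    if k < 1:
        return []
    heap = []  # max-heap of the best k so far, stored as (-distance, point)

    def worst():
        # k-th best distance found so far; infinite until k candidates exist
        return -heap[0][0] if len(heap) == k else math.inf

    def search(node):
        if node is None:
            return
        point, dim, left, right = node
        d = math.dist(t, point)
        if d < worst():
            if len(heap) == k:
                heapq.heapreplace(heap, (-d, point))  # evict the current farthest
            else:
                heapq.heappush(heap, (-d, point))
        # Visit the child on t's side of the splitting plane first.
        near, far = (left, right) if t[dim] < point[dim] else (right, left)
        search(near)
        # The far side can only hold a closer point if the sphere of radius
        # worst() around t crosses the splitting plane.
        if abs(t[dim] - point[dim]) < worst():
            search(far)

    search(root)
    return sorted((-neg_d, p) for neg_d, p in heap)


rng = random.Random(0)
data = [(rng.uniform(0, 10), rng.uniform(0, 10)) for _ in range(200)]
tree = build(data)
print(kd_knn(tree, (6.0, 3.0), 3))
```

Because every pruned subtree provably contains no point closer than the current k-th distance, the result matches a linear scan exactly. Only the number of visited nodes changes.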
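3. Ball-tree Algorithm
In the kd-tree, the core cause of performance degradation is that the subspaces it produces are hyperrectangles, while nearest-neighbor search uses Euclidean distance, i.e. a hypersphere around the test point. A hypersphere is very likely to intersect several hyperrectangles, as the figure below shows. Every intersected subspace has to be checked, which greatly reduces efficiency.
If the partition regions are also hyperspheres, the probability of intersection drops sharply. The figure below shows a ball tree partitioning the space with hyperspheres. Without corners, the partition spheres are much less likely to intersect the search sphere, and efficiency improves greatly, especially when the data dimension is high.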
3.1 Building a Ball Tree
The pseudocode for building a ball tree is as follows; the full code is in Appendix 4.3:
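function fit_ball_tree is
    input: x, y, array of data points and their labels
    output: node, root of the constructed ball tree
    if x contains only one point then
        create a leaf node containing this single point:
            node.pivot := x[0]
            node.label := y[0]
            node.son1 := None
            node.son2 := None
            node.radius := 0
        return node
    else
        let c be the widest dimension (the dimension with the largest spread of coordinates)
        let p1, p2 be the points at the two extremes of dimension c
        let p be the midpoint of p1 and p2: p := (p1+p2)/2
        let radius be the distance from p to the farthest point in x
        let xl be the left set (points closer to p1 than to p2)
        let xr be the right set (all remaining points)
        split the labels accordingly into yl, yr
        create a node with two children:
            node.pivot := p
            node.label := None
            node.son1 := fit_ball_tree(xl, yl)
            node.son2 := fit_ball_tree(xr, yr)
            node.radius := radius
        return node
    end if
end function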
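The build pseudocode translates into a short stdlib sketch. The `BallNode` dataclass and the `leaf_points` helper are illustrations of ours. One assumption differs from Appendix 4.3: "widest dimension" is measured here by the max-minus-min spread, while the appendix uses the standard deviation. A guard for the degenerate case where all points coincide is also added, since the p1/p2 split would otherwise never separate them.

```python
import math
from dataclasses import dataclass
from typing import Optional, Tuple

Point = Tuple[float, ...]


@dataclass
class BallNode:
    pivot: Point
    radius: float
    label: object = None               # set only on leaves
    son1: Optional["BallNode"] = None
    son2: Optional["BallNode"] = None


def fit_ball_tree(x, y):
    """Recursively split x into two balls, following the fit_ball_tree pseudocode."""
    if len(x) == 1:
        return BallNode(pivot=x[0], radius=0.0, label=y[0])
    n_dims = len(x[0])
    # Widest dimension = largest (max - min) spread.
    c = max(range(n_dims), key=lambda d: max(p[d] for p in x) - min(p[d] for p in x))
    p1 = min(x, key=lambda p: p[c])    # extreme points along dimension c
    p2 = max(x, key=lambda p: p[c])
    pivot = tuple((a + b) / 2 for a, b in zip(p1, p2))
    radius = max(math.dist(pivot, p) for p in x)
    left = [i for i, p in enumerate(x) if math.dist(p, p1) < math.dist(p, p2)]
    if not left:
        # Only happens when all points coincide: split the list in half instead.
        left = list(range(len(x) // 2))
    left_set = set(left)
    right = [i for i in range(len(x)) if i not in left_set]
    return BallNode(
        pivot=pivot,
        radius=radius,
        son1=fit_ball_tree([x[i] for i in left], [y[i] for i in left]),
        son2=fit_ball_tree([x[i] for i in right], [y[i] for i in right]),
    )


def leaf_points(node):
    """All data points stored in the leaves under node."""
    if node.son1 is None:
        return [node.pivot]
    return leaf_points(node.son1) + leaf_points(node.son2)


pts = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = fit_ball_tree(pts, ["a", "a", "b", "a", "b", "b"])
print(root.pivot, round(root.radius, 3))  # (5.5, 4.5) 4.301
```

By construction every point lies inside its node's ball, so the radius is a valid bound for the triangle-inequality pruning used in the search.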
3.2 k-Nearest-Neighbor Search in a Ball Tree
The search pseudocode is as follows; the full code is in Appendix 4.3:
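function ball_tree_search is
    global:
        Q, buffer of the k nearest points found so far (initially holds one point at infinity)
        q, distances from the test point to the points in Q, kept in sync with Q
    input:
        k, number of nearest neighbors to find
        t, test point
        node, current node
    output:
        none
    triangle inequality: if the smallest possible distance from t to the current ball is not less than the distance from t to the farthest point in Q, the ball cannot contain a closer neighbor
    if distance(t, node.pivot) - node.radius ≥ max(q) then
        return
    end if
    if node is a leaf then
        add node.pivot to Q and update q accordingly
        if Q holds more than k points, remove the point farthest from t and update q accordingly
    else
        recursively search the left and right children of the current node
        ball_tree_search(k, t, node.son1)
        ball_tree_search(k, t, node.son2)
    end if
end function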
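The pruning rule follows from the triangle inequality. For any point $x$ inside a ball with center $c$ and radius $r$, $d(t,x)\ge d(t,c)-d(c,x)\ge d(t,c)-r$. Below is a standalone stdlib sketch of the search with that rule, assuming Euclidean distance. Q/q are replaced by a size-k max-heap, and `build`/`ball_knn` are hypothetical names. The children are visited in the same fixed order as in the pseudocode; visiting the child whose pivot is closer to t first is a common refinement not used here.

```python
import heapq
import math
import random


def build(points):
    """Compact ball tree: node = (pivot, radius, left, right); a leaf is (point, 0.0, None, None)."""
    if len(points) == 1:
        return (points[0], 0.0, None, None)
    dims = range(len(points[0]))
    c = max(dims, key=lambda d: max(p[d] for p in points) - min(p[d] for p in points))
    p1 = min(points, key=lambda p: p[c])
    p2 = max(points, key=lambda p: p[c])
    pivot = tuple((a + b) / 2 for a, b in zip(p1, p2))
    radius = max(math.dist(pivot, p) for p in points)
    left = [p for p in points if math.dist(p, p1) < math.dist(p, p2)]
    right = [p for p in points if math.dist(p, p1) >= math.dist(p, p2)]
    if not left:  # all points identical: split the list in half
        half = len(points) // 2
        left, right = points[:half], points[half:]
    return (pivot, radius, build(left), build(right))


def ball_knn(root, t, k):
    """Return [(distance, point), ...] for the k points nearest to t, nearest first."""
    if k < 1:
        return []
    heap = []  # max-heap of the best k so far, stored as (-distance, point)

    def worst():
        return -heap[0][0] if len(heap) == k else math.inf

    def search(node):
        pivot, radius, left, right = node
        # Every point in this ball is at least dist(t, pivot) - radius away from t.
        if math.dist(t, pivot) - radius >= worst():
            return
        if left is None:  # leaf: the pivot is the data point itself
            d = math.dist(t, pivot)
            if len(heap) == k:
                heapq.heapreplace(heap, (-d, pivot))  # evict the current farthest
            else:
                heapq.heappush(heap, (-d, pivot))
        else:
            search(left)
            search(right)

    search(root)
    return sorted((-neg_d, p) for neg_d, p in heap)


rng = random.Random(0)
data = [(rng.uniform(0, 10), rng.uniform(0, 10)) for _ in range(200)]
tree = build(data)
print(ball_knn(tree, (6.0, 3.0), 3))
```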
4. Appendix
4.1 Effect of K on Regression Performance
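import numpy as np
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt

# 200 two-moon points; regress the y-coordinate (X[:, 1]) on the x-coordinate (X[:, 0])
X, Y = make_moons(200, noise=0.05, random_state=1)
x1 = np.arange(-1, 2, 0.1)  # query positions along the x-axis
fig = plt.figure(figsize=(9, 6))
K = [1, 5, 10, 50, 100, 200]
for j in range(6):
    ax = fig.add_subplot(2, 3, j + 1)
    ax.scatter(X[:, 0], X[:, 1], s=5)
    x2 = np.array([])
    k = K[j]
    for i in x1:
        # k nearest samples by x-coordinate; predict the mean of their y-coordinates
        x2 = np.append(x2, np.mean(X[np.argsort(np.abs(X[:, 0] - i))[0:k], 1]))
    ax.plot(x1, x2, c='r')
    ax.title.set_text('k=%d' % k)
plt.show()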
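4.2 kd-tree Construction and Search
- Note: after the kd-tree and the ball tree are built, the networkx package is used to draw them. networkx is mainly used to build and draw graph models, and drawing a tree requires adjusting the node positions, so two functions, hierarchy_pos_ugly and hierarchy_pos_beautiful, are used here to lay out the nodes as a tree.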
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import networkx as nx
import random
def hierarchy_pos_ugly(G, root, levels=None, width=1., height=1.):
    """If there is a cycle that is reachable from root, then this will see infinite recursion.
    G: the graph
    root: the root node
    levels: a dictionary
            key: level number (starting from 0)
            value: number of nodes in this level
    width: horizontal space allocated for drawing
    height: vertical space allocated for drawing"""
    TOTAL = "total"
    CURRENT = "current"

    def make_levels(levels, node=root, currentLevel=0, parent=None):
        """Compute the number of nodes for each level
        """
        if not currentLevel in levels:
            levels[currentLevel] = {TOTAL: 0, CURRENT: 0}
        levels[currentLevel][TOTAL] += 1
        neighbors = G.neighbors(node)
        for neighbor in neighbors:
            if not neighbor == parent:
                levels = make_levels(levels, neighbor, currentLevel + 1, node)
        return levels

    def make_pos(pos, node=root, currentLevel=0, parent=None, vert_loc=0):
        dx = 1 / levels[currentLevel][TOTAL]
        left = dx / 2
        pos[node] = ((left + dx * levels[currentLevel][CURRENT]) * width, vert_loc)
        levels[currentLevel][CURRENT] += 1
        neighbors = G.neighbors(node)
        for neighbor in neighbors:
            if not neighbor == parent:
                pos = make_pos(pos, neighbor, currentLevel + 1, node, vert_loc - vert_gap)
        return pos

    if levels is None:
        levels = make_levels({})
    else:
        levels = {l: {TOTAL: levels[l], CURRENT: 0} for l in levels}
    vert_gap = height / (max([l for l in levels]) + 1)
    return make_pos({})
def hierarchy_pos_beautiful(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5):
    '''
    From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
    Licensed under Creative Commons Attribution-Share Alike
    If the graph is a tree this will return the positions to plot this in a
    hierarchical layout.
    G: the graph (must be a tree)
    root: the root node of current branch
    - if the tree is directed and this is not given,
      the root will be found and used
    - if the tree is directed and this is given, then
      the positions will be just for the descendants of this node.
    - if the tree is undirected and not given,
      then a random choice will be used.
    width: horizontal space allocated for this branch - avoids overlap with other branches
    vert_gap: gap between levels of hierarchy
    vert_loc: vertical location of root
    xcenter: horizontal location of root
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')
    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))  # allows back compatibility with nx version 1.11
        else:
            root = random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
        '''
        see hierarchy_pos docstring for most arguments
        pos: a dict saying where all nodes go if they have been assigned
        parent: parent of this branch. - only affects it if non-directed
        '''
        if pos is None:
            pos = {root: (xcenter, vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.neighbors(root))
        if not isinstance(G, nx.DiGraph) and parent is not None:
            children.remove(parent)
        if len(children) != 0:
            dx = width / len(children)
            nextx = xcenter - width / 2 - dx / 2
            for child in children:
                nextx += dx
                pos = _hierarchy_pos(G, child, width=dx, vert_gap=vert_gap,
                                     vert_loc=vert_loc - vert_gap, xcenter=nextx,
                                     pos=pos, parent=root)
        return pos

    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)
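'''
How to draw the tree:
pos = hierarchy_pos_beautiful(G, "Root")  # node positions of the tree; the second argument is the root node name
node_labels = nx.get_node_attributes(G, 'attr')  # node attribute labels; the second argument is the attribute name
nx.draw(G, pos, with_labels=True, labels=node_labels)  # draw the tree
plt.show()  # display
'''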
X, Y = make_blobs(n_samples=6,
                  n_features=2,
                  centers=2,
                  cluster_std=4,
                  random_state=0)
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)
# plt.show() is called once at the end, after everything has been drawn
ax.scatter(X[:, 0], X[:, 1], c=Y, s=60, cmap='rainbow')
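# function fit_kd_tree is
#     input:
#         x, y: array of data points and their labels
#         dim: splitting dimension of the current node (child dimension = (dim+1) % number of dimensions)
#     output:
#         node: root of the constructed kd-tree
#     if x is empty then
#         return None
#     else if x contains only one point then
#         create a leaf node containing this single point:
#             node.point := x[0]
#             node.label := y[0]
#             node.son1 := None
#             node.son2 := None
#         return node
#     else
#         let p be the median point along dimension dim (sort x by dimension dim; for an even count take the upper of the two middle points)
#         let xl be the left set (points sorted before p along dim)
#         let xr be the right set (points sorted after p along dim)
#         split the labels accordingly into yl, yr
#         create a node with two children:
#             node.point := p
#             node.label := label of p
#             node.son1 := fit_kd_tree(xl, yl, (dim+1) % number of dimensions)
#             node.son2 := fit_kd_tree(xr, yr, (dim+1) % number of dimensions)
#         return node
#     end if
# end function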
G = nx.Graph()


def fit_kd_tree(x, y, dim=0):
    # Empty subset: no child node. A single point needs no special case:
    # it becomes its own median with two empty children.
    if x.size == 0:
        return None
    idxs = np.argsort(x[:, dim])
    mid = idxs.size // 2        # for an even count this is the upper of the two middle points
    middle_idx = idxs[mid]
    p = x[middle_idx]           # p is the median point along dimension dim
    label = y[middle_idx]
    # points sorted before / after the median along dim go to the left / right subtree
    x1, y1 = x[idxs[:mid]], y[idxs[:mid]]
    x2, y2 = x[idxs[mid + 1:]], y[idxs[mid + 1:]]
    # recursively build the left and right subtrees on the next dimension
    son1 = fit_kd_tree(x1, y1, (dim + 1) % x.shape[1])
    son2 = fit_kd_tree(x2, y2, (dim + 1) % x.shape[1])
    node = {'point': p,
            'label': label,
            'son1': son1,
            'son2': son2}
    # record parent-child edges so the tree can be drawn with networkx
    if son1 is not None:
        G.add_edge('(%.1f,%.1f)' % tuple(node['point']),
                   '(%.1f,%.1f)' % tuple(node['son1']['point']))
    if son2 is not None:
        G.add_edge('(%.1f,%.1f)' % tuple(node['point']),
                   '(%.1f,%.1f)' % tuple(node['son2']['point']))
    return node


root = fit_kd_tree(X, Y)
# Traverse the kd-tree and draw the splitting line of every node
def plot_partition(node, dim=0, bound=ax.axis()):  # bound = (xmin, xmax, ymin, ymax) of the region to draw in
    # the splitting line is perpendicular to axis dim and spans the other axis within bound
    line_d = np.arange(bound[(dim + 1) % 2 * 2], bound[(dim + 1) % 2 * 2 + 1], 0.01)
    line = np.ones((line_d.size, 2))
    line[:, (dim + 1) % 2] = line_d
    line[:, dim] = node['point'][dim]
    ax.plot(line[:, 0], line[:, 1])
    if node['son1'] is not None:
        bound1 = list(bound)
        bound1[dim * 2 + 1] = node['point'][dim]  # left child: the split is its upper limit along dim
        plot_partition(node['son1'], (dim + 1) % 2, bound1)
    if node['son2'] is not None:
        bound2 = list(bound)
        bound2[dim * 2] = node['point'][dim]      # right child: the split is its lower limit along dim
        plot_partition(node['son2'], (dim + 1) % 2, bound2)


orig_bound = ax.axis()
plot_partition(root)
ax.axis(orig_bound)
fig2=plt.figure(figsize=(5,5))
pos=hierarchy_pos_ugly(G,root='(%.1f,%.1f)'%tuple(root['point']))
nx.draw(G,pos,with_labels=True,font_size=8,node_size=1500,node_shape='o',node_color='xkcd:light blue')
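# function kd_tree_search is
#     global:
#         Q, buffer of the k nearest points found so far (initially holds one point at infinity)
#         q, distances from the test point to the points in Q, kept in sync with Q
#     input:
#         k, number of nearest neighbors to find
#         t, test point
#         node, current node
#         dim, splitting dimension of the current node (child dimension = (dim+1) % number of dimensions)
#     output:
#         none
#     if distance(t, node.point) < max(q) then
#         add node.point to Q and update q accordingly
#         if Q holds more than k points, remove the point farthest from t and update q accordingly
#     end if
#     max(q) is the distance from t to the farthest point in Q.
#     Check whether [t[dim]-max(q), t[dim]+max(q)] overlaps the two half-spaces split at node.point[dim]:
#     if the lower end lies below the split, recursively search the left child;
#     if the upper end lies above the split, recursively search the right child.
#     if t[dim] - max(q) < node.point[dim] and node.son1 is not None then
#         kd_tree_search(k, t, node.son1, (dim+1) % number of dimensions)
#     end if
#     if t[dim] + max(q) > node.point[dim] and node.son2 is not None then
#         kd_tree_search(k, t, node.son2, (dim+1) % number of dimensions)
#     end if
# end function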
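Q = np.array([[np.inf, np.inf]])  # k-nearest buffer, initially one point at infinity
q = np.array([np.inf])            # distances from t to the points in Q


def kd_tree_search(k, t, node, dim=0):
    global Q, q
    d = np.linalg.norm(t - node['point'])
    if d < np.max(q):
        # buffer full: drop the current farthest point before adding the new one
        if Q.shape[0] == k:
            Q = np.delete(Q, np.argmax(q), axis=0)
            q = np.delete(q, np.argmax(q))
        Q = np.append(Q, [node['point']], axis=0)
        q = np.append(q, d)
    # np.max(q) is re-read before each check, since the left search may have shrunk it
    if t[dim] - np.max(q) < node['point'][dim] and node['son1'] is not None:
        kd_tree_search(k, t, node['son1'], (dim + 1) % t.size)
    if t[dim] + np.max(q) > node['point'][dim] and node['son2'] is not None:
        kd_tree_search(k, t, node['son2'], (dim + 1) % t.size)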
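k = 3
t = np.array([6, 3])
# Q and q are module-level state: reset them before running another search
kd_tree_search(k, t, root)
print(Q)
fig.axes[0].scatter(t[0], t[1], marker='*', s=500, color='green')
fig.axes[0].scatter(Q[:, 0], Q[:, 1], marker='*', s=500, facecolors='none', edgecolors='green')
plt.show()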
4.3 Ball-tree Construction and Search
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import networkx as nx
import random
def hierarchy_pos_ugly(G, root, levels=None, width=1., height=1.):
    """If there is a cycle that is reachable from root, then this will see infinite recursion.
    G: the graph
    root: the root node
    levels: a dictionary
            key: level number (starting from 0)
            value: number of nodes in this level
    width: horizontal space allocated for drawing
    height: vertical space allocated for drawing"""
    TOTAL = "total"
    CURRENT = "current"

    def make_levels(levels, node=root, currentLevel=0, parent=None):
        """Compute the number of nodes for each level
        """
        if not currentLevel in levels:
            levels[currentLevel] = {TOTAL: 0, CURRENT: 0}
        levels[currentLevel][TOTAL] += 1
        neighbors = G.neighbors(node)
        for neighbor in neighbors:
            if not neighbor == parent:
                levels = make_levels(levels, neighbor, currentLevel + 1, node)
        return levels

    def make_pos(pos, node=root, currentLevel=0, parent=None, vert_loc=0):
        dx = 1 / levels[currentLevel][TOTAL]
        left = dx / 2
        pos[node] = ((left + dx * levels[currentLevel][CURRENT]) * width, vert_loc)
        levels[currentLevel][CURRENT] += 1
        neighbors = G.neighbors(node)
        for neighbor in neighbors:
            if not neighbor == parent:
                pos = make_pos(pos, neighbor, currentLevel + 1, node, vert_loc - vert_gap)
        return pos

    if levels is None:
        levels = make_levels({})
    else:
        levels = {l: {TOTAL: levels[l], CURRENT: 0} for l in levels}
    vert_gap = height / (max([l for l in levels]) + 1)
    return make_pos({})
def hierarchy_pos_beautiful(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5):
    '''
    From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
    Licensed under Creative Commons Attribution-Share Alike
    If the graph is a tree this will return the positions to plot this in a
    hierarchical layout.
    G: the graph (must be a tree)
    root: the root node of current branch
    - if the tree is directed and this is not given,
      the root will be found and used
    - if the tree is directed and this is given, then
      the positions will be just for the descendants of this node.
    - if the tree is undirected and not given,
      then a random choice will be used.
    width: horizontal space allocated for this branch - avoids overlap with other branches
    vert_gap: gap between levels of hierarchy
    vert_loc: vertical location of root
    xcenter: horizontal location of root
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')
    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))  # allows back compatibility with nx version 1.11
        else:
            root = random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
        '''
        see hierarchy_pos docstring for most arguments
        pos: a dict saying where all nodes go if they have been assigned
        parent: parent of this branch. - only affects it if non-directed
        '''
        if pos is None:
            pos = {root: (xcenter, vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.neighbors(root))
        if not isinstance(G, nx.DiGraph) and parent is not None:
            children.remove(parent)
        if len(children) != 0:
            dx = width / len(children)
            nextx = xcenter - width / 2 - dx / 2
            for child in children:
                nextx += dx
                pos = _hierarchy_pos(G, child, width=dx, vert_gap=vert_gap,
                                     vert_loc=vert_loc - vert_gap, xcenter=nextx,
                                     pos=pos, parent=root)
        return pos

    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)
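'''
How to draw the tree:
pos = hierarchy_pos_beautiful(G, "Root")  # node positions of the tree; the second argument is the root node name
node_labels = nx.get_node_attributes(G, 'attr')  # node attribute labels; the second argument is the attribute name
nx.draw(G, pos, with_labels=True, labels=node_labels)  # draw the tree
plt.show()  # display
'''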
X, Y = make_blobs(n_samples=6,
                  n_features=2,
                  centers=2,
                  cluster_std=4,
                  random_state=0)
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)
# plt.show() is called once at the end, after everything has been drawn
ax.scatter(X[:, 0], X[:, 1], c=Y, s=60, cmap='rainbow')
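# function fit_ball_tree is
#     input: x, y, array of data points and their labels
#     output: node, root of the constructed ball tree
#     if x contains only one point then
#         create a leaf node containing this single point:
#             node.pivot := x[0]
#             node.label := y[0]
#             node.son1 := None
#             node.son2 := None
#             node.radius := 0
#         return node
#     else
#         let c be the widest dimension
#         let p1, p2 be the points at the two extremes of dimension c
#         let p be the midpoint of p1 and p2: p := (p1+p2)/2
#         let radius be the distance from p to the farthest point in x
#         let xl be the left set (points closer to p1 than to p2)
#         let xr be the right set (all remaining points)
#         split the labels accordingly into yl, yr
#         create a node with two children:
#             node.pivot := p
#             node.label := None
#             node.son1 := fit_ball_tree(xl, yl)
#             node.son2 := fit_ball_tree(xr, yr)
#             node.radius := radius
#         return node
#     end if
# end function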
G = nx.Graph()


def fit_ball_tree(x, y):
    # a single point becomes a leaf ball of radius 0
    # (assumes no duplicate points: identical points could never be separated)
    if x.shape[0] == 1:
        node = {'pivot': x[0],
                'label': y[0],
                'son1': None,
                'son2': None,
                'radius': 0}
        return node
    c = np.argmax(np.std(x, axis=0))              # c is the widest dimension, measured by standard deviation
    p1 = x[np.argmin(x[:, c])]                    # extreme points along dimension c
    p2 = x[np.argmax(x[:, c])]
    p = (p1 + p2) / 2                             # pivot: midpoint of p1 and p2
    radius = max(np.linalg.norm(x - p, axis=1))   # largest distance from p to any point (ball radius)
    # split x into two subsets by whether each point is closer to p1 or to p2
    closer_to_p1 = np.linalg.norm(x - p1, axis=1) < np.linalg.norm(x - p2, axis=1)
    x1, y1 = x[closer_to_p1], y[closer_to_p1]
    x2, y2 = x[~closer_to_p1], y[~closer_to_p1]
    # recursively build the left and right subtrees
    son1 = fit_ball_tree(x1, y1)
    son2 = fit_ball_tree(x2, y2)
    node = {'pivot': p,
            'label': None,
            'son1': son1,
            'son2': son2,
            'radius': radius}
    # record parent-child edges so the tree can be drawn with networkx
    G.add_edge('(%.1f,%.1f)' % tuple(node['pivot']),
               '(%.1f,%.1f)' % tuple(node['son1']['pivot']))
    G.add_edge('(%.1f,%.1f)' % tuple(node['pivot']),
               '(%.1f,%.1f)' % tuple(node['son2']['pivot']))
    return node
root=fit_ball_tree(X, Y)
# Traverse the ball tree and draw every ball as a circle, using the parametric equation of a circle
def plot_partition(node):
    if node['radius'] == 0:  # leaf node: nothing to draw
        return
    theta = np.linspace(0, 2 * np.pi, 200)
    x0 = node['radius'] * np.cos(theta) + node['pivot'][0]
    x1 = node['radius'] * np.sin(theta) + node['pivot'][1]
    ax.plot(x0, x1, color='black')
    if node['son1'] is not None:
        plot_partition(node['son1'])
    if node['son2'] is not None:
        plot_partition(node['son2'])


plot_partition(root)
ax.set_aspect('equal')  # equal axis scaling so the circles are drawn round
fig2=plt.figure(figsize=(5,5))
pos=hierarchy_pos_ugly(G,root='(%.1f,%.1f)'%tuple(root['pivot']))
nx.draw(G,pos,with_labels=True,font_size=8,node_size=1500,node_shape='o',node_color='xkcd:light blue')
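# function ball_tree_search is
#     global:
#         Q, buffer of the k nearest points found so far (initially holds one point at infinity)
#         q, distances from the test point to the points in Q, kept in sync with Q
#     input:
#         k, number of nearest neighbors to find
#         t, test point
#         node, current node
#     output:
#         none
#     triangle inequality: if the smallest possible distance from t to the current ball is not less than
#     the distance from t to the farthest point in Q, the ball cannot contain a closer neighbor
#     if distance(t, node.pivot) - node.radius ≥ max(q) then
#         return
#     end if
#     if node is a leaf then
#         add node.pivot to Q and update q accordingly
#         if Q holds more than k points, remove the point farthest from t and update q accordingly
#     else
#         recursively search the left and right children of the current node
#         ball_tree_search(k, t, node.son1)
#         ball_tree_search(k, t, node.son2)
#     end if
# end function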
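Q = np.array([[np.inf, np.inf]])  # k-nearest buffer, initially one point at infinity
q = np.array([np.inf])            # distances from t to the points in Q


def ball_tree_search(k, t, node):
    global Q, q
    # triangle inequality: no point in this ball is closer than dist(t, pivot) - radius
    if np.linalg.norm(t - node['pivot']) - node['radius'] >= np.max(q):
        return
    if node['son1'] is None and node['son2'] is None:  # leaf node
        # buffer full: drop the current farthest point before adding the new one
        if Q.shape[0] == k:
            Q = np.delete(Q, np.argmax(q), axis=0)
            q = np.delete(q, np.argmax(q))
        Q = np.append(Q, [node['pivot']], axis=0)
        q = np.append(q, np.linalg.norm(t - node['pivot']))
    else:
        ball_tree_search(k, t, node['son1'])
        ball_tree_search(k, t, node['son2'])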
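k = 3
t = np.array([6, 3])
# Q and q are module-level state: reset them before running another search
ball_tree_search(k, t, root)
print(Q)
fig.axes[0].scatter(t[0], t[1], marker='*', s=500, color='green')
fig.axes[0].scatter(Q[:, 0], Q[:, 1], marker='*', s=500, facecolors='none', edgecolors='green')
plt.show()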