以下为k-d树的python实现:
from binary_tree import Node, BinaryTree # binary_tree的代码见文章:https://blog.csdn.net/moyao_miao/article/details/136787981
class KDNode(Node):
"""k-d树节点类"""
def __init__(self, value, dimension):
super().__init__(value)
self.parent_node = None # 记录父节点
self.dimension = dimension # 记录节点划分维度
self.dimension_value = self.value[self.dimension] # 节点在该维度的值
class KDTree(BinaryTree):
"""k-d树类"""
def __init__(self, data_list):
self.dimension = len(data_list[0]) # 设置树的维度为数据列表中点的维度
# 计算每一维的范围,作为归一化系数
self.normalization_list = [max(data_list, key=lambda x: x[dimension])[dimension] -
min(data_list, key=lambda x: x[dimension])[dimension]
for dimension in range(self.dimension)]
self.visited_node_set = set() # 初始化一个用于记录访问过的节点的集合,以避免重复访问
self.max_distance = float('inf') # 初始化最大距离为无穷大,用于最近邻搜索中的比较
self.root = self._list_to_binarytree(data_list)
def _list_to_binarytree(self, data_slice, layer=0):
"""
将数据列表转换为k-d树:通过递归地将数据列表分割,并在每一层交替考虑不同的维度来构建k-d树。中位数被用来决定如何分割数据,保证树是平衡的。创建的节点会按照当前考虑的维度(由层次控制)存储一个点,并递归地为左右子树分配剩余的点。
**改进建议**:
1. **性能优化**:在排序数据时,考虑到该步骤可能在大数据集上成为瓶颈,可以寻求更高效的选择中位数的算法。例如,可以使用类似快速选择算法的方法来找到中位数而不完全排序,以此提升性能。
:param data_slice:包含需要加入k-d树的点的列表。
:param layer:当前递归的层次,默认为0,用于计算当前的维度。
:return:构建好的KD树的根节点。
"""
if data_slice:
dimension = layer % self.dimension # 根据当前层次计算维度
data_slice.sort(key=lambda x: x[dimension]) # 根据当前维度对数据进行排序
median_index = len(data_slice) // 2 # 计算中位数索引,用于分割数据
node = KDNode(data_slice[median_index], dimension) # 创建当前节点
node.left_node = self._list_to_binarytree(data_slice[:median_index], layer + 1) # 为节点左子树递归构建KD树
if node.left_node: node.left_node.parent_node = node # 如果左子树非空,设置父节点
node.right_node = self._list_to_binarytree(data_slice[median_index + 1:], layer + 1) # 为节点右子树递归构建KD树
if node.right_node: node.right_node.parent_node = node # 如果右子树非空,设置父节点
return node
def distance(self, p1, p2):
"""计算两个点归一化的L2距离"""
return sum(((p1[dimension] - p2[dimension]) / self.normalization_list[dimension]) ** 2
for dimension in range(self.dimension)) ** 0.5
def search_nearest(self, data, node):
"""
递归搜索从指定节点开始离数据点最近的叶子节点。
:param data: 要找近邻的数据点
:param node: 搜索的起始节点
:return: 最近的叶子节点
"""
# 若节点离数据点的距离小于最大距离,且节点的左右子节点至少有一个存在且未被访问过,则递归向下搜索
if (self.distance(node.value, data) < self.max_distance and
((node.left_node and node.left_node not in self.visited_node_set) or
(node.right_node and node.right_node not in self.visited_node_set))):
# 若数据点该维度的值小于节点的,则递归搜索其左子节点,如左子节点已被访问过,则搜索其右子节点
if data[node.dimension] < node.dimension_value:
if node.left_node in self.visited_node_set:
node = self.search_nearest(data, node.right_node)
else: node = self.search_nearest(data, node.left_node)
# 若数据点该维度的值大于等于节点的,则递归搜索其右子节点,如右子节点已被访问过,则搜索其左子节点
else:
if node.right_node in self.visited_node_set:
node = self.search_nearest(data, node.left_node)
else: node = self.search_nearest(data, node.right_node)
return node
def KNN(self, data, k):
"""
KNN算法,搜索k个最近邻居。
:param data: 要找近邻的数据点
:param k: 搜索近邻的数量
:return: 由k个最近邻居按距离顺序排好的列表
"""
def distance(node):
return self.distance(node.value, data)
neighbor_list = [] # 初始化一个用于记录找到的邻居的列表
node = self.root # 初始化搜索起始位置为根节点
# 当搜索位置到顶、且数据点到节点的切平面距离超过最大距离之前循环搜索邻居:
while node and abs(data[node.dimension] - node.dimension_value) / self.normalization_list[node.dimension] < self.max_distance:
neighbor_node = self.search_nearest(data, node) # 本次搜索到的邻居
self.visited_node_set.add(neighbor_node) # 记录为已访问过
if len(neighbor_list) < k: neighbor_list.append(neighbor_node) # 列表未满之前直接加入
else:
# 列表已满则对其按离数据点的距离排序,记录最大距离
neighbor_list.sort(key=distance)
self.max_distance = distance(neighbor_list[-1])
# 若本次搜索到的邻居比列表里最远的邻居近,则将其替换
if distance(neighbor_node) < self.max_distance:
neighbor_list.pop()
neighbor_list.append(neighbor_node)
node = neighbor_node.parent_node # 返回上一级继续搜索
return neighbor_list
if __name__ == "__main__":
obj = KDTree([(3, 2), (7, 3), (4, 6), (5, 7), (8, 9), (11, 5), (12, 8), (13, 1), (14, 4), (14, 10)])
print('k-d树图示:');obj.plot()
data, k = (13, 6), 3
print(f'离{data}最近的{k}个近邻:')
[print(node) for node in obj.KNN(data, k)]
输出:
k-d树图示:
(11, 5)
(4, 6) (12, 8)
(7, 3) (8, 9) (14, 4) (14, 10)
(3, 2) N (5, 7) N (13, 1) N N N
离(13, 6)最近的3个近邻:
(14, 4)
(12, 8)
(11, 5)