不使用递归的决策树生成算法
利用队列 queue ,实现层次遍历(广度优先遍历),逐步处理每个节点来建立子树结构。再构建一个辅助队列,将每个节点存储到 nodes_to_process 列表中,以便在树生成完成后可以反向遍历计算每个节点的 leaf_num(叶子节点数量)。对于每个节点,根据特征选择和树的条件构建子节点;如果达到叶节点条件,直接将其标记为叶节点。最后,逆序处理计算每个结点的叶节点数量:通过逆序遍历 nodes_to_process 列表(即从叶节点到根节点),每次更新父节点的 leaf_num 为其所有子节点 leaf_num 的总和。
在构建决策树的过程中,每个节点都会根据特征选择和树的构建条件来决定是否进一步分裂。以下是这个步骤的详细说明:
1
、当前节点的特征选择
对于每个节点
current_node
,需要从剩余的特征集合
A
中选择一个“最优特征”
a
∗
,用于将数据集 D
划分成不同的子集。这个“最优特征”由基尼指数、信息增益或信息增率等来确定,使得划分后的子集在类别上更加纯净。
2
、判断是否满足叶节点条件
在进一步构建子节点之前,检查当前节点是否满足叶节点条件。如果满足以下任一条件,则将 current_node
标记为叶节点,而不再继续分裂:
(
1
)
单一类别
:如果数据集
D
中的所有样本都属于同一类
C
,则不再需要进一步划分。此时可以将 current_node
标记为叶节点,类别为
C
。
(
2
)
属性集为空或样本在剩余特征上取值相同
:如果
A
为空(即没有剩余特征可以选择),或数据集 D
中样本在剩余特征上的取值都相同,那么即使进一步分裂也不能提供更多信息。在这种情况下,current_node
也被标记为叶节点,并根据
D
中的样本数最多的类别作为
current_node 的类别。
(
3
)
达到最大深度
:如果当前节点的深度已经达到了预设的最大深度
MaxDepth
,则停止继续分裂,将 current_node
直接标记为叶节点,并将类别设为当前数据集中样本数最多的类别。
3
、构建子节点
如果不满足叶节点条件,则
current_node
将根据选择的特征
a
∗
来生成子节点。分情况处理:
(
1
)当前特征为离散值:如果
a
∗
是一个离散特征,节点会针对
a
∗
的每个可能的取值创建一个子节点 child_node
,表示
a
∗
取该值的样本子集。将数据集中所有在
a
∗
上取值为 a
∗v的样本(记作Da
∗
=a
∗v)分配到 child_node
,并继续构建树。 如果D
a
∗
=a
∗v为空,即该子集没有样本,说明该特征值在当前分支下没有样本。此时,将 child_node 标记为叶节点,并将其类别设为当前数据中出现次数最多的类别。如果D
a
∗
=a
∗v不为空,则将 child_node
和该子集继续加入到构建队列中。
(
2
)当前特征为连续值:如果
a
∗
是一个连续特征,则会根据分割点(采用二分法选取)将数据集划分为两个子集。构建两个子节点:一个子节点代表 a
∗
≥
split_valuea
的样本子集;另一个子节点代表 a
∗
<split_valuea
的样本子集。将两个子节点及其对应的数据集加入到构建队列中,继续后续的树构建。
4
、将子节点添加到树中
每个
child_node
会作为
current_node
的子节点,存储在
current_node.subtree
中。通过这种方式,不断将子节点加入树中,直到所有节点都满足叶节点条件,不再继续分裂为止。
5
、完成子节点分裂后的后续处理
当队列中所有节点都处理完后,逆序遍历已处理的节点列表,计算每个节点的叶节点数。
不使用递归的建树算法的实现思路
创建两个队列,分别为
queue
与
nodes_to_process
。
queue = deque([(root, X, y)])
用来存储节点和数据,queue
的结构为三元组,分别为根节点、当前节点的
X
值,即去除
a
∗
属性后剩下的 X
值,以及
y
标签。
nodes_to_process = []
记录所有节点以便后续计算
leaf_num
。遍历 queue
队列,创建根节点并将其放入队列,并将当前节点存入
nodes_to_process以记录节点。使用 queue
按层次处理每个节点。
每次处理时,首先检查是否达到叶节点条件(如最大深度或单一类别),如果是则标记为叶节点。如果不是叶节点,则选择最佳分割特征,并根据特征类型(离散或连续)生成对应的子节点。
queue
队列处理完毕后,通过
nodes_to_process
逆序遍历,每个节点的
leaf_num
设为其子节点的 leaf_num
总和。
代码实现
def generate_tree(self, X, y):
root = Node()
root.high = 0 # 根节点的高度为0
queue = deque([(root, X, y)]) # 使用队列来存储节点和数据
nodes_to_process = [] # 记录所有节点以便后续计算 leaf_num
while queue:
current_node, current_X, current_y = queue.popleft()
nodes_to_process.append(current_node)
# 叶节点条件:达到最大深度或只有单一类别或没有特征
if current_node.high >= self.MaxDepth or current_y.nunique() == 1 or current_X.empty:
current_node.is_leaf = True
current_node.leaf_class = current_y.mode()[0]
current_node.leaf_num = 1 # 是叶子节点,叶子数量为 1
continue
# 选择最佳划分特征
best_feature_name, best_impurity = self.choose_best_feature_to_split(current_X, current_y)
current_node.feature_name = best_feature_name
current_node.impurity = best_impurity[0]
current_node.feature_index = self.columns.index(best_feature_name)
feature_values = current_X[best_feature_name]
if len(best_impurity) == 1: # 离散值特征
current_node.is_continuous = False
unique_vals = feature_values.unique()
sub_X = current_X.drop(best_feature_name, axis=1)
for value in unique_vals:
child_node = Node()
child_node.high = current_node.high + 1
queue.append((child_node, sub_X[feature_values == value], current_y[feature_values == value]))
current_node.subtree[value] = child_node
elif len(best_impurity) == 2: # 连续值特征
current_node.is_continuous = True
current_node.split_value = best_impurity[1]
up_part = '>= {:.3f}'.format(current_node.split_value)
down_part = '< {:.3f}'.format(current_node.split_value)
child_node_up = Node()
child_node_down = Node()
child_node_up.high = current_node.high + 1
child_node_down.high = current_node.high + 1
queue.append((child_node_up, current_X[feature_values >= current_node.split_value],
current_y[feature_values >= current_node.split_value]))
queue.append((child_node_down, current_X[feature_values < current_node.split_value],
current_y[feature_values < current_node.split_value]))
current_node.subtree[up_part] = child_node_up
current_node.subtree[down_part] = child_node_down
# 逆序遍历 nodes_to_process,计算每个节点的 leaf_num
while nodes_to_process:
node = nodes_to_process.pop()
if node.is_leaf:
node.leaf_num = 1
else:
node.leaf_num = sum(child.leaf_num for child in node.subtree.values())
return root