0x00 前言
本文是《GBDT源码分析》系列的第二篇,主要关注GBDT的元算法"决策树"在scikit-learn中的实现。
0x01 整体说明
scikit-learn的主要源码都在项目的sklearn文件夹下,其中sklearn/tree
里是基本树模型的实现代码,如图,该文件夹下有以下几个文件。
- _init_.py: 包初始化
- setup.py: 包的安装代码
- tree.py: DecisionTree模型的主要逻辑代码
- export.py: tree的输出
- *.pxd和*.pyx文件: 分别有_tree, _splitter, _criterion, _utils四个内容共八个文件,是tree具体实现步骤的CPython代码。注意具体算法实现的核心代码都是用CPython,即在pyx后缀的文件里,其它py文件实际上可以看做模型的壳。
- 还有一个tests文件夹,从tests里的内容来看,里面对tree做回归和分类的功能,以及tree的输出和保存(export)做了测试。
__init__.py
文件
查看_init_.py文件,可发现tree提供的主要类/方法有5个:
- DecisionTreeClassifier: 位于tree.py
- DecisionTreeRegressor: 位于tree.py
- ExtraTreeClassifier: 位于tree.py
- ExtraTreeRegressor: 位于tree.py
- export_graphviz: 位于export_graphviz
我们最终主要关注其中的DecisionTreeRegressor和export_graphviz。
关注前者的原因是因为GDBT是基于回归树实现的;后者关系到树的可视化和输出,对我们理解能够获得的模型参数有一定帮助。
0x02 数据结构:Tree
决策树算法围绕的中心实际上就是一颗树,这棵树结构在训练阶段中被构建,在预测阶段则根据样本特征寻找一条从根节点到某个叶子结点的路径。这一章我们将介绍这个关键的数据结构;在源码中,这个树结构是用一个名为Tree的类实现的
Node节点
在定义树以前,我们需要先定义节点Node。Node是一个struct,在_tree.pxd中定义。一个Node由以下变量组成:
- left_child: 左子节点的ID
- right_child: 右子节点的ID
- feature: 该节点分裂所选的特征
- threshold: 分裂点的特征值
- impurity: 该节点的不纯度
- n_node_samples: 训练集在该节点的样本数(置信度)
- weighted_n_node_samples: 训练集在该节点的加权样本数
Tree类的变量
类Tree是一个二叉树,可以由TreeBuilder构建(我们将在下一章介绍TreeBuilder类)。Tree的定义和s具体实现代码分别在_tree.pxd和_tree.pyx文件中。
Tree类有以下一些成员变量,其中“可从外部获得的attribute”代表这些值可以通过Tree.xxx_attribute的方式读取:
- n_features: 特征数
- n_classes: 类别数,对于回归问题是1。
- n_outputs: 多标签问题中的标签数;对于一般的单标签问题和回归问题,该值取1。
- max_n_classes: max(n_classes),仅对多标签问题时有实际意义。
- nodes: 树中所有的节点组成的数组。
- value_stride: n_outputs * max_n_classes
- max_depth: 可从外部获得的attribute,树的最大深度。
- node_count: 可从外部获得的attribute,树的节点数。
- capacity: 可从外部获得的attribute,树的容量,大于等于node_count。
- children_left: 可从外部获得的attribute,所有节点的左子节点对应的nodeID组成的数组,如果节点本身是叶子节点,其左子节点为常量TREE_LEAF,即-1。
- children_right: 可从外部获得的attribute,所有节点的右子节点对应的nodeID组成的数组,如果节点本身是叶子节点,其右子节点为常量TREE_LEAF,即-1。
- features: 可从外部获得的attribute,所有节点分裂所选的特征(编号)组成的数组。
- threshold: 可从外部获得的attribute,所有节点分裂点的特征值组成的数组。
- impurity: 可从外部获得的attribute,所有节点的不纯度组成的数组。
- n_node_sample: 可从外部获得的attribute,所有节点上的样本数组成的数组。
- weighted_n_node_samples: 可从外部获得的attribute,所有节点上的加权样本数组成的数组。
- value: 可从外部获得的attribute,维度是(capacity, n_outputs, max_n_classes),存储了各个节点、标签位、类别位中样本的个数;对于回归问题,则是各个节点的预测值。
Tree类的方法
类Tree有以下主要的方法
- _resize: 根据Tree的capacity更改树内所有数组变量的长度;如果capacity是-1,将各个数组长度加倍。实际调用_resize_c方法。
- _add_node: 加入一个新节点。
- predict: 对一个样本输出预测值,DecisionTree算法里面predict方法实际调用的方法。
- apply: 对一个样本输出它预测所在叶子节点的标号,DecisionTree算法里面apply方法实际调用的方法。根据特征是否稀疏会调用_apply_sparse_csr方法或_apply_dense方法。
- decision_path: 对一个样本输出它的预测路径(经过的节点标号),DecisionTree算法里面decision_path方法实际调用的方法。根据特征是否稀疏会调用_decision_path_sparse_csr方法或_decision_path_dense方法。
- compute_feature_importances: 输出所有特征的重要性importance,DecisionTree里面features_importences_属性实际调用的方法。计算思路:遍历所有node,如果该节点不是叶子节点,该节点分裂后不纯度的下降值加到该节点分裂特征的重要性上。这样,每个特征的重要性就是该特征作为分裂特征时不纯度的下降的和。最后再做一个归一化的处理。
0x03 决策树的构建
决策树的构建过程涉及三个关键类:
- Criterion和Splitter:定义和实现了树结构的分裂策略。
- TreeBuilder:通过递归的方式从训练样本中构建树。
Criterion(分裂点好坏的评判标准)
_criterion.pyx
里主要定义和实现了各种关于不纯度计算的类,包括:
- Criterion: 不纯度评评判标准的基类/接口
- ClassificationCriterion: 分类问题的不纯度评判标准的基类
- Entropy: 交叉熵,一种分类criteria
- Gini: 基尼系数,一种分类criteria
- RegressionCriterion: 回归问题的不纯度评判标准的基类
- MSE: 平均方差,一种回归criteria
- MAE: 平均绝对差,一种回归criteria
这里提一下Criterion表示分裂点和左右节点的方法。在Criterion类内置的变量里,有start、pos、end这样三个值,假设样本标签是samples(已排序的),则samples[start: pos]可代表分裂后的左子节点,samples[pos: end]可代表分裂后的右子节点。Criterion类在初始化时其实是空的,仅当构建树的时候才会被使用。
Splitter(最佳分裂点的寻找方法)
_splitter.pyx
里主要定义和实现了分裂点的具体计算方法,包括考虑一些工程上的性能问题。
- Splitter: 所有splitter的基类/接口
- BaseDenseSplitter: 非稀疏特征矩阵的splitter基类
- BestSplitter: 对非稀疏情况寻找最佳分裂的splitter
- RandomSplitter: 对非稀疏情况寻找最佳随机分裂的splitter
- BaseSparseSplitter: 稀疏特征矩阵的splitter基类
- BestSparseSplitter: 对稀疏情况寻找最佳分裂的splitter
- RandomSparseSplitter: 对稀疏情况寻找最佳随机分裂的splitter
同Criterion一样,Splitter在初始化时是空的,只有TreeBuilder才会调用它,用来寻找最佳的分裂点。
TreeBuilder
TreeBuilder类通过递归的方式从训练样本中构建树。它通过Splitter来分裂内部节点以及给叶子节点赋值。TreeBuilder类主要用来控制构建树的过程中的各种停止条件,以及节点分裂时的优先度策略,包括depth-first或best-first策略。
在决策树的参数中,TreeBuilder相关的几个关键参数包括
- min_samples_split: 每个节点上最少的样本数
- min_samples_leaf: 每个叶子节点上最少的样本数
- min_weight_leaf: 每个叶子节点最小的权重(加权样本数占总加权样本数的比例)
- max_depth: 树的最大深度
- min_impurity_split: 当节点的不纯度小于该值时,节点将不再分裂
TreeBuilder有两种,DepthFirstTreeBuilder和BestFirstTreeBuilder。TreeBuilder内构建树的主要过程都是在build方法内实现。
0x04 决策树算法
决策树算法这里包含三部分:BaseDecisionTree、DecisionTreeClassifier和DecisionTreeRegressor。其中,BaseDecisionTree类是后两者的基类。
BaseDecisionTree
查看tree.py,除了之前提到的各种具体的树模型,有一个BaseDecisionTree类,是其它所有树模型的基类,我们先看看这个类。
如同scikit-learn中所有estimator一样,BaseDecisionTree的基类也是BaseEstimator。
它有几个基本方法/属性:
- _init_: 初始化构造方法
- fit: 训练方法
- predict: 预测方法
- apply: 返回样本被预测所在的叶子结点的index,在0.17版本时加入
- decisionpath: 预测结果的预测路径,在0.18版本时加入
- feature_importances: 特征重要性
此外,BaseDecisionTree应还包含基类BaseEstimator中的get_params和set_params方法。
init方法
构造方法里有以下参数,即我们在调用DecisionTree模型时输入的那些参数:
- criterion: 衡量分裂的准则,对于分类树,可选"gini"和"entropy",分别对应基尼纯度和信息增益。对于回归树,可选"mse"和"mae"(0.18版本时加入),分别对应平均平方误差和平均绝对值误差。
- splitter: 每一个node上分裂点选择策略,可选"best"和"random"。(不太懂这两个的区别)
- max_depth: 树的最大深度。
- min_samples_split: 分裂时要求的最小样本数。如果类型是int,为个数;如果类型是float,为总样本数的比例(0.18版本时加入)。
- min_samples_leaf: 叶子结点上最小样本数。如果类型是int,为个数;如果类型是float,为总样本数的比例(0.18版本时加入)。
- min_weight_fraction_leaf: 相较于总样本的权重和,叶子结点上样本的权重和所占的最小比例。(举例来说,如果sample_weight没有设定,所有样本有同样的权重)
- max_features: 分裂时考虑的特征数。如果类型是int,为个数;如果类型是float,为所占比例;还可以取sqrt、log2、auto(sqrt);如果是None,取总特征数。(注意:the search for a split does not stop until at least one valid partition of the node samples is found, even if it requires to effectively inspect more than max_features features.)
- max_leaf_nodes: 使用“最好优先”的方式构造树,使得树最多有max_leaf_nodes个叶子。“最好优先”中的“最好”的定义是impurity的relative reduction最大。
- random_state: 确定初始随机状态的参数。
- min_impurity_split: 若一个节点的不纯度小于该参数,停止分裂,该节点成为叶子节点。
- class_weight: 只有分类树才有的参数,类别label的权重,可以是"balanced"、None或者一个dict。
- presort: fitting时是否分拣数据。对于大数据集可能会是的training过程变慢,对于小数据集或限制深度的情况,可能能加快training速度。
fit方法
该方法的主输入参数有X(特征)和y(标签),可选参数有sample_weight、check_input、X_idx_sorted。下面提取部分主要代码做说明,省略可选参数、数据格式判断、多标签分类等内容。省略号代表这部分代码略过,仅描述一下主要干了些什么。
# 从X的形状获取样本数和特征数,n_features_是树模型的一个attribute
n_samples, self.n_features_ = X.shape
# 判断模型是分类还是回归
is_classification = isinstance(self, ClassifierMixin)
# ......对y(label)做处理
# 对于多标签分类问题,n_outputs_是多标签的个数;对于普通的分类问题,该值为1
self.n_outputs_ = y.shape[1]
# 当模型是分类树时
if is_classification:
# 初始化classes,这是分类树模型的一个attibute;有哪些label
self.classes_ = []
# 初始化n_classes,这是分类树模型的一个attibute;label的个数
self.n_classes_ = []
# 初始化编码后的label
y_encoded = np.zeros(y.shape, dtype=np.int)
# ......对label进行编码,同时计算self.classes_和self.n_classes_
# ......根据sample_weight对class_weight进行修正
#当模型是回归树时
else:
# 初始化self.classes_和self.n_classes_
self.classes_ = [None] * self.n_outputs_
self.n_classes_ = [1] * self.n_outputs_
# ......对参数进行检查和计算,非常长的一段代码
# 构建树的各种定义和过程
# 插入一段别处的代码:在tree.py中有一段关于type和constant的定义
DTYPE = _tree.DTYPE # np.float32
DOUBLE = _tree.DOUBLE # np.float64
# 分类树中的分裂评判标准,在_criterion.pyx中实现
CRITERIA_CLF = {"gini": _criterion.Gini, "entropy": _criterion.Entropy}
# 回归树中的分裂评判标准,即各种不纯度函数的计算方法,在_criterion.pyx中实现
CRITERIA_REG = {"mse": _criterion.MSE, "friedman_mse": _criterion.FriedmanMSE, "mae": _criterion.MAE}
# 非稀疏情况下分裂点选择策略,在_splitter.pyx中实现
DENSE_SPLITTERS = {"best": _splitter.BestSplitter, "random": _splitter.RandomSplitter}
# 稀疏情况下分裂点选择策略,在_splitter.pyx中实现
SPARSE_SPLITTERS = {"best": _splitter.BestSparseSplitter, "random": _splitter.RandomSparseSplitter}
# 读入模型的分裂评判标准
criterion = self.criterion
# 如果模型的criterion不是Criterion类的实例,根据前面定义的字典对分类和回归情况分别初始化criterion
if not isinstance(criterion, Criterion):
if is_classification:
criterion = CRITERIA_CLF[self.criterion](self.n_outputs_, self.n_classes_)
else:
criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples)
# 判断输入是否稀疏,读入模型的分裂评判标准
SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
splitter = self.splitter
# 如果模型的splitter不是Splitter类的实例,根据前面定义的字典初始化Splitter
if not isinstance(self.splitter, Splitter):
splitter = SPLITTERS[self.splitter](criterion, self.max_features_, min_samples_leaf, min_weight_leaf, random_state, self.presort)
# 初始化树,这里的Tree实在_tree.pyx中实现的
self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_)
# 如果max_leaf_nodes有定义,则使用BestFirstTreeBuilder这个TreeBuilder;其它情况使用DepthFirstTreeBuilder
if max_leaf_nodes < 0:
builder = DepthFirstTreeBuilder(splitter, min_samples_split, min_samples_leaf, min_weight_leaf, max_depth, self.min_impurity_split)
else:
builder = BestFirstTreeBuilder(splitter, min_samples_split, min_samples_leaf, min_weight_leaf, max_depth, max_leaf_nodes, self.min_impurity_split)
# 构建树(训练)
builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)
predict方法
# 使用训练得到的tree做预测
proba = self.tree_.predict(X)
# 对于分类树
if isinstance(self, ClassifierMixin):
# 对于只有一个标签的分类问题,取预测概率最大的类别作为输出
if self.n_outputs_ == 1:
return self.classes_.take(np.argmax(proba, axis=1), axis=0)
# 对于多标签的分类问题,对每一列标签分别取预测概率最大的类别作为输出
else:
predictions = np.zeros((n_samples, self.n_outputs_))
for k in range(self.n_outputs_):
predictions[:, k] = self.classes_[k].take(np.argmax(proba[:, k], axis=1), axis=0)
return predictions
# 对于回归树
else:
if self.n_outputs_ == 1:
return proba[:, 0]
else:
return proba[:, :, 0]
apply方法
调用self.tree_.apply(X)
decision_path方法
调用self.tree_.decision_path(X),读取的时候可以通过toarray()方法。
feature_importances_
调用self.tree_.compute_feature_importances()
DecisionTreeClassifier和DecisionTreeRegressor
DecisionTreeClassifier
DecisionTreeClassifier的init和fit方法直接继承了基类BaseDecisionTree,额外增加了predict_proba和predict_log_proba方法。
predict_proba返回输入样本被预测到每个类的概率,即样本被预测到的叶子节点中label的分布;predict_log_proba是对上述概率取log后的结果。
DecisionTreeRegressor
DecisionTreeRegressorr的init和fit方法直接继承了基类BaseDecisionTree,无其它新增方法。
Regression Example
# Very simple toy example for Regression
X = [[0, 0], [3, 2], [-2, 2]]
y = [0.5, 2.5, 1]
clf = tree.DecisionTreeRegressor()
clf = clf.fit(X, y)
# feature importance
print "feature_importances_"
print clf.feature_importances_
# clf.tree_ is the generated tree, its value represents the predicted value on each node
print "tree_.value"
print clf.tree_.value
# the decision_path
print "decision_path"
print clf.decision_path([[-2, 2]]).toarray()
"""Output
feature_importances_
[ 0.94230769 0.05769231]
tree_.value
[[[ 1.33333333]]
[[ 0.75 ]]
[[ 0.5 ]]
[[ 1. ]]
[[ 2.5 ]]]
decision_path
[[1 1 0 1 0]]
"""
树结构可视化
# visualization
import pydotplus
from IPython.display import Image
dot_data = tree.export_graphviz(clf, out_file=None)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
Classification Example
# Very simple toy example for Classifier
X = [[0, 0], [3, 2], [-2, 2], [5, -1]]
y = ['a', 'b', 'a', 'a']
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
# feature importance
print "feature_importances_"
print clf.feature_importances_
# clf.tree_ is the generated tree, its value represents the predicted value on each node
print "tree_.value"
print clf.tree_.value
# the number of classes
print "n_classes_"
print clf.n_classes_
# the classes
print "classes_"
print clf.classes_
# the decision_path
print "decision_path"
print clf.decision_path([[-2, 2]]).toarray()
"""Output
feature_importances_
[ 0.33333333 0.66666667]
tree_.value
[[[ 3. 1.]]
[[ 2. 0.]]
[[ 1. 1.]]
[[ 1. 0.]]
[[ 0. 1.]]]
n_classes_
2
classes_
['a' 'b']
decision_path
[[1 1 0 0 0]]
"""
树结构可视化
# visualization
import pydotplus
from IPython.display import Image
dot_data = tree.export_graphviz(clf, out_file=None)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
0x05 可视化
可以通过调用export_graphviz来输出一颗决策树(见上一章的例子),输出是dot格式的结果,这是一种有向图格式,该结果可以很容易转化成其它文件格式如png、jpeg、pdf,这样我们就可以完成对决策树的可视化过程。
export_graphviz方法在export.py文件里实现,有如下参数:
- decision_tree: 需要可视化的决策树。
- out_file: 输出的文件,可为file object或string。
- max_depth: 输出的树最大深度。
- feature_names: 一个list,为特征的名称;如果不指定,输出结果的特征名称就以X[0], X[1]......来表示。
- class_names: 类别的名称。
- label: 是否在node上显示每个展示值的名称,可选"all", "root"。
- filled: 是否给node上色,来表示分类中的majority class,或回归中的extremity value。
- leaves_parallel: 是否将所有叶子节点都画在底部。
- impurity: 是否在每个node上显示impurity值。
- node_ids: 是否在每个node上显示node ID。
- proportion: 显示values/sample的比例还是绝对值。
- rotate: 横着展示树还是竖着展示树。
- rounded: node的边框是否是圆角的,字体是Helvetica还是Times-Roman。
- special_characters: 是否忽略特殊字符。
export_graphviz中有一些子方法:
- get_color: 当filled是True时,计算node的颜色。
- node_to_str: 生成一个节点的字符串表示,包括nodeID、criteria、impurity、sample count、class distribution/regression value。
- recurse: 递归的遍历每个节点,生成整棵树的字符串表示。
个人主页:http://cathyxlyl.github.io/
文章可以转载, 但必须以超链接形式标明文章原始出处和作者信息