GBDT源码分析之一 :总览

0x00 前言

这个系列将会对python的scikit-learn算法包中GBDT算法的源码实现做一个深入梳理和解读。本文会首先对GBDT算法做一个简单的介绍,并对其源码的结构做一个整体上的梳理。因为这里偏重的是源码分析,所以如果想对GBDT算法本身的原理进行深入了解,可以阅读参考文献中推荐的几位大牛的文章。

文章结构

本文将分为下面几个部分:

  1. 简要介绍一下GBDT算法的基本概念。
  2. scikit-learn中GBDT算法的运行例子。
  3. 对GBDT源码结构的一个整体梳理。这里我们会通过思维导图的方式展现GBDT算法实现涉及的主要源码构成。

0x01 GBDT简介

GBDT(Gradient Boosting Decision Tree) 又称 MART(Multiple Additive Regression Tree)或GBRT(Gradient Boosting Regression Tree),是一种基于回归决策树的Boosting集成算法。

GBDT的核心从算法命名来看一目了然,即决策树(DT)和梯度提升(GB)。

决策树

决策树是一种十分常用和基础的监督学习算法,可适用于分类和回归问题;它将决策过程表述为树状结构,树中的不同路径代表不同的决策分支。决策树的构建过程由根节点出发,根据样本的属性(特征)不断将样本集分裂生成子节点,直至满足停止条件;树结构的每个叶子节点都代表一个最终的预测结果,一般取落入该叶子节点的样本的众数/概率分布/平均值等。由于通过决策树算法生成的模型可以由一系列if-then规则表述,因此非常易于理解和实现,也是最简单的非线性算法之一。

决策树的关键技术包括分裂点的选择、分裂停止的条件以及避免过拟合的方法(如剪枝;合适的分裂停止条件也可以防止过拟合)。经典的决策树算法包括ID3、C4.5、CART等。

回归树

回归树即用来解决回归问题的决策树。在分类树中,样本标签是离散的或非有序的,我们取叶子节点样本标签的众数或概率分布作为预测结果;而在回归树中,样本标签一般是连续性的有序数据,我们取叶子结点中所有样本标签的平均值作为预测结果。

集成方法(Ensemble Method)

集成学习方法是将多个弱模型通过一定的组合方式组成一个新的强模型的方法,一般情况下集成的模型具有更强的预测和泛化能力。在机器学习问题中,这是一种非常强大的思路,也是"集体智慧"的典型例子。集成算法中的弱模型又称元算法;在GBDT中,回归树是GBDT的元算法。

我们在理解集成方法时,可以更多将其看作一些学习框架,重点在于理解这些框架的思路。各种集成算法(如GBDT、随机森林)的核心也可理解为将基本算法(如决策树)带入集成框架(如Boosting、Bagging)的产物。

Boosting与Gradient Boosting

Boosting的意思是"提升",它关注被预测错误的样本,基于预测错误的部分构建新的弱模型并集成,是一种常用的迭代集成方法。原始的Boosting方法可以说是基于"样本"的,它会在一开始给所有样本附上相等的权重值,在每轮迭代(生成一个弱模型)后增加预测错误的样本的权重,减少预测正确的样本的权重,并在此基础上训练新的弱模型;最终通过加权或投票的形式对所有弱模型进行组合,生成强模型。

而Gradient Boosting和原始Boosting方法不同的地方在于,它在残差减少的梯度方向建立新的弱模型。直观上看,它用来训练第K轮弱模型的数据,来自于之前所有弱模型集成后的预测值和样本真实值的"差"(准确来说损失函数梯度减少的方向)。


基于上面描述的一系列概念,我们可以较为容易的理解:一个GBDT模型由多颗回归决策树组成;理论上在训练过程中的一轮迭代中,算法基于残差减少的梯度方向生成一颗决策树(scikit-learn在用GBDT解决多标签问题时,实际上在每一轮迭代中用了多棵回归树,本文中我们不对这种情况做深入说明)。在预测阶段,累加模型中所有决策树的预测值(乘上步长/学习率),即可计算整个模型的预测结果。

GBDT算法在实际生产中运用非常广泛,表达能力也很强,通常不需要复杂的特征工程就能得到较好的预测效果,还能输出特征重要性得分;同时通过设定合理的样本和特征抽样比例,可以在训练过程中实现交叉检验(cross validation),有效地减少模型过拟合的出现。缺点则是基于Boosting集成方法的算法较难实现并行化,且基于GBDT的模型会较为复杂,深入分析和调优会有一定困难性。

0x02 运行示例

scikit-learn中ensemble包下关于GBDT的算法有两个,分别用来解决回归问题GradientBoostingRegressor和分类问题GradientBoostingClassifier,调用起来十分简单。

回归示例(波士顿房价数据集)

from sklearn.ensemble import GradientBoostingRegressor
from sklearn.datasets import load_boston
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# 导入数据
X_train, X_test, y_train, y_test = train_test_split(load_boston().data, load_boston().target, test_size=0.2)


"""初始化算法,设置参数

一些主要参数
loss: 损失函数,GBDT回归器可选'ls', 'lad', 'huber', 'quantile'。
learning_rate: 学习率/步长。
n_estimators: 迭代次数,和learning_rate存在trade-off关系。
criterion: 衡量分裂质量的公式,一般默认即可。
subsample: 样本采样比例。
max_features: 最大特征数或比例。

决策树相关参数包括max_depth, min_samples_split, min_samples_leaf, min_weight_fraction_leaf, max_leaf_nodes, min_impurity_split, 多数用来设定决策树分裂停止条件。

verbose: 日志level。
具体说明和其它参数请参考官网API。
"""
reg_model = GradientBoostingRegressor(
    loss='ls',
    learning_rate=0.02,
    n_estimators=200,
    subsample=0.8,
    max_features=0.8,
    max_depth=3,
    verbose=2
)

# 训练模型
reg_model.fit(X_train, y_train)

# 评估模型
prediction_train = reg_model.predict(X_train)
rmse_train = mean_squared_error(y_train, prediction_train)
prediction_test = reg_model.predict(X_test)
rmse_test = mean_squared_error(y_test, prediction_test)
print "RMSE for training dataset is %f, for testing dataset is %f." % (rmse_train, rmse_test)
"""Output:
RMSE for training dataset is 4.239157, for testing dataset is 10.749044.
"""

分类示例(鸢尾花分类数据集)

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

# 导入数据
X_train, X_test, y_train, y_test = train_test_split(load_iris().data, load_iris().target, test_size=0.2)


"""初始化算法,设置参数

一些主要参数
loss: 损失函数,GBDT分类器可选'deviance', 'exponential'。
learning_rate: 学习率/步长。
n_estimators: 迭代次数,和learning_rate存在trade-off关系。
criterion: 衡量分裂质量的公式,一般默认即可。
subsample: 样本采样比例。
max_features: 最大特征数或比例。

决策树相关参数包括max_depth, min_samples_split, min_samples_leaf, min_weight_fraction_leaf, max_leaf_nodes, min_impurity_split, 多数用来设定决策树分裂停止条件。

verbose: 日志level。
具体说明和其它参数请参考官网API。
"""
clf_model = GradientBoostingClassifier(
    loss='deviance',
    learning_rate=0.01,
    n_estimators=50,
    subsample=0.8,
    max_features=1,
    max_depth=3,
    verbose=2
)

# 训练模型
clf_model.fit(X_train, y_train)

# 评估模型
prediction_train = clf_model.predict(X_train)
cm_train = confusion_matrix(y_train, prediction_train)
prediction_test = clf_model.predict(X_test)
cm_test = confusion_matrix(y_test, prediction_test)
print "Confusion matrix for training dataset is \n%s\n for testing dataset is \n%s." % (cm_train, cm_test)
"""Output:
Confusion matrix for training dataset is 
[[40  0  0]
 [ 0 40  1]
 [ 0  1 38]]
 for testing dataset is 
[[10  0  0]
 [ 0  8  1]
 [ 0  0 11]].
"""

0x03 源码总览

整体介绍

Python的scikit-learn包包含了我们常用的大部分的机器学习算法和数据处理方法,我们主要分析其中实现GBDT的源码。GBDT的实现源码依然可以被分为GB和DT两部分。其中DT为决策树部分,其源码在一个名为Tree的package下;GB为gradient boosting方法,其相关源码在一个名为Ensemble的package下。总体结构见下面的思维导图。

Tree包的源码结构截图如下。里面实现了决策树算法、决策树的基本数据结构Tree、决策树构建策略以及树的可视化等内容。

Ensemble包的源码结构截图如下。Ensemble包里还包含了如bagging、随机森林等其它主题,但我们主要关注其中的base.py和grandient_boosting.py文件。

在本系列后续的两篇文章里,我们将分别介绍Tree包和Ensemble包中和GBDT相关的内容。

0xFF 参考:


作者:cathyxlyl | 简书 | GITHUB

个人主页:http://cathyxlyl.github.io/
文章可以转载, 但必须以超链接形式标明文章原始出处和作者信息

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,718评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,683评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,207评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,755评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,862评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,050评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,136评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,882评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,330评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,651评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,789评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,477评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,135评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,864评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,099评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,598评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,697评论 2 351

推荐阅读更多精彩内容