一、概述
- 决策树是机器学习中最基础、应用最广泛的算法模型,常用于解决分类和回归问题。
- 决策树的构建是一种自上而下的递归分裂学习方法,其学习的关键在于如何选择最优划分属性。一般情况下,随着划分过程不断进行,决策树的分支结点所包含的样本会尽可能属于同一类别,即结点的纯度(purity)越来越高。
- 度量样本集合纯度有三种常用的评估准则:ID3、C4.5和CART。本文将从构造、应用和实现三个角度,对比这三种模型的异同点。
二、决策树学习算法
2.1 相亲数据集
编号 | 年龄 | 长相 | 工资 | 编程 | 类别 |
---|---|---|---|---|---|
1 | 老 | 帅 | 高 | 不会 | 不见 |
2 | 年轻 | 一般 | 中等 | 会 | 见 |
3 | 年轻 | 丑 | 高 | 不会 | 不见 |
4 | 年轻 | 一般 | 高 | 会 | 见 |
5 | 年轻 | 一般 | 低 | 不会 | 不见 |
2.2 ID3(Iterative Dichotomiser 3)--信息增益
信息熵(information entropy)是度量样本集合纯度最常用的一种指标。对于样本集合D,类别数为K,信息熵定义为:
其中,是样本集合D中属于第k(k=1,2,3...K)类的样本子集,表示该子集的元素个数,表示样本集合的元素个数。特征属性A的条件熵H(D|A)定义为:
其中,表示D中特征A取第i个值的样本子集,表示中属于第k类的样本子集。因此,特征A的信息增益等于两者之差:
以表2.1中相亲数据为例,该数据包含5个样本集,正样本占比,负样本占比。于是,根据公式计算出根节点的信息熵为:
然后,计算当前属性集合{年龄,长相,工资,编程}中每个属性的条件熵。以属性“年龄”为例,它有2个可能的取值:{老,年轻}。若使用该属性对D进行划分,则可得到2个子集,分别为:(年龄=老),(年龄=年轻)。
- 子集包含编号{1}的1个样例,其中正样本占比,负样本占比;
- 子集包含编号{2,3,4,5}的4个样例,其中正样本占比,负样本占比;
根据公式计算出属性=年龄的条件熵为:
依此类推:
每个属性对应的信息增益为:
g(D,年龄)=0.171,g(D,长相)=0.42,g(D,工资)=0.42,g(D,编程)=0.971
由此可得,特征“编程”的信息增益最大,使用特征“编程”来进行划分所得的纯度提升越大。图2.1给出了基于“编程”对根节点进行划分的结果:
然后,决策树学习算法对每个分支节点做进一步划分。以图2.1中第一个分支节点(编程=会)为例,该节点包含的样本集合有{2,4}两个样本,可用的属性集合有{年龄,长相,工资},基于计算出各属性的信息增益,继续划分树节点,直到满足停止分裂的条件。
最后,ID3对取值较多的特征有所偏好,特征取值越多意味着确定性越高,也就是条件熵越小,信息增益越大。如果将前面的编号也作为一个划分属性,其信息增益为0.971,它将产生5个分支,每个分支结点仅包含一个样本,显然这样划分出来的决策树不具备泛化能力,无法对新样本进行有效预测。因此,C4.5对ID3进行优化,通过引入信息增益比,对取值较多的特征进行惩罚,避免出现过拟合的特性,提升决策树的泛化能力。
信息熵和条件熵计算的scala-spark实现:
/**
* 计算信息熵
*
* @param df
* @param column
* @return
*/
def calculate(df: DataFrame, column: String = "label"): Double = {
val counts = df.select(column).groupBy(column).agg(count(column)).collect().map(row => row.getLong(1))
val totalCount = counts.sum.toDouble
if (totalCount == 0) {
return 0
}
counts.map {
count =>
var impurity = 0.0
if (count != 0) {
val freq = count / totalCount
impurity -= freq * log2(freq)
}
impurity
}.reduce((v1, v2) => v1 + v2)
}
/**
* 计算每个特征的条件熵
*
* @param df
* @param column
* @return
*/
def calculateFeature(df: DataFrame, column: String): Double = {
val counts = df.select(column).groupBy(column).agg(count(column)).collect()
val totalCount = counts.map(row => row.getLong(1)).sum.toDouble
if (totalCount == 0) {
return 0
}
val impurity = counts.map {
row =>
val featureValue = row.get(0).toString.toDouble
val featureCount = row.getLong(1)
val freq = featureCount / totalCount
val tmp = df.filter(col(column) === featureValue)
freq * calculate(tmp, "label")
}.reduce((v1, v2) => v1 + v2)
impurity
}
2.3 C4.5--信息增益比
特征A对于数据集D的信息增益比定义为:
其中
称为数据集D关于A的取值熵。因此,可以根据上面的公式求出每个特征的取值熵:
最终可计算出各个特征的信息增益比:
通过信息增益比,特征年龄对应的指标上升了,而特征长相和工资有所下降。
2.4 CART--基尼指数(Gini)
CART决策树使用基尼指数来选择划分属性,Gini描述的是数据的纯度,其定义为:
其中,是样本集合D中属于第k(k=1,2,3...K)类的样本子集,表示该子集的元素个数,表示样本集合的元素个数。如果所有样本都属于同一个类别,则=,=0,此时impurity最小。CART利用基尼指数构造二叉决策树。如果特征是离散型变量,将样本按特征A的取值切分成两份;如果特征是连续型变量,CART的处理方式和C4.5相同,先将特征值进行升序排序,然后把左边第一个值(index=1)作为一个分类,右边其他值作为另一个分类,计算其Gini指数,然后移动index的位置,直到计算完所有的分类结果,然后选取Gini最小的位置对应的index作为切分点。特征A的Gini指数定义为:
当n=2时,该公式可以简化为:
使用CART分类准则,选取年龄维度,把老作为特征标签,那么年轻就被划分到另外一类
老(总数=1) | 年轻(总数=4) | |
---|---|---|
类别 | 不见 | 见、不见、见、不见 |
帅(总数=1) | 一般、丑(总数=4) | |
---|---|---|
类别 | 不见 | 见、不见、见、不见 |
一般(总数=3) | 帅、丑(总数=2) | |
---|---|---|
类别 | 见、不见 | 不见 |
高(总数=3) | 中等、低(总数=2) | |
---|---|---|
类别 | 见、不见 | 见、不见 |
会(总数=2) | 不会(总数=3) | |
---|---|---|
类别 | 见 | 不见 |
因此,特征编程的Gini指数最小,选择该特征作为最优的切分点。
三、小结
通过比较ID3、C4.5和CART三种决策树的构造准则,在同一个样本集上,表现出不同的划分行为。
- ID3和C4.5在每个结点上可以产生多个分支,而CART每个结点只会产生两个分支
- C4.5通过引入信息增益比,弥补了ID3在特征取值比较多时,由于过拟合造成泛化能力变弱的缺陷
- ID3只能处理离散型变量,而C4.5和CART可以处理连续型变量
- ID3和C4.5只能用于分类任务,而CART可以用于分类和回归任务
参考文献:
[1]诸葛越,葫芦娃.百面机器学习
[2]周志华.机器学习