本文来自之前在Udacity上自学机器学习的系列笔记。这是第6篇,介绍了监督学习中的决策树模型。
决策树
决策树是监督学习中的分类模型的一种。关于分类模型,我们先了解下面的的概念:
- Instance:实例,表示输入input
- Concept:概念,表示输出函数Function
- Target concept:目标概念,即根据输入,希望得到的输出是什么
- Sample:训练集Training set
- Candidate:候选的输出函数Function
- Testing set:测试集Testing set
什么是决策树?它就像一颗倒过来的树形结构一样。比如,我们决定是否去餐厅A吃饭,我们会考虑多个因素,比如说肚子是否饿了、餐厅是否满座、中餐还是西餐等。最后,我们可以依次对每一个因素的判断,来得到结论。
所以,决策树就是:
- 选择数据集的一个特征,数据集作为初始的判断节点(Nodes),根据对这个特征的判断(Attributes),将数据集分为几类子节点;
- 沿着不同的子数据集,再选择数据集的一个特征,进行判断并细分数据集;
- 重复步骤2,直到获得满足条件的数据集,即叶子节点的数据集中的数据大多都属于同一类别了。
AND、OR、XOR
我们可以用“AND”,“OR”和“XOR”来加深理解。
- "AND"表示A和B同时为真时,C为真;否者为假;
- “OR”表示A或B为真时,C为真;否者为假;
- “XOR”表示A和B的结果不相同时,C为真,否则为假。
从而可以表示为:
上面是两个节点的情况,这可以推广到n个节点的情况。对于“AND”和“OR”,n个节点判断的计算量是线性的n阶运算;而“XOR”的计算量是指数级的阶运算。如果我们用真值表来表示n个节点的“XOR”决策树的话,每个节点是二元的,那么一共有种情况,而输出则有种,这个运算量是很大的。
模型
由上面的介绍,我们现在知道了,决策树是一种树状结构模型,可以解决分类的问题。
解决这类问题,可以通过特征选择的方法。
算法
ID3:
- 选择最优的特征A;
- 赋予特征A到决策树的节点;
- 对于特征A的每个值,创建节点的分支;
- 划分训练集到决策树的子节点;
- 如果子结点的数据得到完美地划分,那么模型训练结束;否则继续选择其他特征进行迭代。
ID3算法的特征选取按照“信息增益”最大的特征进行。所谓“信息增益”,字面上意思就是,根据该特征来划分数据可以降低数据集中的不确定性,即对每个划分出来的子数据集,里面的样本基本属于同一类。
这个方法背后思想也是很朴素的。因为当我们对一个不确定的问题进行思考时,会优先根据问题的特征,以及特征可能存在的情况进行细分,并且这种细分可以将问题最大程度地确定清楚。
熵(Entropy)
在信息学中,用“熵”这个概念来表示信息的不确定性,或者说数据集的不纯度。
假设一个数据集的样本中,包含了个样本,每个样本有个特征;一共有个分类标签。用表示标签集合的概率分布,那么整个数据集的熵定义为:
如果Entropy越大,说明数据集的不纯度越高,不确定性越大。
信息增益(Information Gain)
“信息增益”是可以降低这个数据集不纯度的定义。假设父节点经过某特征的划分,根据这个特征的取值,划分为多个子节点。那么可以得到:
计算例子
Grade | Bumpiness | Speed limit | Speed |
---|---|---|---|
steep | bumpy | yes | slow |
steep | smooth | yes | slow |
flat | bumpy | no | fast |
steep | smooth | no | fast |
上述表格的特征分别是公路的Grade, Bumpiness和Speed limit特征,然后驾驶速度Speed是相应的输出。
对于父节点(ssff),对应有,且
如果选择Grade特征进行划分数据集,可以划分为(ssf)和(f)。
对于子节点(f),有,且
对于子节点(ssf),有,且
从而
综上,可以得到特征Grade下的信息增益
同理,可以求得
所以,Speed limit的信息增益最大,可以使用Speed limit作为最优特征进行划分数据集。
其他注意点
如果特征是连续值,例如年龄、体重、距离等,可以进行离散化处理,比如说年龄段在20到30的。
对于存在连续值的特征的决策树,可能有必要重复某个特征的判断,比如说年龄在20到30岁,如果判断为F,那就还需要对年龄继续判断,到底是大于30岁,还是小于20岁。
如果所有数据集都得到正确地划分,或者没有更多的特征时,就可以停止决策树。另外,没有出现过拟合情况也是。
决策树的优点是它相对来说比较直观和容易理解。但缺点就是容易过拟合,这是因为树的分叉过细导致。特别是当数据集包含大量特征的数据时。所以需要仔细地调整参数,避免过拟合。
具体scikit-learn可以参考:
https://scikit-learn.org/stable/modules/tree.html#tree