一、决策树简介
决策树的特点:
1)既可以处理分类问题,也可以处理回归问题
2)对于缺失值数据也能比较好的处理
3)高度可解释
决策树的思想很简单: 根据特征及判断阈,按照一定的路径来决定最终的判断结果。所以这里有几个问题:
1)节点如何确定?也就是说该如何选择用于分裂的特征?路径怎么决定?也就是说,我要把哪个特征放前面进行判断,哪个特征放后面进行判断?
2)阈值怎么确定?
3)何时停止?并不是把所有的数据都分割才算是好的,否则有可能算法对训练集能非常好的拟合,但是对测试集就拟合的不够好(过拟合)
4)预测的值是什么?对于分类问题比较好说,我们可以通过样本类型来进行分类判断。对于回归问题,我们如何得到预测值呢? 我们可以有很多种方法,例如取多次样本的平均值,或者,直接用叶节点的数据再做一个线形回归等等。
因此,决策树的生成也分为如下几步:
1)节点的分裂:一般当一个节点所代表的属性无法给出判断时,将这一个节点分裂成若干个子节点;
2)阈值的确定:选择适当的阈值使得分类错误率最小。
3)深度的确定
4 )最终预测值的算法确定
二、熵
在讲后面的具体算法之前,我们先要讲一下“熵”。那么什么是“熵”呢?按照一般的说法。熵就是混沌状态的一种度量。熵越大,表示越混沌。而所谓的增熵原理,就是说宇宙中的事物都有自发变得更混乱的倾向,也就是说熵会不断增加。 怎么理解? 例如:一个整洁的房间,如果不加收拾,会变得越来越凌乱。一个建筑物,如果不加修缮,会变得越来越破旧。我们本来将整理得好好的耳机线放在口袋里,可是,不出意外的是,每次从口袋里掏出它,它又变得乱糟糟的。但是反过来,我们从来没有见过一个乱糟糟的房间不加收拾,会自然变得很整洁;一团乱麻的手机线,不加真理,会自然变得很整齐。任何物体,一定是从有序变成无序,而不是相反,因为无序总是更有可能发生。既然是度量,我们就有公式。熵的计算公式如下:
关于熵的介绍,我们这里就点到为止。还有不懂或者感兴趣的同学,可以参考下面的文章。
三、信息熵
上面讲了熵,现在再来说信息熵。
首先,我们要知道什么是信息量。
信息量是对信息的度量,就跟时间的度量是秒一样,当我们考虑一个离散的随机变量x的时候,当我们观察到的这个变量的一个具体值的时候,我们接收到了多少信息呢?
多少信息用信息量来衡量,我们接受到的信息量跟具体发生的事件有关。
信息的大小跟随机事件的概率有关。越小概率的事情发生了产生的信息量越大,如湖南产生的地震了;越大概率的事情发生了产生的信息量越小,如太阳从东边升起来了(肯定发生嘛,没什么信息量)。
下面我们正式引出信息熵。
信息量度量的是一个具体事件发生了所带来的信息,而熵则是在结果出来之前对可能产生的信息量的期望——考虑该随机变量的所有可能取值,即所有可能发生事件所带来的信息量的期望。即
四、ID3 决策树算法
有了上面的理论基础,为了能让我们的决策树分类越正确,越稳定,信息量越大,我们就需要信息熵能达到最小。这就是我们进行选择节点和阈值的依据。另外,构建树的另外一个基本想法是随着树的深度的增加,节点的熵迅速的降低。熵降低的速度越快越好,这样就能得到一颗较矮的决策树(能 3 步判断出来就不用 5 步判断)。现在我们来使用ID3算法构建一个决策树。先看如下的一个样本例子:
这是一个分类问题。根据天气情况、温度情况、湿度情况、风力情况来决定要不要出去玩。 去或者不去,因此是个二分类问题。我们总共有14条样本数据。在这14条样本数据中,有9条选择了去,有5条选择了不去。因此,我们原始的信息熵为:。
1)现在我们来选择根节点:
我们有四个特征,这四个特征是都是分类变量,其中第一个特征有3个值(sunny、overcast、rainy);第二个特征也有三个值(hot、mild、cool);第三个特征有两个值(high、normal)、第四个特征有两个值(True、False)
我们分别使用这四个特征作为根节点,来计算分类后的熵。
使用outlook时,当特征值为sunny时,总共有5条数据,其中,2条为yes,3条为no,因此
当特征值为overcast时,总共有4条数据,这4条数据都是yes,熵为0。()
当特征值为rainy时,总共有5条数据,3条yes,2条为no,
那么,我们采用第一个特征作为根节点,得到的总体熵为:
同理,我们可以计算出,采用temperature作为根节点后的熵为:0.91;采用humidty作为根节点后的熵为:0.79;采用windy后的熵为:0.89
我们可以看到,使用outlook能使得我们的熵最小(下降得最快),因此,根节点我们选用outlook。
2)选择内节点
根节点确定后,我们在用同样的方法来选择内节点。
当根节点的值为overcast时,不用再分了。熵已经最小。
当根节点的值为sunny时,选temperature作为内节点,得到熵为:,选humidity作为内节点,得到的熵为0,选wind得到的熵为0.95。显然我们要选择humidity作为sunny下面的内节点。
当根节点值为rainy时,原理同上。
以上就是ID3构建决策树的基本思想。
当然,这个算法有个致命弱点,就是特征的不同特征值越多,越容易被采用为内(根)节点。试想一下,如果有一个特征,对于14个样本数据,有14个不同的值,那么他的信息熵就会为0,然后决策树会用这个特征来划分原数据集,其实这种划分毫无意义。所以,引申出了另外一个决策算法——C4.5
五、C4.5 决策树算法
为了解决上面ID3提到的问题,我们把ID3的评估算法稍微改进一下:
ID3 评估一个特征能否成为节点,是用原始熵减去该节点的熵,得到的一个信息增益。信息增益越大,代表该节点相对于原始节点来说,信息熵越低,降低得越快。
C4.5 在信息增益的基础上,需要除以该节点“纯度”。例如如果该特征有14个不同的特征值,其纯度为14,某特征有3个不同的特征值,其纯度为3。显然,纯度越大,也即分母越大,那么最终得到的值越小。 相当于加了一个惩罚项(惩罚因子),来进行平衡。
另外,ID3只能处理特征值是分类变量的情况。C4.5改进了算法,支持特征为连续值的情况。主要思路是将连续型变量离散化。具体做法为:
1)将连续变量的值从小到大排列;
2)取两个相邻变量的平均数,得到m-1个划分点;
3)将这m-1个划分点作为离散值,分别计算其信息熵的增益,选择信息增益最大的点作为该连续特征的分类阈值。
六、CART 决策树算法
不同于ID3、C4.5,CART没有使用信息熵增益来进行决策,而是用了另外一个衡量标准——基尼系数。基尼系数代表了模型的不纯度,基尼系数越小,不纯度越低,特征越好。这和信息增益(比)相反。
假设K个类别,第k个类别的概率为,概率分布的基尼系数表达式为:。
如果是二分类问题,第一个样本输出概率为p,概率分布的基尼系数表达式为:。
对于样本D,个数为|D|,根据特征A的某个值a,把D分成|D1|和|D2|,则在特征A的条件下,样本D的基尼系数表达式为:。
直接上例子。还是上面那个例子。假设我们选择outlook作为根节点,则:
当阈值为sunny时,其基尼系数为:
当阈值为overcast时,其基尼系数为:
当阈值为rainy时,其基尼系数为:
同理,我们计算出其他几个特征对应的基尼系数:
因为humidity和windy的特征都只有两个,所以他们都只有一个切分点
显然,Gini(outlook,overcast) 最小,因此我们会选择outlook作为根节点,overcast作为切分点,划分出一个叶节点,一个内节点,然后对内节点再按照上述步骤计算基尼系数,进一步进行划分。
所以,CART作为分类算法时,其思想跟ID3是一样的,只不过是决策的方法不同而已。另外,在对离散变量进行处理时,是采用的不停的二分策略。在ID3、C4.5,特征A被选取建立决策树节点,如果它有3个类别A1,A2,A3,我们会在决策树上建立一个三叉点,这样决策树是多叉树。CART采用的是不停的二分。会考虑把特征A分成{A1}和{A2,A3}、{A2}和{A1,A3}、{A3}和{A1,A2}三种情况,找到基尼系数最小的组合。
CART除了可以作为分类算法,也可以作为回归算法。分类树与回归树的区别在样本的输出,如果样本输出是离散值,这是分类树;样本输出是连续值,这是回归树。当作为分类树时,我们用基尼系数来进行度量划分节点的优劣,且预测结果采用叶子节点里概率最大的类别作为当前节点的预测类别。而作为回归树时,我们用SSE来进行度量划分节点的优劣,采用叶子节点的均值或者中位数来预测输出结果。具体来说,度量目标是对于划分特征A,对应划分点s两边的数据集D1和D2,求出使D1和D2各自集合的均方差最小,同时D1和D2的均方差之和最小。
#给定一个分类点,将数据分为两类,分别计算这两类的均方差
compute_SSE_split <- function(v, y, split_point) {
index<-v<split_point #这里的index全是false或者ture,也可以认为是0和1
y1<-y[index] # 取出来所有为true对应的y值
y2<-y[!index] # 取出所有为fasle对应的y值
SSE<-sum((y1-mean(y1))^2) + sum((y2-mean(y2))^2) #计算两堆数的方差
return(SSE)
}
#给定一个向量v和一个y值,把向量v按照里面的每个元素进行分类,将数据分为两类,分别计算这两类数据的均方差
compute_all_SSE_splits <- function(v, y) {
sapply(unique(v), function(sp) compute_SSE_split(v,y,sp))
}
set.seed(99)
x1<-rbinom(20,1,0.5) #生成20个0,1。1次试验,概率0.5
set.seed(100)
x2<-round(10+rnorm(20,5,5),2) #产生20个均值为5,标准差也为5的随机数,加上10后,进行四舍五入,保留两位小数。
set.seed(101)
y<-round((1+(x2*2/5)+x1-rnorm(20,0,3)),2) #生成y
rcart_df<-data.frame(x1,x2,y)
x1splits<-compute_all_SSE_splits(x1,y) #对特征1计算所有的方差。因为x1只有0和1两个数,因此只会有2个结果。其中小于0并不能进行区分,其实是有一个结果有效。
x2splits<-compute_all_SSE_splits(x2,y) #对特征2计算所有的方差。特征2是离散值,有多少个不同的特征,就会有多少个不同的结果。当然,对于最小的特征,其最终得到的sse也是无用的。
决策树很容易对训练集过拟合,导致泛化能力差,所以要对CART树进行剪枝,即类似线性回归的正则化。剪枝分为预减枝和后剪枝两种方法。预剪枝就是在做模型前,先想好每个叶节点上要保留多少数据量。事先设定一个阈值,来确定每个叶节点上的数据量。如果小于了该值,则不再分裂,确保每个叶节点上的数据量不至于太少。后剪枝则是先生成决策树,然后再根据剪枝损失函数和交叉检验来决定要剪掉哪条,保留泛化能力最强的分支。剪枝损失函数表达式:。其中,为正则化参数(和线性回归的正则化一样),为训练数据的预测误差(回归树为均方差,分类树为基尼系数),为树T叶子节点数量。CART采用后剪枝法。
七、其他决策树算法
** M5 回归模型树**
M5 回归模型树与CART的区别在于:
1)CART在做回归时,是取的叶节点的均值或中位数,而M5是用叶节点的数据再做了一个线性模型,用线形模型的结果做为输出;
2)M5 在做分裂判断时,用的是标准差的加权(MSE),CART用的是均方差;
# 在R中使用M5算法
library(RWeka)
m5tree<-M5P(LeagueIndex~., data=skillcraft_train)
八、总结及R的实现
算法 | 区分要点 | R包 |
---|---|---|
ID3 | 使用信息增益 | rpart包中rpart函数 |
C4.5 | 使用信息增益率 | RWeka包中J48() |
CART | 使用基尼系数 | rpart包中rpart函数,tree包中的tree函数 |
C5.0 | C4.5的改进,比较适合于大规模数据 | c50包 |
rpart包常用参数如下:
参数 | 意义 |
---|---|
formula | y ~ x1+x2+x3+...+xn,或则 y~. 指明了因变量和自变量 |
data | 使用的数据集 |
na.action | 缺失值处理,默认na.rpart,表示删掉因变量y yy缺失的观测,但是保留自变量缺失的观测 |
method | 决策树的类型,“exp”用于生存分析,“poisson”用于二分类变量,“class”用于分类变量(使用居多),”anova”对应回归树 |
parms | 只用于分类树,parms=list(split,prior,loss),其中split的选项默认"gini"(对应CART),和"information"(对应ID3算法) |
control | 控制决策树形状的参数 |
C50使用示例:
#预测纸币的真实性,四个特征分别是:小波变换后图像的方差、偏斜度、峰度以及图像的熵
bnote <- read.csv("data_banknote_authentication.txt", header=FALSE)
names(bnote) <- c("waveletVar", "waveletSkew", "waveletCurt", "entropy", "class") #添加列名
bnote$class <- factor(bnote$class)
#划分测试集和训练集
library(caret)
set.seed(266)
bnote_sampling_vector <- createDataPartition(bnote$class, p = 0.80, list = FALSE)
bnote_train <- bnote[bnote_sampling_vector,]
bnote_test <- bnote[-bnote_sampling_vector,]
#使用C50算法构建决策树
library(C50)
bnote_tree <- C5.0(class~.,data=bnote_train)
#在测试集上进行预测
bnote_predictions <- predict(bnote_tree,bnote_test)
mean(bnote_test$class == bnote_predictions) #评判测试集上的预测效果
rpart 使用示例:
#预测星际争霸的比赛
skillcraft <- read.csv("SkillCraft1_Dataset.csv")
#去掉GameId,因为不是特征量,对分类无用
skillcraft<-skillcraft[-1]
skillcraft$TotalHours=as.numeric(skillcraft$TotalHours)
skillcraft$HoursPerWeek=as.numeric(skillcraft$HoursPerWeek)
skillcraft$Age=as.numeric(skillcraft$Age)
View(skillcraft)
#划分测试集和训练集
library(caret)
set.seed(133)
skillcraft_sampling_vector <- createDataPartition(skillcraft$LeagueIndex, p = 0.80, list = FALSE)
skillcraft_train <- skillcraft[skillcraft_sampling_vector,]
skillcraft_test <- skillcraft[-skillcraft_sampling_vector,]
#使用rpart进行决策树分类
library(rpart)
regtree <- rpart(LeagueIndex~., data=skillcraft_train)
#计算SSE
compute_SSE <- function(correct,predictions) {
return(sum((correct-predictions)^2))
}
#在测试集上进行预测
regtree_predictions = predict(regtree, skillcraft_test)
(regtree_SSE <- compute_SSE(regtree_predictions, skillcraft_test$LeagueIndex))#评判测试集上的预测效果
#minsplit指定父节点和子节点中所包含的最少样本量,是最小分支节点数,这里指大于等于20,那么该节点会继续分划下去,否则停止
#minbucket:叶子节点最小样本数
#cp 复杂程度.用于控制树的复杂度,指某个点的复杂度,对每一步拆分,模型的拟合优度必须提高的程度。值越低,复杂度越高,默认为0.01;
#maxdepth 最大10层
#xval指定交叉验证的次数;
regtree.random <- rpart(LeagueIndex~., data=skillcraft_train,
control=rpart.control(minsplit=20, cp=0.001, maxdepth=10))
regtree.random_predictions = predict(regtree.random, skillcraft_test)
(regtree.random_SSE <- compute_SSE(regtree.random_predictions,
skillcraft_test$LeagueIndex))
#调参
library(e1071)
rpart.ranges<-list(minsplit=seq(5,50,by=5),
cp=c(0,0.001,0.002,0.005,0.01,0.02,0.05,0.1,0.2,0.5), maxdepth=1:10)
(regtree.tune<-tune(rpart,LeagueIndex~., data=skillcraft_train, ranges=rpart.ranges))
#用调参后的值进行预测
regtree.tuned <- rpart(LeagueIndex~., data=skillcraft_train,control=rpart.control(minsplit=50, cp=0.002, maxdepth=7))
regtree.tuned_predictions = predict(regtree.tuned, skillcraft_test)
(regtree.tuned_SSE <- compute_SSE(regtree.tuned_predictions, skillcraft_test$LeagueIndex))
【参考文章】
能否尽量通俗地解释什么叫做熵?
信息熵是什么?
通俗理解信息熵
机器学习实战(三)——决策树
通俗易懂的讲解决策树
决策树算法原理
R语言:决策树ID3/C4.5/CART/C5.0算法的实现