前言
最近刚开始看XGBoost,发现和回归树有关,这一块确实不太熟悉,于是在网上找了一些资料了解了一下。
作用
首先区分回归树与决策树。决策树的作用说白了是一个分类器,通过对特征的选择,划分,对数据进行分类。具体的算法这里也不再多说了,李航老师的《统计学习方法》里面讲的已经很清楚了。
与决策树不同,回归树做的是回归,是对值的回归预测。比如可以通过回归树预测房价,或者预测人的年龄,等等。输出的是连续值,而不是离散的分类类别。
算法
通俗的讲一回归树的思路,找到一个最好的特征的最优的划分点,把整个数据集根据这个划分点分成大于和小于的两个子集。然后对于这两个划分后的子集再分别寻找最优的划分点。直到满足终止条件,那么回归树就构建完成了。
首先,如何选择回归树的划分点。
遍历数据空间的特征,和每个特征所对应的所有取值。假设将j特征的s取值处选为取值点,那么由这个切分点将得到两个区域。
对于最优切分点的寻找是通过最小化目标函数。
其中和的计算是计算区间内的平均值。
使用均值的原因如下:
假设我们用L来表示区间上的损失,那么对于真实值和区间的表示值而言,
为了最小化这个损失,求导,将梯度设为0之后,可以求得结果。
所以,对于每个划分出来的区间,我们用均值来表示这个区间的值。
接下来就是不停的重复以上的步骤,寻找特征,再寻找特征里的最优划分点,划分区域,把均值作为这个区域的输出。直到最后构建好回归树。
下面具体看一下回归树算法的流程(图片来自《统计学习方法》),
关于终止条件一直没有找到一个很确切的定义,个人理解可以人为的设定树的深度,比如当树的深度达到5层时就停止继续划分。另一种思路可以设置一个关于准确度的阈值,当整个回归树的预测准确度(误差)低于阈值时就停止进一步的划分。如果有其他的方法希望可以在下面留言回复。
关于回归树的复杂度,假设当前的数据存在F个特征,每个特征里面有N个取值。如果生成的回归树最终有S个内部结点,那么整个的复杂度为
代码分析
这里使用sklearn分别构建了3棵回归树,对应的对应的深度分别为1,2,5。并将结果与线性回归做了简单的对比。
从结果分析中可以看出,当树的深度为5时,很好的拟合了数据点。表现要比普通的线性回归好很多。
对于树的深度的选择就涉及到了过拟合问题,包括了树的剪枝。后续如果遇到这些情况会再针对剪枝写篇文章总结一下。