树剪枝
一棵树如果节点过多,说明该模型存在过拟合问题。
通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。树回归(一)中的chooseBestSplit
函数中的提前终止条件,实际上是一种预剪枝(prepruning)操作。另一种形式的剪枝需要使用测试集和训练集,称作后剪枝(postpruning)。
预剪枝的不足
树回归(一)中的树构建算法对参数tolS
和tolN
非常敏感,下面用树回归(一)中的第一个数据集,采用不同的ops
参数,来观察结果。
dataSet = loadDataSet('ex00.txt')
myMat = np.mat(dataSet)
createTree(myMat, ops=(0,1))
结果如下:
{'spInd': 0,
'spVal': 0.48813,
'left': {'spInd': 0,
'spVal': 0.620599,
'left': {'spInd': 0,
......
'right': {'spInd': 0,
'spVal': 0.325412,
'left': {'spInd': 0, 'spVal': 0.3371, 'left': 0.1910235, 'right': 0.118208},
'right': -0.028594120689655174}}}
由于输出过长,这里省略部分内容。
与上文中只包含两个节点的树相比,这里构建的树过于臃肿。
下面用一个与ex00.txt
数据集分布类似,但y轴数量级是其100倍的ex2.txt
数据集来构建树。
dataSet = loadDataSet('ex2.txt')
myMat2 = np.mat(dataSet)
createTree(myMat2)
结果如下
{'spInd': 0,
'spVal': 0.499171,
'left': {'spInd': 0,
'spVal': 0.729397,
'left': {'spInd': 0,
'spVal': 0.952833,
'left': 108.838789625,
......
'right': {'spInd': 0,
'spVal': 0.457563,
'left': 7.969946125,
'right': -3.6244789069767447}}
用默认参数构建的树显得比较臃肿。下面是其分布。
ex00.txt
与ex2.txt
两个数据集分布类似,但在都采用默认参数的情况下,ex00.txt
构建的树只有两个叶节点,而ex2.txt
却有很多。产生这种现象的原因在于,停止条件tolS
对误差的数量级十分敏感。如果在选项上花费时间并对上述误差容忍度取平方值,也能得到两个叶节点的树:
createTree(myMat2, ops=(10000, 4))
output:
{'spInd': 0,
'spVal': 0.499171,
'left': 101.35815937735848,
'right': -2.637719329787234}
然而,通过不断修改参数来得到合理结果并不是很好的办法。
下面将介绍后剪枝,利用测试集来对树进行剪枝,并不需要指定参数,是一种更理想化的剪枝方法。
后剪枝
剪枝函数prune()
的伪代码如下:
基于已有的树切分测试数据:
如果存在任意子集是一棵树,则在该子集递归剪枝
计算将当前两个叶节点合并后的误差
计算不合并的误差
如果合并会降低误差的话,就将叶节点合并
def isTree(obj):
return (type(obj).__name__ == 'dict')
def getMean(tree):
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
return (tree['left'] + tree['right'])/2
def prune(tree, testData):
# 没有测试数据则对树进行塌陷处理
if testData.shape[0] == 0:
return getMean(tree)
if (isTree(tree['right']) or isTree(tree['left'])):
lSet, rSet = splitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']):
tree['right'] = prune(tree['right'], rSet)
# 如果都是叶节点,看能不能合并
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = splitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(np.power(lSet[:,-1] - tree['left'],2)) +\
sum(np.power(rSet[:,-1] - tree['right'],2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = sum(np.power(testData[:,-1] - treeMean,2))
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else: return tree
else: return tree
isTree()
判断是否为树。
getMean()
从上往下遍历树直到叶节点为止。该函数对树进行塌陷处理,即返回树平均值。
接下来看看实际效果。
# 构建一个过拟合的树
myTree = createTree(myMat2, ops=(0,1))
# 加载测试集
testData = loadDataSet('ex2test.txt')
testMat = np.mat(testData)
# 剪枝
prune(myTree, testMat)
运行后观察两棵树,可以发现大量节点被剪枝掉,但没有预期那样剪枝成两部分。
一般地,为了寻求最佳模型可以同时使用两种剪枝技术。