树回归|理论与算法实现

上一篇文章中,我们比较全面地学习了线性回归的原理是实现,今天我们还是留在回归板块,针对树回归进行学习和实践。

01 树回归原理

相比于线性回归,树回归更适合对复杂、非线性的数据进行回归建模。

原理

回想一下决策树,树回归的原理就是决策树(人家都叫”树“回归了……),在决策树的学习中,有三种算法,ID3, C4.5, CART,前两种算法只能处理离散型数据,因此只能用于回归,而CART算法由于采用二分法构建树,可以处理连续性数据,因此也可以用于回归,树回归的基本原理,就是CART算法。

混乱度衡量

说完算法原理,我们再来说说对于连续性数据的混乱度度量。我们知道,对于离散型数据,可以使用信息增益、信息增益比、基尼指数这些指标来衡量数据的混乱程度,那么对于连续型数据,怎么衡量呢?

可以使用总方差来衡量连续性数据的混乱程度:(各数据-数据均值)**2,即方差*样本数,称为总方差。

两种树回归

树回归有两种方式:回归树和模型树,其中,

  • 回归树:叶节点是一个值:当前叶子所有样本标签均值
  • 模型树:叶节点是一个线性回归模型:当前叶子所有样本的线性回归模型

下面我们分别实现回归树和模型树。


02 回归树实现

  • 叶节点是一个值:当前叶子所有样本标签均值
  • 误差衡量:总方差,表示一组数据的混乱度,是本组所有数据与这组数据均值之差的平方和

回归树的构建逻辑:二分法,每次选择一个最佳特征,并找到最佳切分特征值(使数据混乱度减少最多的[特征,特征值])进行切分,得到左右子树,然后对左右子树递归调用createTree方法,直到没有最佳特征为止。(实践中,在选择最佳特征时,进行了预剪枝)

#数据读取
def loadDataSet(filename):
    dataMat=[]
    fr=open(filename,'r')
    for line in fr.readlines():
        curLine=line.strip().split('\t')
        fltLine=list(map(float,curLine)) #将curLine各元素转换为float类型
        dataMat.append(fltLine)
    return mat(dataMat)

#二分数据
def binSplitDataSet(dataset,feat,val):
    mat0=dataset[nonzero(dataset[:,feat]>val)[0],:] #数组过滤选择特征大于指定值的数据
    mat1=dataset[nonzero(dataset[:,feat]<=val)[0],:] #数组过滤选择特征小于指定值的数据
    return mat0,mat1

#定义回归树的叶子(该叶子上各样本标签的均值)
def regLeaf(dataset):
    return mean(dataset[:,-1])

#定义连续数据的混乱度(总方差,即连续数据的混乱度=(该组各数据-该组数据均值)**2,即方差*样本数))
def regErr(dataset):
    return var(dataset[:-1])*shape(dataset)[0]

"""最佳特征以及最佳特征值选择函数"""
#leafType为叶节点取值,默认为regleaf,即取样本标签均值,对于模型树,叶节点是一个线性模型
#errType为数据误差(混乱度)计算方式,默认为regErr,总方差
#ops[0]为以最佳特征及特征值切分数据前后,数据混乱度的变化阈值,若小于该阈值,不切分
#ops[1]为切分后两块数据的最少样本数,若少于该值,不切分
#可以预想,回归树形状对ops[0],ops[1]很敏感,若这两个值过小,回归树会很臃肿,过拟合
def chooseBestSplit(dataset,leafType=regLeaf,errType=regErr,ops=(1,4)):
    tolS=ops[0];tolN=ops[1];m,n=shape(dataset)
    S=errType(dataset);bestS=inf;beatIndex=0;bestVal=0
    if len(set(dataset[:,-1].T.tolist()[0]))==1: #若只有一个类别
        return None,leafType(dataset)
    for featIndex in range(n-1):
        for splitVal in set(dataset[:,featIndex].T.tolist()[0]):
            mat0,mat1=binSplitDataSet(dataset,featIndex,splitVal)
            #若切分后两块数据的最少样本数少于设定值,不切分
            if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN): 
                continue
            newS=errType(mat0)+errType(mat1)
            if newS<bestS:
                bestIndex=featIndex;bestVal=splitVal;bestS=newS
    #若以最佳特征及特征值切分后的数据混乱度与原数据混乱度差值小于阈值,不切分
    if (S-bestS)<tolS:
        return None,leafType(dataset)
    mat0,mat1=binSplitDataSet(dataset,bestIndex,bestVal)
    #若以最佳特征及特征值切分后两块数据的最少样本数少于设定值,不切分
    if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):
        return None,leafType(dataset)
    return bestIndex,bestVal

"""构建回归树"""
def createTree(dataset,leafType=regLeaf,errType=regErr,ops=(1,4)):
    feat,val=chooseBestSplit(dataset,leafType,errType,ops)
    if feat==None:
        return val
    regTree={}
    regTree['spFeat']=feat
    regTree['spVal']=val
    lSet,rSet=binSplitDataSet(dataset,feat,val)
    regTree['left']=createTree(lSet,leafType,errType,ops)
    regTree['right']=createTree(rSet,leafType,errType,ops)
    return regTree

好了好了,写了这么多代码,我们来测试一下,原始数据分布如下图,训练结果如下图。可以看到,回归树模型将这组数据分到了5个叶节点上,目前看起来还过得去。


03 回归树剪枝

当我们设置的最小分离叶节点样本数、最小混乱度减小值等参数过小,可能产生过拟合,直观的现象就是,训练出来非常多的叶子,其实是没有必要的,此时就需要剪枝了(很形象嘛)。

剪枝分为预剪枝和后剪枝,

  • 预剪枝:在chooseBestSplit函数中的几个提前终止条件(切分样本小于阈值、混乱度减弱小于阈值),都是预剪枝(参数敏感)。
  • 后剪枝:使用测试集对训练出的回归树进行剪枝(由于不需要用户指定,后剪枝是一种更为理想化的剪枝方法)

后剪枝逻辑:对训练好的回归树,自上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降低测试误差,若能,则合并。

#判断是否是一棵树(字典)
def isTree(obj):
    return (type(obj).__name__=='dict')

#得到树所有叶节点的均值
def getMean(tree):
    #若子树仍然是树,则递归调用getMeant直到叶节点
    if isTree(tree['left']):
        tree['left']=getMean(tree['left'])
    if isTree(tree['right']):
        tree['right']=getMean(tree['right'])
    return (tree['left']+tree['right'])/2.0

"""剪枝函数:对训练好的回归树,自上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降低测试误差,若能则合并"""
def prune(tree,testData):
    #若无测试数据,则直接返回树所有叶节点的均值(塌陷处理)
    if shape(testData)[0]==0:
        return getMean(tree)
    #若存在任意子集是树,则将测试集按当前树的最佳切分特征和特征值切分(子集剪枝用)
    if isTree(tree['left']) or isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
    #若存在任意子集是树,则该子集递归调用剪枝过程(利用刚才切分好的训练集)
    if isTree(tree['left']):
        tree['left']=prune(tree['left'],lSet)
    if isTree(tree['right']):
        tree['right']=prune(tree['right'],rSet)
    #若当前子集都是叶节点,则计算该二叶节点合并前后的误差,决定是否合并
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
        errNotMerge=sum(power(lSet[:,-1].T.tolist()[0]-tree['left'],2))+sum(power(rSet[:,-1].T.tolist()[0]-tree['right'],2))
        treeMean=(tree['left']+tree['right'])/2.0
        errMerge=sum(power(testData[:,-1].T.tolist()[0]-treeMean,2))
        if errMerge<errNotMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

测试一下

#构建回归树,可以看到,该回归树非常臃肿,过拟合
dataMat3=loadDataSet(r'D:\DM\python\data\MLiA_SourceCode\machinelearninginaction\Ch09\ex2.txt')
regTree3=createTree(dataMat3,ops=(1,2))
testData=loadDataSet(r'D:\DM\python\data\MLiA_SourceCode\machinelearninginaction\Ch09\ex2test.txt')
prune(regTree3,testData)

结果如下,

可以看到,虽然有6个叶节点被剪掉了,但仍然有很多叶节点保留->后剪枝可能不如预剪枝有效,因此一般为了寻求最佳模型,会同时使用两种剪枝技术。


04 模型树实现

  • 叶节点是一个线性回归模型:当前叶子所有样本的线性回归模型
  • 误差衡量:平方误差类比线性回归误差,用线性模型对数据拟合,计算真实值与拟合值之差,求差值的平方和
  • 比回归树有更好的可解释性、更高的预测准确度

模型树构建逻辑:通用函数,算法逻辑与createTree()一致,只需改变其中的叶节点计算方法leafType()和误差计算方法errType()

#叶节点计算方法:该叶节点所有样本的标准线性回归模型,算法与linearRegression()一致
def linearSolve(dataset):
    m,n=shape(dataset)
    X=mat(ones((m,n)));Y=mat(ones((m,1)))
    X[:,1:n]=dataset[:,0:n-1] #X第一列为常数项1
    Y=dataset[:,-1]
    xTx=X.T*X
    if linalg.det(xTx)==0.0:
        raise NameError("矩阵为奇异矩阵,不可逆,尝试增大ops的第二个参数")
    ws=xTx.I*(X.T*Y)
    return ws,X,Y

def modelLeaf(dataset):
    ws,X,Y=linearSolve(dataset)
    return ws

#误差计算方法:用线性模型对数据拟合,计算真实值与拟合值之差,求差值的平方和
def modelErr(dataset):
    ws,X,Y=linearSolve(dataset)
    yPred=X*ws
    return sum(power(Y-yPred,2))

测试一下,训练集如图蓝点所示,训练模型如图红线所示,可以看到,模型树对数据的预测更准确合理。

dataMat4=loadDataSet(r'D:\DM\python\data\MLiA_SourceCode\machinelearninginaction\Ch09\exp2.txt')
modelTree1=createTree(dataMat4,leafType=modelLeaf,errType=modelErr,ops=(1,10))

05 模型树剪枝

同样地,模型树也会出现过拟合,也需要剪枝,原理与回归树剪枝一样,只需要替换其中的误差计算方式,然后微调一下剪枝代码,让每次递归时,对训练数据也递归切分。

#判断是否是一棵树(字典)
def isTree(obj):
    return (type(obj).__name__=='dict')

#剪枝函数:对训练好的模型树,自上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降低测试误差,若能则合并
def modelPrune(tree,trainData,testData):
    m,n=shape(testData)
    #若无测试数据,则直接返回树所有叶节点的均值(塌陷处理)
    if m==0:
        return tree
    #若存在任意子集是树,则将测试集按当前树的最佳切分特征和特征值切分(子集剪枝用)
    #同时将训练集也按当前树的最佳切分特征和特征值切分(子集剪枝用)
    if isTree(tree['left']) or isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
        lTrain,rTrain=binSplitDataSet(trainData,tree['spFeat'],tree['spVal'])
    #若存在任意子集是树,则该子集递归调用剪枝过程(利用刚才切分好的训练集)
    if isTree(tree['left']):
        tree['left']=modelPrune(tree['left'],lTrain,lSet)
    if isTree(tree['right']):
        tree['right']=modelPrune(tree['right'],rTrain,rSet)
        
    #若当前子集都是叶节点,则计算该二叶节点合并前后的误差,决定是否合并
    """
    模型树,两个叶节点合并前的误差=((左叶子真实值-拟合值)的平方和+(右叶子真实值-拟合值)的平方和)
    模型树,两个叶节点合并后的误差=(左右真实值-左右拟合值)的平方和
    难点在于如何求左右拟合值,即求上层节点的回归系数wsMerge:用上层节点的traindata,通过linearSolve(traindata)求得
    上层节点的traindata在lTrain,rTrain的递归中已经求好了
    """
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
        lSetX=mat(ones((shape(lSet)[0],n)));rSetX=mat(ones((shape(rSet)[0],n)))
        lSetX[:,1:n]=lSet[:,0:n-1];rSetX[:,1:n]=rSet[:,0:n-1]
        errNotMerge=sum(power(array(lSet[:,-1].T.tolist()[0])-lSetX*tree['left'],2))+sum(power(array(rSet[:,-1].T.tolist()[0])-rSetX*tree['right'],2))
        #难点在于求上层节点的回归系数wsMerge:用上层节点的traindata,通过linearSolve(traindata)求得
        wsMerge=modelLeaf(trainData)  
        testDataX=mat(ones((m,n)));testDataX[:,1:n]=testData[:,0:n-1]
        errMerge=sum(power(array(testData[:,-1].T.tolist()[0])-testDataX*wsMerge,2))
        if errMerge<errNotMerge:
            print("merging")
            return wsMerge
        else:
            return tree
    else:
        return tree

测试一下,


06 模型预测效果对比

本次我们构建了回归树和模型树,顺便构建了一个线性回归函数,我们先来看看对于同一组数据,这三个模型的预测效果吧,这里使用R2值来评估预测效果。

结果如下,

  • 可以看到,这此数据集上,模型树表现比回归树好,线性回归表现最差
  • 说明树回归相比于线性回归,可以更好地处理复杂、非线性的数据集

07 总结

至此,我们基本上学习了回归任务中80%以上的算法模型,主要分为线性回归模型和树回归模型(再回忆一下,KNN也可以用于回归),它们各有优劣,针对具有不同特点的数据集,要选择合适的算法。


08 参考

  • 《机器学习实战》 Peter Harrington Chapter9
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,616评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,020评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,078评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,040评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,154评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,265评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,298评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,072评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,491评论 1 306
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,795评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,970评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,654评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,272评论 3 318
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,985评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,223评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,815评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,852评论 2 351

推荐阅读更多精彩内容