决策树底层实现及绘制(python)

# coding:utf-8
import matplotlib
import matplotlib.pyplot as plt
from collections import defaultdict

from math import log
import matplotlib.path as mpath
import matplotlib.patches as mpatches
import numpy as np
from matplotlib import font_manager as fm, rcParams


class DecTree:
    def __init__(self):
        pass

    # 计算香浓熵
    def cacuChannoEnt(self, data_set):
        cls_dict = defaultdict(int)
        for data in data_set:
            cls_dict[data[-1]] += 1
        total = len(data_set)
        channo_ent = 0.0
        for num in cls_dict.values():
            p = float(num) / total
            channo_ent -= p * log(p, 2)
        return channo_ent

    def createDataSet(self):
        dataSet = [[0, 0, 0, 0, 'maybe'], [0, 1, 1, 1, 'error'], [1, 1, 1, 0, 'yes'], [1, 0, 0, 1, 'no'],
                   [1, 0, 1, 1, 'no'],
                   [1, 0, 0, 0, 'yes'], [1, 0, 0, 0, 'maybe'],
                   [1, 1, 1, 0, 'yes']]
        labels = ['no surfaceing', 'flippers', 'flippers']
        return dataSet, labels

    def splitDataSet(self, data_set, feat, value):
        '''
        抽取某个feat值为value的数据子集
        :param data_set: 
        :param feat: 
        :param value: 
        :return: 
        '''
        sub_data_set = []
        for data in data_set:
            if data[feat] == value:
                sub_feat = data[:feat]
                sub_feat.extend(data[feat + 1:])
                sub_data_set.append(sub_feat)
        return sub_data_set

    def chooseBestFeature(self, data_set):
        '''
        获取能产生最大信息增益的feat
        :param data_set: 
        :return: 
        '''
        base_ent = self.cacuChannoEnt(data_set)
        feat_num = len(data_set[0]) - 1
        best_feat = 0
        total = len(data_set)
        ent_gain = 0.0
        for i in range(feat_num):
            uni_vals = set([data[i] for data in data_set])
            ent_tmp = 0
            for v in uni_vals:
                sub_data_set = self.splitDataSet(data_set, i, v)
                p = float(len(sub_data_set)) / total
                ent_tmp += p * self.cacuChannoEnt(sub_data_set)
            cur_gain = base_ent - ent_tmp
            if cur_gain > ent_gain:
                best_feat = i
                ent_gain = cur_gain
        return best_feat

    def allCls(self, data_set):
        '''
        计算数据集中全部类别
        :param data_set: 
        :return: 
        '''
        v_list = [data[-1] for data in data_set]
        s = set(v_list)
        return s

    def createTree(self, data_set, feat_list):
        '''
        创建决策树
        :param data_set:数据集 
        :param feat_list: feat集合
        :return: 
        '''
        node = {}
        if len(feat_list) == 0:
            data_cls = [data[-1] for data in data_set]
            return {'cls': self.majorCnt(data_cls)}
        all_cls = self.allCls(data_set)
        if len(all_cls) == 1:
            return {'cls': all_cls.pop()}
        feat = self.chooseBestFeature(data_set)
        uni_vals = set([v[feat] for v in data_set])
        node['feat'] = feat_list[feat]
        node['label'] = {}
        for v in uni_vals:
            sub_dat_set = self.splitDataSet(data_set, feat, v)
            sub_feat_list = feat_list[:feat]
            sub_feat_list.extend(feat_list[feat + 1:])
            child_nd = self.createTree(sub_dat_set, sub_feat_list)
            child_nd['lb'] = v
            node['label'][v] = child_nd
        return node

    def majorCnt(self, clsList):
        '''
        数据占比最多的分类
        :param clsList: 
        :return: 
        '''
        num_dict = defaultdict(int)
        for cls in clsList:
            num_dict[cls] += 1
        data = zip(num_dict.values(), num_dict.keys())
        sorted_data = sorted(data, reverse=True)
        return sorted_data[-1][1]

    def classfiy(self, vec, tree_root):
        '''
        对给定数据分类
        :param vec: 
        :param tree_root: 
        :return: 
        '''
        feat = tree_root['feat']
        labels = tree_root['label']
        cls = None
        while len(labels) > 0:
            v = vec[feat]
            node = labels[v]
            if 'cls' in node:
                cls = node['cls']
                break
            labels = node['label']
            feat = node['feat']
        return cls


class DecTreePlotter(object):
    '''
    绘制决策树类
    '''
    decNode = dict(boxstyle='square', fc='0.8')
    leafNode = dict(boxstyle='round4', fc='0.4')

    def __init__(self):
        super(DecTreePlotter, self).__init__()

    def draw(self, tree_root):
        width, height = self._getSize(tree_root)
        fig, ax = plt.subplots()
        ax.grid()
        pt = (0.5, 0.9)
        tree_root['loc'] = pt
        # 绘制根结点
        plt.text(pt[0], pt[1], 'feat:{}'.format(tree_root['feat']), horizontalalignment='center', size=10,
                 bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5))
        self.draw_retrieve(ax, [tree_root], width, height, height-1)

    def draw_retrieve(self, ax, p_nodes, width, height, level):
        child_nodes = []
        index = 0
        cell_width = 1.0 / width - 0.1/width
        for pn in p_nodes:
            cur_nodes = pn['label'].values()
            p_pt = pn['loc']
            for i in range(len(cur_nodes)):
                nd = cur_nodes[i]
                if 'cls' not in nd:
                    child_nodes.append(nd)
                    txt = 'feat:{}'.format(str(nd['feat']))
                else:
                    txt = '{}'.format(nd['cls'])
                txt_pt = ((index + 1) * cell_width, float(level) * (1.0 / height))
                nd['loc'] = txt_pt
                node_type = DecTreePlotter.leafNode if 'cls' in nd else DecTreePlotter.decNode
                self.plotNode(ax, txt, txt_pt, p_pt, node_type)
                mid_pt = (txt_pt[0] / 2 + p_pt[0] / 2, txt_pt[1] / 2 + p_pt[1] / 2)
                plt.text(mid_pt[0], mid_pt[1], str(nd['lb']), color='red', size=20)
                index += 1

        if len(child_nodes) > 0:
            self.draw_retrieve(ax, child_nodes, width, height, level - 1)

    def _getSize(self, tree_root):
        cur_nodes = [tree_root]
        width = len(cur_nodes)
        height = 0
        while len(cur_nodes) > 0:
            tmp_nodes = []
            height += 1
            cur_width = 0
            for node in cur_nodes:
                label_nodes = node['label'].values()
                cur_width += len(label_nodes)
                tmp_nodes.extend([vo for vo in label_nodes if 'cls' not in vo])

            width = cur_width if cur_width > width else width
            cur_nodes = tmp_nodes
        return width, height + 1

    def plotNode(self, ax, nodeText, centerPt, parentPt, nodeType):
        print '{}-{}'.format(centerPt, parentPt)
        ax.annotate(nodeText, xy=parentPt, xycoords='axes fraction', \
                    xytext=centerPt, textcoords='axes fraction', \
                    va='center', ha='center', bbox=nodeType, arrowprops=dict(arrowstyle='<-',connectionstyle="arc,angleA=60,angleB=20,rad=0.0"))


tree = DecTree()
data_set, labels = tree.createDataSet()
root = tree.createTree(data_set, [0, 1, 2,3])
# 测试分类器
dataSet = [[0, 0, 0, 'maybe'], [1, 1, 1, 'yes'], [1, 0, 1, 'no'], [0, 1, 1, 'no'], [0, 0, 0, 'yes'],
           [1, 1, 1, 'yes']]
for vec in data_set:
    cls = tree.classfiy(vec, root)
    print 'vec:{},cls is {},real is {}'.format(vec, cls, vec[-1])

# 绘制决策树
tree_plotter = DecTreePlotter()
tree_plotter.draw(root)
plt.show()

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,294评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,780评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,001评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,593评论 1 289
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,687评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,679评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,667评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,426评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,872评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,180评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,346评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,019评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,658评论 3 323
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,268评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,495评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,275评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,207评论 2 352

推荐阅读更多精彩内容

  • 如果要用一个词来形容2018年的房地产,那就是“用时间换空间”。直白一点说,就是尽量拖。这也是所有出台政策的最终指...
    RGLR阅读 196评论 0 0
  • 我是隔壁老张阅读 181评论 0 0
  • 杭州红舞鞋少儿舞蹈培训——跳舞的孩子最美丽 少儿时期是整个人生的开端。在这一时期,少年儿童无论是身体还是心理都在迅...
    杭州红舞鞋阅读 531评论 0 0
  • 六。 我买了套餐,找到一个位置,拿出手机开了机,一下子蹦出来七条留言。 法国电话公司都提供留言系统。关机或者响铃五...
    卢璐说阅读 256评论 0 2
  • 上午,我们攻克了走美杯下五道题目,那五道题目,乍看之下,哇,好简单,可仔细一看,这题目怎么这么难,做的来的...
    ED艾迪阅读 165评论 0 1