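Visualizing and Pruning Decision Trees
1. Visualizing the decision tree
In the previous post (亲手实现决策树(一), "Implementing a Decision Tree by Hand, Part 1"), we used print_tree to print the decision tree as text: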
3:21?
T->0:google?
    T->{'Premium': 3}
    F->{'Basic': 3}
F->2:yes?
    T->0:slashdot?
        T->{'None': 2}
        F->{'Basic': 3}
    F->{'None': 4}
Next, we look at how to render the decision tree as an image.
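from PIL import Image, ImageDraw

def draw_tree(tree, jpeg='tree.jpeg'):
    # Canvas size: 100 px per leaf horizontally, 100 px per level vertically
    w = get_width(tree) * 100
    h = get_depth(tree) * 100 + 120
    img = Image.new('RGB', (w, h), color=(255, 255, 255))
    draw = ImageDraw.Draw(img)
    # Start at the top centre of the canvas
    draw_node(draw, tree, w / 2, 20)
    img.save(jpeg, 'JPEG')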
The helper get_width returns the number of leaves under a node:
def get_width(tree):
    # A leaf occupies one unit of width
    if tree.tb is None and tree.fb is None:
        return 1
    return get_width(tree.tb) + get_width(tree.fb)
The helper get_depth returns the length of the longest path from a node down to a leaf:
def get_depth(tree):
    # A leaf has depth 0
    if tree.tb is None and tree.fb is None:
        return 0
    return max(get_depth(tree.tb), get_depth(tree.fb)) + 1
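As a quick sanity check of the two helpers, the sketch below runs them on a small hand-built tree with the same shape as the printed tree above. The Node class is a hypothetical stand-in for the decision-node class from the previous post; only the attributes tb, fb, results, col and value are assumed.

```python
class Node:
    """Hypothetical stand-in for the decision-node class (an assumption)."""
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col = col
        self.value = value
        self.results = results
        self.tb = tb
        self.fb = fb


def get_width(tree):
    # A leaf occupies one unit of width
    if tree.tb is None and tree.fb is None:
        return 1
    return get_width(tree.tb) + get_width(tree.fb)


def get_depth(tree):
    # A leaf has depth 0; both branches are measured with get_depth
    if tree.tb is None and tree.fb is None:
        return 0
    return max(get_depth(tree.tb), get_depth(tree.fb)) + 1


tree = Node(
    col=3, value=21,
    tb=Node(col=0, value='google',
            tb=Node(results={'Premium': 3}),
            fb=Node(results={'Basic': 3})),
    fb=Node(col=2, value='yes',
            tb=Node(col=0, value='slashdot',
                    tb=Node(results={'None': 2}),
                    fb=Node(results={'Basic': 3})),
            fb=Node(results={'None': 4})),
)

width = get_width(tree)   # number of leaves
depth = get_depth(tree)   # edges on the longest root-to-leaf path
canvas = (width * 100, depth * 100 + 120)  # same sizing rule as draw_tree
print(width, depth, canvas)
```

For this tree the canvas comes out 500 px wide and 420 px tall, since there are five leaves and the deepest leaf sits three levels below the root.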
The helper draw_node draws a node and, recursively, its subtrees:
def draw_node(draw, tree, x, y):
    if tree.results is None:
        # Width of each branch
        w1 = get_width(tree.fb) * 100
        w2 = get_width(tree.tb) * 100
        # Total horizontal space this node needs
        left = x - (w1 + w2) / 2
        right = x + (w1 + w2) / 2
        # Draw the condition string
        draw.text((x - 20, y - 10), str(tree.col) + ":" + str(tree.value), (0, 0, 0))
        # Draw the lines to the branches
        draw.line((x, y, left + w1 / 2, y + 100), fill=(255, 0, 0))
        draw.line((x, y, right - w2 / 2, y + 100), fill=(255, 0, 0))
        # Draw the branch nodes (false branch on the left, true branch on the right)
        draw_node(draw, tree.fb, left + w1 / 2, y + 100)
        draw_node(draw, tree.tb, right - w2 / 2, y + 100)
    else:
        # Leaf: list each class label with its count
        txt = ' \n'.join(['%s:%d' % v for v in tree.results.items()])
        draw.text((x - 20, y), txt, (0, 0, 0))
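To make the layout arithmetic in draw_node easier to follow, the sketch below computes the same coordinates without drawing anything. Both layout and the tuple-based tree representation are hypothetical, chosen only to keep the example short.

```python
def width(tree):
    """Number of leaves. A leaf is a dict of results; a split is (label, tb, fb)."""
    if isinstance(tree, dict):
        return 1
    _, tb, fb = tree
    return width(tb) + width(fb)


def layout(tree, x, y, out=None):
    """Return [(label, x, y), ...] using the same arithmetic as draw_node."""
    if out is None:
        out = []
    if isinstance(tree, dict):
        out.append((str(tree), x, y))
        return out
    label, tb, fb = tree
    w1 = width(fb) * 100
    w2 = width(tb) * 100
    left = x - (w1 + w2) / 2
    right = x + (w1 + w2) / 2
    out.append((label, x, y))
    # False branch goes to the left half, true branch to the right half
    layout(fb, left + w1 / 2, y + 100, out)
    layout(tb, right - w2 / 2, y + 100, out)
    return out


# One split with two leaves: (label, true branch, false branch)
tree = ('2:yes', {'None': 2}, {'Basic': 3})
positions = layout(tree, width(tree) * 100 / 2, 20)
for label, x, y in positions:
    print(label, x, y)
```

The root is centred at x = 100, the false-branch leaf lands at x = 50 and the true-branch leaf at x = 150, both one level (100 px) lower.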
The resulting image:
Figure: decision tree
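2. Pruning the decision tree
To avoid overfitting, the decision tree needs to be pruned: if splitting a node into its two leaf children reduces entropy by less than a given threshold, the split is dropped and the children are merged back into a single leaf.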
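As a worked illustration of the threshold test, the sketch below computes the entropy reduction for two pairs of leaves. It assumes entropy is the Shannon entropy (base 2) of the class label in the last column of each row; the entropy function from the previous post may be defined differently. The helper gain is hypothetical, and it averages the two child entropies in the same way as the prune function below.

```python
from collections import Counter
from math import log2


def entropy(rows):
    """Shannon entropy (base 2) of the labels in the last column (assumed definition)."""
    counts = Counter(row[-1] for row in rows)
    total = len(rows)
    return -sum((c / total) * log2(c / total) for c in counts.values())


def gain(tb_counts, fb_counts):
    """Entropy reduction from keeping two leaves separate instead of merged."""
    tb = [[label] for label, c in tb_counts.items() for _ in range(c)]
    fb = [[label] for label, c in fb_counts.items() for _ in range(c)]
    return entropy(tb + fb) - (entropy(tb) + entropy(fb)) / 2


# Two pure leaves with different classes: merging them would lose a full bit
separate = gain({'None': 2}, {'Basic': 2})
# Two leaves with the same class: the split reduces entropy by nothing
redundant = gain({'None': 2}, {'None': 4})

min_gain = 0.5  # illustrative threshold
print(separate, separate < min_gain)    # kept when the comparison is False
print(redundant, redundant < min_gain)  # merged when the comparison is True
```

With a threshold of 0.5, the first pair of leaves would be kept and the second pair would be merged.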
def prune(tree, min_gain):
    # If a branch is not a leaf, prune it recursively
    if tree.tb.results is None:
        prune(tree.tb, min_gain)
    if tree.fb.results is None:
        prune(tree.fb, min_gain)
    # If both children are now leaves, check whether they should be merged
    if tree.tb.results is not None and tree.fb.results is not None:
        # Rebuild the rows held by each leaf from its class counts
        tb, fb = [], []
        for v, c in tree.tb.results.items():
            tb += [[v]] * c
        for v, c in tree.fb.results.items():
            fb += [[v]] * c
        # Measure the reduction in entropy
        delta = entropy(tb + fb) - (entropy(tb) + entropy(fb)) / 2
        if delta < min_gain:
            # Merge the branches
            tree.tb, tree.fb = None, None
            tree.results = unique_counts(tb + fb)
The result after pruning:
Figure: pruned decision tree