〖原视频地址〗
分类器有很多种,比如神经网络、或支持向量机,决策树则是其中之一。决策树一大特点便是简单易读,便于理解。实际上,决策树是为数不多的可解释的分类器。你可以彻底理解为什么这个分类器做出了这样的选择。
鸢尾花数据
这节课使用一组真实的数据集-鸢尾花数据。“鸢尾花”是一个典型的机器学习问题。在这个问题中,我们使用不同的测量标准,例如花瓣的宽度、长度来辨别是哪种鸢尾花。在数据集里有三种不同的鸢尾花,山鸢尾、变色鸢尾、青龙鸢尾。
数据集中每种花有50个样本,所以共有150个样本。每个样本有四个特征值来描述,分别是萼片和花瓣的长度和宽度。每行数据的前面四列是特征值,最后一列是本行鸢尾花数据的种类,也就是标签。
本节课的目标就是通过决策器训练这些数据,然后可视化分类器的决策过程。
导入数据
scikit-learn 提供了一系列的样本数据集,我们可以很方便的导入到项目中。
from sklearn.datasets import load_iris
iris = load_iris()
拆分数据
我们需要 从样本数据中抽取部分数据作为验证的测试数据,剩余数据作为训练数据。
test_index = [0, 50, 100]
# traing data
train_target = np.delete(iris.target, test_index)
train_data = np.delete(iris.data, test_index, axis=0)
这里我们用到了 Numpy。Numpy - 是Python语言的一个扩充程序库,支持高级大量的维度数组与矩阵运算,此外也针对数组运算提供大量的数学函数库。
我们先导入 Numpy,这个库包含在 anaconda3-4.4.0 中。
import numpy as np
训练数据
这部分和第一课的代码是一致的。
# testing data
test_target = iris.target[test_index]
test_data = iris.data[test_index]
clf = tree.DecisionTreeClassifier()
clf.fit(train_data, train_target)
print(test_target)
print(clf.predict(test_data))
可视化决策树
本课最关键的部分就是如何将决策树做出判断的过程可视化,需要用到 graphviz 和 pydotplus。
可视化代码如下:
# viz code
dot_data = StringIO()
tree.export_graphviz(clf,
out_file=dot_data,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
impurity=False)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf('iris.pdf')
我们最终将生成 iris.pdf 的决策图,如下:
B 站视频网址
我顺便把 youtube 的视频嵌上字幕后上传到了 B 站,网址在 这里