大师兄的数据分析学习笔记(十六):分类模型(二)
大师兄的数据分析学习笔记(十八):分类模型(四)
三、决策树
- 决策树就是模仿做决策的过程,根据特征一步步判断。
-
假设相亲的过程:
- 在决策树中,特征的顺序至关重要,所以首先需要将特征排序:
1. 信息增益
- 熵代表随机变量,或整个系统的不确定性。
- 信息增益指的是一个事件的熵减去另一个事件在条件下的熵,得到的是熵的变化。
- 根据公式算,信息增益值越大的特征,代表对事件的影响越大,应该排在前面。
2. 信息增益率
- 信息增益率对信息增益进行了改进,考虑分裂后每个子结点的样本数量的纯度。
- 信息增益率的定义:
- 在上面的公式中,分子为信息增益,而分母为事件Y的熵。
- 在相同的信息增益下,分裂的特征越纯越好,意味着信息增益率偏向选择纯度大的特征。
3. 基尼系数
- 基尼系数也叫不纯度:
- 如果一个切分的不纯度减小的比较大,就可以考虑把这个切分先进行决策。
4. 决策树的问题
4.1 连续值切分
- 如果遇到连续值,则对决策树进行从小到大排序,对每个间隔进行切分。
- 计算切分后的各个因子,取该因子最好的连续值切分作为他的切分。
4.2 规则用尽
- 如果所有的规则用尽,但数据还没有切分干净,则可以采用投票的方式。
- 也可以再次用规则进行切分,直到获得没有杂质的叶子节点。
4.2 过拟合
- 如发生过拟合的情况,则需要进行决策树剪枝。
- 剪枝分为前剪枝和后剪枝:
- 前剪枝表示在构造决策树之前,规定每个叶子节点的样本数量,或规定决策树的最大深度。
- 后剪枝表示构造决策树后,对样本值比较悬殊的枝叶进行修建。
5. 代码实现
import os
import pandas as pd
import numpy as np
import pydotplus
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB,BernoulliNB
from sklearn.metrics import accuracy_score,recall_score,f1_score
from sklearn.tree import DecisionTreeClassifier,export_graphviz
models = []
models.append(("DesicionTree",DecisionTreeClassifier())) # 基尼决策树
models.append(("DesicionTreeEntropy",DecisionTreeClassifier(criterion="entropy"))) # 信息增益决策树
df = pd.read_csv(os.path.join(".", "data", "WA_Fn-UseC_-HR-Employee-Attrition.csv"))
X_tt,X_validation,Y_tt,Y_validation = train_test_split(df.JobLevel,df.JobSatisfaction,test_size=0.2)
X_train,X_test,Y_train,Y_test = train_test_split(X_tt,Y_tt,test_size=0.25)
for clf_name,clf in models:
clf.fit(np.array(X_train).reshape(-1,1),np.array(Y_train).reshape(-1,1))
xy_lst = [(X_train,Y_train),(X_validation,Y_validation),(X_test,Y_test)]
for i in range(len(xy_lst)):
X_part = xy_lst[i][0]
Y_part = xy_lst[i][1]
Y_pred = clf.predict(np.array(X_part).reshape(-1,1))
print(i)
print(clf_name,"-ACC",accuracy_score(Y_part,Y_pred))
print(clf_name,"-REC",recall_score(Y_part,Y_pred,average='macro'))
print(clf_name,"-F1",f1_score(Y_part,Y_pred,average='macro'))
print("="*40)
dot_data = export_graphviz(clf,out_file=None,filled=True,rounded=True,special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf(f"graph_{clf_name}.pdf")