1.决策树模型可视化:
1.1 准备步骤
安装graphviz
程序,https://graphviz.gitlab.io/download/
安装graphviz
模块,pip install graphviz
安装pydotplus
模块,pip install pydotplus
1.2 可视化步骤
使用sklearn中export_graphviz()
方法导出树的结构,
使用pydotplusgraph_from_dot_file()、graph_from_dot_data()
方法根据树的结构生成可视化结果。
2.代码:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz
import pydotplus
import matplotlib.pyplot as plt
data_file = './data/Iris.csv'
CATEGRORY_LABEL_DICT = {
'Iris-setosa': 0, # 山鸢尾
'Iris-versicolor': 1, # 变色鸢尾
'Iris-virginica': 2 # 维吉尼亚鸢尾
}
# 使用的特征列
FEAT_COLS = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
def plot_decision_tree(dt_model):
"""
可视化决策树的结构
"""
dot_file = 'decision_tree.dot'
categrory_list = list(CATEGRORY_LABEL_DICT.keys())
# 导出树的结构
export_graphviz(dt_model, out_file=dot_file, feature_names=FEAT_COLS, class_names=categrory_list, filled=True,
impurity=False)
with open(dot_file) as f:
# 读取生成的结构文件,生成可视化结果
graph = pydotplus.graph_from_dot_file(dot_file)
graph.write_png('decision_tree.png')
def plot_features_importance(dt_model):
"""
可视化特征重要性
"""
print('特征名称', FEAT_COLS)
# 通过feature_importances_获取特征重要性
print('特征重要性:', dt_model.feature_importances_)
# 生成图表
plt.figure()
plt.barh(range(len(FEAT_COLS)), dt_model.feature_importances_)
plt.xlabel('features importance')
plt.ylabel('features name')
plt.yticks(np.arange(len(FEAT_COLS)),FEAT_COLS)
plt.tight_layout()
plt.show()
def main():
"""
主函数
"""
iris_data = pd.read_csv(data_file, index_col='Id')
# 添加label一列作为预测标签
iris_data['Label'] = iris_data['Species'].apply(lambda category_name: CATEGRORY_LABEL_DICT[category_name])
# 4列花的属性作为样本特征
X = iris_data[FEAT_COLS].values
# label列为样本标签
y = iris_data['Label'].values
# 将原始数据集拆分成训练集和测试集,测试集占总样本数的1/3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1 / 3, random_state=10)
# 构建模型
dt_model = DecisionTreeClassifier(max_depth=4)
dt_model.fit(X_train, y_train)
# 可视化决策树的结构
plot_decision_tree(dt_model)
# 可视化特征重要性
plot_features_importance(dt_model)
if __name__ == '__main__':
main()
3.输出结果:
特征名称 ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
特征重要性: [0. 0. 0.40141195 0.59858805]