首页人工智能Scikit-learn7.深入可视化决策树

7.深入可视化决策树

1.决策树模型可视化:

1.1 准备步骤

安装graphviz程序,https://graphviz.gitlab.io/download/
安装graphviz模块,pip install graphviz
安装pydotplus模块,pip install pydotplus

1.2 可视化步骤

使用sklearn中export_graphviz()方法导出树的结构,
使用pydotplusgraph_from_dot_file()、graph_from_dot_data()方法根据树的结构生成可视化结果。

2.代码:

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz
import pydotplus
import matplotlib.pyplot as plt

data_file = './data/Iris.csv'

CATEGRORY_LABEL_DICT = {
	'Iris-setosa': 0,  # 山鸢尾
	'Iris-versicolor': 1,  # 变色鸢尾
	'Iris-virginica': 2  # 维吉尼亚鸢尾
}

# 使用的特征列
FEAT_COLS = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']


def plot_decision_tree(dt_model):
	"""
		可视化决策树的结构
	"""
	dot_file = 'decision_tree.dot'
	categrory_list = list(CATEGRORY_LABEL_DICT.keys())
	# 导出树的结构
	export_graphviz(dt_model, out_file=dot_file, feature_names=FEAT_COLS, class_names=categrory_list, filled=True,
					impurity=False)
	with open(dot_file) as f:
		# 读取生成的结构文件,生成可视化结果
		graph = pydotplus.graph_from_dot_file(dot_file)
		graph.write_png('decision_tree.png')


def plot_features_importance(dt_model):
	"""
		可视化特征重要性
	"""
	print('特征名称', FEAT_COLS)
	# 通过feature_importances_获取特征重要性
	print('特征重要性:', dt_model.feature_importances_)

	# 生成图表
	plt.figure()
	plt.barh(range(len(FEAT_COLS)), dt_model.feature_importances_)
	plt.xlabel('features importance')
	plt.ylabel('features name')
	plt.yticks(np.arange(len(FEAT_COLS)),FEAT_COLS)
	plt.tight_layout()
	plt.show()


def main():
	"""
		主函数
	"""
	iris_data = pd.read_csv(data_file, index_col='Id')

	# 添加label一列作为预测标签
	iris_data['Label'] = iris_data['Species'].apply(lambda category_name: CATEGRORY_LABEL_DICT[category_name])

	# 4列花的属性作为样本特征
	X = iris_data[FEAT_COLS].values
	# label列为样本标签
	y = iris_data['Label'].values

	# 将原始数据集拆分成训练集和测试集,测试集占总样本数的1/3
	X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1 / 3, random_state=10)

	# 构建模型
	dt_model = DecisionTreeClassifier(max_depth=4)
	dt_model.fit(X_train, y_train)

	# 可视化决策树的结构
	plot_decision_tree(dt_model)

	# 可视化特征重要性
	plot_features_importance(dt_model)


if __name__ == '__main__':
	main()

3.输出结果:

特征名称 ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
特征重要性: [0.         0.         0.40141195 0.59858805]
RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments