提问::plot_tree模块绘图无法显示分类名,但是已经加上了class_names参数
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV # 通过网格方式来搜索参数
from sklearn.tree import DecisionTreeClassifier as dtc
import matplotlib.pyplot as plt # 可视化
from matplotlib import rcParams # 图大小
from sklearn.tree import plot_tree # 树图
from termcolor import colored as cl # 文本自定义
data = pd.read_csv("class_weathering_chemical.csv")#导入数据集
features = data.columns[1:15]
X = data[features]#设置待估Xy
y = data[data.columns[15:17]]
# 设置需要搜索的参数值,在这里寻找最优的决策树深度
parameters = {'max_depth':[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]}
model = dtc() # 注意:在这里不用指定参数
# GridSearchCV
clf = GridSearchCV(model, parameters, cv=5)
clf.fit(X, y)
# 输出最好的参数以及对应的准确率
print ("best score is: %.4f"%clf.best_score_, " best param: ",clf.best_params_,)
model = dtc(max_depth=3,criterion = 'entropy')#一层无法可视化,于是构建两层
model = model.fit(X, y)
rcParams['figure.figsize'] = (5, 5)
target_names = ['rich_k','PbBa']
from sklearn.tree import plot_tree # 树图
plot_tree(
model,
feature_names = features,
class_names= target_names,
filled=True,
rounded = True
)
plt.savefig('tree_visualization.png')
运行出来的图像并没有显示分类名
想要达成
这样的下面有显示类名的效果