Nucleon_17th 2022-12-03 16:54 采纳率: 100%
浏览 36
已结题

plot_tree模块绘图无法显示分类名(但有输入分类名参数)

提问::plot_tree模块绘图无法显示分类名,但是已经加上了class_names参数


import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV # 通过网格方式来搜索参数
from sklearn.tree import DecisionTreeClassifier as dtc
import matplotlib.pyplot as plt # 可视化
from matplotlib import rcParams # 图大小
from sklearn.tree import plot_tree # 树图
from termcolor import colored as cl # 文本自定义

data = pd.read_csv("class_weathering_chemical.csv")#导入数据集
features = data.columns[1:15]
X = data[features]#设置待估Xy
y = data[data.columns[15:17]]

# 设置需要搜索的参数值,在这里寻找最优的决策树深度
parameters = {'max_depth':[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]}
model = dtc()  # 注意:在这里不用指定参数
# GridSearchCV
clf = GridSearchCV(model, parameters, cv=5)   
clf.fit(X, y)
# 输出最好的参数以及对应的准确率
print ("best score is: %.4f"%clf.best_score_, "  best param: ",clf.best_params_,)

model = dtc(max_depth=3,criterion = 'entropy')#一层无法可视化,于是构建两层
model = model.fit(X, y)

rcParams['figure.figsize'] = (5, 5)
target_names = ['rich_k','PbBa']
from sklearn.tree import plot_tree # 树图
plot_tree(
    model, 
    feature_names = features,  
    class_names= target_names,
    filled=True,
    rounded = True
)
plt.savefig('tree_visualization.png')

运行出来的图像并没有显示分类名

img

想要达成

img

这样的下面有显示类名的效果

  • 写回答

1条回答 默认 最新

  • ShowMeAI 2022-12-03 18:07
    关注

    y = data[data.columns[15:17]]
    似乎你的label的2个类别是用one-hot展开为2列的,用decision tree你放到一列里就可以。

    你可以试试用下面的函数转到一列中,再fit

    # conver one-hot to label
    def one_hot_to_label(y):
        y_label = np.zeros(y.shape[0])
        for i in range(y.shape[0]):
            y_label[i] = np.argmax(y[i])
        return y_label
    
    y = one_hot_to_label(y)
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 12月11日
  • 已采纳回答 12月3日
  • 创建了问题 12月3日

悬赏问题

  • ¥15 Opencv配置出错
  • ¥15 模电中二极管,三极管和电容的应用
  • ¥15 关于模型导入UNITY的.FBX: Check external application preferences.警告。
  • ¥15 气象网格数据与卫星轨道数据如何匹配
  • ¥100 java ee ssm项目 悬赏,感兴趣直接联系我
  • ¥15 微软账户问题不小心注销了好像
  • ¥15 x264库中预测模式字IPM、运动向量差MVD、量化后的DCT系数的位置
  • ¥15 curl 命令调用正常,程序调用报 java.net.ConnectException: connection refused
  • ¥20 关于web前端如何播放二次加密m3u8视频的问题
  • ¥15 使用百度地图api 位置函数报错?