import os

import graphviz
import pandas as pd
from sklearn import preprocessing
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree import export_graphviz
# Read the training CSV containing the thermal-error data.
# Passing the path (instead of an open() handle) lets pandas open AND close
# the file; the previous bare open() handles were never closed.
TRAIN_CSV = 'data8.csv'
TEST_CSV = 'data.csv'
train_df = pd.read_csv(TRAIN_CSV, encoding='utf-8')
# Read the test-set CSV containing the thermal-error data (presumably with
# the same columns as the training file — verify).
test_df = pd.read_csv(TEST_CSV, encoding='utf-8')
# --- Split features / target ---
# Every column except the last is a feature; the last column is the target.
X_train = train_df.iloc[:, :-1]
Y_train = train_df.iloc[:, -1]
# The test features must carry the training column names because sklearn
# checks feature names at predict time. set_axis raises ValueError if the two
# files have a different number of feature columns.
X_test = test_df.iloc[:, :-1].set_axis(X_train.columns, axis=1)
Y_test = test_df.iloc[:, -1]

# --- Standardization ---
# Learn mean/std from the TRAINING set only and apply that same transform to
# the test set. (preprocessing.scale's with_mean/with_std are on/off flags,
# not statistics, so the test set cannot be scaled with training stats that
# way.) A decision tree's axis-aligned splits are insensitive to this kind of
# feature scaling, so the model below uses the raw features; the scaled
# arrays are kept for any scale-sensitive model.
scaler = preprocessing.StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

# --- Hyper-parameter search ---
param_grid = {
    'max_depth': list(range(1, 10)),
    'min_samples_split': [2, 4, 6],
    'min_samples_leaf': [1, 2, 3],
}
# Tune with 5-fold CV on the training set ONLY: the test set must stay unseen
# until the final evaluation, otherwise the reported MSE is optimistically
# biased. Scoring by (negated) MSE matches the metric reported below, and a
# fixed random_state makes the tree's internal tie-breaking reproducible.
grid = GridSearchCV(DecisionTreeRegressor(random_state=0), param_grid,
                    cv=5, scoring='neg_mean_squared_error',
                    error_score='raise')
grid.fit(X_train, Y_train)
# refit=True (the default) retrains the best configuration on all of X_train.
best_model = grid.best_estimator_
# Same object as best_model, so the tree that gets visualized is exactly the
# model that was evaluated.
best_dt = best_model
# Print the selected hyper-parameters.
print('Best Parameters:', grid.best_params_)

# --- Evaluate on the held-out test set ---
Y_pred = best_model.predict(X_test)
mse = mean_squared_error(Y_test, Y_pred)
print(f"MSE: {mse:.4f}")
print(Y_pred)
# --- Visualize the selected decision tree ---
# graphviz.Source.render() saves the DOT source and then runs the separate
# Graphviz `dot` program to produce the output file (PDF by default).
# `pip install graphviz` installs only the Python wrapper, NOT that program.
# "ExecutableNotFound: failed to execute WindowsPath('dot')" therefore means
# Graphviz itself is missing or its bin folder is not on PATH. Fix: install
# Graphviz from https://graphviz.org/download/, add its bin folder to PATH,
# and restart the terminal/IDE so the new PATH is picked up.
# Typical location for recent Graphviz releases on Windows; adjust if
# Graphviz was installed elsewhere.
GRAPHVIZ_BIN = r"C:\Program Files\Graphviz\bin"
_path_entries = os.environ.get("PATH", "").split(os.pathsep)
if os.path.isdir(GRAPHVIZ_BIN) and GRAPHVIZ_BIN not in _path_entries:
    # Make an installed-but-not-on-PATH Graphviz visible to this process only.
    os.environ["PATH"] = os.environ.get("PATH", "") + os.pathsep + GRAPHVIZ_BIN

dot_data = export_graphviz(best_dt, out_file=None,
                           feature_names=[str(c) for c in X_train.columns],
                           filled=True, rounded=True,
                           special_characters=True)
graph = graphviz.Source(dot_data)
try:
    graph.render("decision_tree")
except graphviz.ExecutableNotFound:
    # Keep the DOT source so the result is not lost; once Graphviz is
    # installed it can be rendered with:
    #   dot -Tpdf decision_tree -o decision_tree.pdf
    graph.save("decision_tree")
    print("Graphviz 'dot' executable was not found on PATH; the tree was "
          "saved as DOT source in 'decision_tree'. Install Graphviz and add "
          "its bin directory to PATH to render it.")
运行上面的代码时，执行 graph.render("decision_tree") 这一行报出下面的错误，该怎么解决？
ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH