# 代码可以跑起来，但是不出图，劳烦各位看看哪里有问题
# (The code runs, but no figure is produced — could someone help find the problem?)
def multi_models_roc(names, sampling_methods, colors, X_test, y_test, save=True, dpin=100,
                     filename='multi_models_roc.png', show=True):
    """Draw the ROC curves of several fitted classifiers on a single figure.

    Each model contributes one curve computed from
    ``method.predict_proba(X_test)[:, 1]`` (used as the score for label 1),
    labelled with the model's name and its AUC to three decimals. A dashed
    grey diagonal marks chance level.

    Parameters
    ----------
    names : sequence of str
        Legend label for each model.
    sampling_methods : sequence
        Fitted classifiers providing ``predict_proba``. Paired with *names*
        and *colors* by position (``zip`` stops at the shortest sequence).
    colors : sequence
        Matplotlib color for each model's curve.
    X_test, y_test
        Features and true labels to evaluate on; label ``1`` is treated as
        the positive class.
    save : bool, default True
        If true, write the finished figure to *filename*.
    dpin : int, default 100
        Figure resolution in dots per inch.
    filename : str, default 'multi_models_roc.png'
        Output path used when *save* is true.
    show : bool, default True
        If true, call ``plt.show()`` after the figure has been saved.

    Returns
    -------
    module
        ``matplotlib.pyplot``, returned as before for compatibility. Note that
        ``plt.show()`` may close the figure, so calling ``.savefig`` on the
        return value afterwards can write a blank image; use *save* /
        *filename* (or pass ``show=False``) instead.
    """
    fig = plt.figure(figsize=(20, 20), dpi=dpin)

    # Chance-level reference diagonal: drawn once, not once per model.
    plt.plot([0, 1], [0, 1], '--', lw=5, color='grey')

    # zip() only accepts positional iterables; keyword arguments raise TypeError.
    for name, method, colorname in zip(names, sampling_methods, colors):
        # Score of the positive class (column 1 of predict_proba).
        y_test_predprob = method.predict_proba(X_test)[:, 1]
        fpr, tpr, _ = roc_curve(y_test, y_test_predprob, pos_label=1)
        plt.plot(fpr, tpr, lw=5, color=colorname,
                 label='{} (AUC={:.3f})'.format(name, auc(fpr, tpr)))

    plt.axis('square')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('False Positive Rate', fontsize=20)
    plt.ylabel('True Positive Rate', fontsize=20)
    plt.title('ROC Curve', fontsize=25)
    plt.legend(loc='lower right', fontsize=20)

    # Save BEFORE showing: plt.show() can close the figure (non-interactive
    # backends, Jupyter inline), after which savefig writes an empty image.
    # The original show-then-save order is the likely cause of the missing plot.
    if save:
        fig.savefig(filename)
    if show:
        plt.show()
    return plt


if __name__ == "__main__":
    # NOTE(review): clf_lr, clf_tr, clf_svm, clf_forest, clf_xgbc, X_train and
    # y_train are assumed to be defined earlier (e.g. previous notebook cells);
    # they are not visible here.
    names = ['Logistic Regression',
             'Decision Tree',
             'SVM',
             'Random Forest',
             'XGBoost']
    sampling_methods = [clf_lr,
                        clf_tr,
                        clf_svm,
                        clf_forest,
                        clf_xgbc]
    colors = ['crimson',
              'orange',
              'gold',
              'mediumseagreen',
              'steelblue']

    # Let the function save the image itself (before plt.show()); calling
    # savefig on the returned object after show() produced a blank file.
    train_roc_graph = multi_models_roc(names, sampling_methods, colors, X_train, y_train,
                                       save=True, filename='ROC_Train_all.png')