本小白在看机器学习实战时,绘制精度、召回率相对阈值的函数图时报了错。
代码如下:
from sklearn.datasets import fetch_mldata
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score,recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
#导入部分
mnist = fetch_mldata('MNIST original')
X,y = mnist["data"],mnist["target"]
#显现部分
some_digit = X[36000]
some_digit_image = some_digit.reshape(28,28)
plt.imshow(some_digit_image,cmap=matplotlib.cm.binary,interpolation="nearest")
plt.axis("off")
#plt.show()
#训练集和测试集
X_train,X_test,y_train,y_test=X[:60000],X[60000:],y[:60000],y[60000:]
shuffle_index = np.random.permutation(60000)
X_train,y_train = X_train[shuffle_index],y_train[shuffle_index]
#二分分类器
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train,y_train_5)
predict1 = sgd_clf.predict([some_digit])
print(predict1)
#实施交叉验证
skfolds = StratifiedKFold(n_splits=3,random_state=42)
for train_index,test_index in skfolds.split(X_train,y_train_5):
clone_clf = clone(sgd_clf)
X_train_folds = X_train[train_index]
y_train_folds = (y_train_5[train_index])
X_test_fold = X_train[test_index]
y_test_fold = (y_train_5[test_index])
clone_clf.fit(X_train_folds,y_train_folds)
y_pred = clone_clf.predict(X_test_fold)
n_correct = sum(y_pred == y_test_fold)
print(n_correct/len(y_pred))
#kfold方法
print(cross_val_score(sgd_clf,X_train,y_train_5,cv=3,scoring="accuracy"))
y_train_pred = cross_val_predict(sgd_clf,X_train,y_train_5,cv=3)
#print(confusion_matrix(y_train_5,y_train_pred))
#print(precision_score(y_train_5,y_pred)) #精度
#print(recall_score(y_train_5,y_train_pred)) #召回率
#print(f1_score(y_train_5,y_pred)) #fi分数
y_scores = sgd_clf.decision_function([some_digit])
print(y_scores)
#threshold = 0
#y_some_digit_pred = (y_scores>threshold)
#print(y_some_digit_pred)
#提高阈值
threshold = 200000
y_some_digit_pred = (y_scores>threshold)
print(y_some_digit_pred)
#绘制阈值函数图
y_scores = cross_val_predict(sgd_clf,X_train,y_train_5,cv=3,method="decision_function")
precisions, recalls, thresholds = precision_recall_curve(y_train_5,y_scores)
def plot_precison_recall_vs_threshold(precisions,recalls,thresholds):
plt.plot(thresholds,precisions[:-1],"b--",label="Precision")
plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
plt.xlabel("Threshold")
plt.legend(loc="upper left")
plt.ylim([0,1])
plot_precison_recall_vs_threshold(precisions,recalls,thresholds)
plt.show()
报错信息如下:
Traceback (most recent call last):
File "F:/python项目/mnist.py", line 77, in
precisions, recalls, thresholds = precision_recall_curve(y_train_5,y_scores)
File "C:\Users\15701\Anaconda3\lib\site-packages\sklearn\metrics\ranking.py", line 417, in precision_recall_curve
sample_weight=sample_weight)
File "C:\Users\15701\Anaconda3\lib\site-packages\sklearn\metrics\ranking.py", line 304, in _binary_clf_curve
y_score = column_or_1d(y_score)
File "C:\Users\15701\Anaconda3\lib\site-packages\sklearn\utils\validation.py", line 583, in column_or_1d
raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (60000, 2)
不胜感激