bad input shape (60000, 2)

本小白在看机器学习实战时,绘制精度、召回率相对阈值的函数图时报了错。

代码如下:

 from sklearn.datasets import fetch_mldata
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score,recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score

# Load the MNIST dataset and pull out the data array and the target array.
# NOTE(review): fetch_mldata appears to be unavailable in newer scikit-learn
# releases (fetch_openml is its replacement) - confirm the installed version.
mnist = fetch_mldata('MNIST original')
X = mnist["data"]
y = mnist["target"]

# Display one sample: reshape it to a 28x28 image and draw it without axes.
some_digit = X[36000]
some_digit_image = some_digit.reshape((28, 28))
plt.imshow(
    some_digit_image,
    cmap=matplotlib.cm.binary,
    interpolation="nearest",
)
plt.axis("off")
#plt.show()

# Training and test sets: the first 60000 rows are used for training,
# the remaining rows for testing.
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
# Shuffle the training rows; the same permutation is applied to the data
# and to the labels so they stay aligned.
shuffle_index = np.random.permutation(60000)
X_train = X_train[shuffle_index]
y_train = y_train[shuffle_index]

# Binary classifier: boolean targets that are True where the label equals 5.
y_train_5 = y_train == 5
y_test_5 = y_test == 5

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
# Sanity check: predict the class of the single sample picked earlier.
predict1 = sgd_clf.predict([some_digit])
print(predict1)

# Manual stratified 3-fold cross-validation of the binary classifier.
# random_state is deliberately not passed: it only has an effect together
# with shuffle=True, and recent scikit-learn releases raise ValueError when
# it is set while shuffle keeps its default of False. The training rows were
# already shuffled with np.random.permutation before this point.
skfolds = StratifiedKFold(n_splits=3)
for train_index,test_index in skfolds.split(X_train,y_train_5):
    # Work on an unfitted copy so every fold starts from the same settings.
    clone_clf = clone(sgd_clf)
    X_train_folds = X_train[train_index]
    y_train_folds = y_train_5[train_index]
    X_test_fold = X_train[test_index]
    y_test_fold = y_train_5[test_index]

    clone_clf.fit(X_train_folds,y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)
    # np.sum counts the matches in one vectorized pass instead of iterating
    # the boolean array element by element with the builtin sum().
    n_correct = np.sum(y_pred == y_test_fold)
    # Fraction of correct predictions on the held-out fold.
    print(n_correct/len(y_pred))

# k-fold cross-validation using the scikit-learn helpers (3 folds).
print(
    cross_val_score(
        sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy"
    )
)
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
# The metric calls below are disabled; they compare against y_train_pred,
# which has one prediction per training sample.
#print(confusion_matrix(y_train_5,y_train_pred))
#print(precision_score(y_train_5,y_train_pred))     # precision
#print(recall_score(y_train_5,y_train_pred))        # recall
#print(f1_score(y_train_5,y_train_pred))            # F1 score
# Decision score of the fitted classifier for the single sample.
y_scores = sgd_clf.decision_function([some_digit])
print(y_scores)
#threshold = 0
#y_some_digit_pred = (y_scores>threshold)
#print(y_some_digit_pred)
# Raise the threshold and see whether the sample is still classified positive.
threshold = 200000
y_some_digit_pred = y_scores > threshold
print(y_some_digit_pred)
# Plot precision and recall as functions of the decision threshold.

# Collect a cross-validated decision score for every training sample.
y_scores = cross_val_predict(sgd_clf,X_train,y_train_5,cv=3,method="decision_function")
# Workaround for scikit-learn 0.19.0 (issue #9589): for a binary classifier
# cross_val_predict(..., method="decision_function") may return an array of
# shape (n_samples, 2) instead of (n_samples,), and precision_recall_curve
# then fails with "ValueError: bad input shape (60000, 2)". Column 1 carries
# the scores of the positive class, so keep only that column. On versions
# that already return a 1-D array this branch is skipped.
if y_scores.ndim == 2:
    y_scores = y_scores[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_train_5,y_scores)

def plot_precison_recall_vs_threshold(precisions,recalls,thresholds):
    """Plot precision and recall against the decision threshold.

    Draws two curves with pyplot: precision as a dashed blue line and
    recall as a solid green line. The last element of ``precisions`` and
    ``recalls`` is dropped before plotting, so both inputs are expected to
    be one element longer than ``thresholds``.
    """
    curves = (
        (precisions, "b--", "Precision"),
        (recalls, "g-", "Recall"),
    )
    for values, line_format, curve_label in curves:
        plt.plot(thresholds, values[:-1], line_format, label=curve_label)
    plt.xlabel("Threshold")
    plt.legend(loc="upper left")
    # Precision and recall are plotted on a fixed 0..1 vertical range.
    plt.ylim([0, 1])
# Open a new figure first: the digit image drawn earlier with plt.imshow is
# still the current axes (with its axis turned off), so without this the
# curves would be painted on top of that image.
plt.figure()
plot_precison_recall_vs_threshold(precisions,recalls,thresholds)
plt.show()

报错信息如下:
Traceback (most recent call last):
File "F:/python项目/mnist.py", line 77, in <module>
precisions, recalls, thresholds = precision_recall_curve(y_train_5,y_scores)
File "C:\Users\15701\Anaconda3\lib\site-packages\sklearn\metrics\ranking.py", line 417, in precision_recall_curve
sample_weight=sample_weight)
File "C:\Users\15701\Anaconda3\lib\site-packages\sklearn\metrics\ranking.py", line 304, in _binary_clf_curve
y_score = column_or_1d(y_score)
File "C:\Users\15701\Anaconda3\lib\site-packages\sklearn\utils\validation.py", line 583, in column_or_1d
raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (60000, 2)

不胜感激

查看全部
weixin_43436824
小小白要努力成长啊
2018/11/24 01:21
  • 机器学习
  • sklearn
  • 点赞
  • 收藏
  • 回答
    私信
满意答案
查看全部

2个回复