小小白要努力成长啊 2018-11-24 01:21

bad input shape (60000, 2)

I'm a beginner working through 机器学习实战, and I got an error while plotting precision and recall as functions of the decision threshold.

The code is as follows:

from sklearn.datasets import fetch_mldata
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score,recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score

# Load the data
mnist = fetch_mldata('MNIST original')
X,y = mnist["data"],mnist["target"]

# Display one digit
some_digit = X[36000]
some_digit_image = some_digit.reshape(28,28)
plt.imshow(some_digit_image,cmap=matplotlib.cm.binary,interpolation="nearest")
plt.axis("off")
#plt.show()

# Training and test sets
X_train,X_test,y_train,y_test=X[:60000],X[60000:],y[:60000],y[60000:]
shuffle_index = np.random.permutation(60000)
X_train,y_train = X_train[shuffle_index],y_train[shuffle_index]

# Binary classifier (5 vs. not-5)
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train,y_train_5)
predict1 = sgd_clf.predict([some_digit])
print(predict1)

# Manual cross-validation
skfolds = StratifiedKFold(n_splits=3,random_state=42)
for train_index,test_index in skfolds.split(X_train,y_train_5):
    clone_clf = clone(sgd_clf)
    X_train_folds = X_train[train_index]
    y_train_folds = (y_train_5[train_index])
    X_test_fold = X_train[test_index]
    y_test_fold = (y_train_5[test_index])

    clone_clf.fit(X_train_folds,y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)
    n_correct = sum(y_pred == y_test_fold)
    print(n_correct/len(y_pred))

# Cross-validation with cross_val_score
print(cross_val_score(sgd_clf,X_train,y_train_5,cv=3,scoring="accuracy"))
y_train_pred = cross_val_predict(sgd_clf,X_train,y_train_5,cv=3)
#print(confusion_matrix(y_train_5,y_train_pred))
#print(precision_score(y_train_5,y_train_pred))    # precision
#print(recall_score(y_train_5,y_train_pred))       # recall
#print(f1_score(y_train_5,y_train_pred))           # F1 score
y_scores = sgd_clf.decision_function([some_digit])
print(y_scores)
#threshold = 0
#y_some_digit_pred = (y_scores>threshold)
#print(y_some_digit_pred)
# Raise the threshold
threshold = 200000
y_some_digit_pred = (y_scores>threshold)
print(y_some_digit_pred)
# Plot precision and recall against the threshold
y_scores = cross_val_predict(sgd_clf,X_train,y_train_5,cv=3,method="decision_function")
precisions, recalls, thresholds = precision_recall_curve(y_train_5,y_scores)

def plot_precision_recall_vs_threshold(precisions,recalls,thresholds):
    plt.plot(thresholds,precisions[:-1],"b--",label="Precision")
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
    plt.xlabel("Threshold")
    plt.legend(loc="upper left")
    plt.ylim([0,1])

plot_precision_recall_vs_threshold(precisions,recalls,thresholds)
plt.show()

The error message is:
Traceback (most recent call last):
  File "F:/python项目/mnist.py", line 77, in <module>
    precisions, recalls, thresholds = precision_recall_curve(y_train_5,y_scores)
  File "C:\Users\15701\Anaconda3\lib\site-packages\sklearn\metrics\ranking.py", line 417, in precision_recall_curve
    sample_weight=sample_weight)
  File "C:\Users\15701\Anaconda3\lib\site-packages\sklearn\metrics\ranking.py", line 304, in _binary_clf_curve
    y_score = column_or_1d(y_score)
  File "C:\Users\15701\Anaconda3\lib\site-packages\sklearn\utils\validation.py", line 583, in column_or_1d
    raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (60000, 2)
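
The traceback ends in `column_or_1d`, which rejects the scores because they are two-dimensional. The sketch below imitates that kind of check with a hypothetical helper, `require_1d`, to show which shapes pass and which fail. It is an illustration only, not scikit-learn's actual implementation, and it assumes nothing beyond NumPy.

```python
import numpy as np


def require_1d(scores):
    """Hypothetical stand-in for a 1-D shape check (not scikit-learn's code).

    Accepts shape (n,) or (n, 1) and returns a flat array; raises
    ValueError for anything else, such as (n, 2).
    """
    scores = np.asarray(scores)
    if scores.ndim == 1:
        return scores
    if scores.ndim == 2 and scores.shape[1] == 1:
        return scores.ravel()
    raise ValueError("bad input shape {0}".format(scores.shape))


ok_flat = require_1d(np.zeros(5))          # shape (5,) passes
ok_column = require_1d(np.zeros((5, 1)))   # shape (5, 1) is flattened

try:
    require_1d(np.zeros((5, 2)))           # shape (5, 2) is rejected
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(ok_flat.shape, ok_column.shape, error_message)
```

A quick `print(y_scores.shape)` right before the failing call is usually enough to tell which of these cases applies.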

Any help would be greatly appreciated.

2 answers

  • 一叶扁舟。。。 2019-04-15 18:51

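    Printing the shapes shows the mismatch: print(y_train_5.shape) gives (60000,), while print(y_scores.shape) gives (60000, 2).
    print(y_scores) gives:
    [[ 0. -229600.48544944]
     [ 0. -792845.57622101]
     [ 0. -529311.13077603]
     ...,
     [ 0. -806955.80116218]
     [ 0. -199716.61091746]
     [ 0. -499524.22190059]]
    precision_recall_curve expects one score per sample, so keep only the second column: y_scores = y_scores[:, 1]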
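
    The fix can also be written defensively, so that it does nothing when the scores are already one-dimensional. The sketch below uses a small synthetic array shaped like the printed output; the helper name `positive_class_scores` is hypothetical. Taking column 1 assumes, as in that output, that the second column holds the decision score for the positive class. A two-column result from `cross_val_predict(..., method="decision_function")` on a binary problem has reportedly been seen only with some older scikit-learn releases, which is why the guard tests `ndim` instead of slicing unconditionally.

```python
import numpy as np


def positive_class_scores(y_scores):
    """Return one decision score per sample (hypothetical helper).

    If the scores arrive as an (n, 2) array, keep column 1, which is
    assumed to hold the positive-class score; 1-D input is returned as is.
    """
    y_scores = np.asarray(y_scores)
    if y_scores.ndim == 2:
        return y_scores[:, 1]
    return y_scores


# Synthetic stand-in for the (60000, 2) array shown in the answer.
raw_scores = np.array([
    [0.0, -229600.48544944],
    [0.0, -792845.57622101],
    [0.0, -529311.13077603],
])

fixed_scores = positive_class_scores(raw_scores)
print(raw_scores.shape, "->", fixed_scores.shape)

# Scores that are already 1-D pass through unchanged.
unchanged = positive_class_scores(fixed_scores)
print(unchanged.shape)
```

    In the question's script, the guard would go right after the `cross_val_predict(..., method="decision_function")` call and before `precision_recall_curve`.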

    This answer was selected by the asker as the best answer.