以下是我的代码,该导入的库没有任何问题,代码检查了多遍依然正确
def plot_learning_curve(estimator,title, X, y,
x=None, #选择子图
ylim=None, #设置纵坐标的取值范围
cv=None, #交叉验证
n_jobs=None #设定索要使用的线程
):
train_sizes,train_scores,test_scores = learning_curve(estimator,
X, y,
shuffle=True
,cv=cv
,random_state=420
,n_jobs=n_jobs)
ax.set_title(title)#设置标题
if ax == None:
ax = plt.gca()
else:
ax = plt.figure()
if ylim is not None:
ax.set_ylim(*ylim)
ax.set_xlabel("Training examples")
ax.set_ylabel("Score")
ax.grid() #绘制网格,不是必须
ax.plot(train_sizes,np.mean(train_scores, axis=1),'o-',color="r",label="Training score")
ax.plot(train_sizes,np.mean(test_scores, axis=1),'o-',color="g",label="Test score")
ax.legend(loc="best")
return ax
cv = KFold(n_splits=5, shuffle = True, random_state=42)#交叉验证模式
plot_learning_curve(XGBR(n_estimators=100,random_state=420),"XGB",Xtrain,Ytrain,
ax=None,cv=cv)
plt.show()
报错如下:
TypeError Traceback (most recent call last)
in
----> 1 plot_learning_curve(XGBR(n_estimators=100,random_state=420),"XGB",Xtrain,Ytrain,
2 ax=None,cv=cv)
3 plt.show()
TypeError:
```python
``` plot_learning_curve() got an unexpected keyword argument 'ax'