问题遇到的现象和发生背景
我的毕设是CV相关的课题标签分类识别,在网上找了套相关的代码,有bug不知如何修改了
问题相关代码,请勿粘贴截图
def main(op):
if op == 'train':
train_df = pd.read_csv('../data/train.csv')
print(train_df['Sports'].value_counts())
train_df['filename'] = train_df['filename'].apply(lambda x: '../data/train/{0}'.format(x))
if mode == 1:
n_splits = 5
x = train_df['filename'].values
y = train_df['label'].values
skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
for fold_idx, (train_idx, val_idx) in enumerate(skf.split(x, y)):
train(train_df.iloc[train_idx], train_df.iloc[val_idx], fold_idx)
运行结果及报错内容
数据集有三个属性如图
而代码中读y值处仅有一个label,运行无法读取
我的解答思路和尝试过的方法
尝试将label改为Sports单属性,报错输入中有NaN
我想要达到的结果
代码成功运行