# Load the pickled split definition and pull out the three folds by key.
data_folds_file = "data_folds.pkl"
data_folds = load_pickle(data_folds_file)
train, dev, test = (data_folds[split] for split in ("train", "dev", "test"))

# Total number of entries across all folds (value is only shown in a REPL/notebook).
sum(len(fold) for fold in (train, dev, test))

# Wrap each fold in a HumorDataset.
train_set, dev_set, test_set = HumorDataset(train), HumorDataset(dev), HumorDataset(test)

# Batch size shared by the train and dev loaders; the test loader yields one sample at a time.
batch = 10
# NOTE(review): dev/test loaders are shuffled as well — confirm this is intended,
# evaluation usually does not need shuffling.
train_dataloader = DataLoader(dataset=train_set, batch_size=batch, shuffle=True)
dev_dataloader = DataLoader(dataset=dev_set, batch_size=batch, shuffle=True)
test_dataloader = DataLoader(dataset=test_set, batch_size=1, shuffle=True)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Sanity check: print tensor shapes and labels for the first six training batches.
# The loop variable is deliberately NOT called `batch`, so the integer batch size
# stored in `batch` earlier in the script is not overwritten by a batch of tensors.
#
# NOTE(review): the reported "IndexError: only integers, slices (`:`) ... are valid
# indices" is NumPy's message for indexing an array with a non-integer value.
# enumerate() only counts iterations and cannot raise it; presumably it is raised
# while the DataLoader fetches samples, i.e. inside HumorDataset.__getitem__
# (not shown here) — verify what type of index that method uses on its arrays.
for batch_idx, batch_tensors in enumerate(train_dataloader):
    # Each batch unpacks into three tensors; move all of them to the chosen device.
    x_p, x_c, y = (tensor.to(device) for tensor in batch_tensors)
    print("*********")
    print("punchline shape: ", x_p.shape)
    print("context shape: ", x_c.shape)
    print("humor labels: ", y)
    # Stop after batch index 5 — a short preview is enough.
    if batch_idx == 5:
        break
for循环代码报错
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
batch 和 batch_idx 看起来都是 int 类型,不确定是不是 enumerate 导致的问题。
请问这种情况应该怎么修改?