# Split the training data into 95% train / 5% validation and inspect the
# validation part.
#
# Why the original failed: `train_dataset` is an iterable-style dataset
# (RawTextIterableDataset). `random_split` only needs `len()`, so the split
# itself succeeds, but the resulting `Subset` objects fetch items with
# `dataset[index]`, which an iterable-style dataset does not support. The
# error therefore only appears once `sub_valid_` is iterated.
print('type(data):', type(train_dataset))

# Materialize the iterable dataset into a list so that it supports both
# `len()` and integer indexing (a map-style dataset). The name is rebound on
# purpose: iterating consumes the original iterator, so later code should use
# the list instead.
# NOTE(review): this loads every example into memory — confirm the dataset is
# small enough for that.
train_dataset = list(train_dataset)

train_len = int(len(train_dataset) * 0.95)
valid_len = len(train_dataset) - train_len  # remainder, so lengths sum exactly
# 0.95 goes to training, 0.05 goes to validation.
sub_train_, sub_valid_ = torch.utils.data.random_split(
    train_dataset, [train_len, valid_len]
)

print('type(data):', type(sub_valid_))
print(list(sub_valid_))
for content in sub_valid_:
    print('data:', content)
train_dataset 是一个共有三列的数据集
train_dataset: type(data): <class 'torchtext.data.datasets_utils.RawTextIterableDataset'>
sub_valid_: type(data): <class 'torch.utils.data.dataset.Subset'>
报错如上
我不知道原因出在哪