我在加载CUB-200数据集、创建PyTorch的DataLoader时，发现传给DataLoader的数据集是空的（样本数为0）。我的get_cub200_dataloaders函数如下：
def get_cub200_dataloaders(batch_size=128, num_workers=8, is_instance=False,
                           root='./data'):
    """Build the CUB-200 train/test DataLoaders.

    Args:
        batch_size: Batch size of the training loader. The test loader uses
            half of it (integer division), but never less than 1.
        num_workers: Worker processes of the training loader. The test
            loader uses half of it (integer division).
        is_instance: If True, the training set is built with
            ``CUB2011Classification_Instance`` and its length is returned as
            a third value.
        root: Directory handed to the dataset classes. A relative path is
            resolved against the current working directory, not against this
            file's directory.

    Returns:
        ``(train_loader, test_loader, n_data)`` when ``is_instance`` is True,
        otherwise ``(train_loader, test_loader)``.

    Raises:
        RuntimeError: If the train or test split contains no samples. Without
            this check an empty training set only surfaces as the opaque
            ``ValueError: num_samples should be a positive integer value``
            raised by ``RandomSampler`` inside ``DataLoader``.
    """
    import os
    import warnings

    def _require_samples(dataset, split):
        # Fail early and report the absolute path that was searched, instead
        # of letting DataLoader/RandomSampler raise a message hiding the cause.
        if len(dataset) == 0:
            raise RuntimeError(
                f"CUB-200 {split} split is empty: the dataset found 0 samples "
                f"under root={os.path.abspath(root)!r}. Check that the "
                f"extracted files are where the dataset class expects them "
                f"under this directory (relative roots depend on the current "
                f"working directory) and that download=True actually "
                f"fetched/extracted them."
            )

    # Standard ImageNet per-channel mean/std.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    if is_instance:
        train_set = CUB2011Classification_Instance(root=root,
                                                   download=True,
                                                   train=True,
                                                   transform=train_transform)
    else:
        train_set = CUB2011Classification(root=root, train=True, download=True,
                                          transform=train_transform)
    _require_samples(train_set, 'train')

    n_data = len(train_set)
    if n_data < batch_size:
        # With drop_last=True every partial batch is discarded, so the train
        # loader would silently produce no batches at all.
        warnings.warn(
            f"CUB-200 train split has {n_data} samples, fewer than "
            f"batch_size={batch_size}; with drop_last=True the train loader "
            f"yields no batches."
        )
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              drop_last=True)

    test_set = CUB2011Classification(root=root, train=False, download=True,
                                     transform=test_transform)
    _require_samples(test_set, 'test')
    test_loader = DataLoader(test_set,
                             # Clamp so that batch_size=1 does not become 0,
                             # which DataLoader rejects.
                             batch_size=max(1, batch_size // 2),
                             shuffle=False,
                             num_workers=num_workers // 2)

    if is_instance:
        return train_loader, test_loader, n_data
    return train_loader, test_loader
下载后的数据集目录结构如下（注：原文此处的目录结构缺失，需要补充，例如 ./data 下的完整目录树）：
运行代码后报错如下：
Traceback (most recent call last):
File "C:\Users\59225\Desktop\我的文档\工作学习\实验\SDD-CVPR2024-main\train_origin.py", line 211, in <module>
main(cfg, args.resume, args.opts)
File "C:\Users\59225\Desktop\我的文档\工作学习\实验\SDD-CVPR2024-main\train_origin.py", line 118, in main
train_loader, val_loader, num_data, num_classes = get_dataset(cfg)
File "C:\Users\59225\Desktop\我的文档\工作学习\实验\SDD-CVPR2024-main\mdistiller\dataset\__init__.py", line 25, in get_dataset
train_loader, val_loader, num_data = get_cub200_dataloaders(
File "C:\Users\59225\Desktop\我的文档\工作学习\实验\SDD-CVPR2024-main\mdistiller\dataset\cub200.py", line 175, in get_cub200_dataloaders
train_loader = DataLoader(train_set,
File "D:\Anaconda\envs\pyTorchEnv\lib\site-packages\torch\utils\data\dataloader.py", line 277, in __init__
sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
File "D:\Anaconda\envs\pyTorchEnv\lib\site-packages\torch\utils\data\sampler.py", line 97, in __init__
raise ValueError("num_samples should be a positive integer "
ValueError: num_samples should be a positive integer value, but got num_samples=0
请问有没有哪位朋友能帮忙看看问题出在哪里？