- 在运行 MixMatch 程序时，用 torchvision.datasets 载入 CIFAR10 数据集出现错误：AttributeError: 'CIFAR10' object has no attribute 'targets'。
还有一个问题：由于用 torchvision 下载太慢，我先手动把数据集下载下来，放在了 data 目录下面，这对结果会有影响吗？
希望大家可以给点建议和意见,谢谢。
加载数据集的代码如下:
def get_cifar10(root, n_labeled,
                transform_train=None, transform_val=None,
                download=True):
    """Build the labeled / unlabeled / validation / test CIFAR-10 datasets.

    Args:
        root: Directory that holds (or will receive) the CIFAR-10 files.
        n_labeled: Total number of labeled training samples; it is split
            evenly across the 10 classes (``n_labeled // 10`` per class).
        transform_train: Transform for labeled training samples; the
            unlabeled set wraps it in ``TransformTwice``.
        transform_val: Transform for validation and test samples.
        download: Forwarded to the constructors that are allowed to fetch
            the data (the base dataset and the val/test datasets).

    Returns:
        Tuple ``(train_labeled, train_unlabeled, val, test)`` of datasets.
    """
    # ``target_transform`` must be a callable or None; torchvision calls
    # ``target_transform(target)`` on every item, so passing ``True`` would
    # fail with "'bool' object is not callable". The labels are only needed
    # here for splitting, so no target transform is required at all.
    base_dataset = torchvision.datasets.CIFAR10(root, train=True, download=download)

    # Older torchvision releases expose the training labels as
    # ``train_labels``; newer ones use ``targets``. Fall back so both work
    # instead of raising AttributeError: ... no attribute 'targets'.
    labels = getattr(base_dataset, 'targets', None)
    if labels is None:
        labels = base_dataset.train_labels

    train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(labels, n_labeled // 10)

    train_labeled_dataset = CIFAR10_labeled(root, train_labeled_idxs, train=True, transform=transform_train)
    # The unlabeled set yields two independently augmented views per image.
    train_unlabeled_dataset = CIFAR10_unlabeled(root, train_unlabeled_idxs, train=True, transform=TransformTwice(transform_train))
    # Validation indices come from the *training* split (held out per class).
    val_dataset = CIFAR10_labeled(root, val_idxs, train=True, transform=transform_val, download=download)
    test_dataset = CIFAR10_labeled(root, train=False, transform=transform_val, download=download)

    print (f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
    return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset
def train_val_split(labels, n_labeled_per_class, num_classes=10, n_val_per_class=500):
    """Split sample indices into labeled, unlabeled and validation sets per class.

    For each class ``c`` in ``range(num_classes)``:
    - the indices of samples labeled ``c`` are shuffled;
    - the first ``n_labeled_per_class`` go to the labeled training set;
    - the last ``n_val_per_class`` go to the validation set;
    - everything in between goes to the unlabeled training set.

    Each of the three result lists is shuffled again at the end.
    Randomness comes from the global ``np.random`` state, so seeding it
    makes the split reproducible.

    Args:
        labels: Sequence of integer class labels, one per sample.
        n_labeled_per_class: Number of labeled training samples per class.
        num_classes: Number of classes (labels ``0 .. num_classes - 1``).
        n_val_per_class: Number of validation samples per class.

    Returns:
        ``(train_labeled_idxs, train_unlabeled_idxs, val_idxs)`` -- three
        shuffled lists of indices into ``labels``.

    Raises:
        ValueError: if a count is negative, or if a class has fewer than
            ``n_labeled_per_class + n_val_per_class`` samples (the labeled
            and validation splits would otherwise silently overlap).
    """
    if n_labeled_per_class < 0 or n_val_per_class < 0:
        raise ValueError("n_labeled_per_class and n_val_per_class must be non-negative")

    labels = np.array(labels)
    n_needed = n_labeled_per_class + n_val_per_class
    train_labeled_idxs = []
    train_unlabeled_idxs = []
    val_idxs = []
    for c in range(num_classes):
        idxs = np.where(labels == c)[0]
        if len(idxs) < n_needed:
            raise ValueError(
                f"class {c} has only {len(idxs)} samples; need at least "
                f"{n_needed} (n_labeled_per_class + n_val_per_class)"
            )
        np.random.shuffle(idxs)
        # Use an explicit start index instead of ``-n_val_per_class``:
        # ``idxs[-0:]`` would select the whole class when n_val_per_class == 0.
        val_start = len(idxs) - n_val_per_class
        train_labeled_idxs.extend(idxs[:n_labeled_per_class])
        train_unlabeled_idxs.extend(idxs[n_labeled_per_class:val_start])
        val_idxs.extend(idxs[val_start:])
    np.random.shuffle(train_labeled_idxs)
    np.random.shuffle(train_unlabeled_idxs)
    np.random.shuffle(val_idxs)
    return train_labeled_idxs, train_unlabeled_idxs, val_idxs
错误信息如下：
(base) D:\CSStudy\PycharmProject\MixMatch-pytorch-master>python train.py --gpu 0 --n-labeled 250 --out cifar10@250
==> Preparing cifar10
Using downloaded and verified file: ./data\cifar-10-python.tar.gz
Traceback (most recent call last):
File "train.py", line 431, in <module>
main()
File "train.py", line 88, in main
train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transform_val=transform_val)
File "D:\CSStudy\PycharmProject\MixMatch-pytorch-master\dataset\cifar10.py", line 21, in get_cifar10
train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, int(n_labeled/10))
AttributeError: 'CIFAR10' object has no attribute 'targets'