帅气的阳阳子 2019-11-21 16:10 采纳率: 25%
浏览 4062
已采纳

运行mixmatch源码CIFAR10数据集时报错AttributeError: 'CIFAR10' object has no attribute 'targets',是怎么回事?

  1. 在运行mixmatch程序的时候,用torchvision.datasets载入CIFAT10的时候出现AttributeError: 'CIFAR10' object has no attribute 'targets',错误

还有一个问题就是:由于用torchvision下载太慢,我先把数据集下下来了,然后放在了data目录下面,这个对结果会有影响嘛?
希望大家可以给点建议和意见,谢谢。

加载数据集的代码如下:

def get_cifar10(root, n_labeled,
                 transform_train=None, transform_val=None,
                 download=True):

    base_dataset = torchvision.datasets.CIFAR10(root, train=True, target_transform=True, download=download,)
    train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, int(n_labeled/10))

    train_labeled_dataset = CIFAR10_labeled(root, train_labeled_idxs, train=True, transform=transform_train)
    train_unlabeled_dataset = CIFAR10_unlabeled(root, train_unlabeled_idxs, train=True, transform=TransformTwice(transform_train))
    val_dataset = CIFAR10_labeled(root, val_idxs, train=True, transform=transform_val, download=True)
    test_dataset = CIFAR10_labeled(root, train=False, transform=transform_val, download=True)

    print (f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
    return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset
def train_val_split(labels, n_labeled_per_class):
    labels = np.array(labels)
    train_labeled_idxs = []
    train_unlabeled_idxs = []
    val_idxs = []

    for i in range(10):
        idxs = np.where(labels == i)[0]
        np.random.shuffle(idxs)
        train_labeled_idxs.extend(idxs[:n_labeled_per_class])
        train_unlabeled_idxs.extend(idxs[n_labeled_per_class:-500])
        val_idxs.extend(idxs[-500:])
    np.random.shuffle(train_labeled_idxs)
    np.random.shuffle(train_unlabeled_idxs)
    np.random.shuffle(val_idxs)

    return train_labeled_idxs, train_unlabeled_idxs, val_idxs

错误信息如下
(base) D:\CSStudy\PycharmProject\MixMatch-pytorch-master>python train.py --gpu 0 --n-labeled 250 --out cifar10@250
==> Preparing cifar10
Using downloaded and verified file: ./data\cifar-10-python.tar.gz
Traceback (most recent call last):
File "train.py", line 431, in
main()
File "train.py", line 88, in main
train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transf
orm_val=transform_val)
File "D:\CSStudy\PycharmProject\MixMatch-pytorch-master\dataset\cifar10.py", line 21, in get_cifar10
train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, int(n_labeled/10))
AttributeError: 'CIFAR10' object has no attribute 'targets'

  • 写回答

1条回答 默认 最新

  • 南梦倾寒 2019-11-22 15:32
    关注

    应该是torch版本的问题,不同torch对应的后缀不同,我正在尝试修改这个问题,推荐查一下torch英文手册

    我之前的问题是torch版本问题过低,我需要将data变成train_data
    把targets变成train_labels,就好了,希望对你有帮助~

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

悬赏问题

  • ¥15 如何在scanpy上做差异基因和通路富集?
  • ¥20 关于#硬件工程#的问题,请各位专家解答!
  • ¥15 关于#matlab#的问题:期望的系统闭环传递函数为G(s)=wn^2/s^2+2¢wn+wn^2阻尼系数¢=0.707,使系统具有较小的超调量
  • ¥15 FLUENT如何实现在堆积颗粒的上表面加载高斯热源
  • ¥30 截图中的mathematics程序转换成matlab
  • ¥15 动力学代码报错,维度不匹配
  • ¥15 Power query添加列问题
  • ¥50 Kubernetes&Fission&Eleasticsearch
  • ¥15 報錯:Person is not mapped,如何解決?
  • ¥15 c++头文件不能识别CDialog