帅气的阳阳子 2019-11-21 16:10 采纳率: 25%
浏览 4074
已采纳

运行mixmatch源码CIFAR10数据集时报错AttributeError: 'CIFAR10' object has no attribute 'targets',是怎么回事?

  1. 在运行mixmatch程序的时候,用torchvision.datasets载入CIFAT10的时候出现AttributeError: 'CIFAR10' object has no attribute 'targets',错误

还有一个问题就是:由于用torchvision下载太慢,我先把数据集下下来了,然后放在了data目录下面,这个对结果会有影响嘛?
希望大家可以给点建议和意见,谢谢。

加载数据集的代码如下:

def get_cifar10(root, n_labeled,
                 transform_train=None, transform_val=None,
                 download=True):

    base_dataset = torchvision.datasets.CIFAR10(root, train=True, target_transform=True, download=download,)
    train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, int(n_labeled/10))

    train_labeled_dataset = CIFAR10_labeled(root, train_labeled_idxs, train=True, transform=transform_train)
    train_unlabeled_dataset = CIFAR10_unlabeled(root, train_unlabeled_idxs, train=True, transform=TransformTwice(transform_train))
    val_dataset = CIFAR10_labeled(root, val_idxs, train=True, transform=transform_val, download=True)
    test_dataset = CIFAR10_labeled(root, train=False, transform=transform_val, download=True)

    print (f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
    return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset
def train_val_split(labels, n_labeled_per_class):
    labels = np.array(labels)
    train_labeled_idxs = []
    train_unlabeled_idxs = []
    val_idxs = []

    for i in range(10):
        idxs = np.where(labels == i)[0]
        np.random.shuffle(idxs)
        train_labeled_idxs.extend(idxs[:n_labeled_per_class])
        train_unlabeled_idxs.extend(idxs[n_labeled_per_class:-500])
        val_idxs.extend(idxs[-500:])
    np.random.shuffle(train_labeled_idxs)
    np.random.shuffle(train_unlabeled_idxs)
    np.random.shuffle(val_idxs)

    return train_labeled_idxs, train_unlabeled_idxs, val_idxs

错误信息如下
(base) D:\CSStudy\PycharmProject\MixMatch-pytorch-master>python train.py --gpu 0 --n-labeled 250 --out cifar10@250
==> Preparing cifar10
Using downloaded and verified file: ./data\cifar-10-python.tar.gz
Traceback (most recent call last):
File "train.py", line 431, in
main()
File "train.py", line 88, in main
train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transf
orm_val=transform_val)
File "D:\CSStudy\PycharmProject\MixMatch-pytorch-master\dataset\cifar10.py", line 21, in get_cifar10
train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, int(n_labeled/10))
AttributeError: 'CIFAR10' object has no attribute 'targets'

  • 写回答

1条回答 默认 最新

  • 南梦倾寒 2019-11-22 15:32
    关注

    应该是torch版本的问题,不同torch对应的后缀不同,我正在尝试修改这个问题,推荐查一下torch英文手册

    我之前的问题是torch版本问题过低,我需要将data变成train_data
    把targets变成train_labels,就好了,希望对你有帮助~

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

悬赏问题

  • ¥15 metadata提取的PDF元数据,如何转换为一个Excel
  • ¥15 关于arduino编程toCharArray()函数的使用
  • ¥100 vc++混合CEF采用CLR方式编译报错
  • ¥15 coze 的插件输入飞书多维表格 app_token 后一直显示错误,如何解决?
  • ¥15 vite+vue3+plyr播放本地public文件夹下视频无法加载
  • ¥15 c#逐行读取txt文本,但是每一行里面数据之间空格数量不同
  • ¥50 如何openEuler 22.03上安装配置drbd
  • ¥20 ING91680C BLE5.3 芯片怎么实现串口收发数据
  • ¥15 无线连接树莓派,无法执行update,如何解决?(相关搜索:软件下载)
  • ¥15 Windows11, backspace, enter, space键失灵