CCK1124 2022-04-28 16:18 采纳率: 0%
浏览 68

卷积神经网络数据预处理

问题遇到的现象和发生背景

labels = df.category.tolist() 这行代码报错

问题相关代码,请勿粘贴截图

import os
import numpy as np
#import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import time
import torchvision
from torch.autograd import Variable
from torchvision import datasets, transforms
import os # os包集成了一些对文件路径和目录进行操作的类
from PIL import Image
from torch import optim
#from models.Category import category

class TrafficData(Dataset):

def __init__(self, path, train=True):
    super(TrafficData, self).__init__()
    df = pd.read_csv(os.path.join(path, 'annotation.csv'))
    labels = df.category.tolist()
    image_files = df.file_name.tolist()
    self.path = path
    del df
    self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307), (0.3081))
        ])
    if train:
        self.image_files = image_files[:int(len(image_files)*0.9)]
        self.labels = labels[:int(len(image_files)*0.9)]
    else:
        self.image_files = image_files[int(len(image_files)*0.9):]
        self.labels = labels[int(len(image_files)*0.9):]
        
def __getitem__(self, index):
    image = Image.open(os.path.join(self.path + '/images/', self.image_files[index]))
    return self.transform(image), self.labels[index]

def __len__(self):
    return len(self.image_files)

TrafficData1=TrafficData('C:/Users/17204/Desktop/traffic/train', train=True)

TrafficData2==TrafficData('C:/Users/17204/Desktop/traffic/test', train=False)

train_loader = DataLoader(dataset=TrafficData1,
batch_size=32, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=TrafficData2,
batch_size=32, shuffle=True, drop_last=True)

运行结果及报错内容

AttributeError Traceback (most recent call last)
in
46
47

48 TrafficData1=TrafficData('C:/Users/17204/Desktop/traffic/train', train=True)
49
50 TrafficData2==TrafficData('C:/Users/17204/Desktop/traffic/test', train=False)

in init(self, path, train)
22 super(TrafficData, self).init()
23 df = pd.read_csv(os.path.join(path, 'annotation.csv'))

24 labels = df.category.tolist()
25 image_files = df.file_name.tolist()
26 self.path = path

~\anaconda3\envs\pytorch\lib\site-packages\pandas\core\generic.py in getattr(self, name)
5139 if self._info_axis._can_hold_identifiers_and_holds_name(name):
5140 return self[name]
-> 5141 return object.getattribute(self, name)
5142
5143 def setattr(self, name: str, value) -> None:

AttributeError: 'DataFrame' object has no attribute 'category'

我的解答思路和尝试过的方法
我想要达到的结果

怎么解决

  • 写回答

1条回答 默认 最新

  • 不会长胖的斜杠 后端领域新星创作者 2022-04-28 16:27
    关注

    r'C:/Users/17204/Desktop/traffic/train'
    加个r试试

    评论

报告相同问题?

问题事件

  • 创建了问题 4月28日

悬赏问题

  • ¥15 react-diff-viewer组件,如何解决数据量过大卡顿问题
  • ¥20 遥感植被物候指数空间分布图制作
  • ¥15 安装了xlrd库但是import不了…
  • ¥20 Github上传代码没有contribution和activity记录
  • ¥20 SNETCracker
  • ¥15 数学建模大赛交通流量控制
  • ¥15 为什么我安装了open3d但是在调用的时候没有报错但是什么都没有发生呢
  • ¥50 paddleocr最下面一行似乎无法识别
  • ¥15 求某类社交网络数据集
  • ¥15 靶向捕获探针方法/参考文献