问题遇到的现象和发生背景
labels = df.category.tolist() 这行代码报错
问题相关代码,请勿粘贴截图
import os
import numpy as np
#import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import time
import torchvision
from torch.autograd import Variable
from torchvision import datasets, transforms
import os # os包集成了一些对文件路径和目录进行操作的类
from PIL import Image
from torch import optim
#from models.Category import category
class TrafficData(Dataset):
def __init__(self, path, train=True):
super(TrafficData, self).__init__()
df = pd.read_csv(os.path.join(path, 'annotation.csv'))
labels = df.category.tolist()
image_files = df.file_name.tolist()
self.path = path
del df
self.transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((128, 128)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307), (0.3081))
])
if train:
self.image_files = image_files[:int(len(image_files)*0.9)]
self.labels = labels[:int(len(image_files)*0.9)]
else:
self.image_files = image_files[int(len(image_files)*0.9):]
self.labels = labels[int(len(image_files)*0.9):]
def __getitem__(self, index):
image = Image.open(os.path.join(self.path + '/images/', self.image_files[index]))
return self.transform(image), self.labels[index]
def __len__(self):
return len(self.image_files)
TrafficData1=TrafficData('C:/Users/17204/Desktop/traffic/train', train=True)
TrafficData2==TrafficData('C:/Users/17204/Desktop/traffic/test', train=False)
train_loader = DataLoader(dataset=TrafficData1,
batch_size=32, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=TrafficData2,
batch_size=32, shuffle=True, drop_last=True)
运行结果及报错内容
AttributeError Traceback (most recent call last)
in
46
47
48 TrafficData1=TrafficData('C:/Users/17204/Desktop/traffic/train', train=True)
49
50 TrafficData2==TrafficData('C:/Users/17204/Desktop/traffic/test', train=False)
in init(self, path, train)
22 super(TrafficData, self).init()
23 df = pd.read_csv(os.path.join(path, 'annotation.csv'))
24 labels = df.category.tolist()
25 image_files = df.file_name.tolist()
26 self.path = path
~\anaconda3\envs\pytorch\lib\site-packages\pandas\core\generic.py in getattr(self, name)
5139 if self._info_axis._can_hold_identifiers_and_holds_name(name):
5140 return self[name]
-> 5141 return object.getattribute(self, name)
5142
5143 def setattr(self, name: str, value) -> None:
AttributeError: 'DataFrame' object has no attribute 'category'
我的解答思路和尝试过的方法
我想要达到的结果
怎么解决