问题遇到的现象和发生背景
pytorch重写Dataset类,读取的csv数据类型为str,怎样转换为(1,48,48)的矩阵?
问题相关代码,请勿粘贴截图
class MyDataset(data.Dataset):
def __init__(self, root,transforms=None):
super(MyDataset, self).__init__()
self.root = root
self.transforms = transforms
df_label = pd.read_csv(root, header=None, usecols=[0])
df_path = pd.read_csv(root, header=None, usecols=[1])
self.label = np.array(df_label)[1:, 0]
self.path = np.array(df_path)[1:, 0]
def __getitem__(self, item):
img=self.path[item]
target=self.label[item]
print(type(img), type(img[0]),img)
#img=img.reshape(48,48)
#img = Image.fromarray(img.numpy(), mode='L')
if self.transforms is not None:
img = self.transforms(img)
return img, target
def __len__(self):
return self.path.shape[0]
运行结果及报错内容
<class 'str'> <class 'str'> 251 251 251 253 246 217 186 172 162 139 144 113 92 164 209 225 232 234 237 239 237 234 231 233 233 230 228 225 212 203 182 164 148 136 119 108 110 116 129 151 149 129 103 109 99 88 93 87 251 251 251 253 223 193 166 161 136 141 123 80 150 200 219 228 231 236 238 236 237
我的解答思路和尝试过的方法
我想要达到的结果