问题相关代码,请勿粘贴截图
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import Dataset
import pandas, numpy, random
import matplotlib.pyplot as plt
class MnistDataset(Dataset):
def init(self, csv_file):
self.data_df = pandas.read_csv(csv_file, header=None, low_memory=False)
pass
def __len__(self):
return len(self.data_df)
def __getitem__(self, index):
label = self.data_df.iloc[index, 0]
target = torch.zeros((10))
target[label] = 1.0
image_values = torch.FloatTensor(self.data_df.iloc[index, 1:].values) / 255.0
return label, image_values, target
def plot_image(self, index):
img = self.data_df.iloc[index, 1:].values.reshape(28, 28)
plt.title("label=" + str(self.data_df.iloc[index, 0]))
plt.imshow(img, interpolation='none', cmap='Blues')
pass
pass
mnist_dataset = MnistDataset('./mnist_data/mnist_train.csv')
mnist_dataset.plot_image(7)
TypeError: Image data of dtype object cannot be converted to float
运行结果及报错内容
应该在最底下的地方打印第七张图片,但是没有正常显示,只显示了网格,没显示图片。报错了
我的解答思路和尝试过的方法
尝试了自己解决,但是无法实现,搜索也无法搜索到解决方法,故来询问。