# 数据集类
class MyDataset(paddle.io.Dataset):
"""
步骤一:继承paddle.io.Dataset类
"""
def __init__(self, data, num_features=10, num_labels=1):
"""
步骤二:实现构造函数,定义数据集大小
data: numpy.Array 1维数组
"""
super(MyDataset, self).__init__()
self.data = data
self.num_features = num_features
self.num_labels = num_labels
x = []
y = []
for i in range(0, len(data) - num_features - num_labels + 1):
x.append(data[i:i+num_features])
y.append(data[i+num_features:i+num_features+num_labels])
print('x',x)
print('y',y)
self.x = np.vstack(x).reshape(-1, self.num_features, 1)
self.y = np.vstack(y)
self.x = np.array(self.x, dtype="float32")
self.y = np.array(self.y, dtype="float32")
self.num_samples = len(x)
def __getitem__(self, index):
"""
步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
"""
data = self.x[index]
label = self.y[index]
return data, label
def __len__(self):
"""
步骤四:实现__len__方法,返回数据集总数目
"""
return self.num_samples
train_dataset = MyDataset(train_data, num_features=10)
test_dataset = MyDataset(test_data, num_features=10)
print(train_dataset.__getitem__(0)[0].shape)
print(train_dataset.__getitem__(0)[1].shape)
请问这段代码最后两行啥意思?getitem那两行
- 写回答
- 好问题 0 提建议
- 追加酬金
- 关注问题
- 邀请回答
-
2条回答 默认 最新
- Jackyin0720 2022-10-22 14:57关注
item就是一个索引,在__getitem__()内使用的时候是随机索引,python的底层会随机分配索引,在函数外面,我们可以指定索引
就像是最下面的代码块A、B部分
A代码def __init__(self, dataset, idxs): self.dataset = dataset self.idxs = [int(i) for i in idxs] def __len__(self): return len(self.idxs) def __getitem__(self, item): image, label = self.dataset[self.idxs[item]] print('\r\n') print(item) print(self.idxs) print(self.idxs[item]) print(len(self.dataset)) for k,v in self.dataset: print(self.dataset[35524]) print(self.dataset[self.idxs[item]]) a = input('请输入a=0 or a=1: \t') a = input('请输入a=0 or a=1: \t') if a ==1: print('执行') return torch.tensor(image), torch.tensor(label)
B代码
class Student(object): def __init__(self, user_dic): self.value = user_dic def __getitem__(self, item): print('__getitem__', item) return self.value[item] def __setitem__(self, key, value): print('__setitem__', key, value) self.value[key] = value def __delitem__(self, key): print('__delitem__', key) del self.value[key] def __len__(self): return len(self.value)
这两句概括起来就是利用自己数据集来训练神经网络pytorch,重写Dataset类
【自我理解,仅供参考】
另外,提供一个参考学习链接,期望对你的有所帮助:https://blog.csdn.net/weixin_44911037/article/details/123202869
【里面以实例清晰讲解说明,利于理解】解决 无用评论 打赏 举报 编辑记录
悬赏问题
- ¥15 onlyoffice编辑完后立即下载,下载的不是最新编辑的文档
- ¥15 求caverdock使用教程
- ¥15 Coze智能助手搭建过程中的问题请教
- ¥15 12864只亮屏 不显示汉字
- ¥20 三极管1000倍放大电路
- ¥15 vscode报错如何解决
- ¥15 前端vue CryptoJS Aes CBC加密后端java解密
- ¥15 python随机森林对两个excel表格读取,shap报错
- ¥15 基于STM32心率血氧监测(OLED显示)相关代码运行成功后烧录成功OLED显示屏不显示的原因是什么
- ¥100 X轴为分离变量(因子变量),如何控制X轴每个分类变量的长度。