weixin_47872887 2022-10-22 14:50 采纳率: 52.5%
浏览 47
已结题

请问这段代码最后两行啥意思?getitem那两行


# 数据集类
class MyDataset(paddle.io.Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, data, num_features=10, num_labels=1):
        """
        步骤二:实现构造函数,定义数据集大小

        data: numpy.Array 1维数组
        """
        super(MyDataset, self).__init__()
        self.data = data
        self.num_features = num_features
        self.num_labels = num_labels

        x = [] 
        y = [] 
        for i in range(0, len(data) - num_features - num_labels + 1):
            x.append(data[i:i+num_features])
            y.append(data[i+num_features:i+num_features+num_labels])
        print('x',x)
        print('y',y)
        self.x = np.vstack(x).reshape(-1, self.num_features, 1)
        self.y = np.vstack(y)
        self.x = np.array(self.x, dtype="float32")
        self.y = np.array(self.y, dtype="float32")

        self.num_samples = len(x)

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        data = self.x[index]
        label = self.y[index]

        return data, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return self.num_samples

train_dataset = MyDataset(train_data, num_features=10)
test_dataset = MyDataset(test_data, num_features=10)
print(train_dataset.__getitem__(0)[0].shape)
print(train_dataset.__getitem__(0)[1].shape)
  • 写回答

2条回答 默认 最新

  • Jackyin0720 2022-10-22 14:57
    关注

    item就是一个索引,在__getitem__()内使用的时候是随机索引,python的底层会随机分配索引,在函数外面,我们可以指定索引
    就像是最下面的代码块A、B部分
    A代码

     def __init__(self, dataset, idxs):
            self.dataset = dataset
            self.idxs = [int(i) for i in idxs]
    
        def __len__(self):
            return len(self.idxs)
    
        def __getitem__(self, item):
            image, label = self.dataset[self.idxs[item]]
            print('\r\n')
            print(item)
            print(self.idxs)
            print(self.idxs[item])
            print(len(self.dataset))
            for k,v in self.dataset:
                print(self.dataset[35524])
                print(self.dataset[self.idxs[item]])
                a = input('请输入a=0 or a=1: \t')
            a = input('请输入a=0 or a=1: \t')
            if a ==1:
                print('执行')
            return torch.tensor(image), torch.tensor(label)
    
    
    

    B代码

    class Student(object):
    
        def __init__(self, user_dic):
            self.value = user_dic
    
        def __getitem__(self, item):
            print('__getitem__', item)
            return self.value[item]
    
        def __setitem__(self, key, value):
            print('__setitem__', key, value)
            self.value[key] = value
    
        def __delitem__(self, key):
            print('__delitem__', key)
            del self.value[key]
    
        def __len__(self):
            return len(self.value)
    
    
    

    这两句概括起来就是利用自己数据集来训练神经网络pytorch,重写Dataset类
    【自我理解,仅供参考】
    另外,提供一个参考学习链接,期望对你的有所帮助:https://blog.csdn.net/weixin_44911037/article/details/123202869
    【里面以实例清晰讲解说明,利于理解】

    评论 编辑记录

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 10月24日
  • 创建了问题 10月22日

悬赏问题

  • ¥15 onlyoffice编辑完后立即下载,下载的不是最新编辑的文档
  • ¥15 求caverdock使用教程
  • ¥15 Coze智能助手搭建过程中的问题请教
  • ¥15 12864只亮屏 不显示汉字
  • ¥20 三极管1000倍放大电路
  • ¥15 vscode报错如何解决
  • ¥15 前端vue CryptoJS Aes CBC加密后端java解密
  • ¥15 python随机森林对两个excel表格读取,shap报错
  • ¥15 基于STM32心率血氧监测(OLED显示)相关代码运行成功后烧录成功OLED显示屏不显示的原因是什么
  • ¥100 X轴为分离变量(因子变量),如何控制X轴每个分类变量的长度。