特立独行的蜗牛 2023-04-02 10:41 采纳率: 53.3%
浏览 16

二维输入显示维度问题显示错误

stack expects each tensor to be equal size, but got [1, 330, 34410] at entry 0 and [1, 330, 4010] at entry 1
使用赵志斌的stft的代码,用作二维输出,报错显示张量不匹配,我该如何修改这段代码


```python
source1_dataset = stft.CWRUSTFT(path,s1_name, normlizetype).data_preprare()
source1_loader = torch.utils.data.DataLoader(source1_dataset, batch_size=64, shuffle=True,
                                                 num_workers=0,
                                                 pin_memory=(True))
source1_iter = iter(source1_loader)
source_data, source_label = next(iter(source1_loader))
print(source_label,source_data.shape())

```

  • 写回答

2条回答 默认 最新

  • Leodong. 2023-04-02 10:51
    关注

    该回答通过自己思路及引用到各个渠道搜索综合及思考,得到内容具体如下:
    根据错误提示,可以看出两个张量的第三个维度长度不同,分别是34410和4010。这说明在数据处理中出现了维度不匹配的问题。

    根据你提供的代码,可以看出这是在数据加载时出现的问题。可能是由于数据集中的某些样本的长度不同,导致在数据加载时无法将它们作为一个batch处理,从而出现维度不匹配的错误。要解决这个问题,可以尝试以下几种方法:

    1、手动调整数据集中的所有样本的长度,使它们的第三个维度长度相同。这可能需要一些数据预处理的工作。
    
    2、调整batch_size的大小,使其与数据集中最短的样本长度相同。这样可以保证每个batch中所有样本的维度都相同。
    
    3、在数据加载时动态调整每个batch中样本的长度,使其均匀分布。这样可以保证每个batch中所有样本的维度都相同,但需要对数据加载代码进行一定的修改。
    

    具体代码及思路如下:

    在数据加载时动态调整每个batch中样本的长度,使其均匀分布的实现方式如下:
    
    import torch.nn.functional as F
    
    # 获取数据集中所有样本的长度
    lengths = [len(data) for data, label in source1_dataset]
    
    # 计算所有样本长度的中位数
    median_length = int(np.median(lengths))
    
    # 定义数据加载器
    source1_loader = torch.utils.data.DataLoader(source1_dataset, batch_size=64, shuffle=True,
                                                  num_workers=0, pin_memory=True,
                                                  collate_fn=lambda batch: collate_fn(batch, median_length))
    
    def collate_fn(batch, target_length):
        # 将batch中的数据按长度从长到短排序
        sorted_batch = sorted(batch, key=lambda x: len(x[0]), reverse=True)
        
        # 获取batch中所有样本的长度
        lengths = [len(data) for data, label in sorted_batch]
        
        # 将所有样本的长度调整为中位数
        padded_data = [F.pad(torch.Tensor(data), (0, 0, 0, target_length - len(data)), 'constant', 0) for data, label in sorted_batch]
        
        # 将数据和标签打包成元组
        padded_batch = [(padded_data[i], label) for i, (data, label) in enumerate(sorted_batch)]
        
        return padded_batch
    

    这里使用了collate_fn参数来定义数据加载器的数据合并方式。在collate_fn函数中,首先将batch中的数据按长度从长到短排序,然后将所有样本的长度调整为中位数,最后将数据和标签打包成元组返回。这样可以保证每个batch中所有样本的维度都相同,且长度均匀分布。
    将修改后的数据加载器应用到代码中,可以按以下方式进行:

    source1_dataset = stft.CWRUSTFT(path,s1_name, normlizetype).data_preprare()
    
    # 定义数据加载器
    source1_loader = torch.utils.data.DataLoader(source1_dataset, batch_size=64, shuffle=True,
                                                  num_workers=0, pin_memory=True,
                                                  collate_fn=lambda batch: collate_fn(batch, median_length))
    
    source1_iter = iter(source1_loader)
    source_data, source_label = next(iter(source1_loader))
    print(source_label, source_data.shape)
    

    需要注意的是,需要先定义collate_fn函数,再将其应用到数据加载器中。另外,median_length变量需要在定义source1_loader之前计算出来。


    如果以上回答对您有所帮助,点击一下采纳该答案~谢谢

    评论 编辑记录

报告相同问题?

问题事件

  • 创建了问题 4月2日

悬赏问题

  • ¥15 黄永刚的晶体塑性子程序中输入的材料参数里的晶体取向参数是什么形式的?
  • ¥20 数学建模来解决我这个问题
  • ¥15 计算机网络ip分片偏移量计算头部是-20还是-40呀
  • ¥15 stc15f2k60s2单片机关于流水灯,时钟,定时器,矩阵键盘等方面的综合问题
  • ¥15 YOLOv8已有一个初步的检测模型,想利用这个模型对新的图片进行自动标注,生成labellmg可以识别的数据,再手动修改。如何操作?
  • ¥30 NIRfast软件使用指导
  • ¥20 matlab仿真问题,求功率谱密度
  • ¥15 求micropython modbus-RTU 从机的代码或库?
  • ¥15 django5安装失败
  • ¥15 Java与Hbase相关问题