SunnyEdward 2022-05-27 16:18 采纳率: 75%
浏览 406
已结题

torch.FloatTensor 处理不等长的列表报错

我有这样一个列表,其中的子列表长度是不一致的,在使用torch.FloatTensor的时候报错:expected sequence of length 5 at dim 1 (got 10).请问有解决的方案吗?

[[2643037.0, 6899734.0, 5104963.0, 2197590.0, 6796032.0], [274737.0, 6553502.0, 8387494.0, 8189333.0, 4175343.0, 8270815.0, 1691725.0, 221449.0, 4407937.0, 3193705.0], [3583583.0, 4270651.0, 2132056.0], [5780075.0], [7808558.0], [3884986.0], [6510503.0], [1860312.0, 988514.0, 3551827.0, 3452753.0], [1153600.0, 7232925.0, 8786106.0, 6854379.0, 2230286.0, 2753785.0, 1934809.0, 7726970.0, 5210459.0], [8548687.0], [3324262.0, 761825.0], [5885106.0], [8666430.0], [4933558.0, 3709514.0, 3021415.0, 5561149.0, 7826716.0, 2611716.0, 8622575.0, 8557317.0, 3717192.0, 4507507.0], [3892502.0, 5153929.0, 3309247.0, 1203831.0, 2624460.0, 8857676.0, 1253447.0, 7272396.0, 1785236.0, 4668947.0, 473317.0], [8284366.0], [1193636.0, 2860098.0], [3103978.0], [945809.0], [7706326.0, 4080953.0, 5530568.0, 6762956.0, 5198121.0, 8918377.0, 815908.0, 6486437.0, 3007308.0, 2296588.0, 3791218.0, 2923812.0, 5338646.0, 2711829.0, 3504941.0]]
  • 写回答

4条回答 默认 最新

  • weixin_41076199 2022-06-03 09:34
    关注

    简单来说,就是矩阵的每一行的长度应该是相同的!!!
    而你这里面的每一行(也就是每一句话)长度是不相同的!!!

    解决方案: 可以通过填充的方式使所有序列长度都相同!即通过补0的方式将error_matrix -> correct_matrix
    实现办法:torch.nn.utils.rnn.pad_sequence()方法

    a=[[2643037.0, 6899734.0, 5104963.0, 2197590.0, 6796032.0], [274737.0, 6553502.0, 8387494.0, 8189333.0, 4175343.0, 8270815.0, 1691725.0, 221449.0, 4407937.0, 3193705.0], [3583583.0, 4270651.0, 2132056.0], [5780075.0], [7808558.0], [3884986.0], [6510503.0], [1860312.0, 988514.0, 3551827.0, 3452753.0], [1153600.0, 7232925.0, 8786106.0, 6854379.0, 2230286.0, 2753785.0, 1934809.0, 7726970.0, 5210459.0], [8548687.0], [3324262.0, 761825.0], [5885106.0], [8666430.0], [4933558.0, 3709514.0, 3021415.0, 5561149.0, 7826716.0, 2611716.0, 8622575.0, 8557317.0, 3717192.0, 4507507.0], [3892502.0, 5153929.0, 3309247.0, 1203831.0, 2624460.0, 8857676.0, 1253447.0, 7272396.0, 1785236.0, 4668947.0, 473317.0], [8284366.0], [1193636.0, 2860098.0], [3103978.0], [945809.0], [7706326.0, 4080953.0, 5530568.0, 6762956.0, 5198121.0, 8918377.0, 815908.0, 6486437.0, 3007308.0, 2296588.0, 3791218.0, 2923812.0, 5338646.0, 2711829.0, 3504941.0]]
    
    import torch
    
    padded_sequence = torch.nn.utils.rnn.pad_sequence([torch.FloatTensor(i) for i in a], batch_first=True)
    print(padded_sequence)
    
    

    结果:
    tensor([[2643037., 6899734., 5104963., 2197590., 6796032., 0., 0.,
    0., 0., 0., 0., 0., 0., 0.,
    0.],
    [ 274737., 6553502., 8387494., 8189333., 4175343., 8270815., 1691725.,
    221449., 4407937., 3193705., 0., 0., 0., 0.,
    0.],
    [3583583., 4270651., 2132056., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0.,
    0.],
    [5780075., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0.,
    0.],
    [7808558., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0.,
    0.],
    [3884986., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0.,
    0.],
    [6510503., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0.,
    0.],..........................

    评论

报告相同问题?

问题事件

  • 系统已结题 6月4日
  • 修改了问题 5月27日
  • 创建了问题 5月27日