import torch
length_of_sequences = map(len, [[1,2,3,3,3,3,3,3,3,3,3]])
batch_target_tensor = torch.ones(12,
max(length_of_sequences), dtype=torch.long) * (-100)
batch_target_tensor1 = torch.zeros(1,
max(length_of_sequences), dtype=torch.long) * (-100)
print(batch_target_tensor,batch_target_tensor1)
打印batch_target_tensor1时出现报错max() arg is an empty sequence。
切换成如下模式就不会报错,想知道是为什么,有没有帮忙解决一下。
import torch
length_of_sequences = map(len, [[1,2,3,3,3,3,3,3,3,3,3]])
length_of_sequences_max = max(length_of_sequences)
batch_target_tensor = torch.ones(12,
length_of_sequences_max, dtype=torch.long) * (-100)
batch_target_tensor1 = torch.zeros(1,
length_of_sequences_max, dtype=torch.long) * (-100)
print(batch_target_tensor,batch_target_tensor1)