RuntimeError: shape '[1, 1]' is invalid for input of size 4
The line that raises the error:
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
B is 4.
view_shape is [1, 1].
How can I fix this?
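A minimal reproduction, assuming PyTorch is installed, shows why the call fails. Tensor.view only reinterprets the existing data, so the target shape must hold exactly as many elements as the tensor. torch.arange(4) has 4 elements, but a [1, 1] shape holds only 1. The values B = 4 and view_shape = [1, 1] come from the question; everything else is illustrative.

```python
import torch

B = 4
view_shape = [1, 1]  # value reported in the question

t = torch.arange(B, dtype=torch.long)  # 4 elements: tensor([0, 1, 2, 3])

error_message = None
try:
    t.view(view_shape)  # a [1, 1] shape can hold only 1 element, so this raises
except RuntimeError as e:
    error_message = str(e)
    print(error_message)

# view() succeeds only when the target shape holds exactly t.numel() elements
print(t.numel(), torch.Size(view_shape).numel())
```

If view_shape was built from the shape of an index tensor (an assumption about the surrounding code, which the question does not show), [1, 1] suggests that tensor has batch size 1 while B is 4, so it is worth checking that the two batch sizes agree.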
Try this instead:
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(B, 1).repeat(repeat_shape)
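view(B, 1) turns arange(B) into a column with B elements, which matches the tensor's size, and repeat then copies each batch index along the second dimension. The sketch below generalizes that idea under stated assumptions: make_batch_indices is a hypothetical helper, idx is assumed to be batch-first with shape [B, ...], and points is a made-up [B, N, C] tensor used only to show the indexing.

```python
import torch

def make_batch_indices(B, idx):
    """Hypothetical helper: return a batch-index tensor shaped like idx.

    Assumes idx is batch-first, i.e. shape [B, ...].
    """
    view_shape = [B] + [1] * (idx.dim() - 1)  # e.g. [B, 1] for a 2-D idx
    repeat_shape = [1] + list(idx.shape[1:])  # e.g. [1, S] for a 2-D idx
    return (
        torch.arange(B, dtype=torch.long, device=idx.device)
        .view(view_shape)      # B elements, so the element count matches
        .repeat(repeat_shape)  # copy each batch index across the other dims
    )

B, S = 4, 3
idx = torch.zeros(B, S, dtype=torch.long)  # hypothetical [B, S] index tensor
batch_indices = make_batch_indices(B, idx)
print(batch_indices)
# tensor([[0, 0, 0],
#         [1, 1, 1],
#         [2, 2, 2],
#         [3, 3, 3]])

points = torch.randn(B, 10, 6)  # hypothetical [B, N, C] point features
new_points = points[batch_indices, idx, :]  # gather S points per batch item
print(new_points.shape)  # torch.Size([4, 3, 6])
```

For a 2-D idx this matches the view(B, 1) fix above. Building view_shape from B rather than from idx.shape[0] avoids the size mismatch, but if idx really has batch size 1 while B is 4, the indexing only works through broadcasting. In that case it may be better to fix whatever produced idx.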