这个tensor的shape是(400),怎么以里面的2分割tensor成若干个tensor啊?
收起
需要借助numpy
import torch import numpy as np t=torch.tensor([1,2,3,4,5,6,8,2,56,5,2,10]) t_numpy=t.numpy() index=np.argwhere(t_numpy==2) print(index)#在tensor中的索引,可以根据该索引对tensor切片
报告相同问题?