y = torch.tensor([0,2]) y_hat = torch.tensor([[0.1,0.3,0.6], [0.3,0.2,0.5]]) y_hat[[0,1], y]
tensor([0.1000, 0.5000])
如题,谢谢各位。
收起
y_hat =
[[0.1,0.3,0.6],[0.3,0.2,0.5]]
然后y_hat[[0,1], y],也就是
y_hat[[0,1], [0,2]]
意思是从y_hat里面挑选出【0,0】元素和【1,2】元素得到tensor([0.1000, 0.5000])
望采纳, 谢谢!
报告相同问题?