代码是这样的,
y = torch.tensor([0,2])
y_hat = torch.tensor([[0.1,0.3,0.6], [0.3,0.2,0.5]])
y_hat[[0,1], y]
运行结果
tensor([0.1000, 0.5000])
如题,谢谢各位。
y = torch.tensor([0,2])
y_hat = torch.tensor([[0.1,0.3,0.6], [0.3,0.2,0.5]])
y_hat[[0,1], y]
tensor([0.1000, 0.5000])
如题,谢谢各位。