I got c10 error when I define a FP16 tensor and use torch::max() and torch::sqrt(). The code are shown as follows:
当我定义一个FP16的tensor并使用torch::max() 和 torch::sqrt()函数时会出现C10 error错误,代码如下:
float test_float[3][3] = { {1.0, 2.0, 3.0}, {4.0, 5.0, 6.0 },{7.0, 8.0, 9.0} };
torch::Tensor test_float_tensor = torch::from_blob(test_float, { 3, 3 }).to(at::kCPU).to(torch::kFloat16);
torch::sqrt(test_float_tensor);
torch::max(test_float_tensor, 1, true);
错误如下:
当把tensor改为FP32或FP64,错误就消失了,代码如下:
float test_float[3][3] = { {1.0, 2.0, 3.0}, {4.0, 5.0, 6.0 },{7.0, 8.0, 9.0} };
torch::Tensor test_float_tensor = torch::from_blob(test_float, { 3, 3 }).to(at::kCPU).to(torch::kFloat32);
torch::sqrt(test_float_tensor);
torch::max(test_float_tensor, 1, true);
为什么会造成这种错误呢?采用FP16 tensor,执行torch::sum()函数时并不会出现这个错误。是FP16 tensor有什么不适用的地方吗?求解答,谢谢!