问题遇到的现象和发生背景
pytorch 1.12.0
载入pt的时候,原本保存的是在cpu上面,现在我想载入cuda:0上,使用map_location报错
问题相关代码
net.load_state_dict(torch.load(model_weight_path, map_location=torch.device('cuda:0')))
运行结果及报错内容
RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
我的解答思路和尝试过的方法
github上有人问出类似问题,但国内这边没有结果,pytorch官网也没给出load的特殊变化,无法定位根本问题