Problem details:
I ran into a problem when loading a .pth model that I trained with PyTorch:
Exception in thread Thread-1:
Traceback (most recent call last):
  File "C:\Users\YMeng\.conda\envs\pytorchgpu\lib\threading.py", line 926, in _bootstrap_inner
    self.run()
  File "C:\Users\YMeng\.conda\envs\pytorchgpu\lib\threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "F:/Zhaohaocen/HandPose/HandPose/GUI/real_time_show_demo.py", line 148, in detection
    self.detector, self.classifer = Application._model_init()
  File "F:/Zhaohaocen/HandPose/HandPose/GUI/real_time_show_demo.py", line 114, in _model_init
    detector.model.load_state_dict(torch.load(CKPT, map_location=torch.device('cpu')))  # load the trained model from the given path
  File "C:\Users\YMeng\.conda\envs\pytorchgpu\lib\site-packages\torch\serialization.py", line 592, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "C:\Users\YMeng\.conda\envs\pytorchgpu\lib\site-packages\torch\serialization.py", line 851, in _load
    result = unpickler.load()
ModuleNotFoundError: No module named 'models'
I have read some earlier explanations of this error, but I am already using load_state_dict.
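A note on why load_state_dict alone does not avoid this error: the traceback shows the exception is raised inside torch.load, at unpickler.load(), before load_state_dict ever runs. torch.load unpickles the file, so if the file contains an object whose class was defined in a module named models, Python has to import models to rebuild it. One possible reading (an assumption, not confirmed by the post) is that the file at CKPT holds such an object, for example a whole model saved with torch.save(model), rather than only the state_dict shown in the training code below. The stdlib-only sketch below reproduces the mechanism with plain pickle; Detector and the throwaway models module are hypothetical stand-ins.

```python
import pickle
import sys
import types


class Detector:
    """Hypothetical stand-in for a network class that lives in a package called `models`."""

    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]


# Register a throwaway module named 'models' and pretend Detector was defined there.
# pickle records the module *name* ("models"), not the class object itself.
mod = types.ModuleType("models")
Detector.__module__ = "models"
mod.Detector = Detector
sys.modules["models"] = mod


def try_load(blob):
    """Unpickle blob; return ("ok", obj) or ("error", message) on a missing module."""
    try:
        return "ok", pickle.loads(blob)
    except ModuleNotFoundError as exc:
        return "error", str(exc)


# 1) Checkpoint that stores a whole model object: the class reference is pickled.
full_blob = pickle.dumps({"model": Detector(), "epoch": 3})
# 2) Checkpoint that stores only plain data, similar in spirit to a state_dict.
plain_blob = pickle.dumps({"model_state_dict": {"fc.weight": [0.1, 0.2, 0.3]}, "epoch": 3})

# Simulate loading from another project where 'models' is not importable.
del sys.modules["models"]
full_result = try_load(full_blob)
plain_result = try_load(plain_blob)

# Once 'models' is importable again, the full checkpoint loads.
sys.modules["models"] = mod
restored_result = try_load(full_blob)

if __name__ == "__main__":
    print(full_result)      # ('error', "No module named 'models'")
    print(plain_result[0])  # ok
    print(type(restored_result[1]["model"]).__name__)  # Detector
```

If this is the cause, making the original project's models package importable before calling torch.load (for example by adding the directory that contains it to sys.path) should let unpickling succeed, mirroring the last step of the sketch.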
Saving code used during training:
import os
import torch

# Wrap the weights in a dict together with the epoch number.
checkpoint = {'model': model.state_dict(),
              'model_state_dict': model.state_dict(),
              # 'optimizer_state_dict': optimizer.state_dict(),
              'epoch': epoch}
torch.save(checkpoint, os.path.join(save_folder, 'epoch_{}.pth'.format(epoch)))
Loading code:
detector.model.load_state_dict(torch.load(CKPT, map_location=torch.device('cpu')))
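A related point: the training code above wraps the weights in a dict with the keys 'model', 'model_state_dict' and 'epoch', so even once torch.load succeeds, passing its result straight to load_state_dict would hand it those wrapper keys instead of parameter names, and strict loading would report missing and unexpected keys. Below is a minimal sketch of a helper that unwraps such a checkpoint. extract_state_dict and its key list are hypothetical names chosen here, and the helper is plain Python (plain lists stand in for tensors) so it can be checked without torch.

```python
from collections.abc import Mapping


def extract_state_dict(checkpoint, keys=("model_state_dict", "model")):
    """Return the parameter mapping stored inside a (possibly wrapped) checkpoint.

    - {'model_state_dict': {...}, ...}             -> the inner mapping
    - {'model': <object with .state_dict()>, ...}  -> that object's state_dict()
    - a bare state_dict (parameter name -> tensor) -> returned unchanged
    """
    if not isinstance(checkpoint, Mapping):
        raise TypeError(f"expected a dict-like checkpoint, got {type(checkpoint).__name__}")
    for key in keys:
        inner = checkpoint.get(key)
        if isinstance(inner, Mapping):
            return inner
        if inner is not None and hasattr(inner, "state_dict"):
            # A whole model object was pickled under this key.
            return inner.state_dict()
    # No known wrapper key: assume the checkpoint already is a state_dict.
    return checkpoint


if __name__ == "__main__":
    # Same layout as the training script; plain lists stand in for tensors.
    weights = {"fc.weight": [0.1, 0.2], "fc.bias": [0.0]}
    ckpt = {"model": weights, "model_state_dict": weights, "epoch": 7}
    print(sorted(extract_state_dict(ckpt)))  # ['fc.bias', 'fc.weight']
```

With torch, the call would then look like `detector.model.load_state_dict(extract_state_dict(torch.load(CKPT, map_location=torch.device('cpu'))))`; this only fixes the key layout and does not by itself resolve the missing models module.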