Code:
concat_mask = 'MST_shanghaitech' in args.PATH
model = MST(config, concat_mask)
model.load()
model.inference(args.image_path, args.mask_path, config.valid_th, config.mask_th,
not_obj_remove=args.not_obj_remove)
The error output is:
Traceback (most recent call last):
File "test_single.py", line 52, in <module>
model.load()
File "E:\code\MST_inpainting-main\src\MST_model.py", line 102, in load
self.inpaint_decoder.generator.load_state_dict(
File "D:\Anaconda3\envs\torch18\lib\site-packages\torch\nn\modules\module.py", line 1223, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for InpaintGateGenerator:
size mismatch for encoder.1.gate_conv.weight: copying a param with shape torch.Size([128, 6, 7, 7]) from checkpoint, the shape in current model is torch.Size([128, 7, 7, 7]).
How should I modify the parameters to fix this size mismatch?
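The mismatch is one input channel on `encoder.1.gate_conv.weight`: the checkpoint was saved with 6 input channels, while the current model was built with 7, which suggests `concat_mask` (it appears to append a mask channel) is set differently from how the checkpoint was trained. The cleanest fix is usually to construct the model with the flag that matches the checkpoint rather than editing weights; note that `load_state_dict(strict=False)` does not help here, since size mismatches are raised even in non-strict mode. If you do need to keep the 7-channel model, one workaround is to zero-pad the checkpoint tensor along the input-channel dimension before loading. The sketch below is a minimal illustration of that padding, assuming the shapes from the traceback; the variable names are hypothetical, not from the MST codebase:

```python
import torch

# Stand-in for the checkpoint weight from the traceback: [out_ch, in_ch, kH, kW]
ckpt_w = torch.randn(128, 6, 7, 7)
model_in_channels = 7  # what the current model's gate_conv expects

# Zero-pad the missing input channel so load_state_dict accepts the tensor.
# The extra channel starts untrained (all zeros), so expect degraded results
# for inputs that actually use that channel until it is fine-tuned.
pad = torch.zeros(ckpt_w.size(0), model_in_channels - ckpt_w.size(1),
                  ckpt_w.size(2), ckpt_w.size(3))
patched_w = torch.cat([ckpt_w, pad], dim=1)

print(tuple(patched_w.shape))  # (128, 7, 7, 7)
```

In practice you would load the checkpoint dict with `torch.load`, patch the offending entry (here `encoder.1.gate_conv.weight`) this way, and then call `load_state_dict` on the patched dict. But if you are not intentionally fine-tuning, double-check the `concat_mask` logic first; flipping that flag is almost certainly the intended fix.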