def _load_optimizer_state(self):
main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
opt_checkpoint = bf.join(
bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
)
if bf.exists(opt_checkpoint):
logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
state_dict = dist_util.load_state_dict(
opt_checkpoint, map_location=dist_util.dev(),strict=False
)
self.opt.load_state_dict(state_dict)
以上这段代码是用于从检查点中获取优化器参数,我发现在多卡并行的情况下得到的state_dict的内容和单卡情况下的内容不一样,导致我在多卡情况下恢复训练会出现错误:KeyError: 'param_groups',这是为什么?如何解决?
我尝试了各位的方案,发现还是不行,在单GPU时state_dict内是['state','param_groups'],多GPU就变成了如下:
['t_embedder.mlp.0.weight', 't_embedder.mlp.0.bias', 't_embedder.mlp.2.weight', 't_embedder.mlp.2.bias', 'encoder.pos_embedding', 'encoder.to_patch_embedding.1.weight', 'encoder.to_patch_embedding.1.bias', 'encoder.to_patch_embedding.2.weight', 'encoder.to_patch_embedding.2.bias', 'encoder.to_patch_embedding.3.weight', 'encoder.to_patch_embedding.3.bias', 'encoder.transformer.norm.weight', 'encoder.transformer.norm.bias', 'encoder.transformer.layers.0.0.norm.weight', 'encoder.transformer.layers.0.0.norm.bias', 'encoder.transformer.layers.0.0.to_qkv.weight', 'encoder.transformer.layers.0.0.to_out.0.weight', 'encoder.transformer.layers.0.0.to_out.0.bias', 'encoder.transformer.layers.0.1.net.0.weight', 'encoder.transformer.layers.0.1.net.0.bias', 'encoder.transformer.layers.0.1.net.1.weight', 'encoder.transformer.layers.0.1.net.1.bias', 'encoder.transformer.layers.0.1.net.4.weight', 'encoder.transformer.layers.0.1.net.4.bias', 'encoder.transformer.layers.1.0.norm.weight', 'encoder.transformer.layers.1.0.norm.bias', 'encoder.transformer.layers.1.0.to_qkv.weight', 'encoder.transformer.layers.1.0.to_out.0.weight', 'encoder.transformer.layers.1.0.to_out.0.bias', 'encoder.transformer.layers.1.1.net.0.weight', 'encoder.transformer.layers.1.1.net.0.bias', 'encoder.transformer.layers.1.1.net.1.weight', 'encoder.transformer.layers.1.1.net.1.bias', 'encoder.transformer.layers.1.1.net.4.weight', 'encoder.transformer.layers.1.1.net.4.bias', 'encoder.transformer.layers.2.0.norm.weight', 'encoder.transformer.layers.2.0.norm.bias', 'encoder.transformer.layers.2.0.to_qkv.weight', 'encoder.transformer.layers.2.0.to_out.0.weight', 'encoder.transformer.layers.2.0.to_out.0.bias', 'encoder.transformer.layers.2.1.net.0.weight', 'encoder.transformer.layers.2.1.net.0.bias', 'encoder.transformer.layers.2.1.net.1.weight', 'encoder.transformer.layers.2.1.net.1.bias', 'encoder.transformer.layers.2.1.net.4.weight', 'encoder.transformer.layers.2.1.net.4.bias', 'encoder.transformer.layers.3.0.norm.weight', 'encoder.transformer.layers.3.0.norm.bias', 'encoder.transformer.layers.3.0.to_qkv.weight', 'encoder.transformer.layers.3.0.to_out.0.weight', 'encoder.transformer.layers.3.0.to_out.0.bias', 'encoder.transformer.layers.3.1.net.0.weight', 'encoder.transformer.layers.3.1.net.0.bias', 'encoder.transformer.layers.3.1.net.1.weight', 'encoder.transformer.layers.3.1.net.1.bias', 'encoder.transformer.layers.3.1.net.4.weight', 'encoder.transformer.layers.3.1.net.4.bias', 'encoder.transformer.layers.4.0.norm.weight', 'encoder.transformer.layers.4.0.norm.bias', 'encoder.transformer.layers.4.0.to_qkv.weight', 'encoder.transformer.layers.4.0.to_out.0.weight', 'encoder.transformer.layers.4.0.to_out.0.bias', 'encoder.transformer.layers.4.1.net.0.weight', 'encoder.transformer.layers.4.1.net.0.bias', 'encoder.transformer.layers.4.1.net.1.weight', 'encoder.transformer.layers.4.1.net.1.bias', 'encoder.transformer.layers.4.1.net.4.weight', 'encoder.transformer.layers.4.1.net.4.bias', 'encoder.transformer.layers.5.0.norm.weight', 'encoder.transformer.layers.5.0.norm.bias', 'encoder.transformer.layers.5.0.to_qkv.weight', 'encoder.transformer.layers.5.0.to_out.0.weight', 'encoder.transformer.layers.5.0.to_out.0.bias', 'encoder.transformer.layers.5.1.net.0.weight', 'encoder.transformer.layers.5.1.net.0.bias', 'encoder.transformer.layers.5.1.net.1.weight', 'encoder.transformer.layers.5.1.net.1.bias', 'encoder.transformer.layers.5.1.net.4.weight', 'encoder.transformer.layers.5.1.net.4.bias', 'decoder.0.weight', 'decoder.0.bias', 'decoder.2.weight', 'decoder.2.bias', 'decoder.2.running_mean', 'decoder.2.running_var', 'decoder.2.num_batches_tracked', 'decoder.3.weight', 'decoder.3.bias', 'decoder.5.weight', 'decoder.5.bias', 'decoder.5.running_mean', 'decoder.5.running_var', 'decoder.5.num_batches_tracked', 'decoder.6.weight', 'decoder.6.bias', 'decoder.8.weight', 'decoder.8.bias', 'decoder.8.running_mean', 'decoder.8.running_var', 'decoder.8.num_batches_tracked', 'decoder.9.weight', 'decoder.9.bias', 'decoder.11.weight', 'decoder.11.bias', 'decoder.11.running_mean', 'decoder.11.running_var', 'decoder.11.num_batches_tracked', 'decoder.12.weight', 'decoder.12.bias', 'decoder.14.weight', 'decoder.14.bias', 'decoder.14.running_mean', 'decoder.14.running_var', 'decoder.14.num_batches_tracked', 'decoder.15.weight', 'decoder.15.bias', 'decoder.17.weight', 'decoder.17.bias', 'decoder.17.running_mean', 'decoder.17.running_var', 'decoder.17.num_batches_tracked', 'decoder.18.weight', 'decoder.18.bias', 'decoder.20.weight', 'decoder.20.bias', 'decoder.20.running_mean', 'decoder.20.running_var', 'decoder.20.num_batches_tracked', 'decoder.21.weight', 'decoder.21.bias'])
所以多卡的state_dict因为没有param_groups导致无法从检查点恢复训练,但是我不知道为什么会这样,如何解决?