在 RepVGG 官方代码的 train.py 中,没看懂下面这个函数在做什么:
def sgd_optimizer(model, lr, momentum, weight_decay, use_custwd):
    """Build an SGD optimizer with one parameter group per trainable parameter.

    Each parameter yielded by ``model.named_parameters()`` that has
    ``requires_grad`` set gets its own group, so learning rate and weight
    decay can be chosen per parameter from its name:

    * Weight decay is set to 0 when the name contains ``'bias'`` or ``'bn'``,
      or, if ``use_custwd`` is true, when it contains ``'rbr_dense'`` or
      ``'rbr_1x1'``. A message is printed for every such parameter.
    * The learning rate is doubled when the name contains ``'bias'``.
    * Every other parameter uses ``lr`` and ``weight_decay`` unchanged.

    The name tests are plain substring checks, so any parameter whose dotted
    name merely contains one of these fragments is matched as well.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
        lr: Base learning rate; also passed to SGD as the default rate.
        momentum: Momentum passed to ``torch.optim.SGD``.
        weight_decay: Weight decay for parameters not exempted above.
        use_custwd: When true, also exempt ``rbr_dense`` / ``rbr_1x1``
            parameters from weight decay.

    Returns:
        The ``torch.optim.SGD`` instance built from the parameter groups.
    """
    params = []
    # key is the dotted parameter name, value is the parameter itself.
    for key, value in model.named_parameters():
        if not value.requires_grad:
            # Frozen parameters are left out of the optimizer entirely.
            continue

        is_bias = 'bias' in key
        apply_weight_decay = weight_decay
        apply_lr = lr

        if (use_custwd and ('rbr_dense' in key or 'rbr_1x1' in key)) or is_bias or 'bn' in key:
            apply_weight_decay = 0
            print('set weight decay=0 for {}'.format(key))
        if is_bias:
            # Just a Caffe-style common practice. Made no difference.
            apply_lr = 2 * lr

        params.append({'params': [value], 'lr': apply_lr, 'weight_decay': apply_weight_decay})

    # Weight decay is carried by each group, so only lr and momentum are
    # given to the optimizer as defaults.
    return torch.optim.SGD(params, lr, momentum=momentum)
```python
```