pytorch param_groups

参数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 有两个`param_group`即,len(optim.param_groups)==2
optim.SGD([
{'params': model.base.parameters()},
{'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

#一个参数组
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

def lr_decay(optimizer, step, lr, decay_step, gamma):
lr = lr * gamma ** (step/decay_step)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return lr

请作者喝一杯咖啡☕️