pytorch param_groups 发表于 2018-07-Sun | 阅读次数: 参数组1234567891011121314# 有两个`param_group`即,len(optim.param_groups)==2optim.SGD([ {'params': model.base.parameters()}, {'params': model.classifier.parameters(), 'lr': 1e-3} ], lr=1e-2, momentum=0.9)#一个参数组optim.SGD(model.parameters(), lr=1e-2, momentum=.9)def lr_decay(optimizer, step, lr, decay_step, gamma): lr = lr * gamma ** (step/decay_step) for param_group in optimizer.param_groups: param_group['lr'] = lr return lr请作者喝一杯咖啡☕️打赏微信支付