PyTorch weight_init

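```python
import torch.nn as nn

def weight_init(m):
    # Re-initialize every Linear layer's weight with Kaiming (He) normal init.
    # Note the trailing underscore: kaiming_normal_ works in place; the old
    # nn.init.kaiming_normal (no underscore) is deprecated.
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)

# `model` is your existing nn.Module; apply() calls weight_init on every submodule.
model.apply(weight_init)
```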
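The same `apply` pattern extends to other layer types. The sketch below is one possible variant, not the author's code: it also covers `nn.Conv2d`, zeroes the biases, and passes `mode="fan_in"` and `nonlinearity="relu"` explicitly. The zero-bias choice, the explicit arguments, the function name `init_weights`, and the small `nn.Sequential` model are all illustrative assumptions.

```python
import torch
import torch.nn as nn

def init_weights(m: nn.Module) -> None:
    # Kaiming (He) normal init for conv and linear weights; zero the biases.
    # (Zero bias is a common convention, assumed here for illustration.)
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Hypothetical example model for 3x32x32 inputs, only for illustration.
# Conv2d with kernel_size=3 and no padding maps 32x32 -> 30x30.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),
)

torch.manual_seed(0)
model.apply(init_weights)  # visits every submodule recursively, then the model itself

# Every Linear/Conv2d bias is now exactly zero.
for layer in model:
    if isinstance(layer, (nn.Linear, nn.Conv2d)):
        print(type(layer).__name__, float(layer.bias.abs().sum()))
```

With `mode="fan_in"` and a ReLU gain, the weights are drawn with standard deviation `sqrt(2 / fan_in)`. For the `Linear(7200, 10)` layer above, that is about 0.0167. Because `apply` recurses through nested modules, the same function also works on models built from custom `nn.Module` subclasses.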