pytorch那些坑

1. network 不存在

用了非推荐的save方式,然后load的模型的名字和自己写的文件名不一样

2. 记录loss信息的时候直接使用了输出的Variable

1
2
3
4
5
6
for data, label in trainloader:
......
out = model(data)
loss = criterion(out, label)
#loss_sum += loss # <--- 这里
loss_sum += loss.data[0]

3. model.eval

在训练每个batch之前记得加model.train(),训练完若干个iteration之后在验证前记得加model.eval()
否则会影响dropout和BN.

4. F.dropout()

用F.dropout()时一定要手动设参数self.training,正确用法:F.dropout(x, 0.2, self.training)

5. zero_grad

6. zero

如果是tensor:

1
tensor.new(tensor.size()).zero_()

如果是Variable,得是:

1
Variable(tensor.new(tensor.size()).zero_())

7. Inference

在做inference时,千万要记住对输入的Variable设置volatile 为true.而不能设置requires_grad 为false.血的教训。。。

请作者喝一杯咖啡☕️