PyTorch 黑科技

import torch

# Enable cuDNN benchmark mode: cuDNN times several convolution algorithms for
# each new input configuration and caches the fastest one. Per the PyTorch
# reproducibility notes this mainly pays off when input shapes stay fixed,
# and it can make results nondeterministic. Set it right after importing torch.
torch.backends.cudnn.benchmark = True

保存模型

# Option 1: save/restore only the model's state_dict (its parameters and buffers).
torch.save(the_model.state_dict(), PATH)
# Rebuild the architecture first; TheModelClass, args and kwargs stand in for your own model.
the_model = TheModelClass(*args, **kwargs)
# Copy the saved tensors into the freshly built model.
# NOTE(review): tensors come back on the device they were saved from; pass
# map_location (e.g. "cpu") to torch.load when that device is unavailable.
the_model.load_state_dict(torch.load(PATH))

# Option 2: pickle the entire model object.
torch.save(the_model, PATH)
# NOTE(review): this unpickles arbitrary Python objects -- only load files you trust,
# and the model's class must be importable where you load it. Recent PyTorch
# releases (2.6+) default torch.load to weights_only=True, so this call may need
# weights_only=False there -- confirm for your installed version.
the_model = torch.load(PATH)

Tensor to NumPy array

# Convert tensor `t` to a NumPy ndarray.
# .detach() replaces the legacy .data: both give a tensor that does not require
# grad, but .detach() lets autograd still notice in-place edits to the result.
# For CPU tensors the returned ndarray shares storage with the tensor.
# for gpu: copy to host memory first (.cpu() returns the tensor itself when it is
# already on the CPU, so this line works for any device)
arr = t.detach().cpu().numpy()
# for cpu
arr = t.detach().numpy()

类型转换

# Cast tensor `a` to float64. Neither call modifies `a` in place; each returns
# the converted tensor (or `a` itself if nothing changes), so assign it back.
a = a.type('torch.DoubleTensor')  # for converting to double tensor (cpu): always yields a CPU tensor
a = a.double()  # float64 on a's current device; equivalent to a.to(torch.float64)

PyTorch 版本

import torch

# Look up the installed PyTorch release string and print it.
installed_version = torch.__version__
print(installed_version)
请作者喝一杯咖啡☕️