title: pytorch 计算图理解
date: 2018-08-06 10:21:14

tags:

因为笔者在train的时候发现梯度流被阻断了,所以学习下pytorch 计算图的原理。

version 0.3

gdown.pl hosts.txt main.py sort.sh

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from torch.autograd import Variable

x = torch.ones(1, requires_grad=True)
y = torch.ones(1)
z = x + y

# z = Variable(z.data, requires_grad=True)
w = torch.ones(1, requires_grad=True)
total = w + z

total.backward()
print(x.requires_grad, x.grad)
print(y.requires_grad, y.grad)
print(z.requires_grad, z.grad)
print(w.requires_grad, w.grad)

非叶子结点无法访问梯度

因为只有叶子结点是Variable, 它们的值可以变
非叶子结点可以通过使用.retrain_grad() 来修改梯度

只能一次backward

先前向计算得到graph, backward后graph就被释放了
然后在前向计算的时候,非叶子结点的梯度就destroy了,所以用retrain_grad进行修复

1
2
3
4
b = w1 * a
c = w2 * a
d = (w3 * b) + (w4 * c)
L = f(d)



https://towardsdatascience.com/getting-started-with-pytorch-part-1-understanding-how-automatic-differentiation-works-5008282073ec

请作者喝一杯咖啡☕️