title: pytorch 计算图理解
date: 2018-08-06 10:21:14
tags:
因为笔者在train的时候发现梯度流被阻断了,所以学习下pytorch 计算图的原理。
version 0.3
gdown.pl hosts.txt main.py sort.sh
1 | import torch |
非叶子结点无法访问梯度
因为只有叶子结点是Variable, 它们的值可以变
非叶子结点可以通过使用.retrain_grad() 来修改梯度
只能一次backward
先前向计算得到graph, backward后graph就被释放了
然后在前向计算的时候,非叶子结点的梯度就destroy了,所以用retrain_grad进行修复
1 | b = w1 * a |