小行星

Love the ordinary, yearn for the unknown



Understanding the PyTorch computation graph

Posted on 2018-08-06
While training I noticed that gradient flow was being blocked, so I dug into how the PyTorch computation graph works.

PyTorch version 0.3 (note: the snippet below uses the 0.4-style requires_grad keyword).

import torch
from torch.autograd import Variable

x = torch.ones(1, requires_grad=True)   # leaf tensor that requires grad
y = torch.ones(1)                       # leaf tensor, no grad
z = x + y                               # non-leaf (intermediate) node

# z = Variable(z.data, requires_grad=True)  # would rewrap z's data as a new leaf, cutting the graph back to x
w = torch.ones(1, requires_grad=True)
total = w + z

total.backward()
print(x.requires_grad, x.grad)   # True  tensor([1.])
print(y.requires_grad, y.grad)   # False None
print(z.requires_grad, z.grad)   # True  None  (non-leaf grads are not retained)
print(w.requires_grad, w.grad)   # True  tensor([1.])

Non-leaf nodes cannot access their gradients

By default autograd only populates .grad on leaf nodes, since those are the Variables whose values the optimizer updates. A non-leaf node can keep its gradient by calling .retain_grad() on it before backward.

backward() can only be called once

The forward pass builds the graph, and backward() frees it (pass retain_graph=True to keep it). Gradients of non-leaf nodes are likewise discarded during backward, which is what retain_grad() fixes; a minimal sketch of both follows.
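A minimal sketch (my own, assuming PyTorch 0.4+ semantics where tensors take requires_grad directly):

import torch

x = torch.ones(1, requires_grad=True)
z = x * 2                        # non-leaf
z.retain_grad()                  # ask autograd to keep z.grad
out = z * 3

out.backward(retain_graph=True)  # keep the graph alive for a second backward
print(z.grad)                    # tensor([3.])
out.backward()                   # allowed because the graph was retained; grads accumulate
print(x.grad)                    # tensor([12.]) = 6 + 6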

The example graph from the reference below, made runnable with all-ones leaves:

a = torch.ones(1, requires_grad=True)
w1, w2, w3, w4 = (torch.ones(1, requires_grad=True) for _ in range(4))
b = w1 * a
c = w2 * a
d = (w3 * b) + (w4 * c)
L = d.sum()   # stands in for f(d): some scalar loss
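Running backward on this graph confirms the leaf/non-leaf behavior (a small check of my own, not from the original article):

L.backward()
print(a.grad)   # tensor([2.]): dL/da = w3*w1 + w4*w2 = 2 with all-ones weights
print(b.grad)   # None: b is a non-leaf and retain_grad() was not called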



https://towardsdatascience.com/getting-started-with-pytorch-part-1-understanding-how-automatic-differentiation-works-5008282073ec

Some TensorFlow snippets

Posted on 2018-08 (Mon)

Loading multiple graphs:

import tensorflow as tf

class ImportGraph():
    """ Importing and running an isolated TF graph """
    def __init__(self, loc):
        # Create a local graph and run it in its own session
        self.graph = tf.Graph()
        config = tf.ConfigProto(log_device_placement=False)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(graph=self.graph, config=config)
        with self.graph.as_default():
            # Import saved model from location 'loc' into the local graph
            saver = tf.train.import_meta_graph(loc + '.meta', clear_devices=True)
            saver.restore(self.sess, loc)
            # There are TWO options for getting the output operation:
            # BY NAME:
            self.logits = self.graph.get_operation_by_name('proj/Reshape_1').outputs[0]
            # FROM A SAVED COLLECTION:
            # self.activation = tf.get_collection('activation')[0]

    def run(self, fd):
        """ Run the operation previously imported """
        return self.sess.run([self.logits], feed_dict=fd)
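Usage sketch (the checkpoint prefixes and the feed-dict key are hypothetical; they depend on what was exported): each instance owns its own tf.Graph and tf.Session, so several checkpoints can be queried side by side without name collisions.

model_a = ImportGraph('./ckpt/model_a')   # hypothetical checkpoint prefix
model_b = ImportGraph('./ckpt/model_b')
out_a = model_a.run({'input:0': batch})   # returns a list; feed keys depend on the saved graph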
A related pattern: give each Model its own graph and session, and release them explicitly in __del__ so models can be created and destroyed repeatedly:

import random
import tensorflow as tf

class Model:
    def __init__(self, param):
        self.param = param

        # create & build graph
        self.graph = tf.Graph()
        self.build_graph()

        # create session, pinned to one randomly chosen GPU
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        gpu_num = random.choice(cuda_gpu_count())   # cuda_gpu_count(): helper (defined elsewhere) listing available GPU ids
        config.gpu_options.visible_device_list = str(gpu_num)
        self.sess = tf.Session(config=config, graph=self.graph)

    def build_graph(self):
        with self.graph.as_default():
            ...

    def __del__(self):
        # explicitly release resources by closing and deleting session and graph
        self.sess.close()
        del self.sess
        del self.graph
        del self.param

    # train the model and return the test accuracy
    def train_test(self, train_data, train_label, test_data, test_label):
        ...
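This pattern is aimed at training many models back to back without leaking graph memory; a hypothetical driver loop (param_grid and the data variables are assumptions):

for param in param_grid:
    model = Model(param)
    acc = model.train_test(train_data, train_label, test_data, test_label)
    del model   # triggers __del__, freeing the session and graph before the next run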

https://blog.csdn.net/silent56_th/article/details/81415940

squared error and softmax

Posted on 2018-08 (Sun)

With squared error on a sigmoid unit, the gradient w.r.t. the pre-activation is y * (1 - y) * (-2 * (t - y)); when y = 0 the y factor kills the gradient, even if the target is t = 1.

With softmax/cross-entropy, the gradient carries a factor y * (1 - y) / y: the 1/y from the log cancels the saturating y, leaving roughly (1 - y), so even at y = 0 the gradient stays large.
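A quick numeric check (my own sketch: sigmoid output y, target t = 1, gradients taken w.r.t. the pre-activation):

y, t = 1e-6, 1.0                            # confidently wrong prediction
grad_mse = y * (1 - y) * (-2 * (t - y))     # ~ -2e-6: squared-error gradient vanishes
grad_ce = -(t / y) * y * (1 - y)            # ~ -1.0: cross-entropy gradient stays large
print(grad_mse, grad_ce)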

PyTorch containers

Posted on 2018-08 (Sat)

Growing a result tensor batch by batch:

if result is None:
    result = temp
else:
    result = torch.cat([result, temp])
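Note that repeated torch.cat copies the accumulated tensor on every iteration. An equivalent idiom (my sketch, shapes assumed) collects chunks in a list and concatenates once:

import torch

chunks = [torch.ones(2, 3) for _ in range(4)]   # e.g. per-batch outputs
result = torch.cat(chunks)                      # one cat instead of N-1
print(result.shape)                             # torch.Size([8, 3])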

pytorch variable out[out > 0] = 1

Posted on 2018-08 (Sat) | Category: 长文

# out[out > 0] = 1
# out[out <= 0] = -1
out = (out >= 0.5).float()   # threshold to {1, 0}
# map 1,0 to 1,-1
out = out * 2 - 1
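A tiny demo of the mapping (the input values are my own example):

import torch

out = torch.tensor([0.1, 0.5, 0.9])
out = (out >= 0.5).float() * 2 - 1
print(out)   # tensor([-1., 1., 1.])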

Comparing weights in PyTorch

Posted on 2018-08 (Sat)

a = list(self.parameters())[0].clone()   # snapshot before the update
loss.backward()
self.optimizer.step()
b = list(self.parameters())[0].clone()
torch.equal(a.data, b.data)              # False if the step changed the first parameter
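The same check over every parameter, not just the first (a sketch; model, loss, and optimizer are assumed to exist):

before = [p.clone() for p in model.parameters()]
loss.backward()
optimizer.step()
changed = any(not torch.equal(b.data, a.data)
              for a, b in zip(before, model.parameters()))
print(changed)   # True if at least one parameter moved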

PyTorch dropout adjusts automatically between train and eval

Posted on 2018-08 (Sat)
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import print_function

import torch.nn as nn
import torch


class Linear(nn.Module):
    def __init__(self, p_dropout=1):            # p=1: drop everything, extreme on purpose
        super(Linear, self).__init__()
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        y = self.dropout(x)
        return y


net = Linear()
a = torch.ones(4)

net.train()      # dropout active
print(net(a))

net.eval()       # dropout becomes a no-op
b = torch.ones(4)
print(net(b))
Output:
Variable containing:
0
0
0
0
[torch.FloatTensor of size 4]

Variable containing:
1
1
1
1
[torch.FloatTensor of size 4]
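With a less extreme rate the train-mode rescaling is also visible: nn.Dropout multiplies surviving elements by 1/(1-p). A small sketch (p = 0.5 is my choice):

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
drop.train()
print(drop(torch.ones(8)))   # surviving entries appear as 2.0 = 1/(1-0.5)
drop.eval()
print(drop(torch.ones(8)))   # all ones: dropout is disabled in eval mode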

pytorch net save

Posted on 2018-08 (Fri)
class NetALL(nn.Module):          # must be a Module for state_dict() to see the sub-nets
    def __init__(self):
        super(NetALL, self).__init__()
        self.net1 = xx            # some nn.Module
        self.net2 = xx

# saving the sub-net's state dict gives keys like:  w1
torch.save(self.net1.state_dict(), path)   # path: destination file, omitted in the original note
# saving the container's state dict prefixes them:  net1.w1
torch.save(self.state_dict(), path)
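Loading must match those prefixes (a sketch; path_sub and path_all are hypothetical file names):

netall = NetALL()
netall.net1.load_state_dict(torch.load(path_sub))   # expects keys like 'w1'
netall.load_state_dict(torch.load(path_all))        # expects keys like 'net1.w1'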

Inspecting a net's train/eval state in PyTorch

Posted on 2018-08 (Fri)

print(module.training)   # True in train mode, False after .eval()
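.train() and .eval() set the flag recursively on sub-modules, which is easy to verify (a small sketch of my own):

import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
net.eval()
print([m.training for m in net.modules()])   # all False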

TypeError: must be type, not classobj

Posted on 2018-08 (Thu)
This is a Python 2 error: super() only works with new-style classes. An old-style class triggers it:

class A():          # old-style in Python 2: super(A, self) raises the TypeError
    xxx

class A(object):    # new-style: inherits from object, super() works
    xxx
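The difference shows up in the class's own type (Python 2 sketch):

class Old():
    pass

class New(object):
    pass

print(type(Old))   # <type 'classobj'>  -> not usable with super()
print(type(New))   # <type 'type'>      -> fine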