测试 PyTorch 级联 grad (Testing gradient flow through cascaded PyTorch modules)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import print_function

import alog
import torch
import torch.nn as nn


class Model1(nn.Module):
    """Stage 1 of the cascade: a single 10 -> 10 linear layer."""

    def __init__(self):
        super().__init__()
        # Fixed toy-example feature sizes.
        self.input_size = 10
        self.output_size = 10
        self.w1 = nn.Linear(self.input_size, self.output_size)

    def print_grad(self):
        # Report whether the weight participates in autograd, and its
        # accumulated gradient (None until backward() has run).
        print("Model1 w1 has grad ", self.w1.weight.requires_grad)
        print("w1 grad", self.w1.weight.grad)

    def forward(self, x):
        return self.w1(x)


class Model2(nn.Module):
    """Stage 2 of the cascade: a single 10 -> 10 linear layer."""

    def __init__(self):
        super().__init__()
        # Fixed toy-example feature sizes.
        self.input_size = 10
        self.output_size = 10
        self.w1 = nn.Linear(self.input_size, self.output_size)

    def print_grad(self):
        # Report whether the weight participates in autograd, and its
        # accumulated gradient (None until backward() has run).
        print("Model2 w1 has grad ", self.w1.weight.requires_grad)
        print("w1 grad", self.w1.weight.grad)

    def forward(self, x):
        return self.w1(x)


class Model3(nn.Module):
    """Stage 3 of the cascade: a single 10 -> 10 linear layer."""

    def __init__(self):
        super().__init__()
        # Fixed toy-example feature sizes.
        self.input_size = 10
        self.output_size = 10
        self.w1 = nn.Linear(self.input_size, self.output_size)

    def print_grad(self):
        # Report whether the weight participates in autograd, and its
        # accumulated gradient (None until backward() has run).
        print("Model3 w1 has grad ", self.w1.weight.requires_grad)
        print("w1 grad", self.w1.weight.grad)

    def forward(self, x):
        return self.w1(x)


class TotalModel(nn.Module):
    """Chains Model1 -> Model2 -> Model3 so one backward() call
    propagates gradients through all three stages."""

    def __init__(self):
        super().__init__()
        self.model1 = Model1()
        self.model2 = Model2()
        self.model3 = Model3()

    def print_grad(self):
        # Dump gradient status for every stage, in pipeline order.
        for stage in (self.model1, self.model2, self.model3):
            stage.print_grad()

    def forward(self, x):
        return self.model3(self.model2(self.model1(x)))


# Build the three-stage cascade and push a random batch (3 samples,
# 10 features) through it.
model = TotalModel()
x = torch.randn(3, 10)
y = model(x)

# BUGFIX: `size_average=True` was deprecated in PyTorch 0.4 and has been
# removed in modern releases; `reduction="mean"` is the exact equivalent.
criterion = nn.MSELoss(reduction="mean")
target = torch.randn(3, 10)
loss = criterion(y, target)

# A single backward() on the scalar loss fills .grad on every stage's
# parameters; print_grad() shows the cascaded gradients.
loss.backward()
model.print_grad()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
Model1 w1 has grad  True
w1 grad tensor([[ 0.0237, -0.0069, -0.0120, 0.0015, 0.0038, 0.0112, -0.0029,
0.0047, -0.0267, -0.0093],
[ 0.0483, 0.0260, -0.0237, 0.0168, -0.0000, 0.0211, 0.0017,
-0.0104, -0.0395, -0.0108],
[ 0.0204, 0.0269, 0.0060, 0.0180, -0.0555, -0.0240, -0.0061,
0.0076, 0.0231, 0.0060],
[-0.1250, -0.0052, 0.0274, -0.0343, 0.1048, 0.0142, 0.0295,
-0.0484, 0.0499, 0.0243],
[-0.0997, -0.0209, 0.0131, -0.0361, 0.1149, 0.0294, 0.0256,
-0.0409, 0.0154, 0.0121],
[ 0.0215, 0.0297, -0.0167, 0.0114, 0.0181, 0.0218, 0.0084,
-0.0219, -0.0248, -0.0041],
[ 0.0130, -0.0943, -0.0195, -0.0341, 0.0569, 0.0333, -0.0120,
0.0335, -0.0725, -0.0287],
[-0.1711, -0.0050, 0.0365, -0.0467, 0.1465, 0.0216, 0.0414,
-0.0686, 0.0669, 0.0332],
[-0.2682, 0.0491, 0.0646, -0.0515, 0.1978, 0.0183, 0.0720,
-0.1280, 0.1395, 0.0666],
[-0.0565, 0.1078, 0.0438, 0.0324, -0.0715, -0.0586, 0.0162,
-0.0397, 0.1266, 0.0469]])
Model2 w1 has grad True
w1 grad tensor([[ 0.0295, -0.0207, -0.0064, -0.0087, 0.0577, -0.0405, -0.0094,
-0.0207, 0.0490, 0.0002],
[ 0.0410, -0.0451, 0.0368, 0.1388, -0.1122, 0.1371, -0.0073,
-0.0063, -0.1194, -0.0762],
[-0.0020, 0.0121, -0.0659, -0.0638, 0.0530, -0.0535, 0.0165,
0.0036, 0.0281, 0.0592],
[ 0.0900, -0.0744, 0.0270, 0.0621, 0.0738, -0.0211, -0.0330,
-0.0549, 0.0600, -0.0556],
[ 0.0399, -0.0417, 0.0400, 0.1044, -0.0627, 0.0865, -0.0136,
-0.0141, -0.0640, -0.0663],
[ 0.1057, -0.0722, -0.0631, -0.0201, 0.1694, -0.1066, -0.0162,
-0.0615, 0.1168, 0.0197],
[ 0.0734, -0.0625, 0.0181, 0.0755, 0.0223, 0.0211, -0.0215,
-0.0382, 0.0061, -0.0514],
[ 0.1538, -0.1126, -0.1070, 0.0758, 0.0873, 0.0058, -0.0012,
-0.0620, -0.0092, 0.0021],
[ 0.0539, -0.0356, -0.0625, 0.0038, 0.0500, -0.0171, 0.0058,
-0.0205, 0.0064, 0.0226],
[-0.0055, -0.0043, 0.0645, 0.0392, -0.0307, 0.0266, -0.0171,
-0.0037, -0.0019, -0.0486]])
Model3 w1 has grad True
w1 grad tensor([[ 0.0018, -0.0045, 0.0175, -0.0154, 0.0253, -0.0375, 0.0574,
0.0155, -0.0139, 0.0022],
[ 0.1069, 0.0265, 0.0383, 0.0559, -0.1969, 0.1059, 0.0688,
0.0279, 0.1663, -0.1112],
[-0.0022, -0.0031, -0.0035, 0.0035, -0.0044, 0.0095, -0.0111,
-0.0026, 0.0031, 0.0026],
[-0.0862, -0.0261, -0.0486, -0.0220, 0.1177, -0.0284, -0.1166,
-0.0371, -0.1056, 0.0883],
[-0.0623, 0.0647, -0.1075, 0.0172, 0.0316, 0.0507, -0.3152,
-0.0999, -0.0735, 0.0283],
[-0.0432, 0.0199, -0.0191, -0.0365, 0.1060, -0.0805, -0.0350,
-0.0185, -0.0962, 0.0354],
[-0.0444, 0.0063, -0.0257, -0.0223, 0.0812, -0.0437, -0.0589,
-0.0223, -0.0753, 0.0397],
[-0.0821, 0.0236, -0.0624, -0.0313, 0.1331, -0.0579, -0.1573,
-0.0555, -0.1328, 0.0674],
[ 0.0241, 0.0421, -0.0186, 0.0224, -0.0595, 0.0435, -0.0706,
-0.0218, 0.0335, -0.0399],
[ 0.0387, 0.0043, 0.0331, 0.0014, -0.0384, -0.0070, 0.0892,
0.0273, 0.0408, -0.0358]])
请作者喝一杯咖啡☕️