Testing PyTorch cascaded grad

Posted on 2018-08-Mon

A quick test of how gradients propagate through cascaded modules: three identical Linear submodules (Model1, Model2, Model3) are chained inside a TotalModel, and after a single backward pass each submodule prints its own weight.grad.

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import print_function

import torch
import torch.nn as nn


class Model1(nn.Module):
    def __init__(self):
        super(Model1, self).__init__()
        self.input_size = 10
        self.output_size = 10
        self.w1 = nn.Linear(self.input_size, self.output_size)

    def print_grad(self):
        print("Model1 w1 has grad ", self.w1.weight.requires_grad)
        print("w1 grad", self.w1.weight.grad)

    def forward(self, x):
        y = self.w1(x)
        return y


class Model2(nn.Module):
    def __init__(self):
        super(Model2, self).__init__()
        self.input_size = 10
        self.output_size = 10
        self.w1 = nn.Linear(self.input_size, self.output_size)

    def print_grad(self):
        print("Model2 w1 has grad ", self.w1.weight.requires_grad)
        print("w1 grad", self.w1.weight.grad)

    def forward(self, x):
        y = self.w1(x)
        return y


class Model3(nn.Module):
    def __init__(self):
        super(Model3, self).__init__()
        self.input_size = 10
        self.output_size = 10
        self.w1 = nn.Linear(self.input_size, self.output_size)

    def print_grad(self):
        print("Model3 w1 has grad ", self.w1.weight.requires_grad)
        print("w1 grad", self.w1.weight.grad)

    def forward(self, x):
        y = self.w1(x)
        return y


class TotalModel(nn.Module):
    """Cascades the three submodules: x -> model1 -> model2 -> model3."""

    def __init__(self):
        super(TotalModel, self).__init__()
        self.model1 = Model1()
        self.model2 = Model2()
        self.model3 = Model3()

    def print_grad(self):
        self.model1.print_grad()
        self.model2.print_grad()
        self.model3.print_grad()

    def forward(self, x):
        y = self.model1(x)
        y = self.model2(y)
        y = self.model3(y)
        return y


model = TotalModel()
x = torch.randn(3, 10)
y = model(x)
# size_average=True is deprecated; the default reduction='mean' is equivalent.
criterion = nn.MSELoss()
target = torch.randn(3, 10)
loss = criterion(y, target)
loss.backward()
model.print_grad()
```

Output:

```
Model1 w1 has grad  True
w1 grad tensor([[ 0.0237, -0.0069, -0.0120,  0.0015,  0.0038,  0.0112, -0.0029,  0.0047, -0.0267, -0.0093],
        [ 0.0483,  0.0260, -0.0237,  0.0168, -0.0000,  0.0211,  0.0017, -0.0104, -0.0395, -0.0108],
        [ 0.0204,  0.0269,  0.0060,  0.0180, -0.0555, -0.0240, -0.0061,  0.0076,  0.0231,  0.0060],
        [-0.1250, -0.0052,  0.0274, -0.0343,  0.1048,  0.0142,  0.0295, -0.0484,  0.0499,  0.0243],
        [-0.0997, -0.0209,  0.0131, -0.0361,  0.1149,  0.0294,  0.0256, -0.0409,  0.0154,  0.0121],
        [ 0.0215,  0.0297, -0.0167,  0.0114,  0.0181,  0.0218,  0.0084, -0.0219, -0.0248, -0.0041],
        [ 0.0130, -0.0943, -0.0195, -0.0341,  0.0569,  0.0333, -0.0120,  0.0335, -0.0725, -0.0287],
        [-0.1711, -0.0050,  0.0365, -0.0467,  0.1465,  0.0216,  0.0414, -0.0686,  0.0669,  0.0332],
        [-0.2682,  0.0491,  0.0646, -0.0515,  0.1978,  0.0183,  0.0720, -0.1280,  0.1395,  0.0666],
        [-0.0565,  0.1078,  0.0438,  0.0324, -0.0715, -0.0586,  0.0162, -0.0397,  0.1266,  0.0469]])
Model2 w1 has grad  True
w1 grad tensor([[ 0.0295, -0.0207, -0.0064, -0.0087,  0.0577, -0.0405, -0.0094, -0.0207,  0.0490,  0.0002],
        [ 0.0410, -0.0451,  0.0368,  0.1388, -0.1122,  0.1371, -0.0073, -0.0063, -0.1194, -0.0762],
        [-0.0020,  0.0121, -0.0659, -0.0638,  0.0530, -0.0535,  0.0165,  0.0036,  0.0281,  0.0592],
        [ 0.0900, -0.0744,  0.0270,  0.0621,  0.0738, -0.0211, -0.0330, -0.0549,  0.0600, -0.0556],
        [ 0.0399, -0.0417,  0.0400,  0.1044, -0.0627,  0.0865, -0.0136, -0.0141, -0.0640, -0.0663],
        [ 0.1057, -0.0722, -0.0631, -0.0201,  0.1694, -0.1066, -0.0162, -0.0615,  0.1168,  0.0197],
        [ 0.0734, -0.0625,  0.0181,  0.0755,  0.0223,  0.0211, -0.0215, -0.0382,  0.0061, -0.0514],
        [ 0.1538, -0.1126, -0.1070,  0.0758,  0.0873,  0.0058, -0.0012, -0.0620, -0.0092,  0.0021],
        [ 0.0539, -0.0356, -0.0625,  0.0038,  0.0500, -0.0171,  0.0058, -0.0205,  0.0064,  0.0226],
        [-0.0055, -0.0043,  0.0645,  0.0392, -0.0307,  0.0266, -0.0171, -0.0037, -0.0019, -0.0486]])
Model3 w1 has grad  True
w1 grad tensor([[ 0.0018, -0.0045,  0.0175, -0.0154,  0.0253, -0.0375,  0.0574,  0.0155, -0.0139,  0.0022],
        [ 0.1069,  0.0265,  0.0383,  0.0559, -0.1969,  0.1059,  0.0688,  0.0279,  0.1663, -0.1112],
        [-0.0022, -0.0031, -0.0035,  0.0035, -0.0044,  0.0095, -0.0111, -0.0026,  0.0031,  0.0026],
        [-0.0862, -0.0261, -0.0486, -0.0220,  0.1177, -0.0284, -0.1166, -0.0371, -0.1056,  0.0883],
        [-0.0623,  0.0647, -0.1075,  0.0172,  0.0316,  0.0507, -0.3152, -0.0999, -0.0735,  0.0283],
        [-0.0432,  0.0199, -0.0191, -0.0365,  0.1060, -0.0805, -0.0350, -0.0185, -0.0962,  0.0354],
        [-0.0444,  0.0063, -0.0257, -0.0223,  0.0812, -0.0437, -0.0589, -0.0223, -0.0753,  0.0397],
        [-0.0821,  0.0236, -0.0624, -0.0313,  0.1331, -0.0579, -0.1573, -0.0555, -0.1328,  0.0674],
        [ 0.0241,  0.0421, -0.0186,  0.0224, -0.0595,  0.0435, -0.0706, -0.0218,  0.0335, -0.0399],
        [ 0.0387,  0.0043,  0.0331,  0.0014, -0.0384, -0.0070,  0.0892,  0.0273,  0.0408, -0.0358]])
```
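So a single loss.backward() on TotalModel fills in weight.grad for all three nested submodules; gradients cascade through the whole module tree automatically. The same check can be written generically with named_parameters(), which walks every nested submodule and avoids a hand-written print_grad per class. Below is a minimal sketch of that idea (the nn.Sequential of three Linear layers standing in for Model1/Model2/Model3 is my simplification, not the original post's code):

```python
#!/usr/bin/env python
# Minimal sketch, not the original post's code: three Linear layers in an
# nn.Sequential stand in for Model1/Model2/Model3.
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 10),  # plays the role of Model1
    nn.Linear(10, 10),  # plays the role of Model2
    nn.Linear(10, 10),  # plays the role of Model3
)

x = torch.randn(3, 10)
target = torch.randn(3, 10)
loss = nn.MSELoss()(model(x), target)
loss.backward()

# After one backward pass, every nested parameter should carry a gradient.
for name, param in model.named_parameters():
    assert param.grad is not None, "missing grad for " + name
    print(name, "grad norm:", param.grad.norm().item())
```

Because named_parameters() recurses through the full submodule tree, this check works unchanged for any nesting depth, not just the three-level cascade tested above.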