PyTorch dropout automatic adjustment

Posted on 2018-08

`nn.Dropout` behaves differently depending on the module's mode: in `train()` mode it randomly zeroes elements of the input (rescaling the survivors by 1/(1-p)), while in `eval()` mode it is a no-op. The example below sets `p_dropout=1`, so training mode zeroes every element and evaluation mode passes the input through unchanged.

```python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import print_function

import torch
import torch.nn as nn


class Linear(nn.Module):
    def __init__(self, p_dropout=1):
        super(Linear, self).__init__()
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        y = self.dropout(x)
        return y


net = Linear()
a = torch.ones(4)
net.train()    # training mode: dropout is active
print(net(a))

net.eval()     # evaluation mode: dropout is disabled
b = torch.ones(4)
print(net(b))
```

Output:

```
Variable containing:
 0
 0
 0
 0
[torch.FloatTensor of size 4]

Variable containing:
 1
 1
 1
 1
[torch.FloatTensor of size 4]
```
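To see the rescaling side of the same behaviour, here is a minimal sketch (not from the original post) that uses a more typical `p=0.5` instead of `p=1`: in `train()` mode the surviving entries are scaled by 1/(1-p) = 2, and in `eval()` mode the layer simply returns its input.

```python
# Minimal sketch with p=0.5 (assumption; the original post uses p_dropout=1).
import torch
import torch.nn as nn

torch.manual_seed(0)       # make the dropout mask reproducible for this run
drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()               # training mode: random zeroing plus 1/(1-p) rescaling
print(drop(x))             # roughly half the entries are 0, the rest are 2

drop.eval()                # eval mode: dropout acts as an identity
print(drop(x))             # all ones, identical to the input
```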