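PyTorch: softmax over each sample in a batch

To apply softmax over all elements of each sample, rather than along a single dimension, flatten every sample into one row. Then apply softmax along dim=1 and reshape the result back to the original shape:
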
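```python
import torch
import torch.nn as nn

x = torch.cat((torch.ones(1, 1, 10, 10), torch.ones(1, 1, 10, 10) * 2), dim=0)  # batch of 2 samples, shape (2, 1, 10, 10); Variable is no longer needed
softmax = nn.Softmax(dim=1)
y = softmax(x.view(2, -1)).view(2, 1, 10, 10)  # flatten each sample to 100 values, softmax over them, restore shape
```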
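The snippet above hard-codes the batch size (2) and the per-sample shape (1, 10, 10). Below is a minimal sketch of a shape-agnostic variant. It assumes a batch-first tensor, and the helper name `batch_softmax` is hypothetical rather than part of PyTorch. It uses `reshape` instead of `view` so that non-contiguous inputs also work.

```python
import torch


def batch_softmax(x: torch.Tensor) -> torch.Tensor:
    """Softmax over all non-batch elements of each sample (hypothetical helper).

    Assumes dim 0 is the batch dimension.
    """
    flat = x.reshape(x.size(0), -1)  # (N, everything else): one row per sample
    # Normalize each row independently, then restore the original shape.
    return torch.softmax(flat, dim=1).reshape(x.shape)


x = torch.cat((torch.ones(1, 1, 10, 10), torch.ones(1, 1, 10, 10) * 2), dim=0)
y = batch_softmax(x)

# Each sample's 100 probabilities should sum to (approximately) 1.
print(y.reshape(y.size(0), -1).sum(dim=1))
```

Softmax is unchanged when the same constant is added to every input. Both samples here are constant (all 1s and all 2s), so both produce the same uniform distribution, with every element approximately 1/100 = 0.01.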