PyTorch: building a tensor out of other tensors needs cat, not torch.tensor()

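import torch

a = torch.tensor(0.0)
b = torch.tensor(0.0)

a.requires_grad = True
b.requires_grad = True

cosa = torch.cos(a)
cosb = torch.cos(b)

y1 = cosa + cosb  # ordinary tensor arithmetic keeps the autograd graph

# torch.tensor() copies only the *values* of cosa and cosb into a brand-new leaf tensor,
# so y2 has no grad_fn and no link back to a and b.
y2 = torch.tensor([
    cosa, cosb
])

y2.sum().backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

# Never reached: the backward() call above raises.
print("a has grad ", a.requires_grad)
print("a grad", a.grad)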
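The following sketch shows where the graph breaks. It assumes a reasonably recent PyTorch. It inspects `grad_fn` on both results, catches the `RuntimeError` from `y2`, and confirms that `y1`, which was built with `+`, still backpropagates. The `backward_failed` flag is a hypothetical helper added only for the check.

```python
import torch

a = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)
cosa, cosb = torch.cos(a), torch.cos(b)

# Ordinary tensor arithmetic records an autograd node.
y1 = cosa + cosb
print(y1.grad_fn)  # an AddBackward0 node

# torch.tensor() copies the values into a new leaf tensor that is cut off from a and b.
y2 = torch.tensor([cosa, cosb])
print(y2.requires_grad, y2.grad_fn)

backward_failed = False  # hypothetical flag, only used to record the failure
try:
    y2.sum().backward()
except RuntimeError as err:
    backward_failed = True
    print("backward failed:", err)

# y1 still works: d(cos a)/da = -sin(a), which is 0 at a = 0.
y1.backward()
print(a.grad)
```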

Instead, concatenate them with torch.cat, which keeps the result in the autograd graph:

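# cosa and cosb are 0-dimensional, and torch.cat rejects 0-d inputs ("zero-dimensional
# tensor (at position 0) cannot be concatenated"), so give each a length-1 dimension first.
y2 = torch.cat([
    cosa.unsqueeze(0), cosb.unsqueeze(0)
], dim=0)
# Equivalent: y2 = torch.stack([cosa, cosb])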
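This sketch checks that the concatenated result really carries gradients back to the inputs. It uses nonzero inputs (1.0 and 2.0), which are illustrative assumptions, so that the gradient -sin(x) is not trivially zero. It also compares `torch.cat` on unsqueezed scalars with `torch.stack`, which adds the new dimension itself.

```python
import math
import torch

# Illustrative inputs (assumed): nonzero so that d(cos x)/dx = -sin(x) is nonzero.
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
cosa, cosb = torch.cos(a), torch.cos(b)

# torch.cat needs inputs with at least one dimension, so unsqueeze each scalar to shape (1,).
y_cat = torch.cat([cosa.unsqueeze(0), cosb.unsqueeze(0)], dim=0)
# torch.stack inserts the new dimension itself and yields the same shape-(2,) tensor.
y_stack = torch.stack([cosa, cosb])

print(torch.equal(y_cat, y_stack))
y_cat.sum().backward()
print(a.grad, -math.sin(1.0))  # the two values should agree up to float32 precision
print(b.grad, -math.sin(2.0))
```

Both forms produce a tensor with a `grad_fn`. `torch.stack` is the shorter choice when every piece is a scalar. `cat` is the natural choice when the pieces already have at least one dimension.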
