pytorch 自带mse loss计算慢

测试发现自带的调用一次需要0.1s………

自带的mse loss计算慢,自定义一个

1
2
3
4
def self_mseloss(y_input, y_target):
num = torch.numel(y_input)
mse_loss = torch.sum((y_input - y_target) ** 2)
return mse_loss / num
请作者喝一杯咖啡☕️