Fix L1Loss when target.requires_grad is True. (#44471)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44471
L1Loss had a completely different (and incorrect, see #43228) path when target.requires_grad was True.
This PR does the following:
1) adds derivative support for target via the normal derivatives.yaml route
2) kill the different (and incorrect) path for when target.requires_grad was True
3) modify the L1Loss CriterionTests to verify that the target derivative is checked.
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D23626008
Pulled By: gchanan
fbshipit-source-id: 2828be16b56b8dabe114962223d71b0e9a85f0f5
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 6f9568f..576232b 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1202,6 +1202,7 @@
- name: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
self: l1_loss_backward(grad, self, target, reduction)
+ target: l1_loss_backward(grad, target, self, reduction)
- name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
self: mse_loss_backward(grad, self, target, reduction)
@@ -1520,6 +1521,7 @@
- name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
grad_output: l1_loss_double_backward_grad_output(grad, self, target, reduction)
self: zeros_like(grad, at::MemoryFormat::Preserve)
+ target: zeros_like(grad, at::MemoryFormat::Preserve)
- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
grad_output: log_sigmoid_backward(grad, self, buffer)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index cee9c12..6ebed5c 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -2629,15 +2629,10 @@
stacklevel=2)
if size_average is not None or reduce is not None:
reduction = _Reduction.legacy_get_string(size_average, reduce)
- if target.requires_grad:
- _Reduction.get_enum(reduction) # throw an error if reduction is invalid
- ret = torch.abs(input - target)
- if reduction != 'none':
- ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
- else:
- expanded_input, expanded_target = torch.broadcast_tensors(input, target)
- ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
- return ret
+
+
+ expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+ return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index a566493..0f066c9 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -3879,7 +3879,7 @@
dict(
module_name='L1Loss',
input_size=(2, 3, 4),
- target_size=(2, 3, 4),
+ target_fn=lambda: torch.randn((2, 3, 4), requires_grad=True),
reference_fn=lambda i, t, _: 1. / i.numel() *
sum((a - b).abs().sum() for a, b in zip(i, t)),
),
@@ -4277,7 +4277,7 @@
dict(
module_name='L1Loss',
input_size=(),
- target_size=(),
+ target_fn=lambda: torch.randn((), requires_grad=True),
reference_fn=lambda i, t, _: 1. / i.numel() * (i - t).abs().sum(),
desc='scalar',
),