Simplify scalar_check of nll_loss. (#30669)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30669
The inputs can't be 0-d, so we don't need that check in the scalar_check.
Test Plan: Imported from OSS
Differential Revision: D18784524
Pulled By: gchanan
fbshipit-source-id: d44222dffc91880a6e8c7be69e6e146e60040d43
diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml
index 76df33d..3d67ae6 100644
--- a/aten/src/ATen/nn.yaml
+++ b/aten/src/ATen/nn.yaml
@@ -31,7 +31,7 @@
cname: ClassNLLCriterion
buffers: [total_weight]
scalar_check:
- output: reduction != at::Reduction::None || self_->dim() == 0
+ output: reduction != at::Reduction::None
total_weight: 'false'
CPU:
forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
@@ -41,7 +41,7 @@
cname: SpatialClassNLLCriterion
buffers: [total_weight]
scalar_check:
- output: reduction != at::Reduction::None || self_->dim() == 0
+ output: reduction != at::Reduction::None
total_weight: 'false'
- name: _thnn_soft_margin_loss(Tensor self, Tensor target, int64_t reduction)
diff --git a/test/test_torch.py b/test/test_torch.py
index c298056..029c490 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6297,6 +6297,10 @@
# self.assertEqual((), torch.normal(zero_d, one_d).shape)
# self.assertEqual((), torch.normal(1, one_d).shape)
+ # nll_loss -- verify input can't be 0-dimensional.
+ self.assertRaises(ValueError, lambda: torch.nn.functional.nll_loss(zero_d, zero_d))
+ self.assertRaises(ValueError, lambda: torch.nn.functional.nll_loss(zero_d, one_d))
+
@onlyCPU
@dtypes(torch.float)
def test_diag(self, device, dtype):