Allow data size equal to 0 for SegmentReduce (#99733)
Summary:
Support the special case where the input data size is 0 for SegmentReduce.
Example:
```
import torch

x = torch.ones((0, 6)).cuda()
lengths = torch.tensor([0, 0]).cuda()
torch.segment_reduce(x, "sum", lengths=lengths, unsafe=False, initial=0)
```
Previously, this raised an error: `Expected data.numel() > 0 to be true, but got false.`
Now the call succeeds and returns the initial value for each empty segment (here, a `(2, 6)` tensor of zeros).
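For reference, a sketch of the expected output for the example above; the shape and values follow from two empty segments reduced with `initial=0` (the exact repr may differ):
```
>>> torch.segment_reduce(x, "sum", lengths=lengths, unsafe=False, initial=0)
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], device='cuda:0')
```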
Test Plan: contbuild & OSS CI
Differential Revision: D45133827
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99733
Approved by: https://github.com/ngimel
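The new test below leans on the `get_default_value` helper defined earlier in test/test_segment_reductions.py to pick the fill value for empty segments. As a rough sketch of its behavior (reconstructed from the reduction identities, not the verbatim helper):
```
import math

def get_default_value(initial_value, reduction):
    # An explicit initial value always wins.
    if initial_value is not None:
        return initial_value
    # Otherwise fall back to the identity (or NaN) for each reduction.
    if reduction == "max":
        return -math.inf
    if reduction == "mean":
        return math.nan
    if reduction == "min":
        return math.inf
    if reduction == "sum":
        return 0.0
    if reduction == "prod":
        return 1.0
```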
diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp
index 932b43c..c326bb0 100644
--- a/aten/src/ATen/native/SegmentReduce.cpp
+++ b/aten/src/ATen/native/SegmentReduce.cpp
@@ -394,7 +394,7 @@
bool unsafe,
const c10::optional<Scalar>& initial) {
axis = maybe_wrap_dim(axis, data.ndimension());
- TORCH_CHECK(data.numel() > 0);
+ TORCH_CHECK(data.numel() >= 0);
// check that one of lengths or offsets is defined
auto lengths_has_value = lengths.has_value();
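With the check relaxed, the autograd path also accepts empty inputs. A minimal sketch (CPU, default float dtype assumed) of what the new test below exercises when `check_backward` is set:
```
import torch

x = torch.ones((0, 6), requires_grad=True)
lengths = torch.tensor([0, 0])
out = torch.segment_reduce(x, "sum", lengths=lengths, initial=0)
out.sum().backward()
print(x.grad.shape)  # torch.Size([0, 6]): no elements to receive gradient
```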
diff --git a/test/test_segment_reductions.py b/test/test_segment_reductions.py
index 5e14a25..8b4f69f 100644
--- a/test/test_segment_reductions.py
+++ b/test/test_segment_reductions.py
@@ -180,6 +180,63 @@
(torch.int, torch.int64),
)
)
+ def test_simple_zero_length(self, device, dtypes):
+ val_dtype, length_type = dtypes
+ lengths = [0, 0]
+ data = torch.ones((0))
+
+ for reduction in reductions:
+ for initial in [0, None]:
+ check_backward = initial is not None
+ initial_value = initial
+ default_value = get_default_value(initial_value, reduction)
+ if reduction == "max":
+ expected_result = [default_value, default_value]
+ expected_grad = []
+ elif reduction == "mean":
+ expected_result = [default_value, default_value]
+ expected_grad = []
+ elif reduction == "min":
+ if initial is not None:
+ initial_value = 1000 # some high number
+ default_value = get_default_value(initial_value, reduction)
+ expected_result = [default_value, default_value]
+ expected_grad = []
+ elif reduction == "sum":
+ expected_result = [default_value, default_value]
+ expected_grad = []
+ elif reduction == "prod":
+ if initial is not None:
+ initial_value = 2 # 0 initial_value will zero out everything for prod
+ default_value = get_default_value(initial_value, reduction)
+ expected_result = [default_value, default_value]
+ expected_grad = []
+ else:
+ expected_result = [default_value, default_value]
+ expected_grad = []
+ for axis in [0]:
+ for unsafe in [True, False]:
+ self._test_common(
+ reduction,
+ device,
+ val_dtype,
+ unsafe,
+ axis,
+ initial_value,
+ data,
+ lengths,
+ expected_result,
+ expected_grad,
+ check_backward,
+ length_type,
+ )
+
+ @dtypes(
+ *product(
+ (torch.half, torch.bfloat16, torch.float, torch.double),
+ (torch.int, torch.int64),
+ )
+ )
def test_multi_d_simple(self, device, dtypes):
val_dtype, length_type = dtypes
axis = 0