Allow data size equal to 0 for SegmentReduce (#99733)

Summary:
Support the special case where the input data has zero elements (`data.numel() == 0`) for SegmentReduce.

Example code below:
```
x = torch.ones((0, 6)).cuda()
lengths = torch.tensor([0, 0]).cuda()
torch.segment_reduce(x, "sum", lengths=lengths, unsafe=False, initial=0)
```
Previously this raised: `Expected data.numel() > 0 to be true, but got false.`
Now the call succeeds, and each zero-length segment reduces to the initial value (0 in this example).
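
For reference, a minimal sketch of the new behavior, mirroring the 1-D case in the added test (each empty segment reduces to the supplied `initial` value):
```
import torch

data = torch.ones(0)             # empty 1-D input, zero elements
lengths = torch.tensor([0, 0])   # two segments, both zero-length
out = torch.segment_reduce(data, "sum", lengths=lengths, unsafe=False, initial=0)
print(out)  # tensor([0., 0.]) -- each empty segment yields the initial value
```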

Test Plan: contbuild & OSS CI

Differential Revision: D45133827

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99733
Approved by: https://github.com/ngimel
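The added test also exercises the backward pass when `initial` is supplied (see `check_backward` below). A hedged sketch of that path, assuming autograd support for `segment_reduce` on empty inputs as the test implies:
```
import torch

# With zero input elements, the gradient w.r.t. data should itself be empty.
data = torch.ones(0, requires_grad=True)
lengths = torch.tensor([0, 0])
out = torch.segment_reduce(data, "sum", lengths=lengths, unsafe=False, initial=0)
out.sum().backward()
print(data.grad)  # tensor([]) -- empty gradient, same shape as data
```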
diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp
index 932b43c..c326bb0 100644
--- a/aten/src/ATen/native/SegmentReduce.cpp
+++ b/aten/src/ATen/native/SegmentReduce.cpp
@@ -394,7 +394,7 @@
     bool unsafe,
     const c10::optional<Scalar>& initial) {
   axis = maybe_wrap_dim(axis, data.ndimension());
-  TORCH_CHECK(data.numel() > 0);
+  TORCH_CHECK(data.numel() >= 0);
 
   // check that one of lengths or offsets is defined
   auto lengths_has_value = lengths.has_value();
diff --git a/test/test_segment_reductions.py b/test/test_segment_reductions.py
index 5e14a25..8b4f69f 100644
--- a/test/test_segment_reductions.py
+++ b/test/test_segment_reductions.py
@@ -180,6 +180,63 @@
             (torch.int, torch.int64),
         )
     )
+    def test_simple_zero_length(self, device, dtypes):
+        val_dtype, length_type = dtypes
+        lengths = [0, 0]
+        data = torch.ones((0))
+
+        for reduction in reductions:
+            for initial in [0, None]:
+                check_backward = initial is not None
+                initial_value = initial
+                default_value = get_default_value(initial_value, reduction)
+                if reduction == "max":
+                    expected_result = [default_value, default_value]
+                    expected_grad = []
+                elif reduction == "mean":
+                    expected_result = [default_value, default_value]
+                    expected_grad = []
+                elif reduction == "min":
+                    if initial is not None:
+                        initial_value = 1000  # some high number
+                        default_value = get_default_value(initial_value, reduction)
+                    expected_result = [default_value, default_value]
+                    expected_grad = []
+                elif reduction == "sum":
+                    expected_result = [default_value, default_value]
+                    expected_grad = []
+                elif reduction == "prod":
+                    if initial is not None:
+                        initial_value = 2  # 0 initial_value will zero out everything for prod
+                        default_value = get_default_value(initial_value, reduction)
+                        expected_result = [default_value, default_value]
+                        expected_grad = []
+                    else:
+                        expected_result = [default_value, default_value]
+                        expected_grad = []
+                for axis in [0]:
+                    for unsafe in [True, False]:
+                        self._test_common(
+                            reduction,
+                            device,
+                            val_dtype,
+                            unsafe,
+                            axis,
+                            initial_value,
+                            data,
+                            lengths,
+                            expected_result,
+                            expected_grad,
+                            check_backward,
+                            length_type,
+                        )
+
+    @dtypes(
+        *product(
+            (torch.half, torch.bfloat16, torch.float, torch.double),
+            (torch.int, torch.int64),
+        )
+    )
     def test_multi_d_simple(self, device, dtypes):
         val_dtype, length_type = dtypes
         axis = 0