[ONNX] Sum empty tensor could not be exported to ONNX successfully. (#58141) (#59537)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59537

PyTorch sum over empty tensor gives 0, while ONNX produces an error.

torch.sum will be translated into onnx::ReduceSum op. Per the definition of ReduceSum, update the keepdims attribute for this scenario.

Test Plan: Imported from OSS

Reviewed By: nikithamalgifb, ansley

Differential Revision: D29046604

Pulled By: SplitInfinity

fbshipit-source-id: 6f5f3a66cb8eda8b5114b8474dda6fcdbae73469

Co-authored-by: fatcat-z <jiz@microsoft.com>
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index bf7ee66..28d36e2 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -9036,6 +9036,17 @@
         self.run_test(M(2, 1), (x,))
         self.run_test(M([-1, 3], [-2, -1]), (x,))
 
+    def test_sum_empty_tensor(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x[0:0].sum()
+
+        x = torch.ones(12)
+        self.run_test(M(), (x,))
+
+        x = torch.ones(2, 0, 3)
+        self.run_test(M(), (x,))
+
 def make_test(name, base, layer, bidirectional, initial_state,
               variable_length, dropout, script_test_min_opset_version,
               **extra_kwargs):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 2af83c7..0bab2b2 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -738,6 +738,13 @@
     n.setType(OptionalType.ofTensor())
     return n
 
+def _handle_reduce_dim_none(g, self, op_name):
+    dim_size = _get_tensor_dim_size(self, 0)
+    if dim_size is None or dim_size == 0:
+        # If input tensor is empty, according to ONNX ReduceSum definition,
+        # set keepdims=1 so that the resulted tensor has the same rank as the input.
+        return g.op(op_name, self, keepdims_i=1)
+    return g.op(op_name, self, keepdims_i=0)
 
 # ---------------------------------------------------------------------
 # ONNX operator version
diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py
index a5e47da..7f20833 100644
--- a/torch/onnx/symbolic_opset13.py
+++ b/torch/onnx/symbolic_opset13.py
@@ -147,9 +147,9 @@
         self = _maybe_cast_reduce_op_input(g, self)
         if dim is None:
             # all-reduce path
-            return g.op(onnx_op_name, self, keepdims_i=0)
+            return sym_help._handle_reduce_dim_none(g, self, onnx_op_name)
         else:
-            keepdim = sym_help._get_const(keepdim, "i", "keepdim")
+            keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
             return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)
     return symbolic
 
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index f238427..d6177d1 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -368,7 +368,7 @@
         self = _maybe_cast_reduce_op_input(g, self)
         if dim is None:
             # all-reduce path
-            return g.op(onnx_op_name, self, keepdims_i=0)
+            return sym_help._handle_reduce_dim_none(g, self, onnx_op_name)
         else:
             # dim-reduce path
             desc = "is" if allow_multi_dim_support else "i"