[ONNX] Sum empty tensor could not be exported to ONNX successfully. (#58141) (#59537)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59537
PyTorch sum over empty tensor gives 0, while ONNX produces an error.
torch.sum will be translated into onnx::ReduceSum op. Per the definition of ReduceSum, update the keepdims attribute for this scenario.
Test Plan: Imported from OSS
Reviewed By: nikithamalgifb, ansley
Differential Revision: D29046604
Pulled By: SplitInfinity
fbshipit-source-id: 6f5f3a66cb8eda8b5114b8474dda6fcdbae73469
Co-authored-by: fatcat-z <jiz@microsoft.com>
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index bf7ee66..28d36e2 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -9036,6 +9036,17 @@
self.run_test(M(2, 1), (x,))
self.run_test(M([-1, 3], [-2, -1]), (x,))
+ def test_sum_empty_tensor(self):
+ class M(torch.nn.Module):
+ def forward(self, x):
+ return x[0:0].sum()
+
+ x = torch.ones(12)
+ self.run_test(M(), (x,))
+
+ x = torch.ones(2, 0, 3)
+ self.run_test(M(), (x,))
+
def make_test(name, base, layer, bidirectional, initial_state,
variable_length, dropout, script_test_min_opset_version,
**extra_kwargs):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 2af83c7..0bab2b2 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -738,6 +738,13 @@
n.setType(OptionalType.ofTensor())
return n
+def _handle_reduce_dim_none(g, self, op_name):
+ dim_size = _get_tensor_dim_size(self, 0)
+ if dim_size is None or dim_size == 0:
+ # If input tensor is empty, according to ONNX ReduceSum definition,
+ # set keepdims=1 so that the resulted tensor has the same rank as the input.
+ return g.op(op_name, self, keepdims_i=1)
+ return g.op(op_name, self, keepdims_i=0)
# ---------------------------------------------------------------------
# ONNX operator version
diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py
index a5e47da..7f20833 100644
--- a/torch/onnx/symbolic_opset13.py
+++ b/torch/onnx/symbolic_opset13.py
@@ -147,9 +147,9 @@
self = _maybe_cast_reduce_op_input(g, self)
if dim is None:
# all-reduce path
- return g.op(onnx_op_name, self, keepdims_i=0)
+ return sym_help._handle_reduce_dim_none(g, self, onnx_op_name)
else:
- keepdim = sym_help._get_const(keepdim, "i", "keepdim")
+ keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)
return symbolic
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index f238427..d6177d1 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -368,7 +368,7 @@
self = _maybe_cast_reduce_op_input(g, self)
if dim is None:
# all-reduce path
- return g.op(onnx_op_name, self, keepdims_i=0)
+ return sym_help._handle_reduce_dim_none(g, self, onnx_op_name)
else:
# dim-reduce path
desc = "is" if allow_multi_dim_support else "i"