[ONNX] Fix concat with empty tensors (#87620)
Fixes #54410
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87620
Approved by: https://github.com/BowenBao
diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py
index 89526c7..aca26e0 100644
--- a/test/onnx/test_pytorch_onnx_no_runtime.py
+++ b/test/onnx/test_pytorch_onnx_no_runtime.py
@@ -1068,6 +1068,28 @@
]
self.assertEqual(len(all_aten_nodes), 0)
+ def test_cat_with_empty_tensor(self):
+ class NoopConcat(torch.nn.Module):
+ def forward(self, x):
+ return torch.cat((torch.Tensor([]), x))
+
+ x = torch.randn(4, 5, 6)
+ # TODO: Parametrize this test for opset_version
+ for opset_version in {9, 11}:
+ f = io.BytesIO()
+ torch.onnx.export(NoopConcat(), (x,), f, opset_version=opset_version)
+ loaded_model = onnx.load_from_string(f.getvalue())
+ self.assertEqual(
+ len(loaded_model.graph.output[0].type.tensor_type.shape.dim), 3
+ )
+ for idx, dim in enumerate(x.shape):
+ self.assertEqual(
+ loaded_model.graph.output[0]
+ .type.tensor_type.shape.dim[idx]
+ .dim_value,
+ dim,
+ )
+
class TestQuantizeEagerONNXExport(common_utils.TestCase):
def _test_lower_graph_impl(self, model, data):
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 9984f60..e8fd99e 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -527,6 +527,30 @@
@_beartype.beartype
def cat(g: jit_utils.GraphContext, tensor_list, dim):
tensors = symbolic_helper._unpack_list(tensor_list)
+ # torch.cat ignores empty tensors such as `torch.Tensor([])`
+ # These needs to be removed as input from ONNX's concat too, otherwise shape inference
+ # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else)
+ nonempty_tensors = []
+ for t in tensors:
+ if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size(
+ t, 0
+ ):
+ continue
+ nonempty_tensors.append(t)
+ assert len(nonempty_tensors) > 0
+ assert all(
+ [
+ symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None
+ or symbolic_helper._get_tensor_rank(t)
+ == symbolic_helper._get_tensor_rank(nonempty_tensors[0])
+ for t in nonempty_tensors
+ ]
+ )
+ tensor_list.node().removeAllInputs()
+ for t in nonempty_tensors:
+ tensor_list.node().addInput(t)
+
+ tensors = symbolic_helper._unpack_list(tensor_list)
return g.op("Concat", *tensors, axis_i=dim)