[ONNX] Fix concat with empty tensors (#87620)

Fixes #54410

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87620
Approved by: https://github.com/BowenBao
diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py
index 89526c7..aca26e0 100644
--- a/test/onnx/test_pytorch_onnx_no_runtime.py
+++ b/test/onnx/test_pytorch_onnx_no_runtime.py
@@ -1068,6 +1068,28 @@
         ]
         self.assertEqual(len(all_aten_nodes), 0)
 
+    def test_cat_with_empty_tensor(self):
+        class NoopConcat(torch.nn.Module):
+            def forward(self, x):
+                return torch.cat((torch.Tensor([]), x))
+
+        x = torch.randn(4, 5, 6)
+        # TODO: Parametrize this test for opset_version
+        for opset_version in {9, 11}:
+            f = io.BytesIO()
+            torch.onnx.export(NoopConcat(), (x,), f, opset_version=opset_version)
+            loaded_model = onnx.load_from_string(f.getvalue())
+            self.assertEqual(
+                len(loaded_model.graph.output[0].type.tensor_type.shape.dim), 3
+            )
+            for idx, dim in enumerate(x.shape):
+                self.assertEqual(
+                    loaded_model.graph.output[0]
+                    .type.tensor_type.shape.dim[idx]
+                    .dim_value,
+                    dim,
+                )
+
 
 class TestQuantizeEagerONNXExport(common_utils.TestCase):
     def _test_lower_graph_impl(self, model, data):
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 9984f60..e8fd99e 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -527,6 +527,30 @@
 @_beartype.beartype
 def cat(g: jit_utils.GraphContext, tensor_list, dim):
     tensors = symbolic_helper._unpack_list(tensor_list)
+    # torch.cat ignores empty tensors such as `torch.Tensor([])`
+    # These needs to be removed as input from ONNX's concat too, otherwise shape inference
+    # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else)
+    nonempty_tensors = []
+    for t in tensors:
+        if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size(
+            t, 0
+        ):
+            continue
+        nonempty_tensors.append(t)
+    assert len(nonempty_tensors) > 0
+    assert all(
+        [
+            symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None
+            or symbolic_helper._get_tensor_rank(t)
+            == symbolic_helper._get_tensor_rank(nonempty_tensors[0])
+            for t in nonempty_tensors
+        ]
+    )
+    tensor_list.node().removeAllInputs()
+    for t in nonempty_tensors:
+        tensor_list.node().addInput(t)
+
+    tensors = symbolic_helper._unpack_list(tensor_list)
     return g.op("Concat", *tensors, axis_i=dim)