[ONNX] Improve Expand shape inference (#69264)

Extend shape inference support for `Expand`, when value of argument `shape` is unknown. Infer the rank of the output of `Expand`, and set shape to dynamic, if shape of argument `shape` is known.

Without this, shape inference aborts, and falls back to the static shape provided by tracer, which is incorrect in many cases.

Co-authored-by: BowenBao <bowbaomicrosoft.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72985
diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py
index 9319670..3808de1 100644
--- a/test/onnx/test_pytorch_onnx_shape_inference.py
+++ b/test/onnx/test_pytorch_onnx_shape_inference.py
@@ -114,5 +114,14 @@
         slice = g.op("Slice", input, start_input, end, axis, step)
         self.run_test(g, slice.node(), expect_tensor(None, shape=(None, None)))
 
+    def test_expand(self):
+        g = self.create_empty_graph()
+        input = g.addInput()
+        constant = self.insert_tensor_constant(g, torch.ones(2, 4))
+        input.setType(constant.type().with_sizes([None, None]))
+        shape = g.op("Shape", input)
+        expand = g.op("Expand", constant, shape)
+        self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None)))
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
index 7219d2d..167e401 100644
--- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
+++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -1374,6 +1374,8 @@
         if (input0_shape_size.has_value()) {
           auto input0_shape_value = input0_shape_size.value();
           if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
+            // When value of `shape` is statically known,
+            // output shape can be computed.
             auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
                 n->input(1)->debugName());
             auto final_shape =
@@ -1381,6 +1383,23 @@
             if (final_shape.has_value()) {
               UpdateShape(n->output(), final_shape.value());
             }
+          } else if (
+              auto expand_shape =
+                  ConstantValueMap::GetShapeInto1DInt64VectorWithOneUnknown(
+                      n->input(1)->debugName())) {
+            // When shape of `shape` is statically known,
+            // output rank can be computed.
+            TORCH_INTERNAL_ASSERT(
+                expand_shape.value().size() == 1,
+                "`Shape` input to `Expand` should be a 1-D tensor. Instead got rank ",
+                expand_shape.value().size());
+            if (expand_shape.value()[0] > 0) {
+              std::vector<c10::ShapeSymbol> final_shape;
+              for (const auto i : c10::irange(expand_shape.value()[0])) {
+                final_shape.emplace_back(c10::ShapeSymbol::newSymbol());
+              }
+              UpdateShape(n->output(), c10::SymbolicShape(final_shape));
+            }
           }
         }
       }