[ONNX] Add dynamic axes support to torchscript exporter with dynamo=True (#128371)

This PR enables specific axe to be dynamic with calling torch.export.export and torch.export.Dim.

Features:
(1) Turn dynamic_axes to dynamic_shapes
(2) Dim constraints remain the same (see test case with hitting constraints). This might give different user experience, since we didn't have any constraints in torchscript-onnx exporting.
(3) If input_names is used in dynamic_axes, ValueError will be raised, as input_names is currently not supported.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128371
Approved by: https://github.com/justinchuby
diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py
index 30bfd27..1834016 100644
--- a/test/onnx/dynamo/test_exporter_api.py
+++ b/test/onnx/dynamo/test_exporter_api.py
@@ -33,6 +33,11 @@
         return (y, z)
 
 
+class SampleModelForDynamicShapes(torch.nn.Module):
+    def forward(self, x, b):
+        return x.relu(), b.sigmoid()
+
+
 class _LargeModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -230,8 +235,15 @@
 
 class TestONNXExportWithDynamo(common_utils.TestCase):
     def test_args_normalization_with_no_kwargs(self):
+        exported_program = torch.export.export(
+            SampleModelTwoInputs(),
+            (
+                torch.randn(1, 1, 2),
+                torch.randn(1, 1, 2),
+            ),
+        )
         onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            SampleModelTwoInputs(), torch.randn(1, 1, 2), torch.randn(1, 1, 2)
+            exported_program, torch.randn(1, 1, 2), torch.randn(1, 1, 2)
         )
         onnx_program_from_old_exporter = torch.onnx.export(
             SampleModelTwoInputs(),
@@ -243,9 +255,25 @@
             onnx_program_from_old_exporter.model_proto,
         )
 
-    def test_args_normalization_with_kwargs(self):
+    def test_args_is_tensor_not_tuple(self):
+        exported_program = torch.export.export(SampleModel(), (torch.randn(1, 1, 2),))
         onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            SampleModelTwoInputs(), torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
+            exported_program, torch.randn(1, 1, 2)
+        )
+        onnx_program_from_old_exporter = torch.onnx.export(
+            SampleModel(), torch.randn(1, 1, 2), dynamo=True
+        )
+        self.assertEqual(
+            onnx_program_from_new_exporter.model_proto,
+            onnx_program_from_old_exporter.model_proto,
+        )
+
+    def test_args_normalization_with_kwargs(self):
+        exported_program = torch.export.export(
+            SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)}
+        )
+        onnx_program_from_new_exporter = torch.onnx.dynamo_export(
+            exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
         )
         onnx_program_from_old_exporter = torch.onnx.export(
             SampleModelTwoInputs(),
@@ -258,8 +286,11 @@
         )
 
     def test_args_normalization_with_empty_dict_at_the_tail(self):
+        exported_program = torch.export.export(
+            SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)}
+        )
         onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            SampleModelTwoInputs(), torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
+            exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
         )
         onnx_program_from_old_exporter = torch.onnx.export(
             SampleModelTwoInputs(),
@@ -271,17 +302,111 @@
             onnx_program_from_old_exporter.model_proto,
         )
 
-    def test_dynamic_axes_enable_dynamic_shape(self):
+    def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
+        exported_program = torch.export.export(
+            SampleModelForDynamicShapes(),
+            (
+                torch.randn(2, 2, 3),
+                torch.randn(2, 2, 3),
+            ),
+            dynamic_shapes={
+                "x": {
+                    0: torch.export.Dim("customx_dim_0"),
+                    1: torch.export.Dim("customx_dim_1"),
+                    2: torch.export.Dim("customx_dim_2"),
+                },
+                "b": {
+                    0: torch.export.Dim("customb_dim_0"),
+                    1: torch.export.Dim("customb_dim_1"),
+                    2: torch.export.Dim("customb_dim_2"),
+                },
+            },
+        )
         onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            SampleModelTwoInputs(),
-            torch.randn(1, 1, 2),
-            b=torch.randn(1, 1, 2),
-            export_options=ExportOptions(dynamic_shapes=True),
+            exported_program,
+            torch.randn(2, 2, 3),
+            b=torch.randn(2, 2, 3),
         )
         onnx_program_from_old_exporter = torch.onnx.export(
-            SampleModelTwoInputs(),
-            (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}, {}),
-            dynamic_axes={"b": [0, 1, 2]},
+            SampleModelForDynamicShapes(),
+            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}, {}),
+            dynamic_axes={
+                "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
+                "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
+            },
+            dynamo=True,
+        )
+        self.assertEqual(
+            onnx_program_from_new_exporter.model_proto,
+            onnx_program_from_old_exporter.model_proto,
+        )
+
+    def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
+        exported_program = torch.export.export(
+            SampleModelForDynamicShapes(),
+            (
+                torch.randn(2, 2, 3),
+                torch.randn(2, 2, 3),
+            ),
+            dynamic_shapes={
+                "x": {
+                    0: torch.export.Dim("customx_dim_0"),
+                    1: torch.export.Dim("customx_dim_1"),
+                    2: torch.export.Dim("customx_dim_2"),
+                },
+                "b": {
+                    0: torch.export.Dim("customb_dim_0"),
+                    1: torch.export.Dim("customb_dim_1"),
+                    2: torch.export.Dim("customb_dim_2"),
+                },
+            },
+        )
+        onnx_program_from_new_exporter = torch.onnx.dynamo_export(
+            exported_program,
+            torch.randn(2, 2, 3),
+            b=torch.randn(2, 2, 3),
+        )
+        onnx_program_from_old_exporter = torch.onnx.export(
+            SampleModelForDynamicShapes(),
+            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}, {}),
+            dynamic_axes={
+                "x": [0, 1, 2],
+                "b": [0, 1, 2],
+            },
+            dynamo=True,
+        )
+        self.assertEqual(
+            onnx_program_from_new_exporter.model_proto,
+            onnx_program_from_old_exporter.model_proto,
+        )
+
+    def test_dynamic_axes_supports_partial_dynamic_shapes(self):
+        exported_program = torch.export.export(
+            SampleModelForDynamicShapes(),
+            (
+                torch.randn(2, 2, 3),
+                torch.randn(2, 2, 3),
+            ),
+            dynamic_shapes={
+                "x": None,
+                "b": {
+                    0: torch.export.Dim("customb_dim_0"),
+                    1: torch.export.Dim("customb_dim_1"),
+                    2: torch.export.Dim("customb_dim_2"),
+                },
+            },
+        )
+        onnx_program_from_new_exporter = torch.onnx.dynamo_export(
+            exported_program,
+            torch.randn(2, 2, 3),
+            b=torch.randn(2, 2, 3),
+        )
+        onnx_program_from_old_exporter = torch.onnx.export(
+            SampleModelForDynamicShapes(),
+            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}, {}),
+            dynamic_axes={
+                "b": [0, 1, 2],
+            },
             dynamo=True,
         )
         self.assertEqual(
@@ -303,16 +428,37 @@
                 dynamo=True,
             )
 
-    def test_raises_unsupported_specific_dynamic_axes_warning(self):
-        message = (
-            "Specified dynamic axes is not supported for dynamo export at the moment."
-        )
-
-        with self.assertWarnsOnceRegex(UserWarning, message):
+    def test_input_names_are_not_yet_supported_in_dynamic_axes(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Assinging new input names is not supported yet. Please use model forward signature "
+            "to specify input names in dynamix_axes.",
+        ):
             _ = torch.onnx.export(
-                SampleModel(),
-                (torch.randn(1, 1, 2),),
-                dynamic_axes={"input": [0, 1, 2]},
+                SampleModelForDynamicShapes(),
+                (
+                    torch.randn(2, 2, 3),
+                    torch.randn(2, 2, 3),
+                ),
+                input_names=["input"],
+                dynamic_axes={"input": [0, 1]},
+                dynamo=True,
+            )
+
+    def test_dynamic_shapes_hit_constraints_in_dynamo(self):
+        # SampleModelTwoInputs has constraints becuse of add of two inputs,
+        # so the two input shapes are related.
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Constraints violated",
+        ):
+            _ = torch.onnx.export(
+                SampleModelTwoInputs(),
+                (torch.randn(2, 2, 3), torch.randn(2, 2, 3)),
+                dynamic_axes={
+                    "x": {0: "x_dim_0", 1: "x_dim_1", 2: "x_dim_2"},
+                    "b": {0: "b_dim_0", 1: "b_dim_1", 2: "b_dim_2"},
+                },
                 dynamo=True,
             )
 
@@ -323,6 +469,17 @@
             )
             self.assertTrue(os.path.exists(path))
 
+    def test_raises_error_when_input_is_script_module(self):
+        class ScriptModule(torch.jit.ScriptModule):
+            def forward(self, x):
+                return x
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "Dynamo export does not support ScriptModule or ScriptFunction.",
+        ):
+            _ = torch.onnx.export(ScriptModule(), torch.randn(1, 1, 2), dynamo=True)
+
 
 if __name__ == "__main__":
     common_utils.run_tests()
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 94a5778..0d02fab 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -512,6 +512,10 @@
     """
 
     if dynamo:
+        if isinstance(model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)):
+            raise TypeError(
+                "Dynamo export does not support ScriptModule or ScriptFunction."
+            )
         # Unsupported parameters for dynamo export
         # TODO: These are not supported AT THE TIME
         warnings.warn(
@@ -519,7 +523,6 @@
             "do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and "
             "autograd_inlining are not supported for dynamo export at the moment."
         )
-        # TODO: check args normalization
         args = _decide_input_format(model, args)
         kwargs = {}
         if args is not None and isinstance(args[-1], dict):
@@ -527,18 +530,14 @@
             args = args[:-1]
         # TODO: refactor this when we have migrated ExportedProgam and
         # needs users to specify dynamic_axes
-        if dynamic_axes is None or not isinstance(dynamic_axes, dict):
-            dynamic_shapes = False
-        else:
-            dynamic_shapes = True
-            warnings.warn(
-                "Specified dynamic axes is not supported for dynamo export at the moment."
-            )
-        # TODO: expose more ExportOptions?
-        export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_shapes)
-        onnx_program = torch.onnx.dynamo_export(
-            model, *args, **kwargs, export_options=export_options
+        dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
+            model, dynamic_axes, input_names
         )
+        exported_program = torch.export.export(
+            model, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes  # type: ignore[arg-type]
+        )
+        # TODO: expose ExportOptions?
+        onnx_program = torch.onnx.dynamo_export(exported_program, *args, **kwargs)
         if f is not None:
             onnx_program.save(f)
         return onnx_program
@@ -916,6 +915,65 @@
 
 
 @_beartype.beartype
+def _from_dynamic_axes_to_dynamic_shapes(
+    model,
+    dynamic_axes: Optional[
+        Union[Mapping[str, Mapping[int, str]], Mapping[str, Sequence[int]]]
+    ] = None,
+    input_names: Optional[Sequence[str]] = None,
+) -> Optional[Dict[str, Any]]:
+    """
+
+    dynamic_axes examples:
+    (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
+    (2) dynamic_axes = {"x": [0], "y": [1]}
+
+    these will be converted to dynamic_shapes respectively:
+    (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}}
+    (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}}  # auto-generated dim names
+
+    """
+    if dynamic_axes is None:
+        return None
+
+    if input_names is None:
+        input_names_set = set()
+    else:
+        input_names_set = set(input_names)
+
+    dynamic_shapes: Dict[str, Optional[Any]] = {}
+    for input_name, axes in dynamic_axes.items():
+        if input_name in input_names_set:
+            raise ValueError(
+                "Assinging new input names is not supported yet. Please use model forward signature "
+                "to specify input names in dynamix_axes."
+            )
+        if isinstance(axes, dict):
+            dynamic_shapes[input_name] = {
+                k: torch.export.Dim(v) for k, v in axes.items()
+            }
+        elif isinstance(axes, list):
+            dynamic_shapes[input_name] = {
+                k: torch.export.Dim(f"{input_name}_dim_{k}") for k in axes
+            }
+        else:
+            raise TypeError(
+                f"dynamic_axes value must be either a dict or a list, but got {type(axes)}"
+            )
+    # torch.export.export needs static dim to present in dynamic_shapes
+    # for all input tensors, so we need to add them with None
+    try:
+        sig = _signature(model)
+    except ValueError as e:
+        warnings.warn(f"{e}, skipping auto filling None on static axes...")
+        return dynamic_shapes
+    for input_name in sig.parameters.keys():
+        if input_name not in dynamic_shapes:
+            dynamic_shapes[input_name] = None
+    return dynamic_shapes
+
+
+@_beartype.beartype
 def _trace(func, args, operator_export_type, return_outs=False):
     # Special case for common case of passing a single Tensor
     if isinstance(args, torch.Tensor):