[ONNX] Add dynamic axes support to torchscript exporter with dynamo=True (#128371)
This PR enables specific axe to be dynamic with calling torch.export.export and torch.export.Dim.
Features:
(1) Turn dynamic_axes to dynamic_shapes
(2) Dim constraints remain the same (see test case with hitting constraints). This might give different user experience, since we didn't have any constraints in torchscript-onnx exporting.
(3) If input_names is used in dynamic_axes, ValueError will be raised, as input_names is currently not supported.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128371
Approved by: https://github.com/justinchuby
diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py
index 30bfd27..1834016 100644
--- a/test/onnx/dynamo/test_exporter_api.py
+++ b/test/onnx/dynamo/test_exporter_api.py
@@ -33,6 +33,11 @@
return (y, z)
+class SampleModelForDynamicShapes(torch.nn.Module):
+ def forward(self, x, b):
+ return x.relu(), b.sigmoid()
+
+
class _LargeModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -230,8 +235,15 @@
class TestONNXExportWithDynamo(common_utils.TestCase):
def test_args_normalization_with_no_kwargs(self):
+ exported_program = torch.export.export(
+ SampleModelTwoInputs(),
+ (
+ torch.randn(1, 1, 2),
+ torch.randn(1, 1, 2),
+ ),
+ )
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
- SampleModelTwoInputs(), torch.randn(1, 1, 2), torch.randn(1, 1, 2)
+ exported_program, torch.randn(1, 1, 2), torch.randn(1, 1, 2)
)
onnx_program_from_old_exporter = torch.onnx.export(
SampleModelTwoInputs(),
@@ -243,9 +255,25 @@
onnx_program_from_old_exporter.model_proto,
)
- def test_args_normalization_with_kwargs(self):
+ def test_args_is_tensor_not_tuple(self):
+ exported_program = torch.export.export(SampleModel(), (torch.randn(1, 1, 2),))
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
- SampleModelTwoInputs(), torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
+ exported_program, torch.randn(1, 1, 2)
+ )
+ onnx_program_from_old_exporter = torch.onnx.export(
+ SampleModel(), torch.randn(1, 1, 2), dynamo=True
+ )
+ self.assertEqual(
+ onnx_program_from_new_exporter.model_proto,
+ onnx_program_from_old_exporter.model_proto,
+ )
+
+ def test_args_normalization_with_kwargs(self):
+ exported_program = torch.export.export(
+ SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)}
+ )
+ onnx_program_from_new_exporter = torch.onnx.dynamo_export(
+ exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
)
onnx_program_from_old_exporter = torch.onnx.export(
SampleModelTwoInputs(),
@@ -258,8 +286,11 @@
)
def test_args_normalization_with_empty_dict_at_the_tail(self):
+ exported_program = torch.export.export(
+ SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)}
+ )
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
- SampleModelTwoInputs(), torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
+ exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
)
onnx_program_from_old_exporter = torch.onnx.export(
SampleModelTwoInputs(),
@@ -271,17 +302,111 @@
onnx_program_from_old_exporter.model_proto,
)
- def test_dynamic_axes_enable_dynamic_shape(self):
+ def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
+ exported_program = torch.export.export(
+ SampleModelForDynamicShapes(),
+ (
+ torch.randn(2, 2, 3),
+ torch.randn(2, 2, 3),
+ ),
+ dynamic_shapes={
+ "x": {
+ 0: torch.export.Dim("customx_dim_0"),
+ 1: torch.export.Dim("customx_dim_1"),
+ 2: torch.export.Dim("customx_dim_2"),
+ },
+ "b": {
+ 0: torch.export.Dim("customb_dim_0"),
+ 1: torch.export.Dim("customb_dim_1"),
+ 2: torch.export.Dim("customb_dim_2"),
+ },
+ },
+ )
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
- SampleModelTwoInputs(),
- torch.randn(1, 1, 2),
- b=torch.randn(1, 1, 2),
- export_options=ExportOptions(dynamic_shapes=True),
+ exported_program,
+ torch.randn(2, 2, 3),
+ b=torch.randn(2, 2, 3),
)
onnx_program_from_old_exporter = torch.onnx.export(
- SampleModelTwoInputs(),
- (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}, {}),
- dynamic_axes={"b": [0, 1, 2]},
+ SampleModelForDynamicShapes(),
+ (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}, {}),
+ dynamic_axes={
+ "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
+ "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
+ },
+ dynamo=True,
+ )
+ self.assertEqual(
+ onnx_program_from_new_exporter.model_proto,
+ onnx_program_from_old_exporter.model_proto,
+ )
+
+ def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
+ exported_program = torch.export.export(
+ SampleModelForDynamicShapes(),
+ (
+ torch.randn(2, 2, 3),
+ torch.randn(2, 2, 3),
+ ),
+ dynamic_shapes={
+ "x": {
+ 0: torch.export.Dim("customx_dim_0"),
+ 1: torch.export.Dim("customx_dim_1"),
+ 2: torch.export.Dim("customx_dim_2"),
+ },
+ "b": {
+ 0: torch.export.Dim("customb_dim_0"),
+ 1: torch.export.Dim("customb_dim_1"),
+ 2: torch.export.Dim("customb_dim_2"),
+ },
+ },
+ )
+ onnx_program_from_new_exporter = torch.onnx.dynamo_export(
+ exported_program,
+ torch.randn(2, 2, 3),
+ b=torch.randn(2, 2, 3),
+ )
+ onnx_program_from_old_exporter = torch.onnx.export(
+ SampleModelForDynamicShapes(),
+ (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}, {}),
+ dynamic_axes={
+ "x": [0, 1, 2],
+ "b": [0, 1, 2],
+ },
+ dynamo=True,
+ )
+ self.assertEqual(
+ onnx_program_from_new_exporter.model_proto,
+ onnx_program_from_old_exporter.model_proto,
+ )
+
+ def test_dynamic_axes_supports_partial_dynamic_shapes(self):
+ exported_program = torch.export.export(
+ SampleModelForDynamicShapes(),
+ (
+ torch.randn(2, 2, 3),
+ torch.randn(2, 2, 3),
+ ),
+ dynamic_shapes={
+ "x": None,
+ "b": {
+ 0: torch.export.Dim("customb_dim_0"),
+ 1: torch.export.Dim("customb_dim_1"),
+ 2: torch.export.Dim("customb_dim_2"),
+ },
+ },
+ )
+ onnx_program_from_new_exporter = torch.onnx.dynamo_export(
+ exported_program,
+ torch.randn(2, 2, 3),
+ b=torch.randn(2, 2, 3),
+ )
+ onnx_program_from_old_exporter = torch.onnx.export(
+ SampleModelForDynamicShapes(),
+ (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}, {}),
+ dynamic_axes={
+ "b": [0, 1, 2],
+ },
dynamo=True,
)
self.assertEqual(
@@ -303,16 +428,37 @@
dynamo=True,
)
- def test_raises_unsupported_specific_dynamic_axes_warning(self):
- message = (
- "Specified dynamic axes is not supported for dynamo export at the moment."
- )
-
- with self.assertWarnsOnceRegex(UserWarning, message):
+ def test_input_names_are_not_yet_supported_in_dynamic_axes(self):
+ with self.assertRaisesRegex(
+ ValueError,
+ "Assinging new input names is not supported yet. Please use model forward signature "
+ "to specify input names in dynamix_axes.",
+ ):
_ = torch.onnx.export(
- SampleModel(),
- (torch.randn(1, 1, 2),),
- dynamic_axes={"input": [0, 1, 2]},
+ SampleModelForDynamicShapes(),
+ (
+ torch.randn(2, 2, 3),
+ torch.randn(2, 2, 3),
+ ),
+ input_names=["input"],
+ dynamic_axes={"input": [0, 1]},
+ dynamo=True,
+ )
+
+ def test_dynamic_shapes_hit_constraints_in_dynamo(self):
+ # SampleModelTwoInputs has constraints becuse of add of two inputs,
+ # so the two input shapes are related.
+ with self.assertRaisesRegex(
+ torch._dynamo.exc.UserError,
+ "Constraints violated",
+ ):
+ _ = torch.onnx.export(
+ SampleModelTwoInputs(),
+ (torch.randn(2, 2, 3), torch.randn(2, 2, 3)),
+ dynamic_axes={
+ "x": {0: "x_dim_0", 1: "x_dim_1", 2: "x_dim_2"},
+ "b": {0: "b_dim_0", 1: "b_dim_1", 2: "b_dim_2"},
+ },
dynamo=True,
)
@@ -323,6 +469,17 @@
)
self.assertTrue(os.path.exists(path))
+ def test_raises_error_when_input_is_script_module(self):
+ class ScriptModule(torch.jit.ScriptModule):
+ def forward(self, x):
+ return x
+
+ with self.assertRaisesRegex(
+ TypeError,
+ "Dynamo export does not support ScriptModule or ScriptFunction.",
+ ):
+ _ = torch.onnx.export(ScriptModule(), torch.randn(1, 1, 2), dynamo=True)
+
if __name__ == "__main__":
common_utils.run_tests()
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 94a5778..0d02fab 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -512,6 +512,10 @@
"""
if dynamo:
+ if isinstance(model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)):
+ raise TypeError(
+ "Dynamo export does not support ScriptModule or ScriptFunction."
+ )
# Unsupported parameters for dynamo export
# TODO: These are not supported AT THE TIME
warnings.warn(
@@ -519,7 +523,6 @@
"do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and "
"autograd_inlining are not supported for dynamo export at the moment."
)
- # TODO: check args normalization
args = _decide_input_format(model, args)
kwargs = {}
if args is not None and isinstance(args[-1], dict):
@@ -527,18 +530,14 @@
args = args[:-1]
# TODO: refactor this when we have migrated ExportedProgam and
# needs users to specify dynamic_axes
- if dynamic_axes is None or not isinstance(dynamic_axes, dict):
- dynamic_shapes = False
- else:
- dynamic_shapes = True
- warnings.warn(
- "Specified dynamic axes is not supported for dynamo export at the moment."
- )
- # TODO: expose more ExportOptions?
- export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_shapes)
- onnx_program = torch.onnx.dynamo_export(
- model, *args, **kwargs, export_options=export_options
+ dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
+ model, dynamic_axes, input_names
)
+ exported_program = torch.export.export(
+ model, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes # type: ignore[arg-type]
+ )
+ # TODO: expose ExportOptions?
+ onnx_program = torch.onnx.dynamo_export(exported_program, *args, **kwargs)
if f is not None:
onnx_program.save(f)
return onnx_program
@@ -916,6 +915,65 @@
@_beartype.beartype
+def _from_dynamic_axes_to_dynamic_shapes(
+ model,
+ dynamic_axes: Optional[
+ Union[Mapping[str, Mapping[int, str]], Mapping[str, Sequence[int]]]
+ ] = None,
+ input_names: Optional[Sequence[str]] = None,
+) -> Optional[Dict[str, Any]]:
+ """
+
+ dynamic_axes examples:
+ (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
+ (2) dynamic_axes = {"x": [0], "y": [1]}
+
+ these will be converted to dynamic_shapes respectively:
+ (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}}
+ (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}} # auto-generated dim names
+
+ """
+ if dynamic_axes is None:
+ return None
+
+ if input_names is None:
+ input_names_set = set()
+ else:
+ input_names_set = set(input_names)
+
+ dynamic_shapes: Dict[str, Optional[Any]] = {}
+ for input_name, axes in dynamic_axes.items():
+ if input_name in input_names_set:
+ raise ValueError(
+ "Assinging new input names is not supported yet. Please use model forward signature "
+ "to specify input names in dynamix_axes."
+ )
+ if isinstance(axes, dict):
+ dynamic_shapes[input_name] = {
+ k: torch.export.Dim(v) for k, v in axes.items()
+ }
+ elif isinstance(axes, list):
+ dynamic_shapes[input_name] = {
+ k: torch.export.Dim(f"{input_name}_dim_{k}") for k in axes
+ }
+ else:
+ raise TypeError(
+ f"dynamic_axes value must be either a dict or a list, but got {type(axes)}"
+ )
+ # torch.export.export needs static dim to present in dynamic_shapes
+ # for all input tensors, so we need to add them with None
+ try:
+ sig = _signature(model)
+ except ValueError as e:
+ warnings.warn(f"{e}, skipping auto filling None on static axes...")
+ return dynamic_shapes
+ for input_name in sig.parameters.keys():
+ if input_name not in dynamic_shapes:
+ dynamic_shapes[input_name] = None
+ return dynamic_shapes
+
+
+@_beartype.beartype
def _trace(func, args, operator_export_type, return_outs=False):
# Special case for common case of passing a single Tensor
if isinstance(args, torch.Tensor):