[export] add kwargs support for export. (#105337)
Fixes #105242.
During export, the exported function's signature changes multiple times. Suppose we'd like to export f as shown in the following example:
```python
def f(arg1, arg2, kw1, kw2):
    pass
args = (arg1, arg2)
kwargs = {"kw2": arg3, "kw1": arg4}
torch.export(f, args, kwargs)
```
The signature changes multiple times during the export process, in the following order:
1. **gm_torch_level = dynamo.export(f, *args, \*\*kwargs)**. In this step, we turn all kinds of parameters, such as **positional_only**, **var_positional**, **kw_only**, and **var_kwargs**, into **positional_or_kw**. It also preserves the positional and keyword argument names of the original function (i.e. f in this example) [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/export.py#L546C13-L546C27). The order of the kwargs will be the **key order** of the kwargs dict (since Python 3.6, that is the insertion order of the keys) rather than the order in the original function signature, and this order is baked into the _orig_args variable of gm_torch_level's pytree info. So we'll have:
```python
def gm_torch_level(arg1, arg2, kw2, kw1): ...
```
Such a difference is acceptable, as it is transparent to users of export.
2. **gm_aot_export = aot_export_module(gm_torch_level, pos_or_kw_args)**. In this step, we need to turn the kwargs into positional args in the order gm_torch_level expects, which is stored in _orig_args. The returned gm_aot_export has a graph signature over flat_args, where flat_args, in_spec = pytree.tree_flatten(pos_or_kw_args):
```python
flat_args, _ = pytree.tree_flatten(pos_or_kw_args)
def gm_aot_export(*flat_args): ...
```
3. **exported_program(*args, \*\*kwargs)**. The exported artifact is exported_program, which is a wrapper over gm_aot_export and has the same calling convention as the original function f. To achieve this, we need to (1) specialize the order of kwargs into pos_or_kw_args and (2) flatten pos_or_kw_args into what gm_aot_export expects. We can combine the two steps into one with:
```python
_, in_spec = pytree.tree_flatten((args, kwargs))
# Then during exported_program.__call__(*args, **kwargs)
flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
```
Here, kwargs is treated as a normal pytree whose key order is preserved in in_spec.
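As a quick check of the key-order behavior described in steps 1 and 3, the sketch below (plain values instead of tensors, assuming the current torch pytree utilities) captures in_spec with kwargs given as kw2 then kw1, and then flattens a later call whose kwargs arrive in the opposite order:
```python
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree

# "Export time": kwargs were passed with kw2 before kw1, so in_spec records
# the key order ["kw2", "kw1"].
_, in_spec = pytree.tree_flatten((("a", "b"), {"kw2": 3, "kw1": 4}))

# "Call time": the caller passes the kwargs in a different order. tree_flatten_spec
# lays the leaves out according to in_spec, so kw2 still precedes kw1.
flat_args = fx_pytree.tree_flatten_spec((("a", "b"), {"kw1": 40, "kw2": 30}), in_spec)
print(flat_args)  # ['a', 'b', 30, 40]
```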
Implementation-wise, we treat _orig_args in the dynamo-exported graph module as the single source of truth, and kwargs are ordered following it.
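A minimal sketch of that reordering, mirroring the `_reorder_kwargs_by_names` helper this PR adds in `torch/_export/__init__.py` (the argument names and values below are illustrative):
```python
from collections import OrderedDict
from typing import Any, Dict, List, Tuple

def reorder_kwargs_by_names(arg_names: List[str], args: Tuple[Any, ...], kwargs: Dict[str, Any]):
    # arg_names is the flat name list dynamo records in _orig_args; the names
    # past the positional args give the kwarg order gm_torch_level expects.
    assert len(arg_names) == len(args) + len(kwargs), (
        f"Expected {len(arg_names)} arg names in total, but got "
        f"{len(args)} positional args and {len(kwargs)} kwargs."
    )
    return OrderedDict((name, kwargs[name]) for name in arg_names[len(args):])

# With _orig_args == ["arg1", "arg2", "kw2", "kw1"], the call-site kwarg order is ignored:
ordered = reorder_kwargs_by_names(
    ["arg1", "arg2", "kw2", "kw1"], ("a", "b"), {"kw1": 4, "kw2": 3}
)
print(list(ordered.items()))  # [('kw2', 3), ('kw1', 4)]
```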
Test plan:
See added tests in test_export.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105337
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
diff --git a/docs/source/scripts/exportdb/generate_example_rst.py b/docs/source/scripts/exportdb/generate_example_rst.py
index 38f71b9..96b6aae 100644
--- a/docs/source/scripts/exportdb/generate_example_rst.py
+++ b/docs/source/scripts/exportdb/generate_example_rst.py
@@ -72,6 +72,7 @@
exported_program = export(
model,
inputs.args,
+ inputs.kwargs,
constraints=example_case.constraints,
)
graph_output = str(exported_program)
diff --git a/test/export/test_db.py b/test/export/test_db.py
index bfa57ba..9b44912 100644
--- a/test/export/test_db.py
+++ b/test/export/test_db.py
@@ -32,6 +32,7 @@
exported_program = export(
model,
inputs.args,
+ inputs.kwargs,
constraints=case.constraints,
)
exported_program.graph_module.print_readable()
@@ -61,6 +62,7 @@
exported_model = export(
model,
inputs.args,
+ inputs.kwargs,
constraints=case.constraints,
)
@@ -83,6 +85,7 @@
exported_model = export(
rewrite_case.model,
inputs.args,
+ inputs.kwargs,
constraints=rewrite_case.constraints,
)
diff --git a/test/export/test_export.py b/test/export/test_export.py
index e368499..62db418 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -148,13 +148,20 @@
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestExport(TestCase):
+
+ def _test_export_same_as_eager(self, f, args, kwargs=None):
+ kwargs = kwargs or {}
+ exported_program = export(f, args, kwargs)
+ reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
+ self.assertEqual(exported_program(*args, **kwargs), f(*args, **kwargs))
+ self.assertEqual(exported_program(*args, **reversed_kwargs), f(*args, **reversed_kwargs))
+
def test_basic(self):
def f(x, y):
return x[0] + y
inp = ([torch.ones(1, 3)], torch.ones(1, 3))
- exported_program = export(f, inp)
- self.assertTrue(torch.allclose(exported_program(*inp), f(*inp)))
+ self._test_export_same_as_eager(f, inp)
def test_raise_user_error_when_guard_on_data_dependent_operation(self):
def fn_ddo(x):
@@ -235,7 +242,7 @@
torchdynamo.exc.UserError, "It appears that you're trying to set a constraint " +
"on a value which we evaluated to have a static value of 3. "
):
- export(f, example_inputs, constraints)
+ export(f, example_inputs, {}, constraints)
def test_not_correct_dim(self):
def f(x):
@@ -267,8 +274,71 @@
return map(body, xs, y, z)
inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
- exported_program = export(list_tensor_map, inps)
- self.assertTrue(torch.allclose(exported_program(*inps), list_tensor_map(*inps)))
+ self._test_export_same_as_eager(list_tensor_map, inps)
+
+ def test_export_func_with_kwargs(self):
+ def kw_func(arg1, arg2, kw1, kw2):
+ return arg1 + arg2, kw1 + kw2
+
+ args = (torch.ones(6, 4), torch.ones(1, 1))
+ kwargs = {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}
+ self._test_export_same_as_eager(kw_func, args, kwargs)
+
+ def test_export_func_with_pytree_kwargs(self):
+ def kw_func(arg1, arg2, a, b):
+ return arg1 + a["kw1"] + b[0], arg2 + a["kw2"] + b[1]
+
+ args = (torch.ones(2, 3), torch.ones(3, 4))
+ kwargs = {"a": {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)}, "b": [torch.ones(2, 3), torch.ones(3, 4)]}
+ self._test_export_same_as_eager(kw_func, args, kwargs)
+
+ def test_export_func_with_default_kwargs(self):
+ def kw_func(arg1, arg2, a, b=1):
+ return arg1 + arg2, a["kw1"] + a["kw2"] + b
+
+ def kw_func2(arg1, arg2, a=1, b=2):
+ return arg1 + a, arg2 + b
+
+
+ args = (torch.ones(6, 4), torch.ones(1, 1))
+ kwargs1 = {"a": {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}}
+ kwargs2 = {"a": {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}, "b": 2}
+ self._test_export_same_as_eager(kw_func, args, kwargs1)
+ self._test_export_same_as_eager(kw_func, args, kwargs2)
+ kwargs3 = {"b": 1}
+ self._test_export_same_as_eager(kw_func2, args, kwargs3)
+
+ def test_export_func_with_var_postional_args(self):
+ def kw_func(arg1, arg2, *args):
+ return arg1 + args[0], arg2 + args[1]
+
+ args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
+ self._test_export_same_as_eager(kw_func, args)
+
+ def test_export_func_with_keyword_only_args(self):
+ def kw_func(arg1, arg2, *args, kw1, kw2):
+ return arg1 + args[0] + kw1, arg2 + args[1] + kw2
+
+ args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
+ kwargs = {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)}
+ self._test_export_same_as_eager(kw_func, args, kwargs)
+
+ def test_export_func_with_var_keyword_args(self):
+ def kw_func(arg1, arg2, *args, kw1, kw2, **kwargs):
+ return arg1 + args[0] + kw1 + kwargs["kw3"], arg2 + args[1] + kw2 + kwargs["kw4"]
+
+ args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
+ kwargs = {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4), "kw3": torch.ones(2, 3), "kw4": torch.ones(3, 4)}
+ self._test_export_same_as_eager(kw_func, args, kwargs)
+
+ def test_export_func_with_var_keyword_pytree_args(self):
+ def kw_func(arg1, arg2, *args, kw1, kw2, **kwargs):
+ return arg1 + arg2[0][0] + args[0] + kw1[0] + kwargs["kw3"][0], arg2[1] + args[1] + kw2 + kwargs["kw4"]
+
+ args = (torch.ones(2, 3), [(torch.ones(2, 3), ), torch.ones(3, 4)], torch.ones(2, 3), torch.ones(3, 4))
+ kwargs = {"kw1": (torch.ones(2, 3), ), "kw2": torch.ones(3, 4),
+ "kw3": (torch.ones(2, 3), torch.ones(3, 4)), "kw4": torch.ones(3, 4)}
+ self._test_export_same_as_eager(kw_func, args, kwargs)
def test_linear_conv(self):
diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py
index c394293..7c3e877 100644
--- a/test/export/test_serialize.py
+++ b/test/export/test_serialize.py
@@ -177,7 +177,7 @@
"""Export a graph, serialize it, deserialize it, and compare the results."""
# TODO(angelayi): test better with some sort of wrapper
constraints = [] if constraints is None else constraints
- ep = export(fn, inputs, constraints)
+ ep = export(fn, inputs, {}, constraints)
serialized_struct, state_dict = serialize(ep, opset_version={"aten": 0})
deserialized_ep = deserialize(serialized_struct, state_dict, expected_opset_version={"aten": 0})
diff --git a/test/export/test_upgrade.py b/test/export/test_upgrade.py
index 80e9b48..8eec5e3 100644
--- a/test/export/test_upgrade.py
+++ b/test/export/test_upgrade.py
@@ -117,7 +117,7 @@
return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc')
inputs = (torch.ones([2, 3]) * 4, 2.)
- ep = export(fn, inputs, [], _add_runtime_assertions=False)
+ ep = export(fn, inputs, {}, [], _add_runtime_assertions=False)
compiler_opset_version = {"aten": 4}
model_opset_version = {"aten": 3}
upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS)
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 1d4fae3..ce5abf1 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -33,6 +33,7 @@
from .exported_program import (
_process_constraints,
+ combine_args_kwargs,
CallSpec,
ExportBackwardSignature,
ExportedProgram,
@@ -124,6 +125,7 @@
def export(
f: Callable,
args: Tuple[Any],
+ kwargs: Optional[Dict[str, Any]] = None,
constraints: Optional[List[Constraint]] = None,
*,
_add_runtime_assertions=True,
@@ -136,21 +138,18 @@
Args:
m: the `nn.Module` or callable to trace.
- args: Tracing example inputs.
+ args: example positional inputs.
- constraints: A list of constraints on the dynamic arguments specifying
+ kwargs: optional example keyword inputs.
+
+ constraints: A optional list of constraints on the dynamic arguments specifying
their possible range of their shapes
Returns:
An ExportedProgram containing the traced method.
"""
- if constraints is None:
- constraints = []
-
- if not isinstance(f, torch.nn.Module):
- for parameter in inspect.signature(f).parameters.values():
- if parameter.kind == parameter.VAR_KEYWORD:
- raise UserError(UserErrorType.INVALID_INPUT, "Kwargs to torch.export is not supported")
+ constraints = constraints or []
+ kwargs = kwargs or {}
with torch._dynamo.config.patch(dataclasses.asdict(ExportDynamoConfig())): # type: ignore[attr-defined]
try:
@@ -160,6 +159,7 @@
constraints=constraints,
assume_static_by_default=True,
tracing_mode="symbolic",
+ **kwargs,
)
params_buffers: "OrderedDict[str, Union[torch.Tensor, torch.nn.Parameter]]" = OrderedDict()
@@ -210,11 +210,11 @@
fake_mode = detected_fake_mode
fake_args = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, args)
+ fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs)
# Fix the graph output signature to be tuple if scalar
# because aot_export expects a tuple as return type
- return_val = f(*args)
- flat_args, in_spec = pytree.tree_flatten(args)
+ return_val = f(*args, **kwargs)
out_spec = orig_out_spec = gm_torch_level._out_spec
# this means it is scalar return value, so will make it tuple
if not isinstance(return_val, (list, tuple)):
@@ -251,7 +251,14 @@
if n.op == "get_attr":
params_buffers_to_node_meta[n.target] = meta
- gm, graph_signature = aot_export_module(gm_torch_level, fake_args, decompositions=DECOMP_TABLE, trace_joint=False)
+ # Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict
+ # to follow the order in orig_args and correctly call gm_torch_level
+ gm, graph_signature = aot_export_module(
+ gm_torch_level,
+ (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()),
+ decompositions=DECOMP_TABLE,
+ trace_joint=False
+ )
export_backward_signature = ExportBackwardSignature(
gradients_to_parameters=graph_signature.backward_signature.gradients_to_parameters,
@@ -300,6 +307,7 @@
for k, v in params_buffers_to_node_meta[buffer_name].items():
node.meta[k] = v
+ flat_args, in_spec = pytree.tree_flatten(combine_args_kwargs(args, kwargs))
range_constraints, equality_constraints = _process_constraints(
gm,
export_graph_signature,
@@ -329,3 +337,10 @@
raise UserError(
UserErrorType.ANTI_PATTERN,
f"Consider annotating your code using constrain_as_*(). {str(e)}")
+
+def _reorder_kwargs_by_names(arg_names: List[str], args: Tuple[Any], kwargs: Dict[str, Any]):
+ assert len(arg_names) == len(args) + len(kwargs), (
+ f"Total number of arg names is expected to be {len(arg_names)} "
+ f"but got {len(args)} positional args, {len(kwargs)} kwargs."
+ )
+ return OrderedDict({kw_name: kwargs[kw_name] for kw_name in arg_names[len(args):]})
diff --git a/torch/_export/db/examples/fn_with_kwargs.py b/torch/_export/db/examples/fn_with_kwargs.py
index 86d8416..3a59ae7 100644
--- a/torch/_export/db/examples/fn_with_kwargs.py
+++ b/torch/_export/db/examples/fn_with_kwargs.py
@@ -12,17 +12,17 @@
**{"input0": torch.randn(4), "input1": torch.randn(4)}
),
tags={"python.data-structure"},
- support_level=SupportLevel.NOT_SUPPORTED_YET,
+ support_level=SupportLevel.SUPPORTED,
)
-def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
+def fn_with_kwargs(pos0, tuple0, *myargs, mykw0, **mykwargs):
"""
Keyword arguments are not supported at the moment.
"""
out = pos0
for arg in tuple0:
- out *= arg
+ out = out * arg
for arg in myargs:
- out *= arg
- out *= mykw0
- out *= mykwargs["input0"] * mykwargs["input1"]
+ out = out * arg
+ out = out * mykw0
+ out = out * mykwargs["input0"] * mykwargs["input1"]
return out
diff --git a/torch/_export/exported_program.py b/torch/_export/exported_program.py
index 8d9caf0..ef9ea7e 100644
--- a/torch/_export/exported_program.py
+++ b/torch/_export/exported_program.py
@@ -108,12 +108,13 @@
self.range_constraints: Dict[sympy.Symbol, RangeConstraint] = range_constraints
self.equality_constraints: List[Tuple[InputDim, InputDim]] = equality_constraints
- def __call__(self, *args: Any) -> Any:
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
if self.call_spec.in_spec is not None:
try:
- args = fx_pytree.tree_flatten_spec(args, self.call_spec.in_spec) # type: ignore[assignment]
+ user_args = combine_args_kwargs(args, kwargs)
+ args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec) # type: ignore[assignment]
except Exception:
- _, received_spec = pytree.tree_flatten(args)
+ _, received_spec = pytree.tree_flatten(user_args)
raise error.InternalError(
"Trying to flatten user inputs with exported input tree spec: \n"
f"{self.call_spec.in_spec}\n"
@@ -380,3 +381,6 @@
range_constraints[symbol] = RangeConstraint(min_val, max_val)
return range_constraints, equality_constraints
+
+def combine_args_kwargs(args, kwargs):
+ return (args, kwargs) if kwargs else args
diff --git a/torch/_export/serde/upgrade.py b/torch/_export/serde/upgrade.py
index 3e8e691..c6e2d41 100644
--- a/torch/_export/serde/upgrade.py
+++ b/torch/_export/serde/upgrade.py
@@ -196,7 +196,7 @@
upgraded_program = exported_program.transform(_pass)
# NB: we have to retrace the graph_module instead of ep because of some failure. Also, we need to turn of
# _add_runtime_assertions because dynamo is not happy with sym_size.int.
- exported_program = export(upgraded_program.graph_module, inputs, [], _add_runtime_assertions=False)
+ exported_program = export(upgraded_program.graph_module, inputs, {}, _add_runtime_assertions=False)
exported_program.call_spec = upgraded_program.call_spec
return exported_program