[export] add kwargs support for export. (#105337)

Solving #105242.

During export, the exported function's signature changes multiple times. Suppose we'd like to export f as shown in the following example:
```python
def f(arg1, arg2, kw1, kw2):
  pass

args = (arg1, arg2)
kwargs = {"kw2": arg3, "kw1": arg4}

torch.export(f, args, kwargs)
```
The signature changes multiple times during the export process, in the following order:
1. **gm_torch_level = dynamo.export(f, *args, \*\*kwargs)**. In this step, we turn all kinds of parameters, such as **positional_only**, **var_positional**, **kw_only**, and **var_keyword**, into **positional_or_kw**. It also preserves the positional and keyword argument names of the original function (i.e. f in this example) [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/export.py#L546C13-L546C27). The order of kwargs will be the **key order** of the kwargs dict (since Python 3.6, dicts preserve insertion order) rather than the order in the original function signature, and this order is baked into the _orig_args variable of gm_torch_level's pytree info. So we'll have:
```python
def gm_torch_level(arg1, arg2, kw2, kw1)
```
This difference is acceptable, as it's transparent to users of export.
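
To make the ordering concrete, here is a minimal sketch (not the actual dynamo internals; the variable names are illustrative) of how the positional parameter names plus the kwargs' insertion order produce the gm_torch_level signature above:
```python
import inspect

def f(arg1, arg2, kw1, kw2):
    pass

args = (1, 2)
kwargs = {"kw2": 3, "kw1": 4}

# Positional names come from f's signature; kwargs follow dict insertion order.
pos_names = list(inspect.signature(f).parameters)[: len(args)]
orig_args = pos_names + list(kwargs)
print(orig_args)  # ['arg1', 'arg2', 'kw2', 'kw1']
```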

2. **gm_aot_export = aot_export_module(gm_torch_level, pos_or_kw_args)**. In this step, we need to turn kwargs into positional args in the order gm_torch_level expects, which is stored in _orig_args. The returned gm_aot_export takes the flattened arguments, where flat_args, in_spec = pytree.tree_flatten(pos_or_kw_args):
```python
flat_args, _ = pytree.tree_flatten(pos_or_kw_args)
def gm_aot_export(*flat_args)
```
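
For reference, a minimal standalone sketch of what pytree.tree_flatten does to a nested argument structure (the example inputs are made up):
```python
import torch
import torch.utils._pytree as pytree

# Nested positional-or-keyword args: tensors inside a tuple/list/dict pytree.
pos_or_kw_args = (torch.ones(2), [torch.zeros(3), {"k": torch.ones(1)}])
flat_args, in_spec = pytree.tree_flatten(pos_or_kw_args)

print(len(flat_args))  # 3 -- the leaf tensors, in depth-first order
print(in_spec)         # TreeSpec describing how to rebuild the nesting
```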

3. **exported_program(*args, \*\*kwargs)**. The exported artifact is exported_program, which is a wrapper over gm_aot_export and has the same calling convention as the original function f. To do this, we need to 1. specialize the order of kwargs into pos_or_kw_args and 2. flatten the pos_or_kw_args into what gm_aot_export expects. We can combine the two steps into one with:
```python
_, in_spec = pytree.tree_flatten((args, kwargs))

# Then during exported_program.__call__(*args, **kwargs)
flat_args  = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
```
Here, kwargs is treated as a normal pytree whose key order is preserved in in_spec.
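
A minimal standalone sketch of this behavior (the tensors and shapes are made up): flattening against the recorded spec yields leaves in the export-time order even if the caller passes kwargs in a different key order.
```python
import torch
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree

args = (torch.ones(2),)
kwargs = {"kw2": torch.ones(3), "kw1": torch.ones(4)}

# Recorded once at export time.
_, in_spec = pytree.tree_flatten((args, kwargs))

# At call time the kwargs may arrive in a different key order; flattening
# against the recorded spec restores the export-time ordering.
call_kwargs = {"kw1": torch.ones(4), "kw2": torch.ones(3)}
flat_args = fx_pytree.tree_flatten_spec((args, call_kwargs), in_spec)
print([t.shape for t in flat_args])  # [torch.Size([2]), torch.Size([3]), torch.Size([4])]
```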

Implementation-wise, we treat _orig_args in the dynamo-exported graph module as the single source of truth, and kwargs are ordered following it.
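
The helper added in the diff below implements this reordering; a slightly compacted version with a toy usage:
```python
from collections import OrderedDict
from typing import Any, Dict, List, Tuple

def _reorder_kwargs_by_names(arg_names: List[str], args: Tuple[Any, ...], kwargs: Dict[str, Any]):
    """Reorder kwargs to follow the trailing names in arg_names (i.e. _orig_args)."""
    assert len(arg_names) == len(args) + len(kwargs), (
        f"Total number of arg names is expected to be {len(arg_names)} "
        f"but got {len(args)} positional args, {len(kwargs)} kwargs."
    )
    return OrderedDict((name, kwargs[name]) for name in arg_names[len(args):])

# Kwargs passed in "kw1, kw2" order are reordered to match _orig_args ("kw2, kw1").
reordered = _reorder_kwargs_by_names(["arg1", "arg2", "kw2", "kw1"], (1, 2), {"kw1": 4, "kw2": 3})
print(list(reordered))  # ['kw2', 'kw1']
```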

Test plan:
See added tests in test_export.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105337
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
diff --git a/docs/source/scripts/exportdb/generate_example_rst.py b/docs/source/scripts/exportdb/generate_example_rst.py
index 38f71b9..96b6aae 100644
--- a/docs/source/scripts/exportdb/generate_example_rst.py
+++ b/docs/source/scripts/exportdb/generate_example_rst.py
@@ -72,6 +72,7 @@
         exported_program = export(
             model,
             inputs.args,
+            inputs.kwargs,
             constraints=example_case.constraints,
         )
         graph_output = str(exported_program)
diff --git a/test/export/test_db.py b/test/export/test_db.py
index bfa57ba..9b44912 100644
--- a/test/export/test_db.py
+++ b/test/export/test_db.py
@@ -32,6 +32,7 @@
         exported_program = export(
             model,
             inputs.args,
+            inputs.kwargs,
             constraints=case.constraints,
         )
         exported_program.graph_module.print_readable()
@@ -61,6 +62,7 @@
             exported_model = export(
                 model,
                 inputs.args,
+                inputs.kwargs,
                 constraints=case.constraints,
             )
 
@@ -83,6 +85,7 @@
         exported_model = export(
             rewrite_case.model,
             inputs.args,
+            inputs.kwargs,
             constraints=rewrite_case.constraints,
         )
 
diff --git a/test/export/test_export.py b/test/export/test_export.py
index e368499..62db418 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -148,13 +148,20 @@
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestExport(TestCase):
+
+    def _test_export_same_as_eager(self, f, args, kwargs=None):
+        kwargs = kwargs or {}
+        exported_program = export(f, args, kwargs)
+        reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
+        self.assertEqual(exported_program(*args, **kwargs), f(*args, **kwargs))
+        self.assertEqual(exported_program(*args, **reversed_kwargs), f(*args, **reversed_kwargs))
+
     def test_basic(self):
         def f(x, y):
             return x[0] + y
 
         inp = ([torch.ones(1, 3)], torch.ones(1, 3))
-        exported_program = export(f, inp)
-        self.assertTrue(torch.allclose(exported_program(*inp), f(*inp)))
+        self._test_export_same_as_eager(f, inp)
 
     def test_raise_user_error_when_guard_on_data_dependent_operation(self):
         def fn_ddo(x):
@@ -235,7 +242,7 @@
             torchdynamo.exc.UserError, "It appears that you're trying to set a constraint " +
             "on a value which we evaluated to have a static value of 3. "
         ):
-            export(f, example_inputs, constraints)
+            export(f, example_inputs, {}, constraints)
 
     def test_not_correct_dim(self):
         def f(x):
@@ -267,8 +274,71 @@
             return map(body, xs, y, z)
 
         inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
-        exported_program = export(list_tensor_map, inps)
-        self.assertTrue(torch.allclose(exported_program(*inps), list_tensor_map(*inps)))
+        self._test_export_same_as_eager(list_tensor_map, inps)
+
+    def test_export_func_with_kwargs(self):
+        def kw_func(arg1, arg2, kw1, kw2):
+            return arg1 + arg2, kw1 + kw2
+
+        args = (torch.ones(6, 4), torch.ones(1, 1))
+        kwargs = {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}
+        self._test_export_same_as_eager(kw_func, args, kwargs)
+
+    def test_export_func_with_pytree_kwargs(self):
+        def kw_func(arg1, arg2, a, b):
+            return arg1 + a["kw1"] + b[0], arg2 + a["kw2"] + b[1]
+
+        args = (torch.ones(2, 3), torch.ones(3, 4))
+        kwargs = {"a": {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)}, "b": [torch.ones(2, 3), torch.ones(3, 4)]}
+        self._test_export_same_as_eager(kw_func, args, kwargs)
+
+    def test_export_func_with_default_kwargs(self):
+        def kw_func(arg1, arg2, a, b=1):
+            return arg1 + arg2, a["kw1"] + a["kw2"] + b
+
+        def kw_func2(arg1, arg2, a=1, b=2):
+            return arg1 + a, arg2 + b
+
+
+        args = (torch.ones(6, 4), torch.ones(1, 1))
+        kwargs1 = {"a": {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}}
+        kwargs2 = {"a": {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}, "b": 2}
+        self._test_export_same_as_eager(kw_func, args, kwargs1)
+        self._test_export_same_as_eager(kw_func, args, kwargs2)
+        kwargs3 = {"b": 1}
+        self._test_export_same_as_eager(kw_func2, args, kwargs3)
+
+    def test_export_func_with_var_postional_args(self):
+        def kw_func(arg1, arg2, *args):
+            return arg1 + args[0], arg2 + args[1]
+
+        args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
+        self._test_export_same_as_eager(kw_func, args)
+
+    def test_export_func_with_keyword_only_args(self):
+        def kw_func(arg1, arg2, *args, kw1, kw2):
+            return arg1 + args[0] + kw1, arg2 + args[1] + kw2
+
+        args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
+        kwargs = {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)}
+        self._test_export_same_as_eager(kw_func, args, kwargs)
+
+    def test_export_func_with_var_keyword_args(self):
+        def kw_func(arg1, arg2, *args, kw1, kw2, **kwargs):
+            return arg1 + args[0] + kw1 + kwargs["kw3"], arg2 + args[1] + kw2 + kwargs["kw4"]
+
+        args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
+        kwargs = {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4), "kw3": torch.ones(2, 3), "kw4": torch.ones(3, 4)}
+        self._test_export_same_as_eager(kw_func, args, kwargs)
+
+    def test_export_func_with_var_keyword_pytree_args(self):
+        def kw_func(arg1, arg2, *args, kw1, kw2, **kwargs):
+            return arg1 + arg2[0][0] + args[0] + kw1[0] + kwargs["kw3"][0], arg2[1] + args[1] + kw2 + kwargs["kw4"]
+
+        args = (torch.ones(2, 3), [(torch.ones(2, 3), ), torch.ones(3, 4)], torch.ones(2, 3), torch.ones(3, 4))
+        kwargs = {"kw1": (torch.ones(2, 3), ), "kw2": torch.ones(3, 4),
+                  "kw3": (torch.ones(2, 3), torch.ones(3, 4)), "kw4": torch.ones(3, 4)}
+        self._test_export_same_as_eager(kw_func, args, kwargs)
 
     def test_linear_conv(self):
 
diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py
index c394293..7c3e877 100644
--- a/test/export/test_serialize.py
+++ b/test/export/test_serialize.py
@@ -177,7 +177,7 @@
         """Export a graph, serialize it, deserialize it, and compare the results."""
         # TODO(angelayi): test better with some sort of wrapper
         constraints = [] if constraints is None else constraints
-        ep = export(fn, inputs, constraints)
+        ep = export(fn, inputs, {}, constraints)
         serialized_struct, state_dict = serialize(ep, opset_version={"aten": 0})
         deserialized_ep = deserialize(serialized_struct, state_dict, expected_opset_version={"aten": 0})
 
diff --git a/test/export/test_upgrade.py b/test/export/test_upgrade.py
index 80e9b48..8eec5e3 100644
--- a/test/export/test_upgrade.py
+++ b/test/export/test_upgrade.py
@@ -117,7 +117,7 @@
             return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc')
 
         inputs = (torch.ones([2, 3]) * 4, 2.)
-        ep = export(fn, inputs, [], _add_runtime_assertions=False)
+        ep = export(fn, inputs, {}, [], _add_runtime_assertions=False)
         compiler_opset_version = {"aten": 4}
         model_opset_version = {"aten": 3}
         upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS)
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 1d4fae3..ce5abf1 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -33,6 +33,7 @@
 
 from .exported_program import (
     _process_constraints,
+    combine_args_kwargs,
     CallSpec,
     ExportBackwardSignature,
     ExportedProgram,
@@ -124,6 +125,7 @@
 def export(
     f: Callable,
     args: Tuple[Any],
+    kwargs: Optional[Dict[str, Any]] = None,
     constraints: Optional[List[Constraint]] = None,
     *,
     _add_runtime_assertions=True,
@@ -136,21 +138,18 @@
     Args:
         m: the `nn.Module` or callable to trace.
 
-        args: Tracing example inputs.
+        args: example positional inputs.
 
-        constraints: A list of constraints on the dynamic arguments specifying
+        kwargs: optional example keyword inputs.
+
+        constraints: A optional list of constraints on the dynamic arguments specifying
             their possible range of their shapes
 
     Returns:
         An ExportedProgram containing the traced method.
     """
-    if constraints is None:
-        constraints = []
-
-    if not isinstance(f, torch.nn.Module):
-        for parameter in inspect.signature(f).parameters.values():
-            if parameter.kind == parameter.VAR_KEYWORD:
-                raise UserError(UserErrorType.INVALID_INPUT, "Kwargs to torch.export is not supported")
+    constraints = constraints or []
+    kwargs = kwargs or {}
 
     with torch._dynamo.config.patch(dataclasses.asdict(ExportDynamoConfig())):  # type: ignore[attr-defined]
         try:
@@ -160,6 +159,7 @@
                 constraints=constraints,
                 assume_static_by_default=True,
                 tracing_mode="symbolic",
+                **kwargs,
             )
 
             params_buffers: "OrderedDict[str, Union[torch.Tensor, torch.nn.Parameter]]" = OrderedDict()
@@ -210,11 +210,11 @@
                 fake_mode = detected_fake_mode
 
             fake_args = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, args)
+            fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs)
 
             # Fix the graph output signature to be tuple if scalar
             # because aot_export expects a tuple as return type
-            return_val = f(*args)
-            flat_args, in_spec = pytree.tree_flatten(args)
+            return_val = f(*args, **kwargs)
             out_spec = orig_out_spec = gm_torch_level._out_spec
             # this means it is scalar return value, so will make it tuple
             if not isinstance(return_val, (list, tuple)):
@@ -251,7 +251,14 @@
                         if n.op == "get_attr":
                             params_buffers_to_node_meta[n.target] = meta
 
-            gm, graph_signature = aot_export_module(gm_torch_level, fake_args, decompositions=DECOMP_TABLE, trace_joint=False)
+            # Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict
+            # to follow the order in orig_args and correctly call gm_torch_level
+            gm, graph_signature = aot_export_module(
+                gm_torch_level,
+                (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()),
+                decompositions=DECOMP_TABLE,
+                trace_joint=False
+            )
 
             export_backward_signature = ExportBackwardSignature(
                 gradients_to_parameters=graph_signature.backward_signature.gradients_to_parameters,
@@ -300,6 +307,7 @@
                             for k, v in params_buffers_to_node_meta[buffer_name].items():
                                 node.meta[k] = v
 
+            flat_args, in_spec = pytree.tree_flatten(combine_args_kwargs(args, kwargs))
             range_constraints, equality_constraints = _process_constraints(
                 gm,
                 export_graph_signature,
@@ -329,3 +337,10 @@
             raise UserError(
                 UserErrorType.ANTI_PATTERN,
                 f"Consider annotating your code using constrain_as_*(). {str(e)}")
+
+def _reorder_kwargs_by_names(arg_names: List[str], args: Tuple[Any], kwargs: Dict[str, Any]):
+    assert len(arg_names) == len(args) + len(kwargs), (
+        f"Total number of arg names is expected to be {len(arg_names)} "
+        f"but got {len(args)} positional args, {len(kwargs)} kwargs."
+    )
+    return OrderedDict({kw_name: kwargs[kw_name] for kw_name in arg_names[len(args):]})
diff --git a/torch/_export/db/examples/fn_with_kwargs.py b/torch/_export/db/examples/fn_with_kwargs.py
index 86d8416..3a59ae7 100644
--- a/torch/_export/db/examples/fn_with_kwargs.py
+++ b/torch/_export/db/examples/fn_with_kwargs.py
@@ -12,17 +12,17 @@
         **{"input0": torch.randn(4), "input1": torch.randn(4)}
     ),
     tags={"python.data-structure"},
-    support_level=SupportLevel.NOT_SUPPORTED_YET,
+    support_level=SupportLevel.SUPPORTED,
 )
-def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
+def fn_with_kwargs(pos0, tuple0, *myargs, mykw0, **mykwargs):
     """
     Keyword arguments are not supported at the moment.
     """
     out = pos0
     for arg in tuple0:
-        out *= arg
+        out = out * arg
     for arg in myargs:
-        out *= arg
-    out *= mykw0
-    out *= mykwargs["input0"] * mykwargs["input1"]
+        out = out * arg
+    out = out * mykw0
+    out = out * mykwargs["input0"] * mykwargs["input1"]
     return out
diff --git a/torch/_export/exported_program.py b/torch/_export/exported_program.py
index 8d9caf0..ef9ea7e 100644
--- a/torch/_export/exported_program.py
+++ b/torch/_export/exported_program.py
@@ -108,12 +108,13 @@
         self.range_constraints: Dict[sympy.Symbol, RangeConstraint] = range_constraints
         self.equality_constraints: List[Tuple[InputDim, InputDim]] = equality_constraints
 
-    def __call__(self, *args: Any) -> Any:
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
         if self.call_spec.in_spec is not None:
             try:
-                args = fx_pytree.tree_flatten_spec(args, self.call_spec.in_spec)  # type: ignore[assignment]
+                user_args = combine_args_kwargs(args, kwargs)
+                args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec)  # type: ignore[assignment]
             except Exception:
-                _, received_spec = pytree.tree_flatten(args)
+                _, received_spec = pytree.tree_flatten(user_args)
                 raise error.InternalError(
                     "Trying to flatten user inputs with exported input tree spec: \n"
                     f"{self.call_spec.in_spec}\n"
@@ -380,3 +381,6 @@
         range_constraints[symbol] = RangeConstraint(min_val, max_val)
 
     return range_constraints, equality_constraints
+
+def combine_args_kwargs(args, kwargs):
+    return (args, kwargs) if kwargs else args
diff --git a/torch/_export/serde/upgrade.py b/torch/_export/serde/upgrade.py
index 3e8e691..c6e2d41 100644
--- a/torch/_export/serde/upgrade.py
+++ b/torch/_export/serde/upgrade.py
@@ -196,7 +196,7 @@
             upgraded_program = exported_program.transform(_pass)
             # NB: we have to retrace the graph_module instead of ep because of some failure. Also, we need to turn of
             # _add_runtime_assertions because dynamo is not happy with sym_size.int.
-            exported_program = export(upgraded_program.graph_module, inputs, [], _add_runtime_assertions=False)
+            exported_program = export(upgraded_program.graph_module, inputs, {}, _add_runtime_assertions=False)
             exported_program.call_spec = upgraded_program.call_spec
 
         return exported_program