[1/N] Non-Tensor: Scalar Support: Enable aot compile to support aten operations with scalar input like alpha (#124177)
Some operations take a scalar input parameter, e.g. `torch.add(a, b, alpha=2.0)`. Currently, AOT compile does not support such cases because it requires the signature of the captured graph to align with the operation's signature, which means some inputs of the captured graph may be scalars (float, int, bool, etc.). This breaks the assumption of `compile_fx_aot`, which expects all example inputs to be tensors; see https://github.com/pytorch/pytorch/blob/0f6ce45bcbd7026c00da43db0317ede10830378b/torch/_inductor/compile_fx.py#L1048
This PR supports such cases by allowing a non-aligned signature and filtering out the non-Tensor parameters, as sketched below.
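Conceptually, the filtering keeps only the Tensor leaves when flattening the example inputs. A simplified sketch of the change in `torch/_inductor/__init__.py` (the helper name `flatten_tensor_inputs` is illustrative, not the actual function in the diff):
```
import torch
import torch.utils._pytree as pytree

def flatten_tensor_inputs(args, kwargs):
    # Flatten (args, kwargs) and keep only the Tensor leaves; scalar
    # parameters such as alpha=2.0 no longer appear in the runtime inputs.
    flat_args_with_path, _ = pytree.tree_flatten_with_path((args, kwargs or {}))
    return tuple(x[1] for x in flat_args_with_path if isinstance(x[1], torch.Tensor))
```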
Captured graph for `torch.add(a, b, alpha=2.0)`:
```
opcode         name      target           args              kwargs
-------------  --------  ---------------  ----------------  --------------
placeholder    arg0_1    arg0_1           ()                {}
placeholder    arg1_1    arg1_1           ()                {}
call_function  add       aten.add.Tensor  (arg0_1, arg1_1)  {'alpha': 2.0}
output         output_1  output           ((add,),)         {}
```
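For reference, a minimal usage sketch mirroring the new test added below (note that `torch._export.aot_compile` is a private, unstable API per its own docstring):
```
import torch
import torch._export

a = torch.randn(10)
b = torch.randn(10)

# same_signature=False lets the scalar alpha stay baked into the captured
# graph, so the compiled artifact takes only the two tensor inputs at run time.
so_path = torch._export.aot_compile(
    torch.ops.aten.add, args=(a, b), kwargs={"alpha": 2.0}, same_signature=False
)
```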
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124177
Approved by: https://github.com/jansel, https://github.com/desertfire, https://github.com/jgong5
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index fe7b48f..49afa50 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -9,6 +9,7 @@
from unittest import skip
import torch
+import torch._export
import torch._inductor
import torch.nn as nn
from torch._dynamo.testing import rand_strided, same
@@ -1210,6 +1211,24 @@
torch._export.aot_compile(Model(self.device), example_inputs)
self.check_model(Model(self.device), example_inputs)
+ def test_non_tensor_input(self):
+ def fn(a, b, alpha=1.0):
+ return torch.add(a, b, alpha=alpha)
+
+ a = torch.randn(10, device=self.device)
+ b = torch.randn(10, device=self.device)
+ with self.assertRaises(RuntimeError):
+ torch._export.aot_compile(fn, args=(a, b), kwargs={"alpha": 2.0})
+
+ so_path = torch._export.aot_compile(
+ torch.ops.aten.add, args=(a, b), kwargs={"alpha": 2.0}, same_signature=False
+ )
+ kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path)
+ res = kernel_runner.run([a, b])
+ self.assertTrue(isinstance(res, list))
+ self.assertTrue(len(res) == 1)
+ self.assertEqual(fn(a, b, alpha=2.0), res[0])
+
def test_buffer_mutation_2(self):
class Model(torch.nn.Module):
def __init__(self, device):
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index a1c233a..05aee5f 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -343,6 +343,7 @@
options: Optional[Dict[str, Any]] = None,
remove_runtime_assertions: bool = False,
disable_constraint_solver: bool = False,
+ same_signature: bool = True,
) -> str:
"""
Note: this function is not stable yet
@@ -393,6 +394,7 @@
kwargs,
dynamic_shapes,
disable_constraint_solver=disable_constraint_solver,
+ same_signature=same_signature,
# Disabling this flag, because instead we can rely on the mapping
# dynamo_flat_name_to_original_fqn which is coming from Dynamo.
restore_fqn=False,
diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py
index 5ea494f..0516fc5 100644
--- a/torch/_inductor/__init__.py
+++ b/torch/_inductor/__init__.py
@@ -76,7 +76,9 @@
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
(args, kwargs or {})
)
- flat_example_inputs = tuple(x[1] for x in flat_args_with_path)
+ flat_example_inputs = tuple(
+ x[1] for x in flat_args_with_path if isinstance(x[1], torch.Tensor)
+ )
if in_spec is not None and received_spec != in_spec:
raise ValueError( # noqa: TRY200
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 092f054..08db8d1 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -1642,6 +1642,10 @@
options=options,
remove_runtime_assertions=remove_runtime_assertions,
disable_constraint_solver=disable_constraint_solver,
+ # Some operations may have non-Tensor parameters like int, float, bool. These
+ # non-Tensor parameters will not be inputs to the graph. Therefore, we do
+ # not need to keep the same signature.
+ same_signature=False,
)
kernel_metadata_items = []
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index 728bdf2..72ae886 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -415,6 +415,7 @@
disable_constraint_solver: bool = False,
restore_fqn: bool = True,
_log_export_usage: bool = True,
+ same_signature: bool = True,
) -> torch.fx.GraphModule:
"""
Traces either an nn.Module's forward function or just a callable with PyTorch
@@ -445,6 +446,7 @@
tracing_mode="symbolic",
disable_constraint_solver=disable_constraint_solver,
_log_export_usage=_log_export_usage,
+ same_signature=same_signature,
)(
*args,
**kwargs,