Fix the SDPA AOT export issue (#130164)
Summary:
## Context
TL;DR: `aot_export` failed for the SDPA memory-efficient backend when using `inference_mode`.
The CMF AOTI lowering started to fail on trunk. We have a script (https://fburl.com/code/kfk64i5s) to reproduce the issue quickly (log: P1469307638). By bisecting the stack, we found the issue was introduced by D58701607.
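For illustration, here is a minimal sketch of the failing pattern (a hypothetical example, not the internal repro script above): a module that calls SDPA with a boolean attention mask, exported under `inference_mode`. The production repro goes through AOTI, so whether a given export entry point hits the exact failure depends on the lowering path.

```python
import torch
import torch.nn.functional as F


class SDPAModule(torch.nn.Module):
    def forward(self, q, k, v, mask):
        # The memory-efficient SDPA backend handles the boolean mask via a
        # decomposition that contains an in-place masked_fill_.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))
mask = torch.ones(1, 8, 16, 16, dtype=torch.bool)

with torch.inference_mode():
    # Before this fix, exporting/lowering a graph like this under
    # inference_mode could hit the copy_ failure described below, because
    # functionalization left aten::sdpa intact and the in-place op only
    # appeared afterwards.
    ep = torch.export.export(SDPAModule(), (q, k, v, mask))
    print(ep)
```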
## Root Cause
Under `inference_mode()`:
`aten::scaled_dot_product_attention` was not decomposed before functionalization, and since the op itself is out-of-place, functionalization left it unchanged. It was then decomposed into `masked_fill_`, which further decomposed into `copy_`.
So the chain is `aten::sdpa` --- (functionalization) ---> `aten::sdpa` --- (decompose) ---> `masked_fill_` --- (decompose) ---> `copy_` ---> failure
Under `torch.no_grad()`:
`aten::sdpa` was decomposed before functionalization, so the chain is
`aten::sdpa` --- (decompose) ---> `masked_fill_` --- (functionalization) ---> `masked_fill` --- (decompose) ---> out-of-place ops ---> good
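To make the source of the in-place op concrete, here is a rough Python rendering (my approximation, not the actual C++ code) of the boolean-mask preprocessing in `aten/src/ATen/native/transformers/attention.cpp` (linked in the diff comment below) that introduces `masked_fill_`:

```python
import torch


def convert_boolean_attn_mask(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Approximation of the C++ helper: a boolean mask is turned into an
    # additive float bias via an in-place masked_fill_.
    if attn_mask.dtype == torch.bool:
        bias = torch.zeros_like(attn_mask, dtype=dtype)
        # This masked_fill_ is the in-place op: harmless when it is produced
        # before functionalization (the no_grad path), but it fails when it
        # only shows up after functionalization (the inference_mode path).
        return bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
    return attn_mask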
## How to fix
Long-term:
The issue is tracked in https://github.com/pytorch/pytorch/issues/129418. The long-term fix could be to run one more round of `functionalization` after the `decompose`, i.e.
`aten::sdpa` --- (functionalization) ---> `aten::sdpa` --- (decompose) ---> `masked_fill_` --- (functionalization) ---> `masked_fill` ---> good
Short-term:
That would likely be a large change. To unblock the production use case, this diff marks `aten::sdpa` so that it is decomposed before functionalization.
Test Plan:
The local repro now works:
buck run mode/opt scripts/sijiac/prototypes:sdpa_aoti
Differential Revision: D59385876
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130164
Approved by: https://github.com/zou3519
diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py
index 878c675..e75c7a8 100644
--- a/torch/_subclasses/functional_tensor.py
+++ b/torch/_subclasses/functional_tensor.py
@@ -106,6 +106,11 @@
torch.ops.aten.feature_dropout.default, # type: ignore[has-type]
torch.ops.aten.feature_alpha_dropout.default, # type: ignore[has-type]
torch.ops.aten.unsafe_chunk.default, # type: ignore[has-type]
+ # `scaled_dot_product_attention` is neither aliasing nor mutating, but it
+ # decomposes into in-place ops, so we add it to this list to force it to be decomposed.
+ # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L530-L531
+ # Related to https://github.com/pytorch/pytorch/issues/129418
+ torch.ops.aten.scaled_dot_product_attention.default, # type: ignore[has-type]
]
def __new__(cls, elem):