Generate patterns in fp16 and fp32 (#109142)
aten.softmax will generate a different decomposition for fp16/bf16 and fp32 because when invoked in lower precision it will upcast the inputs to fp32 and then downcast after. This has been causing us to miss bf16 patterns. For example, Camembert improves 20% with this PR (as do I'm sure many other models).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109142
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663, #108894, #108917
diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py
index 7a7c454..49fd2b3 100644
--- a/test/inductor/test_fused_attention.py
+++ b/test/inductor/test_fused_attention.py
@@ -42,13 +42,15 @@
has_dropout=False,
check_train=True,
override_check_equal=False,
+ dtype=torch.float,
+ rtol=1.3e-6,
):
if args1 is None:
tensor_shape = (4, 2, 16, 32)
args1 = [
- torch.randn(tensor_shape, device=self.device),
- torch.randn(tensor_shape, device=self.device),
- torch.randn(tensor_shape, device=self.device),
+ torch.randn(tensor_shape, device=self.device, dtype=dtype),
+ torch.randn(tensor_shape, device=self.device, dtype=dtype),
+ torch.randn(tensor_shape, device=self.device, dtype=dtype),
]
else:
args1 = list(args1)
@@ -91,7 +93,7 @@
and arg1.is_floating_point()
and (not has_dropout or override_check_equal)
):
- self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=1.3e-6)
+ self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)
@skipIfRocm
def _test_sdpa_rewriter_1(self):
@@ -106,8 +108,17 @@
.matmul(value)
)
- self._check_common(dot_prod_attention)
- self._check_common(checkpoint_wrapper(dot_prod_attention))
+ for dtype in [torch.float, torch.half]:
+ if self.device == "cpu" and dtype == torch.half:
+ continue
+ rtol = 1.3e-6 if dtype == torch.float else 0.7
+ self._check_common(dot_prod_attention, dtype=dtype, atol=0.001, rtol=rtol)
+ self._check_common(
+ checkpoint_wrapper(dot_prod_attention),
+ dtype=dtype,
+ atol=0.001,
+ rtol=rtol,
+ )
def _test_pattern_fails_with_reuse(self):
"""
diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index b60f53b..066b7bd 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -786,7 +786,7 @@
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_fuse_attention_roundtrip_pattern(self):
- # are we losing anything in serialization in patterns
+ # are we losing anything in serialization
from torch._inductor.fx_passes.fuse_attention import _get_sfdp_patterns
global_vals = {
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index ca1f436..798d1b0 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1013,24 +1013,31 @@
super().__init__()
def forward(self, arg1, arg2, arg3):
- torch.ops.aten._scaled_dot_product_flash_attention(arg1, arg2, arg3)
+ torch.ops.aten._scaled_dot_product_flash_attention(arg1, arg2, arg3, scale=0.17677669529663687)
args_new = [
- ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
- ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
- ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
+ [
+ ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
+ ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
+ ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
+ ],
+ [
+ ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
+ ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
+ ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
+ ]
]
-
- args = [rand_strided(bsz, num_heads, seq_len, head_dim) for
- (bsz, num_heads, seq_len, head_dim) in args_new]
- try:
- with torch._subclasses.CrossRefFakeMode():
- Repro()(*args)
- except RuntimeError as e:
- # We expect the cross ref to succed for the first output to fail
- # for the rng state, see Note [Seed and Offset]
- self.assertTrue("output[0]" not in str(e))
- self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))
+ for args_list in args_new:
+ args = [rand_strided(bsz, num_heads, seq_len, head_dim) for
+ (bsz, num_heads, seq_len, head_dim) in args_list]
+ try:
+ with torch._subclasses.CrossRefFakeMode():
+ Repro()(*args)
+ except RuntimeError as e:
+ # We expect the cross ref to succed for the first output to fail
+ # for the rng state, see Note [Seed and Offset]
+ self.assertTrue("output[0]" not in str(e))
+ self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))
@skipIfRocm
@unittest.skipIf(not RUN_CUDA, "requires cuda")
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index 2446d27..6bf800d 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -378,134 +378,150 @@
# sizes/values don't actually matter for initial trace
# once we get a possible match we re-trace with the actual values and verify the match still holds
- g = functools.partial(torch.empty, (2, 4, 8, 16), device=device, requires_grad=True)
- gp = functools.partial(
- torch.empty, (2, 8, 4, 16), device=device, requires_grad=True, dtype=torch.half
+ g_inp = functools.partial(
+ torch.empty, (2, 4, 8, 16), device=device, requires_grad=True
)
- b = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
- c = functools.partial(torch.tensor, 2.0, device=device)
+ b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
+ c_inp = functools.partial(torch.tensor, 2.0, device=device)
# workaround https://github.com/pytorch/pytorch/issues/97894
# 0.113377 is a "magic" value that lets us recover the lost input arg relationship
d = {"dropout_p": 0.113377}
- for pattern, replacement, args, workaround, extra_check in [
- (
- _sfdp_pattern_1,
- _sfdp_replacement_1,
- [g(), g(), g(), c()],
- {},
- _sfdp_scale_factor_check(aten.div.Tensor),
- ),
- (
- _sfdp_pattern_2,
- _sfdp_replacement_2,
- [g(), g(), g(), c()],
- {},
- _sfdp_scale_factor_check(aten.mul.Tensor),
- ),
- (
- _sfdp_pattern_3,
- _sfdp_replacement_3,
- [g(), g(), g(), c()],
- d,
- _sfdp_scale_factor_check(aten.div.Tensor),
- ),
- (
- _sfdp_pattern_4,
- _sfdp_replacement_4,
- [g(), g(), g(), c()],
- d,
- _sfdp_scale_factor_check(aten.mul.Tensor),
- ),
- (
- _sfdp_pattern_5,
- _sfdp_replacement_5,
- [g(), g(), g(), b()],
- {},
- _sfdp_params_check,
- ),
- (
- _sfdp_pattern_6,
- _sfdp_replacement_6,
- [g(), g(), g(), b()],
- d,
- _sfdp_params_check,
- ),
- (
- _sfdp_pattern_7,
- _sfdp_replacement_7,
- [gp(), gp(), gp()],
- d,
- _sfdp_params_check,
- ),
- (
- _sfdp_pattern_8,
- _sfdp_replacement_8,
- [gp(), gp(), gp()],
- {},
- _sfdp_params_check,
- ),
- (
- _sfdp_pattern_9,
- _sfdp_replacement_9,
- [gp(), gp(), gp()],
- d,
- _sfdp_params_check,
- ),
- (
- _sfdp_pattern_10,
- _sfdp_replacement_10,
- [gp(), gp(), gp()],
- {},
- _sfdp_params_check,
- ),
- (
- _sfdp_pattern_11,
- _sfdp_replacement_11,
- [g(), g(), g(), c()],
- {},
- _sfdp_scale_factor_check(aten.div.Tensor),
- ),
- (
- _sfdp_pattern_12,
- _sfdp_replacement_12,
- [g(), g(), g(), c()],
- d,
- _sfdp_scale_factor_check(aten.div.Tensor),
- ),
- ]:
- # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
- # gets serialized to a python file and does not require tracing at runtime.
- assert isinstance(workaround, dict)
- training_args = [*args, *workaround.values()]
- name = pattern.__name__
+ # softmax will generate a dtype conversion on inputs if they are in half,
+ # but will not in float, so we generate a pattern for both
+ for dtype in [torch.float, torch.half]:
+ g = functools.partial(g_inp, dtype=dtype)
+ b = functools.partial(b_inp, dtype=dtype)
+ c = functools.partial(c_inp, dtype=dtype)
- yield f"{name}_training", {
- "search_fn": pattern,
- "replace_fn": replacement,
- "example_inputs": training_args,
- "trace_fn": training_graph,
- "pass_dict": patterns,
- "extra_check": extra_check,
- "scalar_workaround": workaround,
- }
+ for pattern, replacement, args, workaround, extra_check in [
+ (
+ _sfdp_pattern_1,
+ _sfdp_replacement_1,
+ [g(), g(), g(), c()],
+ {},
+ _sfdp_scale_factor_check(aten.div.Tensor),
+ ),
+ (
+ _sfdp_pattern_2,
+ _sfdp_replacement_2,
+ [g(), g(), g(), c()],
+ {},
+ _sfdp_scale_factor_check(aten.mul.Tensor),
+ ),
+ (
+ _sfdp_pattern_3,
+ _sfdp_replacement_3,
+ [g(), g(), g(), c()],
+ d,
+ _sfdp_scale_factor_check(aten.div.Tensor),
+ ),
+ (
+ _sfdp_pattern_4,
+ _sfdp_replacement_4,
+ [g(), g(), g(), c()],
+ d,
+ _sfdp_scale_factor_check(aten.mul.Tensor),
+ ),
+ (
+ _sfdp_pattern_5,
+ _sfdp_replacement_5,
+ [g(), g(), g(), b()],
+ {},
+ _sfdp_params_check,
+ ),
+ (
+ _sfdp_pattern_6,
+ _sfdp_replacement_6,
+ [g(), g(), g(), b()],
+ d,
+ _sfdp_params_check,
+ ),
+ (
+ _sfdp_pattern_7,
+ _sfdp_replacement_7,
+ [g(), g(), g()],
+ d,
+ _sfdp_params_check,
+ ),
+ (
+ _sfdp_pattern_8,
+ _sfdp_replacement_8,
+ [g(), g(), g()],
+ {},
+ _sfdp_params_check,
+ ),
+ (
+ _sfdp_pattern_9,
+ _sfdp_replacement_9,
+ [g(), g(), g()],
+ d,
+ _sfdp_params_check,
+ ),
+ (
+ _sfdp_pattern_10,
+ _sfdp_replacement_10,
+ [g(), g(), g()],
+ {},
+ _sfdp_params_check,
+ ),
+ (
+ _sfdp_pattern_11,
+ _sfdp_replacement_11,
+ [g(), g(), g(), c()],
+ {},
+ _sfdp_scale_factor_check(aten.div.Tensor),
+ ),
+ (
+ _sfdp_pattern_12,
+ _sfdp_replacement_12,
+ [g(), g(), g(), c()],
+ d,
+ _sfdp_scale_factor_check(aten.div.Tensor),
+ ),
+ ]:
+ # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
+ # gets serialized to a python file and does not require tracing at runtime.
+ assert isinstance(workaround, dict)
+ training_args = [*args, *workaround.values()]
+ name = pattern.__name__
- if workaround:
- assert len(workaround) == 1 and "dropout_p" in workaround
- # functools.partial insufficient because we look at signature downstream
- pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
- replacement = partialize_and_update_signature(replacement, dropout_p=0.0)
- workaround = {}
+ training_name = (
+ f"{name}_training" if dtype == torch.float else f"{name}_training_half"
+ )
+ yield training_name, {
+ "search_fn": pattern,
+ "replace_fn": replacement,
+ "example_inputs": training_args,
+ "trace_fn": training_graph,
+ "pass_dict": patterns,
+ "extra_check": extra_check,
+ "scalar_workaround": workaround,
+ }
- yield f"{name}_inference", {
- "search_fn": pattern,
- "replace_fn": replacement,
- "example_inputs": args,
- "trace_fn": inference_graph,
- "pass_dict": patterns,
- "extra_check": extra_check,
- "scalar_workaround": workaround,
- }
+ if workaround:
+ assert len(workaround) == 1 and "dropout_p" in workaround
+ # functools.partial insufficient because we look at signature downstream
+ pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
+ replacement = partialize_and_update_signature(
+ replacement, dropout_p=0.0
+ )
+ workaround = {}
+
+ inference_name = (
+ f"{name}_inference"
+ if dtype == torch.float
+ else f"{name}_inference_half"
+ )
+ yield inference_name, {
+ "search_fn": pattern,
+ "replace_fn": replacement,
+ "example_inputs": args,
+ "trace_fn": inference_graph,
+ "pass_dict": patterns,
+ "extra_check": extra_check,
+ "scalar_workaround": workaround,
+ }
@functools.lru_cache(None)
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py
index 6d28df3..a684ab8 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py
@@ -93,3 +93,78 @@
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_1_training_half = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_1_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py
index 2dd4f62..f5216fe 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py
@@ -38,6 +38,92 @@
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
+view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, div_Tensor_1, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
+view_default_8 = CallFunction(aten.view.default, sub_Tensor_1, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_10_training = MultiOutputPattern([view_default_5,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
+expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
+expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
@@ -79,7 +165,7 @@
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
-_sfdp_pattern_10_training = MultiOutputPattern([view_default_5,
+_sfdp_pattern_10_training_half = MultiOutputPattern([view_default_5,
permute_default_6,
permute_default_9,
permute_default_11
@@ -112,4 +198,4 @@
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
-_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+_sfdp_pattern_10_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py
index a1e066a..b24a5fb 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py
@@ -108,3 +108,93 @@
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_11_training_half = MultiOutputPattern([view_default_5,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_11_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py
index 17526b2..4284e4d 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py
@@ -118,3 +118,103 @@
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
+convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_12_training_half = MultiOutputPattern([view_default_5,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None,
+ None
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
+expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_12_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py
index 458da61..96dc0d4 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py
@@ -93,3 +93,78 @@
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_1, mul_Tensor_2)
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor'))
+view_default_8 = CallFunction(aten.view.default, mul_Tensor_3, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_2_training_half = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_2_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py
index ea2c4ae..0869cb7 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py
@@ -103,3 +103,88 @@
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
+convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_3_training_half = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
+expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_3_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py
index b35c24a..bb3c617 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py
@@ -103,3 +103,88 @@
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
+clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format)
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
+mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_5, mul_Tensor_6)
+convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+mul_Tensor_7 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor'))
+view_default_8 = CallFunction(aten.view.default, mul_Tensor_7, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_4_training_half = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
+expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_4_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py
index c83d726..def8441 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py
@@ -95,3 +95,80 @@
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
+add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored())
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_5_training_half = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
+add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_5_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py
index 304090f..075c675 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py
@@ -105,3 +105,90 @@
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
+add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
+convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored())
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_6_training_half = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
+add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
+expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_6_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py
index 774c33d..c1ef7e6 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py
@@ -40,6 +40,102 @@
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
+view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
+clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, div_Tensor_1, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
+div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_7_training = MultiOutputPattern([view_default_5,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
@@ -87,7 +183,7 @@
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
-_sfdp_pattern_7_training = MultiOutputPattern([view_default_5,
+_sfdp_pattern_7_training_half = MultiOutputPattern([view_default_5,
permute_default_6,
permute_default_9,
permute_default_11,
@@ -122,4 +218,4 @@
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
-_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+_sfdp_pattern_7_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py
index 99d844f..1752202 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py
@@ -38,6 +38,92 @@
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
+view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, div_Tensor_1, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
+div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_8_training = MultiOutputPattern([view_default_5,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
@@ -79,7 +165,7 @@
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
-_sfdp_pattern_8_training = MultiOutputPattern([view_default_5,
+_sfdp_pattern_8_training_half = MultiOutputPattern([view_default_5,
permute_default_6,
permute_default_9,
permute_default_11
@@ -112,4 +198,4 @@
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
-_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+_sfdp_pattern_8_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py
index 9368356..abefe3b 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py
@@ -40,6 +40,102 @@
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
+view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
+clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, div_Tensor_1, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
+view_default_8 = CallFunction(aten.view.default, sub_Tensor_1, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_9_training = MultiOutputPattern([view_default_5,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
+expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
+expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
@@ -87,7 +183,7 @@
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
-_sfdp_pattern_9_training = MultiOutputPattern([view_default_5,
+_sfdp_pattern_9_training_half = MultiOutputPattern([view_default_5,
permute_default_6,
permute_default_9,
permute_default_11,
@@ -122,4 +218,4 @@
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
-_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+_sfdp_pattern_9_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/central_index.py b/torch/_inductor/fx_passes/serialized_patterns/central_index.py
index 9e59d0b..02eae1d 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/central_index.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/central_index.py
@@ -2,18 +2,18 @@
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
-from ._sfdp_pattern_1 import (_sfdp_pattern_1_training, _sfdp_pattern_1_inference)
-from ._sfdp_pattern_2 import (_sfdp_pattern_2_training, _sfdp_pattern_2_inference)
-from ._sfdp_pattern_3 import (_sfdp_pattern_3_training, _sfdp_pattern_3_inference)
-from ._sfdp_pattern_4 import (_sfdp_pattern_4_training, _sfdp_pattern_4_inference)
-from ._sfdp_pattern_5 import (_sfdp_pattern_5_training, _sfdp_pattern_5_inference)
-from ._sfdp_pattern_6 import (_sfdp_pattern_6_training, _sfdp_pattern_6_inference)
-from ._sfdp_pattern_7 import (_sfdp_pattern_7_training, _sfdp_pattern_7_inference)
-from ._sfdp_pattern_8 import (_sfdp_pattern_8_training, _sfdp_pattern_8_inference)
-from ._sfdp_pattern_9 import (_sfdp_pattern_9_training, _sfdp_pattern_9_inference)
-from ._sfdp_pattern_10 import (_sfdp_pattern_10_training, _sfdp_pattern_10_inference)
-from ._sfdp_pattern_11 import (_sfdp_pattern_11_training, _sfdp_pattern_11_inference)
-from ._sfdp_pattern_12 import (_sfdp_pattern_12_training, _sfdp_pattern_12_inference)
+from ._sfdp_pattern_1 import (_sfdp_pattern_1_training, _sfdp_pattern_1_inference, _sfdp_pattern_1_training_half, _sfdp_pattern_1_inference_half)
+from ._sfdp_pattern_2 import (_sfdp_pattern_2_training, _sfdp_pattern_2_inference, _sfdp_pattern_2_training_half, _sfdp_pattern_2_inference_half)
+from ._sfdp_pattern_3 import (_sfdp_pattern_3_training, _sfdp_pattern_3_inference, _sfdp_pattern_3_training_half, _sfdp_pattern_3_inference_half)
+from ._sfdp_pattern_4 import (_sfdp_pattern_4_training, _sfdp_pattern_4_inference, _sfdp_pattern_4_training_half, _sfdp_pattern_4_inference_half)
+from ._sfdp_pattern_5 import (_sfdp_pattern_5_training, _sfdp_pattern_5_inference, _sfdp_pattern_5_training_half, _sfdp_pattern_5_inference_half)
+from ._sfdp_pattern_6 import (_sfdp_pattern_6_training, _sfdp_pattern_6_inference, _sfdp_pattern_6_training_half, _sfdp_pattern_6_inference_half)
+from ._sfdp_pattern_7 import (_sfdp_pattern_7_training, _sfdp_pattern_7_inference, _sfdp_pattern_7_training_half, _sfdp_pattern_7_inference_half)
+from ._sfdp_pattern_8 import (_sfdp_pattern_8_training, _sfdp_pattern_8_inference, _sfdp_pattern_8_training_half, _sfdp_pattern_8_inference_half)
+from ._sfdp_pattern_9 import (_sfdp_pattern_9_training, _sfdp_pattern_9_inference, _sfdp_pattern_9_training_half, _sfdp_pattern_9_inference_half)
+from ._sfdp_pattern_10 import (_sfdp_pattern_10_training, _sfdp_pattern_10_inference, _sfdp_pattern_10_training_half, _sfdp_pattern_10_inference_half)
+from ._sfdp_pattern_11 import (_sfdp_pattern_11_training, _sfdp_pattern_11_inference, _sfdp_pattern_11_training_half, _sfdp_pattern_11_inference_half)
+from ._sfdp_pattern_12 import (_sfdp_pattern_12_training, _sfdp_pattern_12_inference, _sfdp_pattern_12_training_half, _sfdp_pattern_12_inference_half)
central_index = {
'_sfdp_pattern_1_training': _sfdp_pattern_1_training,
@@ -40,6 +40,30 @@
'_sfdp_pattern_11_inference': _sfdp_pattern_11_inference,
'_sfdp_pattern_12_training': _sfdp_pattern_12_training,
'_sfdp_pattern_12_inference': _sfdp_pattern_12_inference,
+ '_sfdp_pattern_1_training_half': _sfdp_pattern_1_training_half,
+ '_sfdp_pattern_1_inference_half': _sfdp_pattern_1_inference_half,
+ '_sfdp_pattern_2_training_half': _sfdp_pattern_2_training_half,
+ '_sfdp_pattern_2_inference_half': _sfdp_pattern_2_inference_half,
+ '_sfdp_pattern_3_training_half': _sfdp_pattern_3_training_half,
+ '_sfdp_pattern_3_inference_half': _sfdp_pattern_3_inference_half,
+ '_sfdp_pattern_4_training_half': _sfdp_pattern_4_training_half,
+ '_sfdp_pattern_4_inference_half': _sfdp_pattern_4_inference_half,
+ '_sfdp_pattern_5_training_half': _sfdp_pattern_5_training_half,
+ '_sfdp_pattern_5_inference_half': _sfdp_pattern_5_inference_half,
+ '_sfdp_pattern_6_training_half': _sfdp_pattern_6_training_half,
+ '_sfdp_pattern_6_inference_half': _sfdp_pattern_6_inference_half,
+ '_sfdp_pattern_7_training_half': _sfdp_pattern_7_training_half,
+ '_sfdp_pattern_7_inference_half': _sfdp_pattern_7_inference_half,
+ '_sfdp_pattern_8_training_half': _sfdp_pattern_8_training_half,
+ '_sfdp_pattern_8_inference_half': _sfdp_pattern_8_inference_half,
+ '_sfdp_pattern_9_training_half': _sfdp_pattern_9_training_half,
+ '_sfdp_pattern_9_inference_half': _sfdp_pattern_9_inference_half,
+ '_sfdp_pattern_10_training_half': _sfdp_pattern_10_training_half,
+ '_sfdp_pattern_10_inference_half': _sfdp_pattern_10_inference_half,
+ '_sfdp_pattern_11_training_half': _sfdp_pattern_11_training_half,
+ '_sfdp_pattern_11_inference_half': _sfdp_pattern_11_inference_half,
+ '_sfdp_pattern_12_training_half': _sfdp_pattern_12_training_half,
+ '_sfdp_pattern_12_inference_half': _sfdp_pattern_12_inference_half,
}
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 461e8bc..def3ebe 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -4878,13 +4878,12 @@
head_dim = query.size(3)
max_seqlen_batch_k = key.size(2)
+
if device_hint(query) == "cpu":
- Nnz_q = batch_size * max_seqlen_batch_q
- query_t = query.transpose(1, 2)
- query_reshaped = query_t.reshape(Nnz_q, num_heads, head_dim)
- attention = torch.empty_like(query_reshaped, device=query.device)
- attention = attention.view(
- batch_size, max_seqlen_batch_q, num_heads, head_dim
+ attention = torch.empty(
+ (batch_size, max_seqlen_batch_q, num_heads, head_dim),
+ dtype=query.dtype,
+ device=query.device,
).transpose(1, 2)
logsumexp = torch.empty(
(
@@ -4977,6 +4976,12 @@
philox_offset: Tensor,
scale: Optional[float] = None,
):
+ if device_hint(query) != "cpu":
+ grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2)
+ grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)
+ grad_v = torch.empty_like(value.transpose(1, 2)).transpose(1, 2)
+ return grad_q, grad_k, grad_v
+
batch_size = query.size(0)
num_heads = query.size(1)
head_dim = query.size(3)