Generate patterns in fp16 and fp32 (#109142)

aten.softmax generates a different decomposition for fp16/bf16 than for fp32 because, when invoked in lower precision, it upcasts the inputs to fp32 and downcasts the result afterward. This has been causing us to miss bf16 patterns. For example, Camembert improves 20% with this PR (and I'm sure many other models do as well).
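
For context, here is a minimal sketch of the difference (not part of this change; it uses the generic `torch._decomp` table rather than the exact tracing that `fuse_attention` does, so node names may differ slightly):

```python
# Illustrative only: trace softmax through the aten decompositions and look for
# the upcast/downcast nodes that appear only for low-precision inputs.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import decomposition_table

def f(x):
    return torch.nn.functional.softmax(x, dim=-1)

for dtype in (torch.float32, torch.float16):
    x = torch.randn(2, 4, 8, 16, dtype=dtype)
    gm = make_fx(f, decomposition_table=decomposition_table)(x)
    # the half graph carries convert_element_type / _to_copy nodes around the
    # softmax math; the fp32 graph does not, so one pattern cannot match both
    has_cast = any(
        "convert_element_type" in str(n.target) or "_to_copy" in str(n.target)
        for n in gm.graph.nodes
        if n.op == "call_function"
    )
    print(dtype, "has dtype conversion:", has_cast)
```

Generating each pattern in both fp32 and fp16 covers both shapes of the decomposed graph; since the serialized patterns leave the conversion target dtype as `Ignored()`, the half-traced patterns should cover bf16 graphs as well.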

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109142
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663, #108894, #108917
diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py
index 7a7c454..49fd2b3 100644
--- a/test/inductor/test_fused_attention.py
+++ b/test/inductor/test_fused_attention.py
@@ -42,13 +42,15 @@
         has_dropout=False,
         check_train=True,
         override_check_equal=False,
+        dtype=torch.float,
+        rtol=1.3e-6,
     ):
         if args1 is None:
             tensor_shape = (4, 2, 16, 32)
             args1 = [
-                torch.randn(tensor_shape, device=self.device),
-                torch.randn(tensor_shape, device=self.device),
-                torch.randn(tensor_shape, device=self.device),
+                torch.randn(tensor_shape, device=self.device, dtype=dtype),
+                torch.randn(tensor_shape, device=self.device, dtype=dtype),
+                torch.randn(tensor_shape, device=self.device, dtype=dtype),
             ]
         else:
             args1 = list(args1)
@@ -91,7 +93,7 @@
                         and arg1.is_floating_point()
                         and (not has_dropout or override_check_equal)
                     ):
-                        self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=1.3e-6)
+                        self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)
 
     @skipIfRocm
     def _test_sdpa_rewriter_1(self):
@@ -106,8 +108,17 @@
                 .matmul(value)
             )
 
-        self._check_common(dot_prod_attention)
-        self._check_common(checkpoint_wrapper(dot_prod_attention))
+        for dtype in [torch.float, torch.half]:
+            if self.device == "cpu" and dtype == torch.half:
+                continue
+            rtol = 1.3e-6 if dtype == torch.float else 0.7
+            self._check_common(dot_prod_attention, dtype=dtype, atol=0.001, rtol=rtol)
+            self._check_common(
+                checkpoint_wrapper(dot_prod_attention),
+                dtype=dtype,
+                atol=0.001,
+                rtol=rtol,
+            )
 
     def _test_pattern_fails_with_reuse(self):
         """
diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index b60f53b..066b7bd 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -786,7 +786,7 @@
         FileCheck().check_not("extern_kernels.addmm(").run(code[0])
 
     def test_fuse_attention_roundtrip_pattern(self):
-        # are we losing anything in serialization in patterns
+        # are we losing anything in serialization
         from torch._inductor.fx_passes.fuse_attention import _get_sfdp_patterns
 
         global_vals = {
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index ca1f436..798d1b0 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1013,24 +1013,31 @@
                 super().__init__()
 
             def forward(self, arg1, arg2, arg3):
-                torch.ops.aten._scaled_dot_product_flash_attention(arg1, arg2, arg3)
+                torch.ops.aten._scaled_dot_product_flash_attention(arg1, arg2, arg3, scale=0.17677669529663687)
 
         args_new = [
-            ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
-            ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
-            ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
+            [
+                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
+                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
+                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
+            ],
+            [
+                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
+                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
+                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
+            ]
         ]
-
-        args = [rand_strided(bsz, num_heads, seq_len, head_dim) for
-                (bsz, num_heads, seq_len, head_dim) in args_new]
-        try:
-            with torch._subclasses.CrossRefFakeMode():
-                Repro()(*args)
-        except RuntimeError as e:
-            # We expect the cross ref to succed for the first output to fail
-            # for the rng state, see Note [Seed and Offset]
-            self.assertTrue("output[0]" not in str(e))
-            self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))
+        for args_list in args_new:
+            args = [rand_strided(bsz, num_heads, seq_len, head_dim) for
+                    (bsz, num_heads, seq_len, head_dim) in args_list]
+            try:
+                with torch._subclasses.CrossRefFakeMode():
+                    Repro()(*args)
+            except RuntimeError as e:
+                # We expect the cross ref to succeed for the first output and to fail
+                # for the rng state, see Note [Seed and Offset]
+                self.assertTrue("output[0]" not in str(e))
+                self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))
 
     @skipIfRocm
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index 2446d27..6bf800d 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -378,134 +378,150 @@
 
     # sizes/values don't actually matter for initial trace
     # once we get a possible match we re-trace with the actual values and verify the match still holds
-    g = functools.partial(torch.empty, (2, 4, 8, 16), device=device, requires_grad=True)
-    gp = functools.partial(
-        torch.empty, (2, 8, 4, 16), device=device, requires_grad=True, dtype=torch.half
+    g_inp = functools.partial(
+        torch.empty, (2, 4, 8, 16), device=device, requires_grad=True
     )
-    b = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
-    c = functools.partial(torch.tensor, 2.0, device=device)
+    b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
+    c_inp = functools.partial(torch.tensor, 2.0, device=device)
     # workaround https://github.com/pytorch/pytorch/issues/97894
     # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
     d = {"dropout_p": 0.113377}
 
-    for pattern, replacement, args, workaround, extra_check in [
-        (
-            _sfdp_pattern_1,
-            _sfdp_replacement_1,
-            [g(), g(), g(), c()],
-            {},
-            _sfdp_scale_factor_check(aten.div.Tensor),
-        ),
-        (
-            _sfdp_pattern_2,
-            _sfdp_replacement_2,
-            [g(), g(), g(), c()],
-            {},
-            _sfdp_scale_factor_check(aten.mul.Tensor),
-        ),
-        (
-            _sfdp_pattern_3,
-            _sfdp_replacement_3,
-            [g(), g(), g(), c()],
-            d,
-            _sfdp_scale_factor_check(aten.div.Tensor),
-        ),
-        (
-            _sfdp_pattern_4,
-            _sfdp_replacement_4,
-            [g(), g(), g(), c()],
-            d,
-            _sfdp_scale_factor_check(aten.mul.Tensor),
-        ),
-        (
-            _sfdp_pattern_5,
-            _sfdp_replacement_5,
-            [g(), g(), g(), b()],
-            {},
-            _sfdp_params_check,
-        ),
-        (
-            _sfdp_pattern_6,
-            _sfdp_replacement_6,
-            [g(), g(), g(), b()],
-            d,
-            _sfdp_params_check,
-        ),
-        (
-            _sfdp_pattern_7,
-            _sfdp_replacement_7,
-            [gp(), gp(), gp()],
-            d,
-            _sfdp_params_check,
-        ),
-        (
-            _sfdp_pattern_8,
-            _sfdp_replacement_8,
-            [gp(), gp(), gp()],
-            {},
-            _sfdp_params_check,
-        ),
-        (
-            _sfdp_pattern_9,
-            _sfdp_replacement_9,
-            [gp(), gp(), gp()],
-            d,
-            _sfdp_params_check,
-        ),
-        (
-            _sfdp_pattern_10,
-            _sfdp_replacement_10,
-            [gp(), gp(), gp()],
-            {},
-            _sfdp_params_check,
-        ),
-        (
-            _sfdp_pattern_11,
-            _sfdp_replacement_11,
-            [g(), g(), g(), c()],
-            {},
-            _sfdp_scale_factor_check(aten.div.Tensor),
-        ),
-        (
-            _sfdp_pattern_12,
-            _sfdp_replacement_12,
-            [g(), g(), g(), c()],
-            d,
-            _sfdp_scale_factor_check(aten.div.Tensor),
-        ),
-    ]:
-        # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
-        # gets serialized to a python file and does not require tracing at runtime.
-        assert isinstance(workaround, dict)
-        training_args = [*args, *workaround.values()]
-        name = pattern.__name__
+    # softmax will generate a dtype conversion on the inputs if they are in half,
+    # but not in float, so we generate patterns for both dtypes
+    for dtype in [torch.float, torch.half]:
+        g = functools.partial(g_inp, dtype=dtype)
+        b = functools.partial(b_inp, dtype=dtype)
+        c = functools.partial(c_inp, dtype=dtype)
 
-        yield f"{name}_training", {
-            "search_fn": pattern,
-            "replace_fn": replacement,
-            "example_inputs": training_args,
-            "trace_fn": training_graph,
-            "pass_dict": patterns,
-            "extra_check": extra_check,
-            "scalar_workaround": workaround,
-        }
+        for pattern, replacement, args, workaround, extra_check in [
+            (
+                _sfdp_pattern_1,
+                _sfdp_replacement_1,
+                [g(), g(), g(), c()],
+                {},
+                _sfdp_scale_factor_check(aten.div.Tensor),
+            ),
+            (
+                _sfdp_pattern_2,
+                _sfdp_replacement_2,
+                [g(), g(), g(), c()],
+                {},
+                _sfdp_scale_factor_check(aten.mul.Tensor),
+            ),
+            (
+                _sfdp_pattern_3,
+                _sfdp_replacement_3,
+                [g(), g(), g(), c()],
+                d,
+                _sfdp_scale_factor_check(aten.div.Tensor),
+            ),
+            (
+                _sfdp_pattern_4,
+                _sfdp_replacement_4,
+                [g(), g(), g(), c()],
+                d,
+                _sfdp_scale_factor_check(aten.mul.Tensor),
+            ),
+            (
+                _sfdp_pattern_5,
+                _sfdp_replacement_5,
+                [g(), g(), g(), b()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_6,
+                _sfdp_replacement_6,
+                [g(), g(), g(), b()],
+                d,
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_7,
+                _sfdp_replacement_7,
+                [g(), g(), g()],
+                d,
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_8,
+                _sfdp_replacement_8,
+                [g(), g(), g()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_9,
+                _sfdp_replacement_9,
+                [g(), g(), g()],
+                d,
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_10,
+                _sfdp_replacement_10,
+                [g(), g(), g()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_11,
+                _sfdp_replacement_11,
+                [g(), g(), g(), c()],
+                {},
+                _sfdp_scale_factor_check(aten.div.Tensor),
+            ),
+            (
+                _sfdp_pattern_12,
+                _sfdp_replacement_12,
+                [g(), g(), g(), c()],
+                d,
+                _sfdp_scale_factor_check(aten.div.Tensor),
+            ),
+        ]:
+            # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
+            # gets serialized to a python file and does not require tracing at runtime.
+            assert isinstance(workaround, dict)
+            training_args = [*args, *workaround.values()]
+            name = pattern.__name__
 
-        if workaround:
-            assert len(workaround) == 1 and "dropout_p" in workaround
-            # functools.partial insufficient because we look at signature downstream
-            pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
-            replacement = partialize_and_update_signature(replacement, dropout_p=0.0)
-            workaround = {}
+            training_name = (
+                f"{name}_training" if dtype == torch.float else f"{name}_training_half"
+            )
+            yield training_name, {
+                "search_fn": pattern,
+                "replace_fn": replacement,
+                "example_inputs": training_args,
+                "trace_fn": training_graph,
+                "pass_dict": patterns,
+                "extra_check": extra_check,
+                "scalar_workaround": workaround,
+            }
 
-        yield f"{name}_inference", {
-            "search_fn": pattern,
-            "replace_fn": replacement,
-            "example_inputs": args,
-            "trace_fn": inference_graph,
-            "pass_dict": patterns,
-            "extra_check": extra_check,
-            "scalar_workaround": workaround,
-        }
+            if workaround:
+                assert len(workaround) == 1 and "dropout_p" in workaround
+                # functools.partial insufficient because we look at signature downstream
+                pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
+                replacement = partialize_and_update_signature(
+                    replacement, dropout_p=0.0
+                )
+                workaround = {}
+
+            inference_name = (
+                f"{name}_inference"
+                if dtype == torch.float
+                else f"{name}_inference_half"
+            )
+            yield inference_name, {
+                "search_fn": pattern,
+                "replace_fn": replacement,
+                "example_inputs": args,
+                "trace_fn": inference_graph,
+                "pass_dict": patterns,
+                "extra_check": extra_check,
+                "scalar_workaround": workaround,
+            }
 
 
 @functools.lru_cache(None)
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py
index 6d28df3..a684ab8 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py
@@ -93,3 +93,78 @@
 view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
 _sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_1_training_half = MultiOutputPattern([view_default_5,
+  view_default_9,
+  permute_default_4,
+  view_default_11,
+  None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_1_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py
index 2dd4f62..f5216fe 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py
@@ -38,6 +38,92 @@
 clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
 view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
 bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
+view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, div_Tensor_1, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
+view_default_8 = CallFunction(aten.view.default, sub_Tensor_1, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_10_training = MultiOutputPattern([view_default_5,
+  permute_default_6,
+  permute_default_9,
+  permute_default_11
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
+expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
+expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
 view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
 convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
 amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
@@ -79,7 +165,7 @@
 bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
 view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
 permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
-_sfdp_pattern_10_training = MultiOutputPattern([view_default_5,
+_sfdp_pattern_10_training_half = MultiOutputPattern([view_default_5,
   permute_default_6,
   permute_default_9,
   permute_default_11
@@ -112,4 +198,4 @@
 clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
 view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
-_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+_sfdp_pattern_10_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py
index a1e066a..b24a5fb 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py
@@ -108,3 +108,93 @@
 view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
 _sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_11_training_half = MultiOutputPattern([view_default_5,
+  permute_default_6,
+  permute_default_9,
+  permute_default_11,
+  None
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_11_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py
index 17526b2..4284e4d 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py
@@ -118,3 +118,103 @@
 view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
 _sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
+convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_12_training_half = MultiOutputPattern([view_default_5,
+  permute_default_6,
+  permute_default_9,
+  permute_default_11,
+  None,
+  None
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
+expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_12_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py
index 458da61..96dc0d4 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py
@@ -93,3 +93,78 @@
 view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
 _sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_1, mul_Tensor_2)
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor'))
+view_default_8 = CallFunction(aten.view.default, mul_Tensor_3, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_2_training_half = MultiOutputPattern([view_default_5,
+  view_default_9,
+  permute_default_4,
+  view_default_11,
+  None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_2_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py
index ea2c4ae..0869cb7 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py
@@ -103,3 +103,88 @@
 view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
 _sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
+convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_3_training_half = MultiOutputPattern([view_default_5,
+  view_default_9,
+  permute_default_4,
+  view_default_11,
+  None,
+  None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
+expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_3_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
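
Note on the `_half` variants above: they differ from the fp32 patterns only by the prims.convert_element_type nodes bracketing the amax/sub/exp/sum/div chain, i.e. an upcast around the softmax reduction and a downcast of the result. A minimal illustrative sketch of the equivalent eager computation (not the generator's code, just the shape of the decomposition these nodes come from):

import torch

def softmax_half_decomposed(scores: torch.Tensor) -> torch.Tensor:
    # scores is fp16/bf16; the reduction is carried out in fp32
    x = scores.float()                      # prims.convert_element_type (upcast)
    x = x - x.amax(dim=-1, keepdim=True)    # aten.amax / aten.sub
    x = x.exp()                             # aten.exp
    x = x / x.sum(dim=-1, keepdim=True)     # aten.sum / aten.div
    return x.to(scores.dtype)               # prims.convert_element_type (downcast)
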
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py
index b35c24a..bb3c617 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py
@@ -103,3 +103,88 @@
 view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
 _sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
+clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format)
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
+mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_5, mul_Tensor_6)
+convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+mul_Tensor_7 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor'))
+view_default_8 = CallFunction(aten.view.default, mul_Tensor_7, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_4_training_half = MultiOutputPattern([view_default_5,
+  view_default_9,
+  permute_default_4,
+  view_default_11,
+  None,
+  None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
+expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_4_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
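
For readers unfamiliar with the serialized format: each of these files is a pre-built expression in the inductor pattern-matcher DSL. A minimal hand-written sketch, assuming the same helpers the generated files import from torch._inductor.pattern_matcher:

import torch
from torch._inductor.pattern_matcher import CallFunction, Ignored, KeywordArg

aten = torch.ops.aten

# Matches `scores / <anything>`: KeywordArg captures the dividend under the
# name "scores", Ignored matches (and discards) the divisor, and _users
# constrains how many consumers the matched node may have in the graph.
div_pattern = CallFunction(aten.div.Tensor, KeywordArg("scores"), Ignored(), _users=2)
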
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py
index c83d726..def8441 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py
@@ -95,3 +95,80 @@
 view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
 _sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
+add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored())
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_5_training_half = MultiOutputPattern([view_default_5,
+  view_default_9,
+  permute_default_4,
+  view_default_11,
+  None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
+add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_5_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
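
The pattern_5/6 family targets attention with an additive mask applied before the softmax. A rough eager-mode illustration of the computation the inference pattern above corresponds to (an assumption for illustration, not copied from fuse_attention.py):

import math
import torch

def masked_attention(query, key, value, attn_mask):
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    attn = torch.softmax(scores + attn_mask, dim=-1)
    return attn @ value
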
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py
index 304090f..075c675 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py
@@ -105,3 +105,90 @@
 view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
 _sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
+add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
+convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored())
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_6_training_half = MultiOutputPattern([view_default_5,
+  view_default_9,
+  permute_default_4,
+  view_default_11,
+  None,
+  None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
+add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
+expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_6_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
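
The rand/gt prologue that opens each training-with-dropout pattern (here and in patterns 3 and 4 above) is the decomposed form of dropout applied to the attention weights. A small illustrative sketch of that decomposition (not the exact decomposition source):

import torch

def decomposed_dropout(attn: torch.Tensor, dropout_p: float) -> torch.Tensor:
    keep = torch.rand_like(attn) > dropout_p        # aten.rand + aten.gt.Scalar
    return attn * keep * (1.0 / (1.0 - dropout_p))  # the two aten.mul.Tensor nodes
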
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py
index 774c33d..c1ef7e6 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py
@@ -40,6 +40,102 @@
 view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
 bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
 view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
+view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
+clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, div_Tensor_1, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
+div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_7_training = MultiOutputPattern([view_default_5,
+  permute_default_6,
+  permute_default_9,
+  permute_default_11,
+  None
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
 div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
 convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
 amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
@@ -87,7 +183,7 @@
 bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
 view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
 permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
-_sfdp_pattern_7_training = MultiOutputPattern([view_default_5,
+_sfdp_pattern_7_training_half = MultiOutputPattern([view_default_5,
   permute_default_6,
   permute_default_9,
   permute_default_11,
@@ -122,4 +218,4 @@
 clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
 view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
-_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+_sfdp_pattern_7_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py
index 99d844f..1752202 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py
@@ -38,6 +38,92 @@
 view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
 bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
 view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
+view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, div_Tensor_1, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
+div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_8_training = MultiOutputPattern([view_default_5,
+  permute_default_6,
+  permute_default_9,
+  permute_default_11
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
 div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
 convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
 amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
@@ -79,7 +165,7 @@
 bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
 view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
 permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
-_sfdp_pattern_8_training = MultiOutputPattern([view_default_5,
+_sfdp_pattern_8_training_half = MultiOutputPattern([view_default_5,
   permute_default_6,
   permute_default_9,
   permute_default_11
@@ -112,4 +198,4 @@
 clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
 view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
-_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+_sfdp_pattern_8_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py
index 9368356..abefe3b 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py
@@ -40,6 +40,102 @@
 clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
 view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
 bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
+view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
+clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, div_Tensor_1, sum_dim_IntList_1)
+sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
+view_default_8 = CallFunction(aten.view.default, sub_Tensor_1, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_9_training = MultiOutputPattern([view_default_5,
+  permute_default_6,
+  permute_default_9,
+  permute_default_11,
+  None
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
+expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
+expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
 view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
 convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
 amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
@@ -87,7 +183,7 @@
 bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
 view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
 permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
-_sfdp_pattern_9_training = MultiOutputPattern([view_default_5,
+_sfdp_pattern_9_training_half = MultiOutputPattern([view_default_5,
   permute_default_6,
   permute_default_9,
   permute_default_11,
@@ -122,4 +218,4 @@
 clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
 view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
 bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
-_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
+_sfdp_pattern_9_inference_half = CallFunction(aten.view.default, bmm_default_1, Ignored())
diff --git a/torch/_inductor/fx_passes/serialized_patterns/central_index.py b/torch/_inductor/fx_passes/serialized_patterns/central_index.py
index 9e59d0b..02eae1d 100644
--- a/torch/_inductor/fx_passes/serialized_patterns/central_index.py
+++ b/torch/_inductor/fx_passes/serialized_patterns/central_index.py
@@ -2,18 +2,18 @@
 # To re-generate, run:
 # cd ~/pytorch && python torchgen/fuse_attention_patterns/gen_attention_patterns.py
-from ._sfdp_pattern_1 import (_sfdp_pattern_1_training, _sfdp_pattern_1_inference)
-from ._sfdp_pattern_2 import (_sfdp_pattern_2_training, _sfdp_pattern_2_inference)
-from ._sfdp_pattern_3 import (_sfdp_pattern_3_training, _sfdp_pattern_3_inference)
-from ._sfdp_pattern_4 import (_sfdp_pattern_4_training, _sfdp_pattern_4_inference)
-from ._sfdp_pattern_5 import (_sfdp_pattern_5_training, _sfdp_pattern_5_inference)
-from ._sfdp_pattern_6 import (_sfdp_pattern_6_training, _sfdp_pattern_6_inference)
-from ._sfdp_pattern_7 import (_sfdp_pattern_7_training, _sfdp_pattern_7_inference)
-from ._sfdp_pattern_8 import (_sfdp_pattern_8_training, _sfdp_pattern_8_inference)
-from ._sfdp_pattern_9 import (_sfdp_pattern_9_training, _sfdp_pattern_9_inference)
-from ._sfdp_pattern_10 import (_sfdp_pattern_10_training, _sfdp_pattern_10_inference)
-from ._sfdp_pattern_11 import (_sfdp_pattern_11_training, _sfdp_pattern_11_inference)
-from ._sfdp_pattern_12 import (_sfdp_pattern_12_training, _sfdp_pattern_12_inference)
+from ._sfdp_pattern_1 import (_sfdp_pattern_1_training, _sfdp_pattern_1_inference, _sfdp_pattern_1_training_half, _sfdp_pattern_1_inference_half)
+from ._sfdp_pattern_2 import (_sfdp_pattern_2_training, _sfdp_pattern_2_inference, _sfdp_pattern_2_training_half, _sfdp_pattern_2_inference_half)
+from ._sfdp_pattern_3 import (_sfdp_pattern_3_training, _sfdp_pattern_3_inference, _sfdp_pattern_3_training_half, _sfdp_pattern_3_inference_half)
+from ._sfdp_pattern_4 import (_sfdp_pattern_4_training, _sfdp_pattern_4_inference, _sfdp_pattern_4_training_half, _sfdp_pattern_4_inference_half)
+from ._sfdp_pattern_5 import (_sfdp_pattern_5_training, _sfdp_pattern_5_inference, _sfdp_pattern_5_training_half, _sfdp_pattern_5_inference_half)
+from ._sfdp_pattern_6 import (_sfdp_pattern_6_training, _sfdp_pattern_6_inference, _sfdp_pattern_6_training_half, _sfdp_pattern_6_inference_half)
+from ._sfdp_pattern_7 import (_sfdp_pattern_7_training, _sfdp_pattern_7_inference, _sfdp_pattern_7_training_half, _sfdp_pattern_7_inference_half)
+from ._sfdp_pattern_8 import (_sfdp_pattern_8_training, _sfdp_pattern_8_inference, _sfdp_pattern_8_training_half, _sfdp_pattern_8_inference_half)
+from ._sfdp_pattern_9 import (_sfdp_pattern_9_training, _sfdp_pattern_9_inference, _sfdp_pattern_9_training_half, _sfdp_pattern_9_inference_half)
+from ._sfdp_pattern_10 import (_sfdp_pattern_10_training, _sfdp_pattern_10_inference, _sfdp_pattern_10_training_half, _sfdp_pattern_10_inference_half)
+from ._sfdp_pattern_11 import (_sfdp_pattern_11_training, _sfdp_pattern_11_inference, _sfdp_pattern_11_training_half, _sfdp_pattern_11_inference_half)
+from ._sfdp_pattern_12 import (_sfdp_pattern_12_training, _sfdp_pattern_12_inference, _sfdp_pattern_12_training_half, _sfdp_pattern_12_inference_half)
 
 central_index = {
     '_sfdp_pattern_1_training': _sfdp_pattern_1_training,
@@ -40,6 +40,30 @@
     '_sfdp_pattern_11_inference': _sfdp_pattern_11_inference,
     '_sfdp_pattern_12_training': _sfdp_pattern_12_training,
     '_sfdp_pattern_12_inference': _sfdp_pattern_12_inference,
+    '_sfdp_pattern_1_training_half': _sfdp_pattern_1_training_half,
+    '_sfdp_pattern_1_inference_half': _sfdp_pattern_1_inference_half,
+    '_sfdp_pattern_2_training_half': _sfdp_pattern_2_training_half,
+    '_sfdp_pattern_2_inference_half': _sfdp_pattern_2_inference_half,
+    '_sfdp_pattern_3_training_half': _sfdp_pattern_3_training_half,
+    '_sfdp_pattern_3_inference_half': _sfdp_pattern_3_inference_half,
+    '_sfdp_pattern_4_training_half': _sfdp_pattern_4_training_half,
+    '_sfdp_pattern_4_inference_half': _sfdp_pattern_4_inference_half,
+    '_sfdp_pattern_5_training_half': _sfdp_pattern_5_training_half,
+    '_sfdp_pattern_5_inference_half': _sfdp_pattern_5_inference_half,
+    '_sfdp_pattern_6_training_half': _sfdp_pattern_6_training_half,
+    '_sfdp_pattern_6_inference_half': _sfdp_pattern_6_inference_half,
+    '_sfdp_pattern_7_training_half': _sfdp_pattern_7_training_half,
+    '_sfdp_pattern_7_inference_half': _sfdp_pattern_7_inference_half,
+    '_sfdp_pattern_8_training_half': _sfdp_pattern_8_training_half,
+    '_sfdp_pattern_8_inference_half': _sfdp_pattern_8_inference_half,
+    '_sfdp_pattern_9_training_half': _sfdp_pattern_9_training_half,
+    '_sfdp_pattern_9_inference_half': _sfdp_pattern_9_inference_half,
+    '_sfdp_pattern_10_training_half': _sfdp_pattern_10_training_half,
+    '_sfdp_pattern_10_inference_half': _sfdp_pattern_10_inference_half,
+    '_sfdp_pattern_11_training_half': _sfdp_pattern_11_training_half,
+    '_sfdp_pattern_11_inference_half': _sfdp_pattern_11_inference_half,
+    '_sfdp_pattern_12_training_half': _sfdp_pattern_12_training_half,
+    '_sfdp_pattern_12_inference_half': _sfdp_pattern_12_inference_half,
 }
 
 
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 461e8bc..def3ebe 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -4878,13 +4878,12 @@
     head_dim = query.size(3)
 
     max_seqlen_batch_k = key.size(2)
+
     if device_hint(query) == "cpu":
-        Nnz_q = batch_size * max_seqlen_batch_q
-        query_t = query.transpose(1, 2)
-        query_reshaped = query_t.reshape(Nnz_q, num_heads, head_dim)
-        attention = torch.empty_like(query_reshaped, device=query.device)
-        attention = attention.view(
-            batch_size, max_seqlen_batch_q, num_heads, head_dim
+        attention = torch.empty(
+            (batch_size, max_seqlen_batch_q, num_heads, head_dim),
+            dtype=query.dtype,
+            device=query.device,
         ).transpose(1, 2)
         logsumexp = torch.empty(
             (
@@ -4977,6 +4976,12 @@
     philox_offset: Tensor,
     scale: Optional[float] = None,
 ):
+    if device_hint(query) != "cpu":
+        grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2)
+        grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)
+        grad_v = torch.empty_like(value.transpose(1, 2)).transpose(1, 2)
+        return grad_q, grad_k, grad_v
+
     batch_size = query.size(0)
     num_heads = query.size(1)
     head_dim = query.size(3)
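
The CPU branch above now builds the attention output by allocating a (batch, seq, heads, head_dim) buffer and viewing it as (batch, heads, seq, head_dim) via transpose, presumably so the meta tensor matches the layout of the eager CPU kernel's output. A quick self-contained check of that layout (shape values here are arbitrary):

import torch

B, S, H, D = 4, 16, 2, 32
attention = torch.empty((B, S, H, D), dtype=torch.float16, device="meta").transpose(1, 2)

print(attention.shape)    # torch.Size([4, 2, 16, 32])  -> (B, H, S, D)
print(attention.stride()) # (1024, 32, 64, 1): contiguous in (B, S, H, D) order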