Add support for inline_asm_elementwise in Inductor lowerings (#129846)

This doesn't expose `inline_asm_elementwise` through any public API, but it makes it easy to register an Inductor lowering for a custom op that uses it.

<img width="667" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/f125f4bb-4f8c-46e7-8e06-925f37ed2930">
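For illustration, registering such a lowering looks roughly like the following, condensed from the test added in this PR (the library name `my_lib` is a placeholder; a CUDA build with Triton is assumed):

```python
from functools import partial

import torch
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.virtualized import ops

# Hypothetical custom-op library used only for this sketch.
my_lib = torch.library.Library("my_lib", "DEF")
my_lib.define("tanh_approx(Tensor input) -> Tensor")

# Meta kernel so the op can be traced by torch.compile.
impl_meta = torch.library.Library("my_lib", "IMPL", "Meta")

def tanh_approx_meta(inp):
    return torch.tanh(inp)

impl_meta.impl("tanh_approx", tanh_approx_meta)

# Inductor lowering: emits tl.inline_asm_elementwise(...) in the generated Triton kernel.
def tanh_approx_lowering(inp):
    fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
    return make_pointwise(fn)(inp)

register_lowering(torch.ops.my_lib.tanh_approx, type_promotion_kind=None)(
    tanh_approx_lowering
)

x = torch.randn(32, device="cuda")
y = torch.compile(lambda t: torch.ops.my_lib.tanh_approx(t))(x)
```

The generated Triton kernel then contains a `tl.inline_asm_elementwise('tanh.approx.f32 $0, $1;', ...)` call; if no constraint string is given, the codegen defaults to `"=r"` for the output plus one `"r"` per input.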
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129846
Approved by: https://github.com/shunting314
diff --git a/test/inductor/test_custom_lowering.py b/test/inductor/test_custom_lowering.py
index 7308ac1..90a6a01 100644
--- a/test/inductor/test_custom_lowering.py
+++ b/test/inductor/test_custom_lowering.py
@@ -1,13 +1,15 @@
 # Owner(s): ["module: inductor"]
 
 import unittest
+from functools import partial
 
 import torch
 
 from torch._inductor.ir import Pointwise
-from torch._inductor.lowering import register_lowering
+from torch._inductor.lowering import make_pointwise, register_lowering
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch._inductor.virtualized import ops
+from torch.testing._internal.common_utils import skipIfRocm
 
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
@@ -27,6 +29,7 @@
             "test_inductor_ops", "IMPL", "Meta"
         )
         cls._register_jagged_to_padded_dense()
+        cls._register_asm_op()
 
     @classmethod
     def tearDown(cls):
@@ -97,6 +100,39 @@
         cls.impl_meta.impl("jagged_to_padded_dense", j2pd_meta)
         cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_cuda)
 
+    @classmethod
+    def _register_asm_op(cls):
+        # tanh approximation implemented via inline PTX assembly (tanh.approx.f32)
+        cls.test_inductor_ops.define("tanh_approx(Tensor input) -> Tensor")
+
+        def tanh_approx_meta(inp):
+            return torch.tanh(inp)
+
+        cls.impl_meta.impl("tanh_approx", tanh_approx_meta)
+
+        def tanh_approx_lowering(inp):
+            fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
+            return make_pointwise(fn)(inp)
+
+        register_lowering(
+            torch.ops.test_inductor_ops.tanh_approx, type_promotion_kind=None
+        )(tanh_approx_lowering)
+
+        cls.test_inductor_ops.define("add_custom(Tensor a, Tensor b) -> Tensor")
+
+        def add_custom(a, b):
+            return a + b
+
+        cls.impl_meta.impl("add_custom", add_custom)
+
+        def add_custom_lowering(a, b):
+            fn = partial(ops.inline_asm_elementwise, asm="add.f32 $0, $1, $2;")
+            return make_pointwise(fn)(a, b)
+
+        register_lowering(
+            torch.ops.test_inductor_ops.add_custom, type_promotion_kind=None
+        )(add_custom_lowering)
+
     @unittest.skipIf(not HAS_CUDA, "CUDA needed")
     def test_jagged_to_padded_dense_sanity_cuda(self):
         def fn(inp, offsets, max_seq_len):
@@ -143,6 +179,33 @@
             fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
         )
 
+    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
+    @skipIfRocm
+    def test_tanh_approx(self):
+        def fn(inp):
+            return torch.ops.test_inductor_ops.tanh_approx(inp)
+
+        inp = torch.randn(32, device="cuda")
+        fn_opt = torch.compile(fn)
+
+        a = torch.tanh(inp)
+        b = fn_opt(inp)
+        self.assertEqual(a, b)
+
+    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
+    @skipIfRocm
+    def test_multi_inp_asm(self):
+        def fn(a, b):
+            return torch.ops.test_inductor_ops.add_custom(a, b)
+
+        a = torch.randn(32, device="cuda")
+        b = torch.randn(32, device="cuda")
+        fn_opt = torch.compile(fn)
+
+        out1 = a + b
+        out2 = fn_opt(a, b)
+        self.assertEqual(out1, out2)
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index b998c41..848d713 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -1009,6 +1009,41 @@
         self.run_test(bias_mod)
 
     @supported_platform
+    def test_autograd_function_in_score_mod(self):
+        class ApplyMask(torch.autograd.Function):
+            generate_vmap_rule = True
+
+            @staticmethod
+            def forward(a, mask):
+                return torch.where(mask, a, -float("inf"))
+
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                _, mask = inputs
+                ctx.mark_non_differentiable(mask)
+
+            @staticmethod
+            def backward(ctx, i):
+                return i, None
+
+        def score_mod(score, b, h, q, kv):
+            return ApplyMask.apply(score, q <= kv)
+
+        func = torch.compile(_flex_attention, fullgraph=True)
+
+        q, k, v = (
+            torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True)
+            for _ in range(3)
+        )
+
+        # Just checking that it runs
+        func(q, k, v)
+
+        # expectedFailure
+        # This doesn't work due to vmap + autograd.Function + torch.compile not composing
+        # self.run_test(score_mod)
+
+    @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("score_mod", [_identity, _causal])
     def test_logsumexp_correctness(self, dtype, score_mod):
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 80f5d3e..a4407ae 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -722,6 +722,16 @@
         return f"tl.where({a}, {b}, {c})"
 
     @staticmethod
+    def inline_asm_elementwise(
+        *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1
+    ):
+        triton_type = triton_compute_type(dtype)
+        input_refs = ", ".join([str(i) for i in inputs])
+        if constraints is None:
+            constraints = ", ".join(["=r"] + ["r" for _ in inputs])
+        return f"tl.inline_asm_elementwise('{asm}', '{constraints}', [{input_refs}], dtype={triton_type}, is_pure={is_pure}, pack={pack})"  # noqa: B950
+
+    @staticmethod
     def cos(x):
         return f"tl_math.cos({x})"
 
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index c7b215a..4c716b0 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -337,8 +337,8 @@
         # update pointers
         indices_idx = start_n // SPARSE_KV_MULTIPLE
 
-        cur_block = tl.load(kv_indices + indices_idx)
-        next_block = tl.load(kv_indices + indices_idx + 1)
+        cur_block = tl.load(kv_indices + indices_idx, eviction_policy="evict_last")
+        next_block = tl.load(kv_indices + indices_idx + 1, eviction_policy="evict_last")
         needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
         jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N