Add support for inline_asm_elementwise in Inductor lowerings (#129846)
This doesn't actually expose `inline_asm_elementwise` through any public API, but it makes it easy to register a lowering for a custom op that uses it.
<img width="667" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/f125f4bb-4f8c-46e7-8e06-925f37ed2930">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129846
Approved by: https://github.com/shunting314
diff --git a/test/inductor/test_custom_lowering.py b/test/inductor/test_custom_lowering.py
index 7308ac1..90a6a01 100644
--- a/test/inductor/test_custom_lowering.py
+++ b/test/inductor/test_custom_lowering.py
@@ -1,13 +1,15 @@
# Owner(s): ["module: inductor"]
import unittest
+from functools import partial
import torch
from torch._inductor.ir import Pointwise
-from torch._inductor.lowering import register_lowering
+from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import ops
+from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
@@ -27,6 +29,7 @@
"test_inductor_ops", "IMPL", "Meta"
)
cls._register_jagged_to_padded_dense()
+ cls._register_asm_op()
@classmethod
def tearDown(cls):
@@ -97,6 +100,39 @@
cls.impl_meta.impl("jagged_to_padded_dense", j2pd_meta)
cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_cuda)
+ @classmethod
+ def _register_asm_op(cls):
+ # A toy tanh approximation lowered to inline PTX (tanh.approx.f32)
+ cls.test_inductor_ops.define("tanh_approx(Tensor input) -> Tensor")
+
+ def tanh_approx_meta(inp):
+ return torch.tanh(inp)
+
+ cls.impl_meta.impl("tanh_approx", tanh_approx_meta)
+
+ def tanh_approx_lowering(inp):
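+ # PTX: tanh.approx.f32 writes the approximate tanh of $1 (input) into $0 (output).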
+ fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
+ return make_pointwise(fn)(inp)
+
+ register_lowering(
+ torch.ops.test_inductor_ops.tanh_approx, type_promotion_kind=None
+ )(tanh_approx_lowering)
+
+ cls.test_inductor_ops.define("add_custom(Tensor a, Tensor b) -> Tensor")
+
+ def add_custom(a, b):
+ return a + b
+
+ cls.impl_meta.impl("add_custom", add_custom)
+
+ def add_custom_lowering(a, b):
+ fn = partial(ops.inline_asm_elementwise, asm="add.f32 $0, $1, $2;")
+ return make_pointwise(fn)(a, b)
+
+ register_lowering(
+ torch.ops.test_inductor_ops.add_custom, type_promotion_kind=None
+ )(add_custom_lowering)
+
@unittest.skipIf(not HAS_CUDA, "CUDA needed")
def test_jagged_to_padded_dense_sanity_cuda(self):
def fn(inp, offsets, max_seq_len):
@@ -143,6 +179,33 @@
fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
)
+ @unittest.skipIf(not HAS_CUDA, "CUDA needed")
+ @skipIfRocm
+ def test_tanh_approx(self):
+ def fn(inp):
+ return torch.ops.test_inductor_ops.tanh_approx(inp)
+
+ inp = torch.randn(32, device="cuda")
+ fn_opt = torch.compile(fn)
+
+ a = torch.tanh(inp)
+ b = fn_opt(inp)
+ self.assertEqual(a, b)
+
+ @unittest.skipIf(not HAS_CUDA, "CUDA needed")
+ @skipIfRocm
+ def test_multi_inp_asm(self):
+ def fn(a, b):
+ return torch.ops.test_inductor_ops.add_custom(a, b)
+
+ a = torch.randn(32, device="cuda")
+ b = torch.randn(32, device="cuda")
+ fn_opt = torch.compile(fn)
+
+ out1 = a + b
+ out2 = fn_opt(a, b)
+ self.assertEqual(out1, out2)
+
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index b998c41..848d713 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -1009,6 +1009,42 @@
self.run_test(bias_mod)
@supported_platform
+ def test_autograd_function_in_score_mod(self):
+ class ApplyMask(torch.autograd.Function):
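+ # generate_vmap_rule lets autograd.Function derive a vmap rule from forward(),
+ # which the batching of score_mod inside flex_attention relies on.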
+ generate_vmap_rule = True
+
+ @staticmethod
+ def forward(a, mask):
+ return torch.where(mask, a, -float("inf"))
+
+ @staticmethod
+ def setup_context(ctx, inputs, output):
+ _, mask = inputs
+ ctx.mark_non_differentiable(mask)
+ pass
+
+ @staticmethod
+ def backward(ctx, i):
+ return i, None
+
+ def score_mod(score, b, h, q, kv):
+ return ApplyMask.apply(score, q <= kv)
+
+ func = torch.compile(_flex_attention, fullgraph=True)
+
+ q, k, v = (
+ torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True)
+ for _ in range(3)
+ )
+
+ # Just checking that it runs
+ func(q, k, v)
+
+ # expectedFailure
+ # This doesn't work due to vmap + autograd.Function + torch.compile not composing
+ # self.run_test(score_mod)
+
+ @supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", [_identity, _causal])
def test_logsumexp_correctness(self, dtype, score_mod):
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 80f5d3e..a4407ae 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -722,6 +722,16 @@
return f"tl.where({a}, {b}, {c})"
@staticmethod
+ def inline_asm_elementwise(
+ *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1
+ ):
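+ # Emits a call to Triton's tl.inline_asm_elementwise. `asm` is the PTX string;
+ # `constraints` defaults to "=r" for the output plus one "r" per input.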
+ triton_type = triton_compute_type(dtype)
+ input_refs = ", ".join([str(i) for i in inputs])
+ if constraints is None:
+ constraints = ", ".join(["=r"] + ["r" for _ in inputs])
+ return f"tl.inline_asm_elementwise('{asm}', '{constraints}', [{input_refs}], dtype={triton_type}, is_pure={is_pure}, pack={pack})" # noqa: B950
+
+ @staticmethod
def cos(x):
return f"tl_math.cos({x})"
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index c7b215a..4c716b0 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -337,8 +337,8 @@
# update pointers
indices_idx = start_n // SPARSE_KV_MULTIPLE
- cur_block = tl.load(kv_indices + indices_idx)
- next_block = tl.load(kv_indices + indices_idx + 1)
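+ # These block indices are re-read on every iteration, so hint the cache to evict them last.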
+ cur_block = tl.load(kv_indices + indices_idx, eviction_policy="evict_last")
+ next_block = tl.load(kv_indices + indices_idx + 1, eviction_policy="evict_last")
needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N