| # Owner(s): ["module: inductor"] |
| # flake8: noqa: E731 |
# Skip E731 ("do not assign a lambda expression, use a def")
| import functools |
| from unittest.mock import patch |
| |
| import torch |
| import torch._dynamo.testing |
| import torch._inductor.test_case |
| from torch._higher_order_ops.triton_kernel_wrap import ( |
| generate_ttir, |
| triton_kernel_wrapper_functional, |
| triton_kernel_wrapper_mutation, |
| ) |
| from torch._inductor import metrics |
| from torch._inductor.utils import run_and_get_code |
| from torch._library import capture_triton |
| from torch.testing._internal import common_utils |
| from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu, TEST_WITH_ROCM |
| from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU |
| |
# Defines all the kernels used by the tests below
| from torch.testing._internal.triton_utils import * # noqa: F403 |
| |
| |
| if HAS_GPU: |
| import triton |
| from triton import language as tl |
| |
| if not TEST_WITH_ROCM: |
| if HAS_CUDA: |
| from triton.language.extra.cuda.libdevice import ( |
| fast_dividef, |
| fast_dividef as my_fast_dividef, |
| ) |
| elif HAS_XPU: |
| from triton.language.extra.intel.libdevice import ( |
| fast_dividef, |
| fast_dividef as my_fast_dividef, |
| ) |
| |
| # Define shared triton constants here. |
| CONSTANT_C: tl.constexpr = 4 |
| STRING_CONSTANT_C: tl.constexpr = "CONSTANT_C" |
| BOOL_CONSTANT_C: tl.constexpr = True |
| |
| |
| class KernelTests(torch._inductor.test_case.TestCase): |
| @requires_gpu |
| def test_triton_kernel_with_kernel_param(self): |
| @triton.jit |
| def pass_kernel(kernel): |
| pass |
| |
| @torch.compile(backend="eager") |
| def f(x): |
| grid = (x.numel(),) |
| pass_kernel[grid](kernel=x) |
| |
| t1 = torch.rand(5, device=GPU_TYPE) |
| f(t1) |
| # No need to assert anything, the goal is to make sure dynamo does |
| # not crash |
| |
| @requires_gpu |
| def test_triton_kernel_higher_order_func(self): |
| from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table |
| |
| add_kernel_id = kernel_side_table.add_kernel(add_kernel) |
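        # The side table maps kernels and constant arguments to integer ids so
        # that the wrapper higher-order ops below can refer to them by index
        # rather than embedding Python objects in the graph.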
| |
| t1 = torch.rand(5, device=GPU_TYPE) |
| t2 = torch.rand(5, device=GPU_TYPE) |
| |
| torch_add = t1 + t2 |
| |
| # Test higher order function with mutation |
| output = torch.zeros_like(t1) |
| n_elements = output.numel() |
| constant_args_idx = kernel_side_table.add_constant_args( |
| {"n_elements": n_elements, "BLOCK_SIZE": 16} |
| ) |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| triton_kernel_wrapper_mutation( |
| kernel_idx=add_kernel_id, |
| constant_args_idx=constant_args_idx, |
| grid=[grid], |
| kwargs={ |
| "in_ptr0": t1, |
| "in_ptr1": t2, |
| "out_ptr": output, |
| }, |
| ) |
| self.assertEqual(output, torch_add) |
| # Make sure it is modified |
| self.assertNotEqual(output, torch.zeros_like(t1)) |
| |
| # Test higher order function without mutation |
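        # The functional variant clones the tensors listed in tensors_to_clone
        # and returns them in a dict keyed by parameter name, so the original
        # inputs are left untouched.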
| output = torch.zeros_like(t1) |
| out_dict = triton_kernel_wrapper_functional( |
| kernel_idx=add_kernel_id, |
| constant_args_idx=constant_args_idx, |
| grid=[grid], |
| kwargs={ |
| "in_ptr0": t1, |
| "in_ptr1": t2, |
| "out_ptr": output, |
| }, |
| tensors_to_clone=["in_ptr0", "in_ptr1", "out_ptr"], |
| ) |
| self.assertEqual(out_dict["out_ptr"], torch_add) |
| # Make sure it is NOT modified |
| self.assertEqual(output, torch.zeros_like(t1)) |
| |
| @requires_gpu |
| def test_triton_kernel_functionalize(self): |
| from functorch import make_fx |
| from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table |
| from torch._subclasses.functional_tensor import ( |
| CppFunctionalizeAPI, |
| FunctionalTensorMode, |
| PythonFunctionalizeAPI, |
| ) |
| |
| kernel_side_table.reset_table() |
| |
| def f(x, output): |
| out = triton_kernel_wrapper_functional( |
| kernel_idx=kernel_side_table.add_kernel(mul2_kernel), |
| constant_args_idx=kernel_side_table.add_constant_args( |
| {"n_elements": output.numel(), "BLOCK_SIZE": 16} |
| ), |
| grid=[(x.numel(),)], |
| kwargs={ |
| "in_ptr0": x, |
| "out_ptr": output, |
| }, |
| tensors_to_clone=["in_ptr0", "out_ptr"], |
| ) |
| return out["out_ptr"] |
| |
| t1 = torch.rand(5, device=GPU_TYPE) |
| t2 = torch.rand(5, device=GPU_TYPE) |
| with FunctionalTensorMode(): |
| gm = make_fx(PythonFunctionalizeAPI().functionalize(f))(t1, t2) |
| # Make sure t2 was not modified |
| self.assertNotEqual(gm(t1, t2), t2) |
| |
| gm = make_fx(CppFunctionalizeAPI().functionalize(f))(t1, t2) |
| # Make sure t2 was not modified |
| self.assertNotEqual(gm(t1, t2), t2) |
| |
| gm = make_fx(torch.func.functionalize(f))(t1, t2) |
| # Make sure t2 was not modified |
| self.assertNotEqual(gm(t1, t2), t2) |
| |
| gm = make_fx(f, tracing_mode="fake")(t1, t2) |
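        # The fake-mode trace should contain a single functional wrapper call
        # returning a dict of the cloned tensors; only the 'out_ptr' entry is
        # returned from the graph.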
| self.assertExpectedInline( |
| gm.code.strip(), |
| """\ |
| def forward(self, x_1, output_1): |
| triton_kernel_wrapper_functional_proxy = torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(5,)], kwargs = {'in_ptr0': x_1, 'out_ptr': output_1}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None |
| getitem = triton_kernel_wrapper_functional_proxy['in_ptr0'] |
| getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr']; triton_kernel_wrapper_functional_proxy = None |
| return getitem_1""", |
| ) |
| |
| @requires_gpu |
| def test_triton_kernel_mutation_type(self): |
| from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table |
| from torch._subclasses.fake_tensor import FakeTensorMode |
| from torch._subclasses.functional_tensor import ( |
| FunctionalTensor, |
| FunctionalTensorMode, |
| ) |
| |
| def prep(): |
| x = torch.ones(4, device=GPU_TYPE, requires_grad=True) |
| with FunctionalTensorMode(): |
| x_func = FunctionalTensor.to_functional(x) |
| self.assertTrue(torch._is_functional_tensor(x_func.elem)) |
| return x_func |
| |
| # normal mutation only |
| with FakeTensorMode(): |
| x_func = prep() |
| |
| with FunctionalTensorMode(): |
| x_func.mul_(2) |
| |
| self.assertFalse( |
| torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem) |
| ) |
| |
| # triton kernel mutation only |
| with FakeTensorMode(): |
| x_func = prep() |
| |
| with FunctionalTensorMode(): |
| triton_kernel_wrapper_mutation( |
| kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel), |
| constant_args_idx=kernel_side_table.add_constant_args( |
| {"n_elements": x_func.numel(), "BLOCK_SIZE": 16} |
| ), |
| grid=[(x_func.numel(),)], |
| kwargs={ |
| "ptr": x_func, |
| }, |
| ) |
| |
| self.assertTrue( |
| torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem) |
| ) |
| |
| # normal mutation + triton kernel mutation |
| with FakeTensorMode(): |
| x_func = prep() |
| |
| with FunctionalTensorMode(): |
| x_func.mul_(2) |
| triton_kernel_wrapper_mutation( |
| kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel), |
| constant_args_idx=kernel_side_table.add_constant_args( |
| {"n_elements": x_func.numel(), "BLOCK_SIZE": 16} |
| ), |
| grid=[(x_func.numel(),)], |
| kwargs={ |
| "ptr": x_func, |
| }, |
| ) |
| |
| self.assertFalse( |
| torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem) |
| ) |
| |
| @requires_gpu |
| @common_utils.parametrize("dynamic", [False, True]) |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| def test_triton_kernel_with_views(self, dynamic, backend): |
| def call_triton_take_view(x: torch.Tensor): |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) |
| return output |
| |
| def call_triton_return_view(x: torch.Tensor): |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) |
| return output.view(4, 4) |
| |
| t = torch.rand(4, 4, device=GPU_TYPE) |
| t_view = t.view(16) |
| |
| compiled_func = torch.compile( |
| call_triton_take_view, backend=backend, fullgraph=True, dynamic=dynamic |
| ) |
| self.assertEqual(2 * t_view, compiled_func(t_view)) |
| self.assertEqual(2 * t, compiled_func(t_view).view(4, 4)) |
| |
| compiled_func = torch.compile( |
| call_triton_return_view, backend=backend, fullgraph=True, dynamic=dynamic |
| ) |
| self.assertEqual(2 * t_view, compiled_func(t).view(16)) |
| self.assertEqual(2 * t, compiled_func(t)) |
| |
| @requires_gpu |
| @common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad]) |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| def test_triton_kernel_with_grad_option(self, grad_fn, backend): |
| def call_triton(x: torch.Tensor): |
| with grad_fn(): |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) |
| return output |
| |
| t = torch.rand(5, device=GPU_TYPE) |
| compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True) |
| self.assertEqual(2 * t, compiled_func(t)) |
| |
| @requires_gpu |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| def test_triton_kernel_inner_triton_function(self, backend): |
| def f(x: torch.Tensor): |
| @triton.jit |
| def pow2_kernel( |
| in_ptr0, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| output = x * x |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| pow2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) |
| return output |
| |
| t = torch.rand(5, device=GPU_TYPE) |
| |
| compiled_func = torch.compile(f, backend=backend, fullgraph=True) |
| # TODO(oulgen): NYI - Support this |
| # self.assertEqual(t * t, compiled_func(t)) |
| |
| @requires_gpu |
| @common_utils.parametrize("grad", [False, True]) |
| @common_utils.parametrize("dynamic", [False, True]) |
| @patch.object(torch._inductor.config, "implicit_fallbacks", False) |
| def test_triton_kernel_no_clones(self, grad, dynamic): |
| from torch._inductor.utils import run_and_get_code |
| |
| def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor): |
| n_elements = output.numel() |
| |
| tmp = torch.add(x, 1) |
| grid = (x.numel(),) |
| add_kernel.run( |
| x, y, output, n_elements, warmup=False, grid=grid, BLOCK_SIZE=16 |
| ) |
| |
| return output, tmp |
| |
| t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad) |
| t2 = torch.rand(5, device=GPU_TYPE, requires_grad=grad) |
| o1 = torch.zeros_like(t1, requires_grad=grad) |
| |
| torch_add = call_triton(t1, t2, o1) |
| metrics.reset() |
| o2 = torch.zeros_like(t1, requires_grad=grad) |
| test, codes = run_and_get_code( |
| torch.compile(call_triton, dynamic=dynamic), t1, t2, o2 |
| ) |
| if not grad: |
| self.assertEqual(metrics.generated_kernel_count, 1) |
| self.assertEqual(torch_add, test) |
        # These two asserts are not ideal since they require the original aten
        # ops to be present in the metadata, so there might be false negatives
| self.assertTrue("aten.copy" not in codes[0]) |
| self.assertTrue("aten.clone" not in codes[0]) |
        # The following checks that only the tensor output is returned by
        # the compiled graph
| if dynamic and grad: |
| self.assertTrue("return (buf0, s0, )" in codes[0]) |
| else: |
| self.assertTrue("return (buf0, )" in codes[0]) |
| |
| @requires_gpu |
| def test_triton_kernel_caching(self): |
| from torch._inductor.utils import run_and_get_code |
| |
| def add_in_loop( |
| x: torch.Tensor, |
| y: torch.Tensor, |
| ): |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| add_kernel_autotuned[grid](x, y, output, n_elements) |
| return output |
| |
| def call_triton_add( |
| x: torch.Tensor, |
| y: torch.Tensor, |
| ): |
| for i in range(4): |
| x = add_in_loop(x, y) |
| return x |
| |
| t1 = torch.ones(5, device=GPU_TYPE) |
| t2 = torch.ones(5, device=GPU_TYPE) |
| |
| test, (code,) = run_and_get_code(torch.compile(call_triton_add), t1, t2) |
| self.assertEqual(test, 5 * torch.ones(5, device=GPU_TYPE)) |
| self.assertTrue("add_kernel_autotuned_1.run" not in code) |
| |
| @requires_gpu |
| def test_triton_kernel_caching_duplicate(self): |
| from torch._inductor.utils import run_and_get_code |
| |
| class C: |
| @triton.jit |
| def pass_kernel( |
| in_ptr0, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| tl.store(out_ptr + offsets, x, mask=mask) |
| |
| class D: |
| @triton.jit |
| def pass_kernel( |
| in_ptr0, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| tl.store(out_ptr + offsets, x, mask=mask) |
| |
| def call_triton(x: torch.Tensor): |
| output1 = torch.zeros_like(x) |
| output2 = torch.zeros_like(x) |
| n_elements = output1.numel() |
| grid = (n_elements,) |
| C.pass_kernel[grid](x, output1, n_elements, BLOCK_SIZE=16) |
| D.pass_kernel[grid](x, output2, n_elements, BLOCK_SIZE=16) |
| return output1 + output2 |
| |
| t = torch.ones(5, device=GPU_TYPE) |
| test, (code,) = run_and_get_code(torch.compile(call_triton), t) |
| # Make sure we emitted two kernels here |
| self.assertTrue("pass_kernel_0.run" in code) |
| self.assertTrue("pass_kernel_1.run" in code) |
| |
| @requires_gpu |
| def test_triton_kernel_various_args(self): |
| @triton.autotune( |
| configs=[triton.Config({"BLOCK_SIZE": 128})], |
| key=[], |
| ) |
| @triton.jit |
| def pass_kernel( |
| out_ptr, |
| n_elements, |
| dummy_None, |
| dummy_empty, |
| dummy_float, |
| BLOCK_SIZE: "tl.constexpr", |
| RANDOM_SIZE: "tl.constexpr", |
| ): |
| pass |
| |
| @torch.compile |
| def call_triton(output): |
| n_elements = output.numel() |
| grid = (n_elements,) |
| pass_kernel[grid]( |
| output, |
| n_elements, |
| None, |
| torch.empty_like(output), |
| 3.1415926, |
| RANDOM_SIZE=0, |
| ) |
| return output |
| |
| output = torch.randn(5, device=GPU_TYPE) |
| # Make sure this does not crash |
| call_triton(output) |
| |
| @requires_gpu |
| @skipIfRocm |
    def test_triton_kernel_dependencies(self):
| def call_triton( |
| x: torch.Tensor, |
| y: torch.Tensor, |
| ): |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| add_kernel_autotuned[grid](x, y, output, n_elements) |
| output2 = torch.zeros_like(output) |
| add_kernel_autotuned[grid](output, y, output2, n_elements) |
| output3 = torch.add(output2, 1) |
| return output3 |
| |
| t1 = torch.rand(5, device=GPU_TYPE) |
| t2 = torch.rand(5, device=GPU_TYPE) |
| torch_result = call_triton(t1, t2) |
| compiled_result = torch.compile(call_triton)(t1, t2) |
| self.assertEqual(torch_result, compiled_result) |
| |
| @requires_gpu |
| def test_triton_kernel_reinplace_inplaceable_pass(self): |
| def call_triton( |
| x: torch.Tensor, |
| y: torch.Tensor, |
| ): |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| add_kernel_autotuned[grid](x, y, output, n_elements) |
| add_kernel_autotuned[grid](output, x, output, n_elements) |
| return output |
| |
| t1 = torch.rand(5, device=GPU_TYPE) |
| t2 = torch.rand(5, device=GPU_TYPE) |
| torch_result = call_triton(t1, t2) |
| compiled_result = torch.compile(call_triton)(t1, t2) |
| self.assertEqual(torch_result, compiled_result) |
| |
| @requires_gpu |
| @common_utils.parametrize("grad", [False, True]) |
| def test_triton_kernel_multi_kernel(self, grad): |
| @triton.jit |
| def mul2_and_add_and_zero_negatives_kernel( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ACTIVATION: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| indirection_kernel( |
| in_ptr0, |
| in_ptr0, |
| n_elements, |
| BLOCK_SIZE=BLOCK_SIZE, |
| ACTIVATION="mul2_inplace_kernel", |
| ) |
| indirection_kernel( |
| in_ptr1, |
| in_ptr1, |
| n_elements, |
| BLOCK_SIZE=BLOCK_SIZE, |
| ACTIVATION="mul2_inplace_kernel", |
| ) |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = x + y |
| if ACTIVATION == "zero_negs": |
| output = zero_negs(output) |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| @torch.compile |
| def call_triton( |
| x: torch.Tensor, |
| y: torch.Tensor, |
| xi: torch.Tensor, |
| yi: torch.Tensor, |
| output: torch.Tensor, |
| outputi: torch.Tensor, |
| ): |
| n_elements = output.numel() |
| |
| grid = (x.numel(),) |
| mul2_and_add_and_zero_negatives_kernel[grid]( |
| x, y, output, n_elements, BLOCK_SIZE=16, ACTIVATION="zero_negs" |
| ) |
| mul2_and_add_and_zero_negatives_kernel[grid]( |
| xi, yi, outputi, n_elements, BLOCK_SIZE=16, ACTIVATION=None |
| ) |
| |
| return (output, outputi) |
| |
| t1 = torch.tensor( |
| [-2.0, -1.0, 0.0, 1.0, 2.0], device=GPU_TYPE, requires_grad=grad |
| ) |
| t2 = torch.tensor( |
| [-2.0, -1.0, 0.0, 1.0, 2.0], device=GPU_TYPE, requires_grad=grad |
| ) |
| float_result = 2 * t1 + 2 * t2 |
| float_result = float_result.where(float_result >= 0, 0.0) |
| |
| t1i = torch.randint(-2, 2, (5,), device=GPU_TYPE) |
| t2i = torch.randint(-2, 2, (5,), device=GPU_TYPE) |
| o = torch.zeros_like(t1, requires_grad=grad) |
| oi = torch.zeros_like(t1i) |
| int_result = 2 * t1i + 2 * t2i |
| |
| (result, resulti) = call_triton(t1, t2, t1i, t2i, o, oi) |
| self.assertEqual(float_result, result) |
| self.assertEqual(int_result, resulti) |
| |
| @requires_gpu |
| @skipIfXpu |
| @skipIfRocm |
| def test_triton_kernel_constants(self): |
| @triton.jit |
| def mulC_kernel( |
| in_ptr0, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| CONSTANT_NAME: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| if CONSTANT_NAME == STRING_CONSTANT_C: |
| output = CONSTANT_C * x |
| if BOOL_CONSTANT_C: |
| output *= CONSTANT_C |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| def call_triton( |
| x: torch.Tensor, |
| ): |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| |
| grid = (x.numel(),) |
| mulC_kernel[grid]( |
| x, output, n_elements, BLOCK_SIZE=16, CONSTANT_NAME="CONSTANT_C" |
| ) |
| return output |
| |
        # Triton kernels capture global constants by their parse-time value,
        # not their runtime value
| global CONSTANT_C |
| prev_c = CONSTANT_C |
        # If the behavior of Triton kernels changes, this test will fail
| CONSTANT_C = 10 |
| assert CONSTANT_C != prev_c |
| |
| t = torch.randn(5, device=GPU_TYPE) |
| torch_result = call_triton(t) |
| compiled_result = torch.compile(call_triton)(t) |
| |
| self.assertEqual(torch_result, compiled_result) |
| |
| # reset back |
| CONSTANT_C = prev_c |
| |
| @requires_gpu |
| @common_utils.parametrize("grad", [False, True]) |
| @common_utils.parametrize("dynamic", [False, True]) |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| @common_utils.parametrize("grid_type", [1, 2, 3]) |
| def test_triton_kernel_autotune(self, grad, dynamic, backend, grid_type): |
| def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor): |
| n_elements = output.numel() |
| |
| def grid_fn(meta): |
| return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| |
| if grid_type == 1: |
| grid = (n_elements,) |
| elif grid_type == 2: |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| elif grid_type == 3: |
| grid = grid_fn |
| |
| add_kernel_autotuned[grid](x, y, output, n_elements) |
| return output |
| |
| t1 = torch.rand(256, device=GPU_TYPE, requires_grad=grad) |
| t2 = torch.rand(256, device=GPU_TYPE, requires_grad=grad) |
| output = torch.zeros_like(t1, requires_grad=grad) |
| |
| torch_add = call_triton(t1, t2, output) |
| compiled_func = torch.compile( |
| call_triton, backend=backend, fullgraph=True, dynamic=dynamic |
| ) |
| |
| output2 = torch.zeros_like(t1, requires_grad=grad) |
| self.assertEqual(compiled_func(t1, t2, output2), torch_add) |
| |
| @requires_gpu |
| @common_utils.parametrize("grad", [False, True]) |
| @common_utils.parametrize("dynamic", [False, True]) |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| @common_utils.parametrize("grid_type", [1, 2, 3]) |
| def test_triton_kernel_2d_autotune(self, grad, dynamic, backend, grid_type): |
| def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor): |
| x_elements = output.size()[0] |
| y_elements = output.size()[1] |
| |
| def grid_fn(meta): |
| return ( |
| triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]), |
| triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]), |
| ) |
| |
| if grid_type == 1: |
| grid = (x_elements, y_elements) |
| elif grid_type == 2: |
| grid = lambda meta: ( |
| triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]), |
| triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]), |
| ) |
| elif grid_type == 3: |
| grid = grid_fn |
| |
| add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements) |
| return output |
| |
| t1 = torch.rand((512, 256), device=GPU_TYPE, requires_grad=grad) |
| t2 = torch.rand((512, 256), device=GPU_TYPE, requires_grad=grad) |
| output = torch.zeros_like(t1, requires_grad=grad) |
| |
| torch_result = call_triton(t1, t2, output) |
| compiled_func = torch.compile( |
| call_triton, backend=backend, fullgraph=True, dynamic=dynamic |
| ) |
| output2 = torch.zeros_like(t1, requires_grad=grad) |
| self.assertEqual(compiled_func(t1, t2, output2), torch_result) |
| |
| @requires_gpu |
| @common_utils.parametrize("dynamic", [False, True]) |
| def test_triton_kernel_tracing(self, dynamic): |
| def call_triton_add( |
| x: torch.Tensor, |
| y: torch.Tensor, |
| grid_type: int, |
| num=1, |
| positional=False, |
| autotuned=False, |
| ): |
| output = torch.empty_like(x) |
| n_elements = output.numel() |
| |
| def grid_fn(meta): |
| return (triton.cdiv(num, meta["BLOCK_SIZE"]),) |
| |
| if grid_type == 0: |
| grid = (x.numel(),) |
| elif grid_type == 1: |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| elif grid_type == 2: |
| grid = grid_fn |
| else: |
| grid = [x.numel()] |
| |
| if autotuned: |
| capture_triton(add_kernel_autotuned)[grid](x, y, output, n_elements) |
| else: |
| if positional: |
| capture_triton(add_kernel)[grid](x, y, output, n_elements, 16) |
| else: |
| capture_triton(add_kernel)[grid]( |
| x, y, output, n_elements, BLOCK_SIZE=16 |
| ) |
| |
| return output |
| |
| t0 = torch.rand(5, device=GPU_TYPE, requires_grad=True) |
| t1 = torch.rand(5, device=GPU_TYPE, requires_grad=True) |
| t2 = torch.rand(5, device=GPU_TYPE, requires_grad=True) |
| t3 = torch.rand(5, device=GPU_TYPE, requires_grad=True) |
| torch_add = t2 + t3 |
| |
| tests = [ |
| functools.partial(call_triton_add, grid_type=0), |
| functools.partial(call_triton_add, grid_type=1), |
| functools.partial(call_triton_add, grid_type=1, num=1, positional=True), |
| functools.partial(call_triton_add, grid_type=2, num=200), |
| functools.partial(call_triton_add, grid_type=3), |
| functools.partial(call_triton_add, grid_type=0, autotuned=True), |
| functools.partial(call_triton_add, grid_type=1, num=1, autotuned=True), |
| functools.partial(call_triton_add, grid_type=2, num=200, autotuned=True), |
| functools.partial(call_triton_add, grid_type=3, autotuned=True), |
| ] |
| from functorch import make_fx |
| |
| tracing_mode = "symbolic" if dynamic else "fake" |
| |
| for test in tests: |
| gm = make_fx(test, tracing_mode=tracing_mode)(t0, t1) |
| result = test(t2, t3) |
| self.assertEqual(result, torch_add) |
| |
| @requires_gpu |
| @common_utils.parametrize("grad", [False, True]) |
| @common_utils.parametrize("dynamic", [False, True]) |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| @patch.object(torch._inductor.config, "implicit_fallbacks", False) |
| def test_triton_kernel_native(self, grad, dynamic, backend): |
| def call_triton_add( |
| x: torch.Tensor, |
| y: torch.Tensor, |
| output: torch.Tensor, |
| grid_type: int, |
| num=1, |
| positional=False, |
| ): |
| n_elements = output.numel() |
| |
| def grid_fn(meta): |
| return (triton.cdiv(num, meta["BLOCK_SIZE"]),) |
| |
| if grid_type == 0: |
| grid = (x.numel(),) |
| elif grid_type == 1: |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| else: |
| grid = grid_fn |
| |
| if positional: |
| add_kernel[grid](x, y, output, n_elements, 16) |
| else: |
| add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) |
| |
| return output |
| |
| t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad) |
| t2 = torch.rand(5, device=GPU_TYPE, requires_grad=grad) |
| o1 = torch.zeros_like(t1, requires_grad=grad) |
| |
| torch_add = t1 + t2 |
| |
| # No Dynamo -- Make sure triton kernel works |
| self.assertEqual(call_triton_add(t1, t2, o1, 1), torch_add) |
| # No Dynamo -- Make sure triton kernel works (with positional BLOCK_SIZE) |
| o2 = torch.zeros_like(t1, requires_grad=grad) |
| self.assertEqual(call_triton_add(t1, t2, o2, 1, True), torch_add) |
| |
| # With Dynamo |
| compiled_func = torch.compile( |
| call_triton_add, backend=backend, fullgraph=True, dynamic=dynamic |
| ) |
| # With simple kernel |
| o3 = torch.zeros_like(t1, requires_grad=grad) |
| self.assertEqual(compiled_func(t1, t2, o3, 0), torch_add) |
| # With lambda kernel |
| o4 = torch.zeros_like(t1, requires_grad=grad) |
| self.assertEqual(compiled_func(t1, t2, o4, 1), torch_add) |
| # With lambda kernel (with positional BLOCK_SIZE) |
| o5 = torch.zeros_like(t1, requires_grad=grad) |
| self.assertEqual(compiled_func(t1, t2, o5, 1, 1, True), torch_add) |
| # With user defined function kernel |
| o6 = torch.zeros_like(t1, requires_grad=grad) |
| self.assertEqual(compiled_func(t1, t2, o6, 2, 200), torch_add) |
| |
| @requires_gpu |
| def test_triton_kernel_mutation_not_mark_dirty(self): |
| @torch.compile |
| def f(x): |
| n_elements = x.numel() |
| add_kernel[(n_elements,)](x, x, x, n_elements, 16) |
| return x |
| |
| x = torch.randn(5, device=GPU_TYPE, requires_grad=True) |
| x_cloned = x.clone() |
| out = x_cloned.sin() |
| f(x_cloned) |
| out.sum().backward() |
| |
| @requires_cuda |
| @patch.object(torch._inductor.config, "allow_buffer_reuse", True) |
| def test_triton_kernel_inputs_buffer_reuse(self): |
| def _mul2(x): |
| y = torch.empty_like(x) |
| mul2_kernel[(10,)]( |
| in_ptr0=x, |
| out_ptr=y, |
| n_elements=x.numel(), |
| BLOCK_SIZE=1, |
| ) |
| return y |
| |
| @torch.compile |
| def f(x): |
| for _ in range(4): |
                # The output of one kernel is the input to the next kernel, but
                # at some point we should reuse buffers rather than allocate new ones.
| x = _mul2(x) |
| return x + 1 |
| |
| x = torch.randn(10, device="cuda", dtype=torch.float32) |
| eager_out = f(x) |
| compiled_out, (code,) = run_and_get_code(torch.compile(f), x) |
| self.assertEqual(compiled_out, eager_out) |
| |
| # Check that we're allocating the minimal # of buffers. |
| num_bufs_allocated = code.count( |
| "empty_strided_cuda((10, ), (1, ), torch.float32)" |
| ) |
| self.assertEqual(num_bufs_allocated, 2) |
| |
        # Check that we're reusing buffers instead of allocating new ones.
| num_bufs_reused = code.count("# reuse") |
| self.assertEqual(num_bufs_reused, 3) |
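        # Two fresh allocations plus three reuses cover the five buffers needed
        # (four mul2 outputs and the final add): once a kernel's input is dead,
        # its buffer can be ping-ponged back as the next output.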
| |
| @requires_gpu |
| def test_triton_kernel_matmul_tracking(self): |
| @triton.jit |
| def ones_kernel(x_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = 1.0 |
| tl.store(x_ptr + offsets, x, mask=mask) |
| |
| @torch.compile |
| def f(x): |
| out = torch.zeros_like(x) |
| ones_kernel[(4,)](out, 16, BLOCK_SIZE=16) |
| return torch.mm(out, x) + 10 |
| |
| x = torch.randn(4, 4, device=GPU_TYPE) |
| torch_out = f(x) |
| python_out = torch.mm(torch.ones(4, 4, device=GPU_TYPE), x) + 10 |
| self.assertEqual(torch_out, python_out) |
| |
| @requires_gpu |
| def test_triton_kernel_strided_input(self): |
| def f(inp): |
| # left has strides [256, 1] |
| left, right = torch.split(inp, [128, 128], dim=1) |
| out = torch.empty_like(left) |
| X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16 |
| grid = (left.size(1) // X_BLOCK_SIZE, left.size(0) // Y_BLOCK_SIZE) |
| double_strided_kernel[grid]( |
| in_ptr=left, |
| out_ptr=out, |
| in_y_stride=left.stride(0), |
| out_y_stride=out.stride(0), |
| X_BLOCK_SIZE=X_BLOCK_SIZE, |
| Y_BLOCK_SIZE=Y_BLOCK_SIZE, |
| ) |
| return out |
| |
| inp = torch.randn(64, 256, device=GPU_TYPE) |
| |
| eager_out = f(inp) |
| compiled_out = torch.compile(f)(inp) |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| def test_triton_kernel_strided_input_nonzero_offset(self): |
| def f(inp): |
| # right has strides [256, 1] and storage offset 128 |
| left, right = torch.split(inp, [128, 128], dim=1) |
| out = torch.empty_like(right) |
| X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16 |
| grid = (right.size(1) // X_BLOCK_SIZE, right.size(0) // Y_BLOCK_SIZE) |
| double_strided_kernel[grid]( |
| in_ptr=right, |
| out_ptr=out, |
| in_y_stride=right.stride(0), |
| out_y_stride=out.stride(0), |
| X_BLOCK_SIZE=X_BLOCK_SIZE, |
| Y_BLOCK_SIZE=Y_BLOCK_SIZE, |
| ) |
| return out |
| |
| inp = torch.randn(64, 256, device=GPU_TYPE) |
| |
| eager_out = f(inp) |
| compiled_out = torch.compile(f)(inp) |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| def test_triton_kernel_slice_and_view_input(self): |
| def f(inp): |
| # left has strides [256, 1] |
| left = inp[:, :128] |
| left = left.view(64, 4, 32) |
| out = torch.empty_like(left) |
| X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16 |
| grid = ( |
| (left.size(1) * left.size(2)) // X_BLOCK_SIZE, |
| left.size(0) // Y_BLOCK_SIZE, |
| ) |
| double_strided_kernel[grid]( |
| in_ptr=left, |
| out_ptr=out, |
| in_y_stride=left.stride(0), |
| out_y_stride=out.stride(0), |
| X_BLOCK_SIZE=X_BLOCK_SIZE, |
| Y_BLOCK_SIZE=Y_BLOCK_SIZE, |
| ) |
| return out + left |
| |
| inp = torch.randn(64, 256, device=GPU_TYPE) |
| |
| eager_out = f(inp) |
| compiled_out = torch.compile(f)(inp) |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| def test_triton_kernel_fallback(self): |
| def f(x, y): |
| out = torch.zeros_like(x) |
| out2 = torch.zeros_like(x) |
| # torch.mm is ExternKernelOut |
| add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16) |
            # torch.sort creates a fallback kernel and hence a MultiOutput
| add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16) |
| return out, out2 |
| |
| x = torch.randn(4, 4, device=GPU_TYPE) |
| y = torch.randn(4, 4, device=GPU_TYPE) |
| eager_out = f(x, y) |
| compiled_out = torch.compile(f)(x, y) |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| def test_triton_kernel_out_of_order(self): |
| @triton.jit |
| def add_kernel( |
| in_ptr0, |
| in_ptr1, |
| BLOCK_SIZE: "tl.constexpr", |
| out_ptr, |
| n_elements, |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| def f(x, y): |
| out = torch.zeros_like(x) |
| n_elements = x.numel() |
| add_kernel[(n_elements,)](x, y, 4, out, n_elements) |
| return out |
| |
| x = torch.randn(4, device=GPU_TYPE) |
| y = torch.randn(4, device=GPU_TYPE) |
| eager_out = f(x, y) |
| compiled_out = torch.compile(f)(x, y) |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) |
| @torch._dynamo.config.patch(capture_scalar_outputs=True) |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| def test_triton_kernel_unbacked_shape_tensor(self, backend): |
| @triton.jit |
| def square( |
| in_ptr, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr + offsets, mask=mask) |
| output = x * x |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| def f(x): |
| x = x[x > 2] |
| n_elements = x.numel() |
| output = torch.zeros_like(x) |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| square[grid](x, output, n_elements, BLOCK_SIZE=16) |
| return output |
| |
| x = torch.randn(4, device=GPU_TYPE) |
| eager_out = f(x) |
| compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x) |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| @common_utils.parametrize("dynamic", [False, True]) |
| def test_triton_kernel_equal_to_1_arg(self, dynamic): |
| @triton.jit |
| def add_kernel_half_n_elements( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| half_n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < half_n_elements * 2 |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| def f(x, y): |
| out = torch.empty_like(x) |
| half_n_elements = x.numel() // 2 |
| add_kernel_half_n_elements[(half_n_elements,)]( |
| x, y, out, half_n_elements, BLOCK_SIZE=16 |
| ) |
| return out |
| |
| x = torch.randn(2, device=GPU_TYPE) |
| y = torch.randn(2, device=GPU_TYPE) |
| eager_out = f(x, y) |
| compiled_out, sources = run_and_get_code( |
| torch.compile(f, dynamic=dynamic), x, y |
| ) |
| |
| if dynamic: |
            # when the half_n_elements passed to the Triton kernel is
            # dynamic, the equal_to_1 specialization can't be enforced
| self.assertTrue("equal_to_1=()" in sources[0]) |
| else: |
| self.assertTrue("equal_to_1=(3,)" in sources[0]) |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| @common_utils.parametrize("dynamic", [False, True]) |
| def test_triton_kernel_equal_to_1_float_arg(self, dynamic): |
| def f(x, y): |
| out = torch.empty_like(x) |
| n_elements = x.numel() |
| scaling_factor = (n_elements**0) / 1.0 |
| add_kernel_with_scaling[(n_elements,)]( |
| x, |
| y, |
| out, |
| n_elements, |
| scaling_factor, |
| BLOCK_SIZE=16, |
| ) |
| return out |
| |
| x = torch.randn(2, device=GPU_TYPE) |
| y = torch.randn(2, device=GPU_TYPE) |
| eager_out = f(x, y) |
| compiled_out, sources = run_and_get_code( |
| torch.compile(f, dynamic=dynamic), x, y |
| ) |
| |
        # A float 1.0 (whether literal or symbolic)
        # should not be added to equal_to_1
| self.assertTrue("equal_to_1=()" in sources[0]) |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| @skipIfRocm |
| def test_triton_kernel_with_imported_symbol(self): |
| @triton.jit |
| def add_kernel_with_imported_symbol( |
| in_ptr, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr + offsets, mask=mask) |
| output = fast_dividef(x, 3.14) |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| def f(x): |
| out = torch.empty_like(x) |
| n_elements = x.numel() |
| add_kernel_with_imported_symbol[(n_elements,)]( |
| x, out, n_elements, BLOCK_SIZE=16 |
| ) |
| return out |
| |
| x = torch.randn(4, device=GPU_TYPE) |
| eager_out = f(x) |
| compiled_out = torch.compile(f)(x) |
| |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| @skipIfRocm |
| def test_triton_kernel_with_imported_symbol_with_custom_name(self): |
| @triton.jit |
| def add_kernel_with_imported_symbol( |
| in_ptr, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr + offsets, mask=mask) |
| output = my_fast_dividef(x, 3.14) |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| def f(x): |
| out = torch.empty_like(x) |
| n_elements = x.numel() |
| add_kernel_with_imported_symbol[(n_elements,)]( |
| x, out, n_elements, BLOCK_SIZE=16 |
| ) |
| return out |
| |
| x = torch.randn(4, device=GPU_TYPE) |
| eager_out = f(x) |
| compiled_out = torch.compile(f)(x) |
| |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| @common_utils.parametrize("size", [4, 16]) |
| @common_utils.parametrize("dynamic", [False, True]) |
| def test_triton_kernel_different_shapes(self, size, dynamic): |
| from torch._inductor.utils import run_and_get_code |
| |
| def f(x, y, xx, yy): |
| n_elements = x.numel() |
| output_1 = torch.zeros_like(x) |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| add_kernel[grid](x, y, output_1, n_elements, BLOCK_SIZE=4) |
| |
| n_elements = xx.numel() |
| output_2 = torch.zeros_like(xx) |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| add_kernel[grid](xx, yy, output_2, n_elements, BLOCK_SIZE=4) |
| |
| return output_1, output_2 |
| |
| x = torch.rand(size, device=GPU_TYPE) |
| y = torch.rand(size, device=GPU_TYPE) |
| xx = torch.rand(size, size, device=GPU_TYPE) |
| yy = torch.rand(size, size, device=GPU_TYPE) |
| args = [x, y, xx, yy] |
| |
| eager_out = f(*args) |
| compiled_out, (code,) = run_and_get_code( |
| torch.compile(f, fullgraph=True, dynamic=dynamic, backend="inductor"), *args |
| ) |
| if size == 4 and not dynamic: |
| # Produce 2 kernels due to divisibility |
| self.assertTrue("add_kernel_0.run" in code) |
| self.assertTrue("add_kernel_1.run" in code) |
| else: |
| # size == 16 or dynamic |
| # Only one kernel |
| self.assertTrue("add_kernel_0.run" in code) |
| self.assertTrue("add_kernel_1.run" not in code) |
| |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| def test_triton_kernel_reset_to_zero(self): |
| @triton.autotune( |
| configs=[ |
| triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8), |
| triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8), |
| ], |
| key=["n_elements"], |
| reset_to_zero=["out_ptr"], |
| ) |
| @triton.jit |
| def add_kernel_autotuned_reset( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| @torch.compile(fullgraph=True) |
| def f(x, y): |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| add_kernel_autotuned_reset[grid](x, y, output, n_elements) |
| return output |
| |
| x = torch.randn(4, device=GPU_TYPE) |
| msg = "Only configs and keys are supported for triton.autotune" |
| with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): |
| f(x, x) |
| |
| @requires_gpu |
| @common_utils.parametrize("dynamic", [False, True]) |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| def test_triton_kernel_triton_dtype(self, dynamic, backend): |
| @triton.jit |
| def add_kernel_with_dtype( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| dtype: "tl.constexpr", |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask).to(dtype) |
| y = tl.load(in_ptr1 + offsets, mask=mask).to(dtype) |
| output = x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| def f(x, y, dtype_torch, dtype_triton): |
| output = torch.zeros_like(x).to(dtype=dtype_torch) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| add_kernel_with_dtype[grid]( |
| x, y, output, dtype_triton, n_elements, BLOCK_SIZE=4 |
| ) |
| return output |
| |
| x = torch.randn(4, device=GPU_TYPE) |
| y = torch.randn(4, device=GPU_TYPE) |
| args_list = ( |
| [x, y, torch.float32, tl.float32], |
| [x, y, torch.bfloat16, tl.bfloat16], |
| ) |
| for args in args_list: |
| eager_out = f(*args) |
| compiled_out = torch.compile( |
| f, fullgraph=True, backend=backend, dynamic=dynamic |
| )(*args) |
| self.assertEqual(compiled_out, eager_out) |
| |
| @requires_gpu |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| def test_triton_kernel_special_kwargs_with_autotune(self, backend): |
| @triton.autotune( |
| configs=[ |
| triton.Config({"BLOCK_SIZE": 128}), |
| triton.Config({"BLOCK_SIZE": 64}), |
| ], |
| key=["n_elements"], |
| ) |
| @triton.jit |
| def add_kernel( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| @torch.compile(fullgraph=True, backend=backend) |
| def f(x, y): |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| add_kernel[grid]( |
| x, |
| y, |
| output, |
| n_elements, |
| num_warps=8, |
| num_stages=3, |
| ) |
| return output |
| |
| x = torch.randn(4, device=GPU_TYPE) |
| f(x, x) |
| |
| @requires_gpu |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| def test_triton_kernel_num_ctas(self, backend): |
| @triton.jit |
| def kernel(X): |
| return |
| |
| @torch.compile(backend=backend) |
| def f(x): |
| kernel[(1,)](x, num_ctas=1) |
| kernel.run(x, num_ctas=1, grid=(1,), warmup=False) |
| return x |
| |
| x = torch.randn(4, device=GPU_TYPE) |
| f(x) |
| |
| @requires_gpu |
| @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) |
| def test_triton_kernel_special_kwargs_without_autotune(self, backend): |
| @triton.jit |
| def add_kernel( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| @torch.compile(fullgraph=True, backend=backend) |
| def f(x, y): |
| output = torch.zeros_like(x) |
| n_elements = output.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| add_kernel[grid]( |
| x, |
| y, |
| output, |
| n_elements, |
| BLOCK_SIZE=128, |
| num_warps=8, |
| num_stages=3, |
| ) |
| return output |
| |
| x = torch.randn(4, device=GPU_TYPE) |
| f(x, x) |
| |
| |
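# make_mutation_test wraps a factory that returns (kernel, kwargs, expected)
# into a test asserting that identify_mutated_tensors reports exactly the
# expected list of mutated parameter names for that kernel and its inputs.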
| def make_mutation_test(fn): |
| @requires_gpu |
| def test_fn(self): |
| from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors |
| |
| kernel, inputs, outputs = fn() |
| self.assertListEqual( |
| identify_mutated_tensors(kernel, inputs), |
| outputs, |
| ) |
| |
| return test_fn |
| |
| |
# Triton codegen suffers from scoping issues with locally defined JIT helpers,
# so define the shared helpers at module scope here.
| if HAS_GPU: |
| |
| @triton.jit |
| def helper_id(p): |
| return p |
| |
| @triton.jit |
| def helper_add_and_out(x, y, out_ptr): |
| return x + y, out_ptr |
| |
| |
| class MutationTests(torch._inductor.test_case.TestCase): |
| # Tests injected below |
| |
| @make_mutation_test |
| def test_out_of_order_kernel(): |
| @triton.jit |
| def add_kernel_out_of_order( |
| in_ptr0, |
| n_elements, |
| in_ptr1, |
| out_ptr, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| t = torch.randn(4) |
| return ( |
| add_kernel_out_of_order, |
| { |
| "in_ptr0": t, |
| "n_elements": 4, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_out_of_order_kernel_call(): |
| @triton.jit |
| def add_kernel_out_of_order_fn1( |
| in_ptr0, |
| n_elements, |
| in_ptr1, |
| out_ptr, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| add_kernel_out_of_order_fn2( |
| in_ptr0, in_ptr1, n_elements, out_ptr, BLOCK_SIZE=BLOCK_SIZE |
| ) |
| |
| t = torch.randn(4) |
| return ( |
| add_kernel_out_of_order_fn1, |
| { |
| "in_ptr0": t, |
| "n_elements": 4, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_reduce_sum(): |
| @triton.jit |
| def reduce_sum_kernel(a_ptr, c_ptr, stride_am, stride_an): |
| offs_am = tl.arange(0, 4) |
| offs_an = tl.arange(0, 4) |
| a_ptrs = a_ptr + ( |
| offs_am[:, None] * stride_am + offs_an[None, :] * stride_an |
| ) |
| a = tl.load(a_ptrs) |
| m = tl.sum(a, axis=1) |
| tl.store(c_ptr + tl.arange(0, 4), m) |
| |
| t = torch.randn(4) |
| kernel = reduce_sum_kernel |
| kwargs = { |
| "a_ptr": t, |
| "c_ptr": t, |
| "stride_am": 4, |
| "stride_an": 4, |
| } |
| |
| # TODO(aakhundov): tt.reduce is now supported, but only |
| # in the new MLIR-based Triton analysis pass (not in the |
| # old TTIR string parsing-based one). remove this gating |
| # and use ["c_ptr"] as `expected` after the new Triton |
| # pin lands both in OSS and internally. |
| ttir_module, _ = generate_ttir(kernel, kwargs) |
| if hasattr(ttir_module, "walk"): |
| # with MLIR-based Triton analysis pass |
| expected = ["c_ptr"] |
| else: |
| # with TTIR string parsing-based Triton analysis pass |
| expected = ["a_ptr", "c_ptr"] |
| |
| return ( |
| kernel, |
| kwargs, |
| expected, |
| ) |
| |
| @make_mutation_test |
| def test_argmax(): |
| @triton.jit |
| def argmax_kernel(a_ptr, c_ptr, stride_am, stride_an): |
| offs_am = tl.arange(0, 4) |
| offs_an = tl.arange(0, 4) |
| a_ptrs = a_ptr + ( |
| offs_am[:, None] * stride_am + offs_an[None, :] * stride_an |
| ) |
| a = tl.load(a_ptrs) |
| m = tl.argmax(a, axis=1) |
| tl.store(c_ptr + tl.arange(0, 4), m) |
| |
| t = torch.randn(4) |
| kernel = argmax_kernel |
| kwargs = { |
| "a_ptr": t, |
| "c_ptr": t, |
| "stride_am": 4, |
| "stride_an": 4, |
| } |
| |
| # TODO(aakhundov): tt.reduce is now supported, but only |
| # in the new MLIR-based Triton analysis pass (not in the |
| # old TTIR string parsing-based one). remove this gating |
| # and use ["c_ptr"] as `expected` after the new Triton |
| # pin lands both in OSS and internally. |
| ttir_module, _ = generate_ttir(kernel, kwargs) |
| if hasattr(ttir_module, "walk"): |
| # with MLIR-based Triton analysis pass |
| expected = ["c_ptr"] |
| else: |
| # with TTIR string parsing-based Triton analysis pass |
| expected = ["a_ptr", "c_ptr"] |
| |
| return ( |
| kernel, |
| kwargs, |
| expected, |
| ) |
| |
| @requires_cuda |
| @skipIfRocm |
| def test_triton_kernel_inference_mode(self): |
| def f(x, y, out): |
| n_elements = x.numel() |
| grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=4) |
| |
| with torch.inference_mode(): |
| x = torch.ones(32, device="cuda") |
| y = torch.ones(32, device="cuda") |
| out_ref = torch.zeros_like(x) |
| out_test = torch.zeros_like(x) |
| f(x, y, out_ref) |
| torch.compile(f)(x, y, out_test) |
| self.assertEqual(out_ref, out_test) |
| |
| @make_mutation_test |
| def test_cumsum(): |
| @triton.jit |
| def cumsum_kernel(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): |
| rindex = tl.arange(0, RBLOCK)[None, :] |
| xindex = tl.arange(0, XBLOCK)[:, None] |
| data = tl.load(in_ptr + rindex) |
| scan = tl.cumsum(data, 1) |
| expected_max = tl.sum(data, 1) |
| tl.device_assert(scan <= expected_max) |
| tl.store(out_ptr + xindex * RBLOCK + rindex, scan) |
| |
| t = torch.randn(4) |
| kernel = cumsum_kernel |
| kwargs = { |
| "in_ptr": t, |
| "out_ptr": t, |
| "XBLOCK": 4, |
| "RBLOCK": 16, |
| } |
| |
| # TODO(aakhundov): tt.scan is now supported, but only |
| # in the new MLIR-based Triton analysis pass (not in the |
| # old TTIR string parsing-based one). remove this gating |
| # and use ["out_ptr"] as `expected` after the new Triton |
| # pin lands both in OSS and internally. |
| ttir_module, _ = generate_ttir(kernel, kwargs) |
| if hasattr(ttir_module, "walk"): |
| # with MLIR-based Triton analysis pass |
| expected = ["out_ptr"] |
| else: |
| # with TTIR string parsing-based Triton analysis pass |
| expected = ["in_ptr", "out_ptr"] |
| |
| return ( |
| kernel, |
| kwargs, |
| expected, |
| ) |
| |
| @make_mutation_test |
| def test_fn_call_one_return(): |
| @triton.jit |
| def add_kernel_with_fn_call( |
| in_ptr0, |
| in_ptr1, |
| n_elements, |
| out_ptr, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = x + y |
| out = helper_id(out_ptr) |
| tl.store(out + offsets, output, mask=mask) |
| |
| t = torch.randn(4) |
| return ( |
| add_kernel_with_fn_call, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "n_elements": 4, |
| "out_ptr": t, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_fn_call_multi_return(): |
| @triton.jit |
| def add_kernel_with_fn_call( |
| in_ptr0, |
| in_ptr1, |
| n_elements, |
| out_ptr, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output, out = helper_add_and_out(x, y, out_ptr) |
| tl.store(out + offsets, output, mask=mask) |
| |
| t = torch.randn(4) |
| return ( |
| add_kernel_with_fn_call, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "n_elements": 4, |
| "out_ptr": t, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_nested_cond_op_kernel(): |
| @triton.jit |
| def nested_cond_op_kernel( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| if tl.program_id(0) == 0: |
| if tl.program_id(1) == 0: |
| output = x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| else: |
| pass |
| |
| t = torch.randn(4) |
| return ( |
| nested_cond_op_kernel, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_add_for_loop(): |
| @triton.jit |
| def add_4_times_kernel( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = tl.zeros((n_elements,), dtype=tl.float32) |
| for i in range(4): |
| output += x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| t = torch.randn(4) |
| return ( |
| add_4_times_kernel, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_add_for_loop2(): |
| @triton.jit |
| def add_1_time_kernel( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| for i in range(0, BLOCK_SIZE): |
| i = tl.multiple_of(i, 1) |
| output = x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| t = torch.randn(4) |
| return ( |
| add_1_time_kernel, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_add_nested_for_loop(): |
| @triton.jit |
| def add_4_times_kernel( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = tl.zeros((n_elements,), dtype=tl.float32) |
| for i in range(2): |
| for j in range(2): |
| output += x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| t = torch.randn(4) |
| return ( |
| add_4_times_kernel, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_add_nested_for_loop_multi_return(): |
| @triton.jit |
| def add_4_times_kernel( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output1 = tl.zeros((n_elements,), dtype=tl.float32) |
| output2 = tl.zeros((n_elements,), dtype=tl.float32) |
| for i in range(2): |
| for j in range(2): |
| output1 += y |
| output2 += x |
| output = output1 + output2 |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| t = torch.randn(4) |
| return ( |
| add_4_times_kernel, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_labels(): |
| @triton.jit |
| def kernel_with_label( |
| in_ptr0, |
| in_ptr1, |
| out_ptr, |
| n_elements, |
| BLOCK_SIZE: "tl.constexpr", |
| ): |
| pid = tl.program_id(axis=0) |
| if pid > 1: |
| return |
| block_start = pid * BLOCK_SIZE |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| mask = offsets < n_elements |
| x = tl.load(in_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr1 + offsets, mask=mask) |
| output = x + y |
| tl.store(out_ptr + offsets, output, mask=mask) |
| |
| t = torch.randn(4) |
| return ( |
| kernel_with_label, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_for_loop_arg(): |
| @triton.jit |
| def fwd_kernel( |
| X_ptr, |
| W1_ptr, |
| b1_ptr, |
| O_ptr, |
| M: tl.constexpr, |
| C1: tl.constexpr, |
| C2: tl.constexpr, |
| BLOCK_SIZE_M: tl.constexpr, |
| BLOCK_SIZE_C2: tl.constexpr, |
| ): |
| # Get program ids |
| pid_m = tl.program_id(0) |
| |
| # Compute offsets |
| offs_c1 = tl.arange(0, C1) |
| offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) |
| |
| # Load input data |
| x_block_ptr = X_ptr + offs_m[:, None] * C1 + offs_c1[None, :] |
| x = tl.load(x_block_ptr) |
| |
| # Compute the output one BLOCK_SIZE_C2-wide tile at a time |
| for c2 in range(0, tl.cdiv(C2, BLOCK_SIZE_C2)): |
| # Compute block pointers |
| offs_c2 = c2 * BLOCK_SIZE_C2 + tl.arange(0, BLOCK_SIZE_C2) |
| o_block_ptr = O_ptr + offs_m[:, None] * C2 + offs_c2[None, :] |
| w1_block_ptr = W1_ptr + offs_c1[:, None] * C2 + offs_c2[None, :] |
| b1_block_ptr = b1_ptr + offs_c2 |
| |
| # Compute output |
| w = tl.load(w1_block_ptr) |
| b = tl.load(b1_block_ptr) |
| o = tl.dot(x, w, allow_tf32=False) |
| o += b[None, :] |
| |
| # Store output |
| tl.store(o_block_ptr, o) |
| |
| t = torch.randn(64) |
| return ( |
| fwd_kernel, |
| { |
| "X_ptr": t, |
| "W1_ptr": t, |
| "b1_ptr": t, |
| "O_ptr": t, |
| "M": 64, |
| "C1": 64, |
| "C2": 64, |
| "BLOCK_SIZE_M": 64, |
| "BLOCK_SIZE_C2": 64, |
| }, |
| ["O_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_for_loop_arg_2(): |
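| # Block pointers created with tl.make_block_ptr and advanced in a for loop; only o_ptr is expected to be reported as mutated. |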
| @triton.jit |
| def fwd_kernel( |
| x_ptr, |
| o_ptr, |
| M, |
| N, |
| stride_m, |
| stride_n, |
| BLOCK_B: tl.constexpr, |
| BLOCK_M: tl.constexpr, |
| BLOCK_N: tl.constexpr, |
| ): |
| # Get program ids |
| pid_m = tl.program_id(0) |
| X_block_ptr = tl.make_block_ptr( |
| base=x_ptr, |
| shape=(M, N), |
| strides=(stride_m, stride_n), |
| offsets=(0, 0), |
| block_shape=(BLOCK_M, BLOCK_N), |
| order=(1, 0), |
| ) |
| O_block_ptr = tl.make_block_ptr( |
| base=o_ptr, |
| shape=(M, N), |
| strides=(stride_m, stride_n), |
| offsets=(0, 0), |
| block_shape=(BLOCK_M, BLOCK_N), |
| order=(1, 0), |
| ) |
| |
| for _ in range(BLOCK_B): |
| x = tl.load(X_block_ptr) |
| tl.store(O_block_ptr, x) |
| |
| X_block_ptr = tl.advance(X_block_ptr, (BLOCK_M, 0)) |
| O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M, 0)) |
| |
| t = torch.randn((32, 64, 128)) |
| o = torch.empty_like(t) |
| B, M, N = t.shape |
| return ( |
| fwd_kernel, |
| { |
| "x_ptr": t, |
| "o_ptr": o, |
| "M": M, |
| "N": N, |
| "stride_m": N, |
| "stride_n": 1, |
| "BLOCK_B": B, |
| "BLOCK_M": M, |
| "BLOCK_N": N, |
| }, |
| ["o_ptr"], |
| ) |
| |
| @make_mutation_test |
| def test_while_loop(): |
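| # Same pattern as above, but the block pointers are advanced in a while loop; only o_ptr is expected to be reported as mutated. |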
| @triton.jit |
| def fwd_kernel( |
| x_ptr, |
| o_ptr, |
| M, |
| N, |
| stride_m, |
| stride_n, |
| BLOCK_B: tl.constexpr, |
| BLOCK_M: tl.constexpr, |
| BLOCK_N: tl.constexpr, |
| ): |
| # Get program ids |
| pid_m = tl.program_id(0) |
| X_block_ptr = tl.make_block_ptr( |
| base=x_ptr, |
| shape=(M, N), |
| strides=(stride_m, stride_n), |
| offsets=(0, 0), |
| block_shape=(BLOCK_M, BLOCK_N), |
| order=(1, 0), |
| ) |
| O_block_ptr = tl.make_block_ptr( |
| base=o_ptr, |
| shape=(M, N), |
| strides=(stride_m, stride_n), |
| offsets=(0, 0), |
| block_shape=(BLOCK_M, BLOCK_N), |
| order=(1, 0), |
| ) |
| |
| i = 0 |
| while i < BLOCK_B: |
| x = tl.load(X_block_ptr) |
| tl.store(O_block_ptr, x) |
| |
| X_block_ptr = tl.advance(X_block_ptr, (BLOCK_M, 0)) |
| O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M, 0)) |
| i += 1 |
| |
| t = torch.randn((32, 64, 128)) |
| o = torch.empty_like(t) |
| B, M, N = t.shape |
| return ( |
| fwd_kernel, |
| { |
| "x_ptr": t, |
| "o_ptr": o, |
| "M": M, |
| "N": N, |
| "stride_m": N, |
| "stride_n": 1, |
| "BLOCK_B": B, |
| "BLOCK_M": M, |
| "BLOCK_N": N, |
| }, |
| ["o_ptr"], |
| ) |
| |
| |
| if HAS_GPU: |
| t = torch.randn(4) |
| tt = torch.randn(4, 1) |
| tests = [ |
| [ |
| add_kernel, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ], |
| [ |
| add_kernel_2d_autotuned, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "x_elements": 4, |
| "y_elements": 4, |
| }, |
| ["out_ptr"], |
| ], |
| [ |
| indirection_kernel, |
| { |
| "in_ptr0": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| "ACTIVATION": "mul2_inplace_kernel", |
| }, |
| ["in_ptr0", "out_ptr"], |
| ], |
| [ |
| indirection_kernel, |
| { |
| "in_ptr0": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| "ACTIVATION": "add_kernel", |
| }, |
| ["out_ptr"], |
| ], |
| [ |
| mul2_inplace_kernel, |
| {"ptr": t, "n_elements": 4, "BLOCK_SIZE": 4}, |
| ["ptr"], |
| ], |
| # Can't be optimized: the kernel uses tl.inline_asm_elementwise, so the |
| # analysis conservatively treats every pointer argument as mutated |
| [ |
| inline_asm_kernel, |
| {"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4}, |
| ["X", "Y", "Z"], |
| ], |
| [ |
| add_kernel_with_block_ptr, |
| { |
| "x_ptr": t, |
| "y_ptr": t, |
| "output_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["output_ptr"], |
| ], |
| [ |
| kernel_with_block_ptr_2d, |
| { |
| "x_ptr": tt, |
| "output_ptr": tt, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["output_ptr"], |
| ], |
| [ |
| add_kernel_with_import, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ], |
| [ |
| atomic_add_kernel, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ], |
| [ |
| add_4_times_kernel, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ], |
| [ |
| cond_op_kernel, |
| { |
| "in_ptr0": t, |
| "in_ptr1": t, |
| "out_ptr": t, |
| "n_elements": 4, |
| "BLOCK_SIZE": 4, |
| }, |
| ["out_ptr"], |
| ], |
| ] |
| for kernel, inputs, outputs in tests: |
| fn = make_mutation_test( |
| # Bind the loop variables as default arguments to avoid Python's late-binding |
| # lambda pitfall: defaults are evaluated when the lambda is created, not when |
| # it is called. |
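| # A minimal sketch of the pitfall (illustrative only, `fns` is hypothetical): |
| #     fns = [lambda: k for k in range(3)] |
| #     [f() for f in fns]        # -> [2, 2, 2]: every lambda sees the final k |
| #     fns = [lambda k=k: k for k in range(3)] |
| #     [f() for f in fns]        # -> [0, 1, 2]: defaults capture each iteration |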
| lambda kernel=kernel, inputs=inputs, outputs=outputs: ( |
| kernel, |
| inputs, |
| outputs, |
| ) |
| ) |
| name = f"test_mutations_{kernel.fn.__name__}" |
| # Crude but sufficient way to keep the generated test names unique |
| while name in MutationTests.__dict__: |
| name += "1" |
| |
| setattr(MutationTests, name, fn) |
| |
| |
| class CustomOpTests(torch._inductor.test_case.TestCase): |
| """Tests for custom ops wrapping triton kernels""" |
| |
| @requires_gpu |
| @common_utils.parametrize("autotuned", [False, True]) |
| def test_add_kernel(self, autotuned): |
| from torch._inductor.utils import run_and_get_code |
| |
| libname = "my_cool_namespace" |
| opname = "my_triton_operator" |
| |
| @torch._library.triton_op(f"{libname}::{opname}", mutates_args={}) |
| def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| output = torch.empty_like(x) |
| n_elements = output.numel() |
| |
| def grid(meta): |
| return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| |
| if autotuned: |
| capture_triton(add_kernel_autotuned)[grid](x, y, output, n_elements) |
| else: |
| capture_triton(add_kernel)[grid](x, y, output, n_elements, 16) |
| return output |
| |
| def f(x, y): |
| return add(x, y) |
| |
| x = torch.randn(3, device=GPU_TYPE) |
| y = torch.randn(3, device=GPU_TYPE) |
| |
| out = f(x, y) |
| expected = x + y |
| self.assertEqual(out, expected) |
| out_compiled, codes = run_and_get_code(torch.compile(f), x, y) |
| self.assertEqual(out_compiled, expected) |
| self.assertEqual(len(codes), 1) |
| |
| # Check that the custom op was decomposed away: neither its namespace nor |
| # its name should appear in the inductor-generated code |
| code = "\n".join(codes[0]) |
| self.assertNotIn(libname, code) |
| self.assertNotIn(opname, code) |
| |
| |
| common_utils.instantiate_parametrized_tests(KernelTests) |
| common_utils.instantiate_parametrized_tests(CustomOpTests) |
| |
| |
| if __name__ == "__main__": |
| from torch._inductor.test_case import run_tests |
| |
| run_tests() |