# Owner(s): ["module: dynamo"]
# flake8: noqa: E731
# Skips E731 ("do not assign a lambda expression, use a def")
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
from torch._higher_order_ops.triton_kernel_wrap import (
triton_kernel_wrapper_functional,
triton_kernel_wrapper_mutation,
)
from torch._inductor import metrics
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfRocm
# Defines all the kernels for tests
from torch.testing._internal.triton_utils import * # noqa: F403
if HAS_CUDA:
import triton
from triton import language as tl
# Define shared triton constants here.
CONSTANT_C = 4
STRING_CONSTANT_C = "CONSTANT_C"
BOOL_CONSTANT_C = True
class KernelTests(torch._dynamo.test_case.TestCase):
@requires_cuda
def test_triton_kernel_with_kernel_param(self):
@triton.jit
def pass_kernel(kernel):
pass
@torch.compile(backend="eager")
def f(x):
grid = (x.numel(),)
pass_kernel[grid](kernel=x)
t1 = torch.rand(5, device="cuda")
f(t1)
# No need to assert anything, the goal is to make sure dynamo does
# not crash
@requires_cuda
def test_triton_kernel_higher_order_func(self):
from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
add_kernel_id = kernel_side_table.add_kernel(add_kernel)
t1 = torch.rand(5, device="cuda")
t2 = torch.rand(5, device="cuda")
torch_add = t1 + t2
# Test higher order function with mutation
output = torch.zeros_like(t1)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
triton_kernel_wrapper_mutation(
kernel_idx=add_kernel_id,
grid=[grid],
kwargs={
"in_ptr0": t1,
"in_ptr1": t2,
"out_ptr": output,
"n_elements": n_elements,
"BLOCK_SIZE": 16,
},
)
self.assertEqual(output, torch_add)
# Make sure it is modified
self.assertNotEqual(output, torch.zeros_like(t1))
# Test higher order function without mutation
output = torch.zeros_like(t1)
out_dict = triton_kernel_wrapper_functional(
kernel_idx=add_kernel_id,
grid=[grid],
kwargs={
"in_ptr0": t1,
"in_ptr1": t2,
"out_ptr": output,
"n_elements": n_elements,
"BLOCK_SIZE": 16,
},
tensors_to_clone=["in_ptr0", "in_ptr1", "out_ptr"],
)
self.assertEqual(out_dict["out_ptr"], torch_add)
# Make sure it is NOT modified
self.assertEqual(output, torch.zeros_like(t1))
@requires_cuda
@skipIfRocm
def test_triton_kernel_functionalize(self):
from functorch import make_fx
from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
from torch._subclasses.functional_tensor import (
CppFunctionalizeAPI,
FunctionalTensorMode,
PythonFunctionalizeAPI,
)
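        # Reset the global side table so the kernel registered below gets index 0,
        # which the expected graph dump at the end of this test relies on.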
kernel_side_table.reset_table()
def f(x, output):
out = triton_kernel_wrapper_functional(
kernel_idx=kernel_side_table.add_kernel(mul2_kernel),
grid=[(x.numel(),)],
kwargs={
"in_ptr0": x,
"out_ptr": output,
"n_elements": output.numel(),
"BLOCK_SIZE": 16,
},
tensors_to_clone=["in_ptr0", "out_ptr"],
)
return out["out_ptr"]
t1 = torch.rand(5, device="cuda")
t2 = torch.rand(5, device="cuda")
with FunctionalTensorMode():
gm = make_fx(PythonFunctionalizeAPI().functionalize(f))(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(gm(t1, t2), t2)
gm = make_fx(CppFunctionalizeAPI().functionalize(f))(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(gm(t1, t2), t2)
gm = make_fx(torch.func.functionalize(f))(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(gm(t1, t2), t2)
gm = make_fx(f, tracing_mode="fake")(t1, t2)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1, output_1):
triton_kernel_wrapper_functional_proxy = torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional(kernel_idx = 0, grid = [(5,)], kwargs = {'in_ptr0': x_1, 'out_ptr': output_1, 'n_elements': 5, 'BLOCK_SIZE': 16}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None
getitem = triton_kernel_wrapper_functional_proxy['in_ptr0']
getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr']; triton_kernel_wrapper_functional_proxy = None
return getitem_1""",
)
@requires_cuda
@skipIfRocm
def test_triton_kernel_mutation_type(self):
from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import (
FunctionalTensor,
FunctionalTensorMode,
)
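        # Mutations performed through the triton kernel wrapper are hidden from
        # autograd, while ordinary in-place ops are not; the three cases below
        # cover each combination.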
def prep():
x = torch.ones(4, device="cuda", requires_grad=True)
with FunctionalTensorMode():
x_func = FunctionalTensor.to_functional(x)
self.assertTrue(torch._is_functional_tensor(x_func.elem))
return x_func
# normal mutation only
with FakeTensorMode():
x_func = prep()
with FunctionalTensorMode():
x_func.mul_(2)
self.assertFalse(
torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
)
# triton kernel mutation only
with FakeTensorMode():
x_func = prep()
with FunctionalTensorMode():
triton_kernel_wrapper_mutation(
kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel),
grid=[(x_func.numel(),)],
kwargs={
"ptr": x_func,
"n_elements": x_func.numel(),
"BLOCK_SIZE": 16,
},
)
self.assertTrue(
torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
)
# normal mutation + triton kernel mutation
with FakeTensorMode():
x_func = prep()
with FunctionalTensorMode():
x_func.mul_(2)
triton_kernel_wrapper_mutation(
kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel),
grid=[(x_func.numel(),)],
kwargs={
"ptr": x_func,
"n_elements": x_func.numel(),
"BLOCK_SIZE": 16,
},
)
self.assertFalse(
torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem)
)
@requires_cuda
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_triton_kernel_with_views(self, dynamic, backend):
def call_triton_take_view(x: torch.Tensor):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
return output
def call_triton_return_view(x: torch.Tensor):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
return output.view(4, 4)
t = torch.rand(4, 4, device="cuda")
t_view = t.view(16)
compiled_func = torch.compile(
call_triton_take_view, backend=backend, fullgraph=True, dynamic=dynamic
)
self.assertEqual(2 * t_view, compiled_func(t_view))
self.assertEqual(2 * t, compiled_func(t_view).view(4, 4))
compiled_func = torch.compile(
call_triton_return_view, backend=backend, fullgraph=True, dynamic=dynamic
)
self.assertEqual(2 * t_view, compiled_func(t).view(16))
self.assertEqual(2 * t, compiled_func(t))
@requires_cuda
@common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_triton_kernel_with_grad_option(self, grad_fn, backend):
def call_triton(x: torch.Tensor):
with grad_fn():
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
return output
t = torch.rand(5, device="cuda")
compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
self.assertEqual(2 * t, compiled_func(t))
@requires_cuda
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_triton_kernel_inner_triton_function(self, backend):
def f(x: torch.Tensor):
@triton.jit
def pow2_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = x * x
tl.store(out_ptr + offsets, output, mask=mask)
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
pow2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
return output
t = torch.rand(5, device="cuda")
compiled_func = torch.compile(f, backend=backend, fullgraph=True)
# TODO(oulgen): NYI - Support this
# self.assertEqual(t * t, compiled_func(t))
@requires_cuda
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])
@patch.object(torch._inductor.config, "implicit_fallbacks", False)
def test_triton_kernel_no_clones(self, grad, dynamic):
from torch._inductor.utils import run_and_get_code
def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
n_elements = output.numel()
tmp = torch.add(x, 1)
grid = (x.numel(),)
add_kernel.run(
x, y, output, n_elements, warmup=False, grid=grid, BLOCK_SIZE=16
)
return output, tmp
t1 = torch.rand(5, device="cuda", requires_grad=grad)
t2 = torch.rand(5, device="cuda", requires_grad=grad)
o1 = torch.zeros_like(t1, requires_grad=grad)
torch_add = call_triton(t1, t2, o1)
metrics.reset()
o2 = torch.zeros_like(t1, requires_grad=grad)
test, codes = run_and_get_code(
torch.compile(call_triton, dynamic=dynamic), t1, t2, o2
)
if not grad:
self.assertEqual(metrics.generated_kernel_count, 1)
self.assertEqual(torch_add, test)
        # These two asserts are not optimal since they require the original aten
        # ops to be present in the metadata, so there might be false negatives
self.assertTrue("aten.copy" not in codes[0])
self.assertTrue("aten.clone" not in codes[0])
        # The following checks that the compiled graph returns only the tensor
        # output (plus the symbolic size in the dynamic + grad case)
if dynamic and grad:
self.assertTrue("return (buf0, s0, )" in codes[0])
else:
self.assertTrue("return (buf0, )" in codes[0])
@requires_cuda
@skipIfRocm
def test_triton_kernel_caching(self):
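        # Calling the same autotuned kernel repeatedly should reuse a single
        # compiled kernel, so no "add_kernel_autotuned_1" variant is generated.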
from torch._inductor.utils import run_and_get_code
def add_in_loop(
x: torch.Tensor,
y: torch.Tensor,
):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel_autotuned[grid](x, y, output, n_elements)
return output
def call_triton_add(
x: torch.Tensor,
y: torch.Tensor,
):
for i in range(4):
x = add_in_loop(x, y)
return x
t1 = torch.ones(5, device="cuda")
t2 = torch.ones(5, device="cuda")
test, (code,) = run_and_get_code(torch.compile(call_triton_add), t1, t2)
self.assertEqual(test, 5 * torch.ones(5, device="cuda"))
self.assertTrue("add_kernel_autotuned_1.run" not in code)
@requires_cuda
@skipIfRocm
def test_triton_kernel_caching_duplicate(self):
from torch._inductor.utils import run_and_get_code
class C:
@triton.jit
def pass_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
tl.store(out_ptr + offsets, x, mask=mask)
class D:
@triton.jit
def pass_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
tl.store(out_ptr + offsets, x, mask=mask)
def call_triton(x: torch.Tensor):
output1 = torch.zeros_like(x)
output2 = torch.zeros_like(x)
n_elements = output1.numel()
grid = (n_elements,)
C.pass_kernel[grid](x, output1, n_elements, BLOCK_SIZE=16)
D.pass_kernel[grid](x, output2, n_elements, BLOCK_SIZE=16)
return output1 + output2
t = torch.ones(5, device="cuda")
test, (code,) = run_and_get_code(torch.compile(call_triton), t)
# Make sure we emitted two kernels here
self.assertTrue("pass_kernel_0.run" in code)
self.assertTrue("pass_kernel_1.run" in code)
@requires_cuda
@skipIfRocm
def test_triton_kernel_various_args(self):
@triton.autotune(
configs=[triton.Config({"BLOCK_SIZE": 128})],
key=[],
)
@triton.jit
def pass_kernel(
out_ptr,
n_elements,
dummy_None,
dummy_empty,
dummy_float,
BLOCK_SIZE: "tl.constexpr",
RANDOM_SIZE: "tl.constexpr",
):
pass
@torch.compile
def call_triton(output):
n_elements = output.numel()
grid = (n_elements,)
pass_kernel[grid](
output,
n_elements,
None,
torch.empty_like(output),
3.1415926,
RANDOM_SIZE=0,
)
return output
output = torch.randn(5, device="cuda")
# Make sure this does not crash
call_triton(output)
@requires_cuda
@skipIfRocm
    def test_triton_kernel_dependencies(self):
def call_triton(
x: torch.Tensor,
y: torch.Tensor,
):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel_autotuned[grid](x, y, output, n_elements)
output2 = torch.zeros_like(output)
add_kernel_autotuned[grid](output, y, output2, n_elements)
output3 = torch.add(output2, 1)
return output3
t1 = torch.rand(5, device="cuda")
t2 = torch.rand(5, device="cuda")
torch_result = call_triton(t1, t2)
compiled_result = torch.compile(call_triton)(t1, t2)
self.assertEqual(torch_result, compiled_result)
@requires_cuda
@skipIfRocm
def test_triton_kernel_reinplace_inplaceable_pass(self):
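        # The second kernel call reads and writes `output` in place, exercising
        # inductor's reinplacing pass for user-defined triton kernels.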
def call_triton(
x: torch.Tensor,
y: torch.Tensor,
):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel_autotuned[grid](x, y, output, n_elements)
add_kernel_autotuned[grid](output, x, output, n_elements)
return output
t1 = torch.rand(5, device="cuda")
t2 = torch.rand(5, device="cuda")
torch_result = call_triton(t1, t2)
compiled_result = torch.compile(call_triton)(t1, t2)
self.assertEqual(torch_result, compiled_result)
@requires_cuda
@common_utils.parametrize("grad", [False, True])
def test_triton_kernel_multi_kernel(self, grad):
@triton.jit
def mul2_and_add_and_zero_negatives_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
ACTIVATION: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
indirection_kernel(
in_ptr0,
in_ptr0,
n_elements,
BLOCK_SIZE=BLOCK_SIZE,
ACTIVATION="mul2_inplace_kernel",
)
indirection_kernel(
in_ptr1,
in_ptr1,
n_elements,
BLOCK_SIZE=BLOCK_SIZE,
ACTIVATION="mul2_inplace_kernel",
)
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
if ACTIVATION == "zero_negs":
output = zero_negs(output)
tl.store(out_ptr + offsets, output, mask=mask)
@torch.compile
def call_triton(
x: torch.Tensor,
y: torch.Tensor,
xi: torch.Tensor,
yi: torch.Tensor,
output: torch.Tensor,
outputi: torch.Tensor,
):
n_elements = output.numel()
grid = (x.numel(),)
mul2_and_add_and_zero_negatives_kernel[grid](
x, y, output, n_elements, BLOCK_SIZE=16, ACTIVATION="zero_negs"
)
mul2_and_add_and_zero_negatives_kernel[grid](
xi, yi, outputi, n_elements, BLOCK_SIZE=16, ACTIVATION=None
)
return (output, outputi)
t1 = torch.tensor(
[-2.0, -1.0, 0.0, 1.0, 2.0], device="cuda", requires_grad=grad
)
t2 = torch.tensor(
[-2.0, -1.0, 0.0, 1.0, 2.0], device="cuda", requires_grad=grad
)
float_result = 2 * t1 + 2 * t2
float_result = float_result.where(float_result >= 0, 0.0)
t1i = torch.randint(-2, 2, (5,), device="cuda")
t2i = torch.randint(-2, 2, (5,), device="cuda")
o = torch.zeros_like(t1, requires_grad=grad)
oi = torch.zeros_like(t1i)
int_result = 2 * t1i + 2 * t2i
(result, resulti) = call_triton(t1, t2, t1i, t2i, o, oi)
self.assertEqual(float_result, result)
self.assertEqual(int_result, resulti)
@requires_cuda
def test_triton_kernel_constants(self):
@triton.jit
def mulC_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
CONSTANT_NAME: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
if CONSTANT_NAME.value == STRING_CONSTANT_C:
output = CONSTANT_C * x
if BOOL_CONSTANT_C:
output *= CONSTANT_C
tl.store(out_ptr + offsets, output, mask=mask)
def call_triton(
x: torch.Tensor,
):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = (x.numel(),)
mulC_kernel[grid](
x, output, n_elements, BLOCK_SIZE=16, CONSTANT_NAME="CONSTANT_C"
)
return output
        # Triton kernels capture global constants by their parse-time value,
        # not their runtime value
global CONSTANT_C
prev_c = CONSTANT_C
        # If the behavior of triton kernels changes, this test will fail
CONSTANT_C = 10
assert CONSTANT_C != prev_c
t = torch.randn(5, device="cuda")
torch_result = call_triton(t)
compiled_result = torch.compile(call_triton)(t)
self.assertEqual(torch_result, compiled_result)
# reset back
CONSTANT_C = prev_c
@requires_cuda
@skipIfRocm
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@common_utils.parametrize("grid_type", [1, 2, 3])
def test_triton_kernel_autotune(self, grad, dynamic, backend, grid_type):
def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
n_elements = output.numel()
def grid_fn(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if grid_type == 1:
grid = (n_elements,)
elif grid_type == 2:
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
elif grid_type == 3:
grid = grid_fn
add_kernel_autotuned[grid](x, y, output, n_elements)
return output
t1 = torch.rand(256, device="cuda", requires_grad=grad)
t2 = torch.rand(256, device="cuda", requires_grad=grad)
output = torch.zeros_like(t1, requires_grad=grad)
torch_add = call_triton(t1, t2, output)
compiled_func = torch.compile(
call_triton, backend=backend, fullgraph=True, dynamic=dynamic
)
output2 = torch.zeros_like(t1, requires_grad=grad)
self.assertEqual(compiled_func(t1, t2, output2), torch_add)
@requires_cuda
@skipIfRocm
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@common_utils.parametrize("grid_type", [1, 2, 3])
def test_triton_kernel_2d_autotune(self, grad, dynamic, backend, grid_type):
def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):
x_elements = output.size()[0]
y_elements = output.size()[1]
def grid_fn(meta):
return (
triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
)
if grid_type == 1:
grid = (x_elements, y_elements)
elif grid_type == 2:
grid = lambda meta: (
triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
)
elif grid_type == 3:
grid = grid_fn
add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements)
return output
t1 = torch.rand((512, 256), device="cuda", requires_grad=grad)
t2 = torch.rand((512, 256), device="cuda", requires_grad=grad)
output = torch.zeros_like(t1, requires_grad=grad)
torch_result = call_triton(t1, t2, output)
compiled_func = torch.compile(
call_triton, backend=backend, fullgraph=True, dynamic=dynamic
)
output2 = torch.zeros_like(t1, requires_grad=grad)
self.assertEqual(compiled_func(t1, t2, output2), torch_result)
@requires_cuda
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@patch.object(torch._inductor.config, "implicit_fallbacks", False)
def test_triton_kernel_native(self, grad, dynamic, backend):
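        # Covers grids given as a tuple, a lambda, and a named function, with
        # BLOCK_SIZE passed both positionally and as a keyword, eager and compiled.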
def call_triton_add(
x: torch.Tensor,
y: torch.Tensor,
output: torch.Tensor,
grid_type: int,
num=1,
positional=False,
):
n_elements = output.numel()
def grid_fn(meta):
return (triton.cdiv(num, meta["BLOCK_SIZE"]),)
if grid_type == 0:
grid = (x.numel(),)
elif grid_type == 1:
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
else:
grid = grid_fn
if positional:
add_kernel[grid](x, y, output, n_elements, 16)
else:
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
return output
t1 = torch.rand(5, device="cuda", requires_grad=grad)
t2 = torch.rand(5, device="cuda", requires_grad=grad)
o1 = torch.zeros_like(t1, requires_grad=grad)
torch_add = t1 + t2
# No Dynamo -- Make sure triton kernel works
self.assertEqual(call_triton_add(t1, t2, o1, 1), torch_add)
# No Dynamo -- Make sure triton kernel works (with positional BLOCK_SIZE)
o2 = torch.zeros_like(t1, requires_grad=grad)
self.assertEqual(call_triton_add(t1, t2, o2, 1, True), torch_add)
# With Dynamo
compiled_func = torch.compile(
call_triton_add, backend=backend, fullgraph=True, dynamic=dynamic
)
# With simple kernel
o3 = torch.zeros_like(t1, requires_grad=grad)
self.assertEqual(compiled_func(t1, t2, o3, 0), torch_add)
# With lambda kernel
o4 = torch.zeros_like(t1, requires_grad=grad)
self.assertEqual(compiled_func(t1, t2, o4, 1), torch_add)
# With lambda kernel (with positional BLOCK_SIZE)
o5 = torch.zeros_like(t1, requires_grad=grad)
self.assertEqual(compiled_func(t1, t2, o5, 1, 1, True), torch_add)
# With user defined function kernel
o6 = torch.zeros_like(t1, requires_grad=grad)
self.assertEqual(compiled_func(t1, t2, o6, 2, 200), torch_add)
@requires_cuda
def test_triton_kernel_mutation_not_mark_dirty(self):
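        # Mutating x_cloned through the triton kernel must not mark it dirty for
        # autograd; otherwise the backward() call below would fail.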
@torch.compile
def f(x):
n_elements = x.numel()
add_kernel[(n_elements,)](x, x, x, n_elements, 16)
return x
x = torch.randn(5, device="cuda", requires_grad=True)
x_cloned = x.clone()
out = x_cloned.sin()
f(x_cloned)
out.sum().backward()
@requires_cuda
def test_triton_kernel_matmul_tracking(self):
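        # The kernel fills `out` with ones; the compiled graph must track that
        # mutation so the subsequent torch.mm sees ones rather than zeros.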
@triton.jit
def ones_kernel(x_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = 1.0
tl.store(x_ptr + offsets, x, mask=mask)
@torch.compile
def f(x):
out = torch.zeros_like(x)
ones_kernel[(4,)](out, 16, BLOCK_SIZE=16)
return torch.mm(out, x) + 10
x = torch.randn(4, 4, device="cuda")
torch_out = f(x)
python_out = torch.mm(torch.ones(4, 4, device="cuda"), x) + 10
self.assertEqual(torch_out, python_out)
@requires_cuda
def test_triton_kernel_strided_input(self):
def f(inp):
# left has strides [256, 1]
left, right = torch.split(inp, [128, 128], dim=1)
out = torch.empty_like(left)
X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
grid = (left.size(1) // X_BLOCK_SIZE, left.size(0) // Y_BLOCK_SIZE)
double_strided_kernel[grid](
in_ptr=left,
out_ptr=out,
in_y_stride=left.stride(0),
out_y_stride=out.stride(0),
X_BLOCK_SIZE=X_BLOCK_SIZE,
Y_BLOCK_SIZE=Y_BLOCK_SIZE,
)
return out
inp = torch.randn(64, 256, device="cuda")
eager_out = f(inp)
compiled_out = torch.compile(f)(inp)
self.assertEqual(compiled_out, eager_out)
@requires_cuda
def test_triton_kernel_strided_input_nonzero_offset(self):
def f(inp):
# right has strides [256, 1] and storage offset 128
left, right = torch.split(inp, [128, 128], dim=1)
out = torch.empty_like(right)
X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
grid = (right.size(1) // X_BLOCK_SIZE, right.size(0) // Y_BLOCK_SIZE)
double_strided_kernel[grid](
in_ptr=right,
out_ptr=out,
in_y_stride=right.stride(0),
out_y_stride=out.stride(0),
X_BLOCK_SIZE=X_BLOCK_SIZE,
Y_BLOCK_SIZE=Y_BLOCK_SIZE,
)
return out
inp = torch.randn(64, 256, device="cuda")
eager_out = f(inp)
compiled_out = torch.compile(f)(inp)
self.assertEqual(compiled_out, eager_out)
@requires_cuda
def test_triton_kernel_slice_and_view_input(self):
def f(inp):
# left has strides [256, 1]
left = inp[:, :128]
left = left.view(64, 4, 32)
out = torch.empty_like(left)
X_BLOCK_SIZE, Y_BLOCK_SIZE = 32, 16
grid = (
(left.size(1) * left.size(2)) // X_BLOCK_SIZE,
left.size(0) // Y_BLOCK_SIZE,
)
double_strided_kernel[grid](
in_ptr=left,
out_ptr=out,
in_y_stride=left.stride(0),
out_y_stride=out.stride(0),
X_BLOCK_SIZE=X_BLOCK_SIZE,
Y_BLOCK_SIZE=Y_BLOCK_SIZE,
)
return out + left
inp = torch.randn(64, 256, device="cuda")
eager_out = f(inp)
compiled_out = torch.compile(f)(inp)
self.assertEqual(compiled_out, eager_out)
@requires_cuda
def test_triton_kernel_fallback(self):
def f(x, y):
out = torch.zeros_like(x)
out2 = torch.zeros_like(x)
# torch.mm is ExternKernelOut
add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16)
            # torch.sort creates a fallback kernel and hence MultiOutput
add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16)
return out, out2
x = torch.randn(4, 4, device="cuda")
y = torch.randn(4, 4, device="cuda")
eager_out = f(x, y)
compiled_out = torch.compile(f)(x, y)
self.assertEqual(compiled_out, eager_out)
def make_mutation_test(fn):
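    # Wraps a function returning (kernel, kwargs, expected mutated args) into a
    # test that checks identify_mutated_tensors against the expected list.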
@requires_cuda
@requires_lark
@skipIfRocm
def test_fn(self):
from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors
kernel, inputs, outputs = fn()
self.assertListEqual(
identify_mutated_tensors(kernel, inputs),
outputs,
)
return test_fn
# Triton codegen suffers from scoping issues, so define the shared helpers here
# at module scope.
if HAS_CUDA:
@triton.jit
def helper_id(p):
return p
@triton.jit
def helper_add_and_out(x, y, out_ptr):
return x + y, out_ptr
class MutationTests(torch._dynamo.test_case.TestCase):
# Tests injected below
@make_mutation_test
def test_out_of_order_kernel():
@triton.jit
def add_kernel_out_of_order(
in_ptr0,
n_elements,
in_ptr1,
out_ptr,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
t = torch.randn(4)
return (
add_kernel_out_of_order,
{
"in_ptr0": t,
"n_elements": 4,
"in_ptr1": t,
"out_ptr": t,
"BLOCK_SIZE": 4,
},
["out_ptr"],
)
@make_mutation_test
def test_out_of_order_kernel_call():
@triton.jit
def add_kernel_out_of_order_fn1(
in_ptr0,
n_elements,
in_ptr1,
out_ptr,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
add_kernel_out_of_order_fn2(
in_ptr0, in_ptr1, n_elements, out_ptr, BLOCK_SIZE=BLOCK_SIZE
)
t = torch.randn(4)
return (
add_kernel_out_of_order_fn1,
{
"in_ptr0": t,
"n_elements": 4,
"in_ptr1": t,
"out_ptr": t,
"BLOCK_SIZE": 4,
},
["out_ptr"],
)
@make_mutation_test
def test_argmax():
@triton.jit
def argmax_kernel(a_ptr, c_ptr, stride_am, stride_an):
offs_am = tl.arange(0, 4)
offs_an = tl.arange(0, 4)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_an[None, :] * stride_an
)
a = tl.load(a_ptrs)
m = tl.argmax(a, axis=1)
tl.store(c_ptr + tl.arange(0, 4), m)
t = torch.randn(4)
return (
argmax_kernel,
{
"a_ptr": t,
"c_ptr": t,
"stride_am": 4,
"stride_an": 4,
},
# TODO(oulgen): tt.reduce closures are not implemented yet
["a_ptr", "c_ptr"],
)
@make_mutation_test
def test_fn_call_one_return():
@triton.jit
def add_kernel_with_fn_call(
in_ptr0,
in_ptr1,
n_elements,
out_ptr,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
out = helper_id(out_ptr)
tl.store(out + offsets, output, mask=mask)
t = torch.randn(4)
return (
add_kernel_with_fn_call,
{
"in_ptr0": t,
"in_ptr1": t,
"n_elements": 4,
"out_ptr": t,
"BLOCK_SIZE": 4,
},
# TODO(oulgen): helper return values not implemented yet
["in_ptr0", "in_ptr1", "out_ptr"],
)
@make_mutation_test
def test_fn_call_multi_return():
@triton.jit
def add_kernel_with_fn_call(
in_ptr0,
in_ptr1,
n_elements,
out_ptr,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output, out = helper_add_and_out(x, y, out_ptr)
tl.store(out + offsets, output, mask=mask)
t = torch.randn(4)
return (
add_kernel_with_fn_call,
{
"in_ptr0": t,
"in_ptr1": t,
"n_elements": 4,
"out_ptr": t,
"BLOCK_SIZE": 4,
},
# TODO(oulgen): multiple return values not implemented yet
["in_ptr0", "in_ptr1", "out_ptr"],
)
if HAS_CUDA and HAS_LARK:
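    # Each entry below is (kernel, kwargs, expected mutated arguments); the loop
    # at the end of this block turns them into MutationTests methods.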
t = torch.randn(4)
tests = [
[
add_kernel,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["out_ptr"],
],
[
add_kernel_2d_autotuned,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"x_elements": 4,
"y_elements": 4,
},
["out_ptr"],
],
[
indirection_kernel,
{
"in_ptr0": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
"ACTIVATION": "mul2_inplace_kernel",
},
["in_ptr0", "out_ptr"],
],
[
indirection_kernel,
{
"in_ptr0": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
"ACTIVATION": "add_kernel",
},
            # TODO(oulgen): Multiple functions are not implemented yet
["in_ptr0", "out_ptr"],
],
[
mul2_inplace_kernel,
{"ptr": t, "n_elements": 4, "BLOCK_SIZE": 4},
["ptr"],
],
        # Can't optimize since the kernel contains a tl.inline_asm_elementwise call
[
inline_asm_kernel,
{"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4},
["X", "Y", "Z"],
],
[
add_kernel_with_block_ptr,
{
"x_ptr": t,
"y_ptr": t,
"output_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["output_ptr"],
],
[
add_kernel_with_import,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["out_ptr"],
],
[
atomic_add_kernel,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["out_ptr"],
],
[
add_4_times_kernel,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
# TODO(oulgen): For loops not implemented yet
["in_ptr0", "in_ptr1", "out_ptr"],
],
[
cond_op_kernel,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
# TODO(oulgen): Dynamic control flow is not implemented yet
["in_ptr0", "in_ptr1", "out_ptr"],
],
]
for kernel, inputs, outputs in tests:
fn = make_mutation_test(
            # Add default arguments to avoid the Python lambda late-binding
            # pitfall; defaults force the capture at lambda creation time
lambda kernel=kernel, inputs=inputs, outputs=outputs: (
kernel,
inputs,
outputs,
)
)
name = f"test_mutations_{kernel.fn.__name__}"
        # Crude way to make the test names unique
while name in MutationTests.__dict__:
name += "1"
setattr(MutationTests, name, fn)
common_utils.instantiate_parametrized_tests(KernelTests)
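# Re-run every KernelTests test with optimize_user_defined_triton_kernels
# disabled, via a patched copy of the test class.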
no_opt_test_class = make_test_cls_with_patches(
KernelTests,
"NoOptimization",
"_no_optimizations",
(config, "optimize_user_defined_triton_kernels", False),
)
globals()[no_opt_test_class.__name__] = no_opt_test_class
no_opt_test_class.__module__ = __name__
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()