# Owner(s): ["module: inductor"]
import os
import re
import unittest

import torch
from torch import nn
from torch._dynamo.testing import reset_rng_state
from torch._inductor import config, test_operators
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.nn import functional as F
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CUDA


class TransformerSnippet(nn.Module):
def __init__(self):
super().__init__()
self.ln1 = nn.LayerNorm(64)
self.ln2 = nn.LayerNorm(64)
def forward(self, x1, x2):
x1 = F.dropout(x1, 0.1)
x2 = F.dropout(self.ln1(x2), 0.1)
return self.ln2(x1 + x2)
def example_inputs(self):
return (torch.randn(2, 64).cuda(), torch.randn(2, 64).cuda())


def _contains_multi_kernel_code(wrapper_code: str):
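    """
    Return True if the generated wrapper source defines a multi-kernel via
    async_compile.multi_kernel.
    """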
return (
re.search(r"multi_kernel_[^ ]* = async_compile.multi_kernel[(]", wrapper_code)
is not None
)


def make_cpp_wrapper_test(orig_test, **extra_args):
"""
Wrap an existing test into a new test with cpp-wrapper enabled.
Make this as a free function rather than staticmethod in MultiKernelTest.
Otherwise we get 'TypeError: 'staticmethod' object is not callable'
error in py3.8. (py3.10 works)
"""
@config.patch("cpp_wrapper", True)
def fn(self):
        # The same kernel may have been compiled by previous tests with
        # cpp_wrapper disabled. Clear the cache so we recompile the kernel
        # with cpp_wrapper enabled.
from torch._inductor import codecache
codecache.PyCodeCache.cache_clear()
return orig_test(self, **extra_args)
return fn
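

# Enable multi-kernel codegen (unless overridden via TORCHINDUCTOR_MULTI_KERNEL)
# and benchmarking of generated kernels for every test in this class.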
@config.patch(
{
"triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")),
"benchmark_kernel": True,
}
)
@instantiate_parametrized_tests
class MultiKernelTest(TestCase):
def test_softmax(self, expect_multi_kernel=True):
x = torch.rand(2, 1024).cuda()
ref = torch.softmax(x, -1)
compiled_fn = torch.compile(torch.softmax)
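        # Run the compiled function and collect the generated wrapper source.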
act, wrapper_code = run_and_get_code(compiled_fn, x, -1)
        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
wrapper_code = wrapper_code[-1]
self.assertTrue(torch.allclose(ref, act))
if expect_multi_kernel:
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
else:
            # Skip verifying the wrapper_code in fbcode since we may fail to
            # compile the cpp-wrapper CUDA code due to lacking a proper CUDA
            # compiler setup in the fbcode environment. In that case, the last
            # collected wrapper_code will correspond to the first-pass
            # cpp-wrapper codegen, which contains the multi-kernel.
if not config.is_fbcode():
self.assertFalse(_contains_multi_kernel_code(wrapper_code))
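    # Disable the multi-kernel choice cache so the mocked benchmark result
    # below decides which sub-kernel gets picked on every run.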
@parametrize("force_kernel", (0, 1))
@unittest.mock.patch.dict(
os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"}
)
def test_softmax_force_non_persistent_reduction(self, force_kernel):
"""
Force a specific sub-kernel being picked by mocking the benchmark result.
"""
x = torch.rand(2, 1024).cuda()
mock_latency = [0.2, 0.2]
        mock_latency[force_kernel] = 0.1  # this makes sure force_kernel will be picked
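        # Adding force_kernel to the output makes the graph differ between the
        # two parametrized runs, so each run compiles its own kernel.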
def f(x):
return torch.softmax(x, -1) + force_kernel
orig_run = MultiKernelCall.run_with_argless_kernels
picked_kernel = None
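        # Wrap run_with_argless_kernels to record which sub-kernel the
        # multi-kernel dispatch picked.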
def mock_run(self, kernel_calls):
out = orig_run(self, kernel_calls)
nonlocal picked_kernel
picked_kernel = self.picked_kernel
return out
with unittest.mock.patch.object(
MultiKernelCall, "run_with_argless_kernels", mock_run
), unittest.mock.patch.object(
MultiKernelCall, "benchmark_sub_kernels", lambda *args: mock_latency
):
torch.compile(f)(x)
self.assertEqual(picked_kernel, force_kernel)
@config.patch("warn_mix_layout", True)
def test_softmax_warn_mixed_layout(self):
self.test_softmax()
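    # With cpp_wrapper enabled, the final-pass wrapper has already resolved the
    # multi-kernel to a single sub-kernel, so no multi-kernel code is expected.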
test_softmax_cpp_wrapper = make_cpp_wrapper_test(
test_softmax, expect_multi_kernel=False
)
def test_layernorm(self):
ln = nn.LayerNorm(1024).cuda()
x = torch.rand(2, 1024).cuda()
ref = ln(x)
act = torch.compile(ln)(x)
self.assertTrue(
torch.allclose(ref, act, atol=1e-4, rtol=1e-4), f"ref:\n{ref}\nact:\n{act}"
)
def test_inplace_update(self):
"""
Inductor generate inplace kernel for mul.
"""
def f(x, y):
return x.sum(dim=-1, keepdims=True) * (y @ y)
x = torch.rand(1024, 1024).cuda()
y = torch.rand(1024, 1024).cuda()
ref = f(x, y)
act = torch.compile(f)(x, y)
self.assertTrue(torch.allclose(ref, act))
def test_transformer_snippet(self):
model = TransformerSnippet().cuda()
x = model.example_inputs()
def f(*x):
y = model(*x)
return y
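        # Reset the RNG state before both the eager and the compiled run so
        # the dropout masks match when fallback_random is enabled.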
reset_rng_state()
ref = f(*x)
opt_f = torch.compile(f)
reset_rng_state()
act = opt_f(*x)
        # Don't compare tensors when using inductor's random number generator:
        # inductor's RNG implementation differs from eager's, so we should fall
        # back to eager RNG if we want to test accuracy.
if config.fallback_random:
self.assertTrue(
torch.allclose(ref, act, atol=1e-4, rtol=1e-4),
f"ref:\n{ref}\nact:\n{act}",
)
def test_transformer_snippet_with_fallback_random(self):
"""
Same as test_transformer_snippet but fallback the random number
generator to eager so we can check accuracy.
"""
with config.patch("fallback_random", True):
self.test_transformer_snippet()
def test_batchnorm_training(self):
"""
For training, batchnorm will tracking running mean/variance during forward pass.
The kernel generated by inductor currently will pass in those tensors twice as arguments:
once for input and once for output. They are ruled out as in-out argument because
they are considered as graph inputs.
Multi-kernel previously assumes that we never pass the same argument mutli times
for a kernel. No mater if we change inductor behavior to assure that, it's better
to make multi-kernel being able to handle those cases.
"""
bn = nn.BatchNorm2d(3).to("cuda")
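        # Train mode: the forward pass updates bn's running stats in place.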
@torch.compile
def f(x):
bn(x).sum().backward()
_, (wrapper_code, _) = run_and_get_code(
f, torch.randn(2, 3, 8, 8, device="cuda")
)
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
def test_pass_same_arg_multi_times(self):
"""
A super simple example that simulate how BatchNorm update the running
stats.
Inductor currently pass the same tensor multiple times for the generated
kernel: once for input and once for output.
Here is a paster for the generated kernel (without multi-kernel enabled):
https://gist.github.com/shunting314/f0b446b4b9a28f4940e31dcd3e809cf9
"""
def f(x, y):
x = x.sum(dim=1, keepdim=False)
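            # y is read and written in place, so the generated kernel receives
            # the same buffer as both an input and an output argument.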
y.copy_(y * 0.9 + x * 0.1)
x = torch.randn(8, 16, device="cuda")
y = torch.randn(8, device="cuda")
y_ref = y.clone()
ref = f(x, y_ref)
act = torch.compile(f)(x, y)
self.assertTrue(torch.allclose(y_ref, y))
def test_reduction_scratch_buffer(self, force_multi_kernel=1):
"""
The explicited realized buffer in the test function will be passed in
as a scratch buffer for the non-persistent reduction kernel but
can be skipped for the persistent reduction kernel.
This causes different argument lists for non-persistent reduction kernel and
persistent reduction kernel.
Check documentation around torch._inductor.config.triton.multi_kernel about
how to interpret the force_multi_kernel argument.
"""
def f(x):
x = x.sum(dim=-1, keepdim=True) + x
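            # Explicitly realize the intermediate so it is materialized as a
            # buffer instead of being fused into the next reduction.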
x = test_operators.realize(x)
x = x.sum(dim=-1, keepdim=True) + x
return x
x = torch.rand(16, 16, device="cuda")
ref = f(x)
with config.patch("triton.multi_kernel", force_multi_kernel):
act = torch.compile(f)(x)
self.assertTrue(torch.allclose(ref, act))
# Use benchmarking to pick the faster kernel
test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test(
test_reduction_scratch_buffer, force_multi_kernel=1
)
    # Force picking the persistent reduction. This is a good test since the
    # persistent reduction uses fewer call arguments than the corresponding
    # non-persistent reduction.
test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = (
make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2)
)
    # Force picking the non-persistent reduction.
test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = (
make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3)
)


if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_CUDA:
run_tests()