| # Owner(s): ["module: inductor"] |
| |
| import os |
| import re |
| import unittest |
| |
| import torch |
| from torch import nn |
| from torch._dynamo.testing import reset_rng_state |
| from torch._inductor import config, test_operators |
| from torch._inductor.codegen.multi_kernel import MultiKernelCall |
| from torch._inductor.test_case import TestCase |
| from torch._inductor.utils import run_and_get_code |
| from torch.nn import functional as F |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| ) |
| from torch.testing._internal.inductor_utils import HAS_CUDA |
| |
| |
| class TransformerSnippet(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.ln1 = nn.LayerNorm(64) |
| self.ln2 = nn.LayerNorm(64) |
| |
| def forward(self, x1, x2): |
| x1 = F.dropout(x1, 0.1) |
| x2 = F.dropout(self.ln1(x2), 0.1) |
| |
| return self.ln2(x1 + x2) |
| |
| def example_inputs(self): |
| return (torch.randn(2, 64).cuda(), torch.randn(2, 64).cuda()) |
| |
| |
| def _contains_multi_kernel_code(wrapper_code: str): |
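| """ |
| Return True if the generated wrapper code defines a multi-kernel, i.e. it |
| contains a line like "multi_kernel_<name> = async_compile.multi_kernel(...)". |
| """ |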
| return ( |
| re.search(r"multi_kernel_[^ ]* = async_compile.multi_kernel[(]", wrapper_code) |
| is not None |
| ) |
| |
| |
| def make_cpp_wrapper_test(orig_test, **extra_args): |
| """ |
| Wrap an existing test into a new test with cpp-wrapper enabled. |
| |
| Make this a free function rather than a staticmethod in MultiKernelTest. |
| Otherwise we get a "TypeError: 'staticmethod' object is not callable" |
| error in Python 3.8. (Python 3.10 works.) |
| """ |
| |
| @config.patch("cpp_wrapper", True) |
| def fn(self): |
| # The same kernel may have been compiled by previous tests with |
| # cpp_wrapper disabled. Clear the cache so we recompile the kernel |
| # with cpp_wrapper enabled. |
| from torch._inductor import codecache |
| |
| codecache.PyCodeCache.cache_clear() |
| return orig_test(self, **extra_args) |
| |
| return fn |
| |
| |
| @config.patch( |
| { |
| "triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")), |
| "benchmark_kernel": True, |
| } |
| ) |
| @instantiate_parametrized_tests |
| class MultiKernelTest(TestCase): |
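| """ |
| Tests for inductor's multi-kernel feature: with triton.multi_kernel enabled, |
| inductor emits multiple candidate sub-kernels for a reduction (a persistent |
| and a non-persistent variant) and MultiKernelCall picks one at runtime, |
| typically by benchmarking. |
| """ |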
| def test_softmax(self, expect_multi_kernel=True): |
| x = torch.rand(2, 1024).cuda() |
| ref = torch.softmax(x, -1) |
| compiled_fn = torch.compile(torch.softmax) |
| act, wrapper_code = run_and_get_code(compiled_fn, x, -1) |
| |
| # wrapper_code will contain 2 entries if cpp_wrapper=True: |
| # one for the first pass and one for the second pass. |
| # We mainly care about the wrapper for the final pass here. |
| wrapper_code = wrapper_code[-1] |
| self.assertTrue(torch.allclose(ref, act)) |
| if expect_multi_kernel: |
| self.assertTrue(_contains_multi_kernel_code(wrapper_code)) |
| else: |
| # Skip verifying the wrapper_code in fbcode since we may fail to |
| # compile the cpp-wrapper CUDA code due to the lack of a properly |
| # set up CUDA compiler in the fbcode environment. In that case, the |
| # last collected wrapper_code corresponds to the first-pass |
| # cpp-wrapper codegen, which contains the multi-kernel. |
| if not config.is_fbcode(): |
| self.assertFalse(_contains_multi_kernel_code(wrapper_code)) |
| |
| @parametrize("force_kernel", (0, 1)) |
| @unittest.mock.patch.dict( |
| os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"} |
| ) |
| def test_softmax_force_non_persistent_reduction(self, force_kernel): |
| """ |
| Force a specific sub-kernel to be picked by mocking the benchmark result. |
| """ |
| x = torch.rand(2, 1024).cuda() |
| mock_latency = [0.2, 0.2] |
| mock_latency[force_kernel] = 0.1 # this makes sure force_kernel will be picked |
| |
| def f(x): |
| return torch.softmax(x, -1) + force_kernel |
| |
| orig_run = MultiKernelCall.run_with_argless_kernels |
| picked_kernel = None |
| |
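| # Wrap MultiKernelCall.run_with_argless_kernels so we can record which |
| # sub-kernel ends up being picked after the (mocked) benchmarking. |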
| def mock_run(self, kernel_calls): |
| out = orig_run(self, kernel_calls) |
| nonlocal picked_kernel |
| picked_kernel = self.picked_kernel |
| return out |
| |
| with unittest.mock.patch.object( |
| MultiKernelCall, "run_with_argless_kernels", mock_run |
| ), unittest.mock.patch.object( |
| MultiKernelCall, "benchmark_sub_kernels", lambda *args: mock_latency |
| ): |
| torch.compile(f)(x) |
| self.assertEqual(picked_kernel, force_kernel) |
| |
| @config.patch("warn_mix_layout", True) |
| def test_softmax_warn_mixed_layout(self): |
| self.test_softmax() |
| |
| test_softmax_cpp_wrapper = make_cpp_wrapper_test( |
| test_softmax, expect_multi_kernel=False |
| ) |
| |
| def test_layernorm(self): |
| ln = nn.LayerNorm(1024).cuda() |
| x = torch.rand(2, 1024).cuda() |
| ref = ln(x) |
| act = torch.compile(ln)(x) |
| self.assertTrue( |
| torch.allclose(ref, act, atol=1e-4, rtol=1e-4), f"ref:\n{ref}\nact:\n{act}" |
| ) |
| |
| def test_inplace_update(self): |
| """ |
| Inductor generates an in-place kernel for the mul. |
| """ |
| |
| def f(x, y): |
| return x.sum(dim=-1, keepdims=True) * (y @ y) |
| |
| x = torch.rand(1024, 1024).cuda() |
| y = torch.rand(1024, 1024).cuda() |
| ref = f(x, y) |
| act = torch.compile(f)(x, y) |
| self.assertTrue(torch.allclose(ref, act)) |
| |
| def test_transformer_snippet(self): |
| model = TransformerSnippet().cuda() |
| x = model.example_inputs() |
| |
| def f(*x): |
| y = model(*x) |
| return y |
| |
| reset_rng_state() |
| ref = f(*x) |
| |
| opt_f = torch.compile(f) |
| reset_rng_state() |
| act = opt_f(*x) |
| |
| # Don't compare tensors if using the inductor random number generator. |
| # The inductor random number implementation is different from eager's. |
| # We should fall back to eager if we want to check accuracy. |
| if config.fallback_random: |
| self.assertTrue( |
| torch.allclose(ref, act, atol=1e-4, rtol=1e-4), |
| f"ref:\n{ref}\nact:\n{act}", |
| ) |
| |
| def test_transformer_snippet_with_fallback_random(self): |
| """ |
| Same as test_transformer_snippet but fall back to the eager random |
| number generator so we can check accuracy. |
| """ |
| with config.patch("fallback_random", True): |
| self.test_transformer_snippet() |
| |
| def test_batchnorm_training(self): |
| """ |
| For training, batchnorm tracks the running mean/variance during the forward pass. |
| The kernel generated by inductor currently passes those tensors in twice as |
| arguments: once for input and once for output. They are not treated as in-out |
| arguments because they are considered graph inputs. |
| |
| Multi-kernel previously assumed that we never pass the same argument multiple |
| times to a kernel. Whether or not we change inductor's behavior to guarantee |
| that, it's better to make multi-kernel able to handle such cases. |
| """ |
| bn = nn.BatchNorm2d(3).to("cuda") |
| |
| @torch.compile |
| def f(x): |
| bn(x).sum().backward() |
| |
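| # torch.compile produces a forward and a backward graph here; check the |
| # forward-pass wrapper (the first collected code string) for multi-kernel code. |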
| _, (wrapper_code, _) = run_and_get_code( |
| f, torch.randn(2, 3, 8, 8, device="cuda") |
| ) |
| self.assertTrue(_contains_multi_kernel_code(wrapper_code)) |
| |
| def test_pass_same_arg_multi_times(self): |
| """ |
| A super simple example that simulates how BatchNorm updates the running |
| stats. |
| |
| Inductor currently passes the same tensor multiple times to the generated |
| kernel: once for input and once for output. |
| |
| Here is a paste of the generated kernel (without multi-kernel enabled): |
| https://gist.github.com/shunting314/f0b446b4b9a28f4940e31dcd3e809cf9 |
| """ |
| |
| def f(x, y): |
| x = x.sum(dim=1, keepdim=False) |
| y.copy_(y * 0.9 + x * 0.1) |
| |
| x = torch.randn(8, 16, device="cuda") |
| y = torch.randn(8, device="cuda") |
| y_ref = y.clone() |
| |
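| # Run f eagerly on the clone and compiled on the original; both mutate |
| # their second argument in place, so the two tensors should match. |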
| ref = f(x, y_ref) |
| act = torch.compile(f)(x, y) |
| self.assertTrue(torch.allclose(y_ref, y)) |
| |
| def test_reduction_scratch_buffer(self, force_multi_kernel=1): |
| """ |
| The explicitly realized buffer in the test function will be passed in |
| as a scratch buffer for the non-persistent reduction kernel but |
| can be skipped for the persistent reduction kernel. |
| |
| This causes different argument lists for the non-persistent and the |
| persistent reduction kernels. |
| |
| See the documentation around torch._inductor.config.triton.multi_kernel |
| for how to interpret the force_multi_kernel argument. |
| """ |
| |
| def f(x): |
| x = x.sum(dim=-1, keepdim=True) + x |
| x = test_operators.realize(x) |
| x = x.sum(dim=-1, keepdim=True) + x |
| return x |
| |
| x = torch.rand(16, 16, device="cuda") |
| ref = f(x) |
| with config.patch("triton.multi_kernel", force_multi_kernel): |
| act = torch.compile(f)(x) |
| self.assertTrue(torch.allclose(ref, act)) |
| |
| # Use benchmarking to pick the faster kernel |
| test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test( |
| test_reduction_scratch_buffer, force_multi_kernel=1 |
| ) |
| # Force picking the persistent reduction. This is a good test since the |
| # persistent reduction uses fewer call arguments than the corresponding |
| # non-persistent reduction. |
| test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = ( |
| make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2) |
| ) |
| # Force picking the non-persistent reduction. |
| test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = ( |
| make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3) |
| ) |
| |
| |
| if __name__ == "__main__": |
| from torch._inductor.test_case import run_tests |
| |
| if HAS_CUDA: |
| run_tests() |