| # Owner(s): ["module: inductor"] |
| |
| import torch |
| from torch import multiprocessing as mp |
| from torch._dynamo.test_case import run_tests, TestCase |
| from torch._inductor import config |
| from torch._inductor.graph import GraphLowering |
| from torch._inductor.ir import Buffer, FixedLayout |
| from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm |
| from torch._inductor.select_algorithm import AlgorithmSelectorCache, ChoiceCaller |
| from torch._inductor.utils import run_and_get_code |
| from torch._inductor.virtualized import V |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.testing import FileCheck |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| skipIfRocm, |
| ) |
| |
| from torch.testing._internal.inductor_utils import HAS_CUDA |
| |
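# Allow TF32 for float32 matmuls; when CUDA is available, also disable
# expandable segments in the caching allocator.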
| torch.set_float32_matmul_precision("high") |
| if HAS_CUDA: |
| torch.cuda.memory._set_allocator_settings("expandable_segments:False") |
| |
| |
| def benchmark_choice(choice, args, out, expected_out, timings): |
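    """
    Run ``choice.benchmark`` and write the resulting timings into the shared
    ``timings`` tensor. Meant to be used as the target of a spawned subprocess,
    so results are communicated through the tensor rather than a return value.
    """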
| result = choice.benchmark(*args, out=out) |
| if expected_out is not None: |
| torch.testing.assert_close(out, expected_out) |
| |
| timings.copy_(torch.tensor(result)) |
| |
| |
| class FailChoiceCaller(ChoiceCaller): |
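    """A ChoiceCaller whose benchmark always raises, used to exercise the failure path."""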
| def benchmark(self, *args, out): |
| raise RuntimeError("This choice caller will always throw") |
| |
| |
| @instantiate_parametrized_tests |
| class TestDoBench(TestCase): |
| def _create_buffer(self, name, shape): |
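        """Create an inductor IR Buffer with a float32 FixedLayout on cuda:0."""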
| return Buffer(name, FixedLayout(torch.device("cuda:0"), torch.float32, shape)) |
| |
| def test_benchmark_choice_in_subproc(self): |
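        """
        Benchmarking a valid mm_plus_mm choice in a spawned subprocess should
        succeed and leave the child process with a zero exit code.
        """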
| gm = make_fx( |
| lambda: torch.zeros(2, 3) |
| )() # a dummy graph to construct the GraphLowering |
| graph = GraphLowering(gm) |
| |
        # the graph handler is needed to create the benchmark example values below
| with V.set_graph_handler(graph): |
| buf1 = self._create_buffer("mat1", (2, 3)) |
| buf2 = self._create_buffer("mat2", (3, 2)) |
| buf3 = self._create_buffer("mat3", (2, 3)) |
| buf4 = self._create_buffer("mat4", (3, 2)) |
| |
| layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2)) |
| |
| mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1) |
| mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2) |
| mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3) |
| mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4) |
| |
| out = AlgorithmSelectorCache.benchmark_example_value(layout) |
| # expected_out = (mat1 @ mat2) + (mat3 @ mat4) |
| expected_out = None |
| |
| choice = aten_mm_plus_mm.bind((buf1, buf2, buf3, buf4), layout) |
            # use a tensor since mutations to a python list in a subprocess
            # are not synced back to the parent process
| timings = torch.zeros(3, dtype=torch.float32) |
| ctx = mp.get_context("spawn") |
| child = ctx.Process( |
| target=benchmark_choice, |
| args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings), |
| ) |
| child.start() |
| child.join() |
| self.assertEqual(0, child.exitcode) |
| print(f"timings is {timings}, out {out}, expected_out {expected_out}") |
| |
| def test_benchmark_choice_fail_in_subproc(self): |
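        """
        A choice whose benchmark raises inside the subprocess should cause the
        child process to exit with a nonzero exit code.
        """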
| gm = make_fx( |
| lambda: torch.zeros(2, 3) |
| )() # a dummy graph to construct the GraphLowering |
| graph = GraphLowering(gm) |
| |
        # the graph handler is needed to create the benchmark example values below
| with V.set_graph_handler(graph): |
| buf1 = self._create_buffer("mat1", (2, 3)) |
| buf2 = self._create_buffer("mat2", (3, 2)) |
| buf3 = self._create_buffer("mat3", (2, 3)) |
| buf4 = self._create_buffer("mat4", (3, 2)) |
| |
| layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2)) |
| |
| mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1) |
| mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2) |
| mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3) |
| mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4) |
| |
| out = AlgorithmSelectorCache.benchmark_example_value(layout) |
| expected_out = (mat1 @ mat2) + (mat3 @ mat4) |
| |
| choice = FailChoiceCaller("fail_choice_caller", [], None) |
| |
            # use a tensor since a python list is not synced back from the subprocess
| timings = torch.zeros(3, dtype=torch.float32) |
| ctx = mp.get_context("spawn") |
| child = ctx.Process( |
| target=benchmark_choice, |
| args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings), |
| ) |
| child.start() |
| child.join() |
| self.assertNotEqual(0, child.exitcode) |
| |
| @parametrize("autotune_in_subproc", (True, False)) |
| def test_max_autotune_mm_plus_mm(self, autotune_in_subproc): |
| """ |
        This previously crashed due to a triton issue: https://github.com/openai/triton/issues/1298 .
        With autotuning in a subprocess, we no longer crash.
| """ |
| m, n, k = 2048, 1536, 64 |
| |
| def mm_plus_mm(a, b, c, d): |
| return a @ b + c @ d |
| |
| a = torch.randn(m, k).cuda() |
| b = torch.randn(k, n).cuda() |
| c = torch.randn(m, k).cuda() |
| d = torch.randn(k, n).cuda() |
| |
| with config.patch( |
| {"max_autotune": True, "autotune_in_subproc": autotune_in_subproc} |
| ): |
| torch.compile(mm_plus_mm)(a, b, c, d) |
| |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_mm_plus_mm_zero_size_input(self, dynamic): |
| """ |
| Make sure autotuning mm_plus_mm with zero-size input works without crashes. |
| """ |
| m, n, k = 0, 1536, 64 |
| |
| def mm_plus_mm(a, b, c, d): |
| return a @ b + c @ d |
| |
| a = torch.randn(m, k).cuda() |
| b = torch.randn(k, n).cuda() |
| c = torch.randn(m, k).cuda() |
| d = torch.randn(k, n).cuda() |
| |
| with config.patch({"max_autotune": True}): |
| torch.compile(mm_plus_mm, dynamic=dynamic)(a, b, c, d) |
| |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_regular_mm(self, dynamic: bool): |
| """ |
        Make sure autotuning mm in subprocesses works without crashes.
| """ |
| |
| def mm(a, b): |
| a = torch.sin(a) |
| return a @ b |
| |
| a = torch.randn(100, 10).cuda() |
| b = torch.randn(10, 100).cuda() |
| |
| with config.patch({"max_autotune": True, "autotune_in_subproc": True}): |
| torch.compile(mm, dynamic=dynamic)(a, b) |
| |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool): |
| """ |
| Make sure autotuning mm with zero-size input works without crashes. |
| """ |
| |
| def mm(a, b): |
| a = torch.sin(a) |
| return a @ b |
| |
| a = torch.randn(0, 10).cuda() |
| b = torch.randn(10, 100).cuda() |
| |
| with config.patch({"max_autotune": True}): |
| torch.compile(mm, dynamic=dynamic)(a, b) |
| |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_addmm(self, dynamic): |
| """ |
        Make sure autotuning addmm in subprocesses works without crashes.
| """ |
| |
| def addmm(x, a, b): |
| return torch.addmm(x, a, b) |
| |
| x = torch.randn(100).cuda() |
| a = torch.randn(100, 10).cuda() |
| b = torch.randn(10, 100).cuda() |
| with config.patch({"max_autotune": True, "autotune_in_subproc": True}): |
| torch.compile(addmm, dynamic=dynamic)(x, a, b) |
| |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_addmm_zero_size_input(self, dynamic): |
| """ |
| Make sure autotuning addmm with zero-size input works without crashes. |
| """ |
| |
| def addmm(x, a, b): |
| return torch.addmm(x, a, b) |
| |
| x = torch.randn(100).cuda() |
| a = torch.randn(0, 10).cuda() |
| b = torch.randn(10, 100).cuda() |
| with config.patch({"max_autotune": True}): |
| torch.compile(addmm, dynamic=dynamic)(x, a, b) |
| |
| @skipIfRocm |
| def test_autotune_conv1x1(self): |
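        """
        Autotune a channels-last 1x1 convolution with only the TRITON backend
        and verify that the generated code does not fall back to
        extern_kernels.convolution.
        """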
        # Define the 1x1 convolutional layer.
        # Assuming the input has 3 channels and we want 16 output channels.
| conv1x1 = ( |
| torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1) |
| .to(memory_format=torch.channels_last) |
| .cuda() |
| ) |
| |
| # Example input tensor: batch size = 4, channels = 3, height = 32, width = 32 |
| # The memory format is set to `channels_last` |
| input_tensor = ( |
| torch.randn(4, 3, 32, 32) |
| .contiguous(memory_format=torch.channels_last) |
| .cuda() |
| ) |
| |
| with config.patch( |
| {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"} |
| ): |
| |
| @torch.compile() |
| def foo(mod, x): |
| return mod(x) |
| |
| with torch.no_grad(): |
| out, code = run_and_get_code(foo, conv1x1, input_tensor) |
| |
| FileCheck().check_not("extern_kernels.convolution").run(code[0]) |
| self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0) |
| |
| def test_cat_addmm(self): |
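        """
        Concatenating two addmm results should autotune with the Triton backend
        and match the eager result within tolerance.
        """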
| def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): |
| return torch.cat( |
| [ |
| torch.addmm(a, b, c), |
| torch.addmm(b, c, a), |
| ], |
| 1, |
| ) |
| |
| args = [ |
| torch.randn(4, 4, device="cuda"), |
| torch.randn(4, 4, device="cuda"), |
| torch.randn(4, 4, device="cuda"), |
| ] |
| with config.patch( |
| { |
| "max_autotune": True, |
| "max_autotune_gemm_backends": "Triton", |
| } |
| ): |
| expected = fn(*args) |
| actual = torch.compile(fn)(*args) |
| torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2) |
| |
| def test_triton_template_with_epilogues_and_dynamic_shape(self): |
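        """
        A matmul with relu and multiply epilogues, compiled with dynamic shapes
        and the Triton backend, should match eager for two different M values.
        """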
| def fn( |
| x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor, mul: torch.Tensor |
| ) -> torch.Tensor: |
| return ( |
| torch.nn.functional.relu( |
| torch.matmul(torch.transpose(x, 0, 1), torch.transpose(w, 0, 1)) |
| + bias |
| ) |
| * mul |
| ) |
| |
| M0 = 5 |
| M1 = 8 |
| K = 4 |
| N = 3 |
| w = torch.rand(N, K).cuda().half() |
| b = torch.rand(N).cuda().half() |
| |
| with config.patch( |
| { |
| "max_autotune": True, |
| "autotune_in_subproc": True, |
| "max_autotune_gemm_backends": "Triton", |
| } |
| ): |
| compiled_fn = torch.compile( |
| fn, fullgraph=True, dynamic=True, mode="max-autotune-no-cudagraphs" |
| ) |
| |
| x0 = torch.rand(K, M0).cuda().half() |
| mul0 = torch.rand(M0, N).cuda().half() |
| y0 = compiled_fn(x0, w, b, mul0) |
| y0_expected = fn(x0, w, b, mul0) |
| torch.testing.assert_close(y0, y0_expected) |
| |
| x1 = torch.rand(K, M1).cuda().half() |
| mul1 = torch.rand(M1, N).cuda().half() |
| y1 = compiled_fn(x1, w, b, mul1) |
| y1_expected = fn(x1, w, b, mul1) |
| torch.testing.assert_close(y1, y1_expected) |
| |
| |
| if __name__ == "__main__": |
| if HAS_CUDA: |
| run_tests() |