# Owner(s): ["module: inductor"]

import torch
from torch import multiprocessing as mp
from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor import config
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import AlgorithmSelectorCache, ChoiceCaller
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CUDA

# "high" allows TF32 to be used for float32 matmuls
torch.set_float32_matmul_precision("high")
if HAS_CUDA:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")


def benchmark_choice(choice, args, out, expected_out, timings):
    result = choice.benchmark(*args, out=out)
    if expected_out is not None:
        torch.testing.assert_close(out, expected_out)

    timings.copy_(torch.tensor(result))


class FailChoiceCaller(ChoiceCaller):
    def benchmark(self, *args, out):
        raise RuntimeError("This choice caller will always throw")


@instantiate_parametrized_tests
class TestDoBench(TestCase):
    def _create_buffer(self, name, shape):
        return Buffer(name, FixedLayout(torch.device("cuda:0"), torch.float32, shape))

    def test_benchmark_choice_in_subproc(self):
        gm = make_fx(
            lambda: torch.zeros(2, 3)
        )()  # a dummy graph to construct the GraphLowering
        graph = GraphLowering(gm)

        # the graph handler is needed to create benchmark example values below
        with V.set_graph_handler(graph):
            buf1 = self._create_buffer("mat1", (2, 3))
            buf2 = self._create_buffer("mat2", (3, 2))
            buf3 = self._create_buffer("mat3", (2, 3))
            buf4 = self._create_buffer("mat4", (3, 2))

            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))

            mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
            mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
            mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
            mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)

            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            # expected_out = (mat1 @ mat2) + (mat3 @ mat4)
            expected_out = None

            choice = aten_mm_plus_mm.bind((buf1, buf2, buf3, buf4), layout)
            # use a tensor since mutations to a python list in a subprocess
            # are not synced back to the parent process
            timings = torch.zeros(3, dtype=torch.float32)
            # spawn (rather than fork) so the child process can initialize CUDA
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
            self.assertEqual(0, child.exitcode)
            print(f"timings is {timings}, out {out}, expected_out {expected_out}")

    def test_benchmark_choice_fail_in_subproc(self):
        gm = make_fx(
            lambda: torch.zeros(2, 3)
        )()  # a dummy graph to construct the GraphLowering
        graph = GraphLowering(gm)

        # the graph handler is needed to create benchmark example values below
        with V.set_graph_handler(graph):
            buf1 = self._create_buffer("mat1", (2, 3))
            buf2 = self._create_buffer("mat2", (3, 2))
            buf3 = self._create_buffer("mat3", (2, 3))
            buf4 = self._create_buffer("mat4", (3, 2))

            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))

            mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
            mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
            mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
            mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)

            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            expected_out = (mat1 @ mat2) + (mat3 @ mat4)

            choice = FailChoiceCaller("fail_choice_caller", [], None)

            # use a tensor since mutations to a python list in the subprocess
            # are not synced back to the parent process
            timings = torch.zeros(3, dtype=torch.float32)
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
            self.assertNotEqual(0, child.exitcode)

    @parametrize("autotune_in_subproc", (True, False))
    def test_max_autotune_mm_plus_mm(self, autotune_in_subproc):
| """ |
| This crash previously due to a triton issue: https://github.com/openai/triton/issues/1298 . |
| With autotuning in subprocess, we don't crash anymore. |
| """ |
        m, n, k = 2048, 1536, 64

        def mm_plus_mm(a, b, c, d):
            return a @ b + c @ d

        a = torch.randn(m, k).cuda()
        b = torch.randn(k, n).cuda()
        c = torch.randn(m, k).cuda()
        d = torch.randn(k, n).cuda()

        with config.patch(
            {"max_autotune": True, "autotune_in_subproc": autotune_in_subproc}
        ):
            torch.compile(mm_plus_mm)(a, b, c, d)

    def test_max_autotune_regular_mm(self):
| """ |
| Make sure autotuning mm in sub processes work without crashes. |
| """ |

        def mm(a, b):
            a = torch.sin(a)
            return a @ b

        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()

        with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
            torch.compile(mm)(a, b)

    def test_max_autotune_addmm(self):
| """ |
| Make sure autotuning addmm in sub processes work without crashes. |
| """ |

        def addmm(x, a, b):
            return torch.addmm(x, a, b)

        x = torch.randn(100).cuda()
        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()
        with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
            torch.compile(addmm)(x, a, b)


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()