| # Owner(s): ["module: inductor"] |
| import json |
| import os |
| import unittest |
| |
| from typing import Callable, List, Optional |
| |
| import torch |
| from torch import multiprocessing as mp, nn |
| from torch._dynamo import reset |
| from torch._dynamo.exc import BackendCompilerFailed |
| from torch._dynamo.testing import rand_strided, reset_rng_state |
| from torch._inductor import config |
| from torch._inductor.autotune_process import ( |
| BenchmarkRequest, |
| CUDA_VISIBLE_DEVICES, |
| TuningProcessPool, |
| ) |
| from torch._inductor.graph import GraphLowering |
| from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout |
| from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm |
| from torch._inductor.select_algorithm import ( |
| AlgorithmSelectorCache, |
| TritonTemplateCaller, |
| ) |
| from torch._inductor.test_case import run_tests, TestCase |
| |
| from torch._inductor.utils import fresh_inductor_cache, run_and_get_code |
| from torch._inductor.virtualized import V |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.testing import FileCheck |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| skipIfRocm, |
| ) |
| |
| from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA |
| |
| torch.set_float32_matmul_precision("high") |
| if HAS_CUDA: |
| torch.cuda.memory._set_allocator_settings("expandable_segments:False") |
| |
| _CUTLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/") |
| |
| |
| def _get_path_without_sccache() -> str: |
| """ |
| Get the PATH environment variable without sccache. |
| """ |
| path_envs = os.environ.get("PATH", "").split(":") |
| path_envs = [env for env in path_envs if "/opt/cache/bin" not in env] |
| return ":".join(path_envs) |
| |
| |
| def benchmark_choice(choice, args, out, expected_out, timings): |
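    """
    Run choice.benchmark(...) and copy the result into the pre-allocated
    `timings` tensor. Intended as the target of a spawned child process: unlike
    a Python list, the tensor is shared with the parent, so the parent can read
    the recorded timings after the child exits.
    """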
| result = choice.benchmark(*args, out=out) |
| if expected_out is not None: |
| torch.testing.assert_close(out, expected_out) |
| |
| timings.copy_(torch.tensor(result)) |
| |
| |
| class FailChoiceCaller(ChoiceCaller): |
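    """A ChoiceCaller whose benchmark always raises, used to exercise failure handling in a subprocess."""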
| def benchmark(self, *args, out): |
| raise RuntimeError("This choice caller will always throw") |
| |
| |
| @instantiate_parametrized_tests |
| class TestMaxAutotune(TestCase): |
| def _create_buffer(self, name, shape): |
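        # Build an inductor Buffer IR node with a fixed float32 layout on cuda:0;
        # used as an input to the template choices benchmarked below.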
| return Buffer(name, FixedLayout(torch.device("cuda:0"), torch.float32, shape)) |
| |
| def test_benchmark_choice_in_subproc(self): |
| gm = make_fx( |
| lambda: torch.zeros(2, 3) |
| )() # a dummy graph to construct the GraphLowering |
| graph = GraphLowering(gm) |
| |
        # the graph handler is needed to create the benchmark example values below
| with V.set_graph_handler(graph): |
| buf1 = self._create_buffer("mat1", (2, 3)) |
| buf2 = self._create_buffer("mat2", (3, 2)) |
| buf3 = self._create_buffer("mat3", (2, 3)) |
| buf4 = self._create_buffer("mat4", (3, 2)) |
| |
| layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2)) |
| |
| mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1) |
| mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2) |
| mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3) |
| mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4) |
| |
| out = AlgorithmSelectorCache.benchmark_example_value(layout) |
| # expected_out = (mat1 @ mat2) + (mat3 @ mat4) |
| expected_out = None |
| |
| choice = aten_mm_plus_mm.bind((buf1, buf2, buf3, buf4), layout) |
            # use a tensor since a mutation to a Python list in a subprocess
            # is not synced back to the parent process
| timings = torch.zeros(3, dtype=torch.float32) |
| ctx = mp.get_context("spawn") |
| child = ctx.Process( |
| target=benchmark_choice, |
| args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings), |
| ) |
| child.start() |
| child.join() |
| self.assertEqual(0, child.exitcode) |
| print(f"timings is {timings}, out {out}, expected_out {expected_out}") |
| |
| def test_benchmark_choice_fail_in_subproc(self): |
| gm = make_fx( |
| lambda: torch.zeros(2, 3) |
| )() # a dummy graph to construct the GraphLowering |
| graph = GraphLowering(gm) |
| |
        # the graph handler is needed to create the benchmark example values below
| with V.set_graph_handler(graph): |
| buf1 = self._create_buffer("mat1", (2, 3)) |
| buf2 = self._create_buffer("mat2", (3, 2)) |
| buf3 = self._create_buffer("mat3", (2, 3)) |
| buf4 = self._create_buffer("mat4", (3, 2)) |
| |
| layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2)) |
| |
| mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1) |
| mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2) |
| mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3) |
| mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4) |
| |
| out = AlgorithmSelectorCache.benchmark_example_value(layout) |
| expected_out = (mat1 @ mat2) + (mat3 @ mat4) |
| |
| choice = FailChoiceCaller("fail_choice_caller", [], None) |
| |
            # use a tensor since a Python list is not synced back to the parent process
| timings = torch.zeros(3, dtype=torch.float32) |
| ctx = mp.get_context("spawn") |
| child = ctx.Process( |
| target=benchmark_choice, |
| args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings), |
| ) |
| child.start() |
| child.join() |
| self.assertNotEqual(0, child.exitcode) |
| |
| @parametrize("autotune_in_subproc", (True, False)) |
| @parametrize("autotune_multi_device", (True, False)) |
| def test_max_autotune_mm_plus_mm(self, autotune_in_subproc, autotune_multi_device): |
| """ |
        This previously crashed due to a triton issue: https://github.com/openai/triton/issues/1298 .
        With autotuning in a subprocess, we don't crash anymore.
| """ |
| m, n, k = 2048, 1536, 64 |
| |
| def mm_plus_mm(a, b, c, d): |
| return a @ b + c @ d |
| |
| a = torch.randn(m, k).cuda() |
| b = torch.randn(k, n).cuda() |
| c = torch.randn(m, k).cuda() |
| d = torch.randn(k, n).cuda() |
| |
| with config.patch( |
| { |
| "max_autotune": True, |
| "autotune_in_subproc": autotune_in_subproc, |
| "autotune_multi_device": autotune_multi_device, |
| } |
| ): |
| torch.compile(mm_plus_mm)(a, b, c, d) |
| |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_mm_plus_mm_zero_size_input(self, dynamic): |
| """ |
| Make sure autotuning mm_plus_mm with zero-size input works without crashes. |
| """ |
| m, n, k = 0, 1536, 64 |
| |
| def mm_plus_mm(a, b, c, d): |
| return a @ b + c @ d |
| |
| a = torch.randn(m, k).cuda() |
| b = torch.randn(k, n).cuda() |
| c = torch.randn(m, k).cuda() |
| d = torch.randn(k, n).cuda() |
| |
| with config.patch({"max_autotune": True}): |
| torch.compile(mm_plus_mm, dynamic=dynamic)(a, b, c, d) |
| |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_regular_mm(self, dynamic: bool): |
| """ |
        Make sure autotuning mm in subprocesses works without crashes.
| """ |
| |
| def mm(a, b): |
| a = torch.sin(a) |
| return a @ b |
| |
| a = torch.randn(100, 10).cuda() |
| b = torch.randn(10, 100).cuda() |
| |
| with config.patch({"max_autotune": True, "autotune_in_subproc": True}): |
| torch.compile(mm, dynamic=dynamic)(a, b) |
| |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool): |
| """ |
| Make sure autotuning mm with zero-size input works without crashes. |
| """ |
| |
| def mm(a, b): |
| a = torch.sin(a) |
| return a @ b |
| |
| a = torch.randn(0, 10).cuda() |
| b = torch.randn(10, 100).cuda() |
| |
| with config.patch({"max_autotune": True}): |
| torch.compile(mm, dynamic=dynamic)(a, b) |
| |
| @skipIfRocm |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_remote_caching(self, dynamic: bool): |
| from unittest.mock import patch |
| |
| def mm(a, b): |
| a = torch.sin(a) |
| return a @ b |
| |
| a = torch.randn(100, 10).cuda() |
| b = torch.randn(10, 100).cuda() |
| |
| class Model(torch.nn.Module): |
| def forward(self, x, y): |
| return x + y |
| |
| def f(x, y): |
| return Model()(x, y) |
| |
| x = torch.randn(100, 100).cuda() |
| y = torch.randn(100, 100).cuda() |
| |
| cache = {} |
| num_get = 0 |
| num_put = 0 |
| |
| class MyCache: |
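            """In-memory stand-in for the remote autotune cache backend; counts get/put calls."""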
| def __init__(self, key, is_autotune=False): |
| pass |
| |
| def get(self, filename): |
| nonlocal cache |
| nonlocal num_get |
| if filename not in cache: |
| return None |
| ret = json.loads(cache[filename]) |
| num_get += 1 |
| return ret |
| |
| def put(self, filename, data): |
| nonlocal cache |
| nonlocal num_put |
| cache[filename] = json.dumps(data) |
| num_put += 1 |
| |
| cache_module = ( |
| "triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend" |
| if config.is_fbcode() |
| else "torch._inductor.remote_cache.RedisRemoteCacheBackend" |
| ) |
| |
| with config.patch( |
| { |
| "autotune_local_cache": False, |
| "autotune_remote_cache": True, |
| } |
| ), patch.dict(os.environ), patch(cache_module, MyCache, create=True): |
| os.environ.pop("TRITON_CACHE_MANAGER", None) |
| with config.patch({"max_autotune": True}): |
| for _ in range(4): |
| with fresh_inductor_cache(): |
| torch.compile(mm, dynamic=dynamic)(a, b) |
| reset() |
| self.assertEqual(num_get, 3) |
| self.assertEqual(num_put, 1) |
| num_get = 0 |
| num_put = 0 |
| for _ in range(4): |
| with fresh_inductor_cache(): |
| torch.compile(f, dynamic=dynamic)(x, y) |
| reset() |
| self.assertEqual(num_get, 3) |
| self.assertEqual(num_put, 1) |
| |
| @skipIfRocm |
| def test_precompilation_threads(self): |
| import threading |
| from typing import Any, Dict |
| from unittest.mock import Mock, patch |
| |
| class FakeChoiceCaller(ChoiceCaller): |
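            """Records the thread on which precompile() runs so the test can verify it is not the main thread."""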
| def __init__(self): |
| super().__init__("none", [], Mock()) |
| self.thread_id = None |
| |
| def precompile(self): |
| self.thread_id = threading.get_ident() |
| |
| def call_name(self) -> str: |
| return None |
| |
| def to_callable(self): |
| return None |
| |
| def hash_key(self) -> str: |
| return None |
| |
| def output_node(self) -> "TensorBox": # noqa: F821 |
| return None |
| |
        fake_choices = [FakeChoiceCaller() for _ in range(10)]
| fake_lookup_result = {choice: 0.123 for choice in fake_choices} |
| |
| def no_lookup( |
| choices: List[ChoiceCaller], |
| op: str, |
| inputs: str, |
| benchmark: Callable[[Any], Dict[ChoiceCaller, float]], |
| ) -> Dict[ChoiceCaller, float]: |
| if benchmark is not None: |
| return benchmark(choices) |
| |
| asc = AlgorithmSelectorCache() |
| |
| def fake_benchmark_fn(*args, **kwargs): |
| return fake_lookup_result |
| |
| main_thread_id = threading.get_ident() |
| mock_debug_handler = Mock() |
| old_debug_handler = V.debug |
| try: |
| V.set_debug_handler(mock_debug_handler) |
| with patch.object(asc, "lookup", new=no_lookup): |
| with patch.object( |
| asc, "make_benchmark_fn", return_value=fake_benchmark_fn |
| ): |
| with config.patch( |
| { |
| "autotune_in_subproc": False, |
| "compile_threads": len(fake_choices), |
| } |
| ): |
| asc("test_call", fake_choices, [], Mock()) |
| for fake_choice in fake_choices: |
                            assert (
                                fake_choice.thread_id is not None
                            ), "Expected every ChoiceCaller's precompile method to have been called"
                            assert (
                                fake_choice.thread_id != main_thread_id
                            ), "Expected every ChoiceCaller's precompile method to have been called on a separate thread"
| finally: |
| V.set_debug_handler(old_debug_handler) |
| |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_addmm(self, dynamic=False): |
| """ |
        Make sure autotuning addmm in subprocesses works without crashes.
| """ |
| |
| torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False |
| |
| def addmm(x, a, b): |
| return torch.addmm(x, a, b) |
| |
| x = torch.randn(100).cuda() |
| a = torch.randn(100, 10).cuda() |
| b = torch.randn(10, 100).cuda() |
| with config.patch({"max_autotune": True, "autotune_in_subproc": True}): |
| Y_compiled = torch.compile(addmm, dynamic=dynamic)(x, a, b) |
| Y = addmm(x, a, b) |
| torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2) |
| |
| @parametrize("dynamic", (False, True)) |
| def test_max_autotune_addmm_zero_size_input(self, dynamic): |
| """ |
| Make sure autotuning addmm with zero-size input works without crashes. |
| """ |
| |
| def addmm(x, a, b): |
| return torch.addmm(x, a, b) |
| |
| x = torch.randn(100).cuda() |
| a = torch.randn(0, 10).cuda() |
| b = torch.randn(10, 100).cuda() |
| with config.patch({"max_autotune": True}): |
| torch.compile(addmm, dynamic=dynamic)(x, a, b) |
| |
| @skipIfRocm |
| def test_autotune_conv1x1(self): |
| # Assuming input has 3 channels and we want to produce 16 channels as output |
| conv1x1 = ( |
| torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1) |
| .to(memory_format=torch.channels_last) |
| .cuda() |
| ) |
| |
| # Example input tensor: batch size = 4, channels = 3, height = 32, width = 32 |
| # The memory format is set to `channels_last` |
| input_tensor = ( |
| torch.randn(4, 3, 32, 32) |
| .contiguous(memory_format=torch.channels_last) |
| .cuda() |
| ) |
| |
| with config.patch( |
| {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"} |
| ): |
| |
| @torch.compile() |
| def foo(mod, x): |
| return mod(x) |
| |
| with torch.no_grad(): |
| out, code = run_and_get_code(foo, conv1x1, input_tensor) |
| |
| FileCheck().check_not("extern_kernels.convolution").run(code[0]) |
| self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0) |
| |
| @skipIfRocm |
| def test_filled_cache_precompile(self): |
| def fn(a, b, c): |
| a = (a @ b) @ c |
| a, b, c = (t.to(torch.float16) for t in [a, b, c]) |
| return (a @ b) @ c |
| |
| fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn) |
| inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)] |
| from torch._dynamo.utils import counters |
| |
| self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2) |
| |
| torch._dynamo.reset() |
| counters.clear() |
| |
| fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn) |
| self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) |
| |
| @skipIfRocm |
| @fresh_inductor_cache() |
| @config.patch(search_autotune_cache=True) |
| def test_search_autotune_cache(self): |
| def fn(a, b, c): |
| a = (a @ b) @ c |
| a, b, c = (t.to(torch.float16) for t in [a, b, c]) |
| return (a @ b) @ c |
| |
| fn_c = torch.compile()(fn) |
| inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)] |
| from torch._dynamo.utils import counters |
| |
| self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2) |
| self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) |
| |
| @skipIfRocm |
| @fresh_inductor_cache() |
| @config.patch(max_autotune=True, max_fusion_size=2) |
| def test_jit_fusion_matches_aot_fusion(self): |
        # In this example, AOTInductor's JIT-compile pass will fuse(buf1, buf2)
        # due to proximity; we want to make sure the AOT-compile pass does the
        # same. AOT could do fuse(buf2, buf4) instead if buf3 were pushed to the
        # end of the V.graph.buffers list, because fuse(buf2, buf4) would have a
        # better proximity score than fuse(buf1, buf2). This scenario is possible
        # since finalizing MultiTemplateBuffers needs to replace buffers.
| def fn(x, number): |
| buf0 = x + x |
| buf1 = number.item() |
| buf2 = x * x |
| buf3 = x @ x # MultiTemplateBuffer |
| buf4 = x**2 |
| return buf0, buf1, buf2, buf3, buf4 |
| |
| inputs = (torch.rand([256, 256], device="cuda"), torch.tensor(3, device="cuda")) |
| torch._export.aot_compile(fn, args=inputs) |
| |
| @config.patch(autotune_local_cache=False, autotune_remote_cache=False) |
| @skipIfRocm |
| def test_precompilations(self): |
| def fn(a, b, c): |
| a = (a @ b) @ c |
| a, b, c = (t.to(torch.float16) for t in [a, b, c]) |
| return (a @ b) @ c |
| |
| fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn) |
| inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)] |
| |
| self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2) |
| |
| from torch._dynamo.utils import counters |
| |
| self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 2) |
| |
| def test_cat_addmm(self): |
| def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): |
| return torch.cat( |
| [ |
| torch.addmm(a, b, c), |
| torch.addmm(b, c, a), |
| ], |
| 1, |
| ) |
| |
| args = [ |
| torch.randn(4, 4, device="cuda"), |
| torch.randn(4, 4, device="cuda"), |
| torch.randn(4, 4, device="cuda"), |
| ] |
| with config.patch( |
| { |
| "max_autotune": True, |
| "max_autotune_gemm_backends": "Triton", |
| } |
| ): |
| expected = fn(*args) |
| actual = torch.compile(fn)(*args) |
| torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2) |
| |
| def test_triton_template_with_epilogues_and_dynamic_shape(self): |
| def fn( |
| x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor, mul: torch.Tensor |
| ) -> torch.Tensor: |
| return ( |
| torch.nn.functional.relu( |
| torch.matmul(torch.transpose(x, 0, 1), torch.transpose(w, 0, 1)) |
| + bias |
| ) |
| * mul |
| ) |
| |
| M0 = 5 |
| M1 = 8 |
| K = 4 |
| N = 3 |
| w = torch.rand(N, K).cuda().half() |
| b = torch.rand(N).cuda().half() |
| |
| with config.patch( |
| { |
| "max_autotune": True, |
| "autotune_in_subproc": True, |
| "max_autotune_gemm_backends": "Triton", |
| } |
| ): |
| compiled_fn = torch.compile( |
| fn, fullgraph=True, dynamic=True, mode="max-autotune-no-cudagraphs" |
| ) |
| |
| x0 = torch.rand(K, M0).cuda().half() |
| mul0 = torch.rand(M0, N).cuda().half() |
| y0 = compiled_fn(x0, w, b, mul0) |
| y0_expected = fn(x0, w, b, mul0) |
| torch.testing.assert_close(y0, y0_expected) |
| |
| x1 = torch.rand(K, M1).cuda().half() |
| mul1 = torch.rand(M1, N).cuda().half() |
| y1 = compiled_fn(x1, w, b, mul1) |
| y1_expected = fn(x1, w, b, mul1) |
| torch.testing.assert_close(y1, y1_expected) |
| |
| @config.patch( |
| benchmark_kernel=True, |
| fallback_random=True, |
| max_autotune_gemm=True, |
| ) |
| @parametrize("device", ("cpu", "cuda")) |
| def test_matmul_dropout(self, device): |
| def fwd(a, b): |
| x = a @ b |
| x = torch.nn.functional.dropout(x, 0.1) |
| return x |
| |
| def fn(a, b): |
| x = fwd(a, b).sum() |
| x.backward() |
| return a.grad |
| |
| N = 128 |
| a = torch.randn(N, N, device=device, requires_grad=True) |
| b = torch.randn(N, N, device=device) |
| |
| opt_fn = torch.compile(fn) |
| reset_rng_state() |
| ref = fn(a, b) |
| reset_rng_state() |
| act = opt_fn(a, b) |
| |
| if N <= 8: |
| print(f"ref\n{ref}\nact\n{act}") |
| torch.testing.assert_close(ref, act, atol=1e-1, rtol=1e-1) |
| |
| @config.patch( |
| max_autotune_gemm=True, |
| ) |
| @unittest.skipIf( |
| torch.cuda.device_count() < 2, "Need at least 2 devices for this test" |
| ) |
| def test_autotune_device_guard(self): |
| x = torch.randn(1024, 1024, device="cuda:1") |
| y = torch.randn(1024, 1024, device="cuda:1") |
| |
| def f(x, y): |
| return x @ y |
| |
| with fresh_inductor_cache(): |
| act = torch.compile(f)(x, y) |
| ref = f(x, y) |
| self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3)) |
| |
| @config.patch(max_autotune=True) |
| def test_empty_conv_input(self, kernel_size=3): |
| x = torch.randn(0, 256, 14, 14, device="cuda") |
| weight = torch.randn(256, 256, kernel_size, kernel_size, device="cuda") |
| |
| def f(x, weight): |
| return torch.convolution( |
| x, |
| weight, |
| bias=None, |
| stride=[1, 1], |
| padding=[0, 0], |
| dilation=[1, 1], |
| transposed=False, |
| output_padding=[0, 0], |
| groups=1, |
| ) |
| |
| opt_f = torch.compile(f) |
| ref = f(x, weight) |
| act = opt_f(x, weight) |
| self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3)) |
| |
| @config.patch(max_autotune=True) |
| def test_empty_conv_input_with_1x1_kernel(self): |
| self.test_empty_conv_input(kernel_size=1) |
| |
| @config.patch(max_autotune=True) |
| def test_conv1x1_with_free_symbols(self): |
| """ |
| Make sure there is no exception due to free symbols. |
| """ |
| conv = nn.Conv2d( |
| 3, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False |
| ).to(device="cuda") |
| |
| @torch.compile |
| def f(x, y, z): |
| h = y.nonzero().size(0) |
| w = z.nonzero().size(0) |
| x = x[:, :, :h, :w] |
| x = conv(x) |
| return x |
| |
| x = torch.randn(4, 3, 224, 224).to( |
| memory_format=torch.channels_last, device="cuda" |
| ) |
| for _ in range(2): |
| y = torch.randint(0, 10, (224,)).to(device="cuda") |
| z = torch.randint(0, 10, (224,)).to(device="cuda") |
| f(x, y, z) |
| |
| def test_non_contiguous_input_mm(self): |
| """ |
        Make sure the triton template can work with non-contiguous inputs without crashing.
| Check https://github.com/pytorch/pytorch/issues/125437 for more details. |
| """ |
| x = rand_strided( |
| (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda" |
| ) |
| y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda") |
| |
| @torch.compile(mode="max-autotune") |
| def f(x, y): |
| return x @ y |
| |
| ref = x @ y |
| act = f(x, y) |
| self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) |
| |
| def test_non_contiguous_input_addmm(self): |
| b = torch.randn((768), dtype=torch.bfloat16, device="cuda") |
| x = rand_strided( |
| (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda" |
| ) |
| y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda") |
| |
| @torch.compile(mode="max-autotune") |
| def f(x, y): |
| return torch.addmm(b, x, y) |
| |
| ref = torch.addmm(b, x, y) |
| act = f(x, y) |
| self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) |
| |
| def test_non_contiguous_input_bmm(self): |
| x = rand_strided( |
| (1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device="cuda" |
| ) |
| y = rand_strided( |
| (1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device="cuda" |
| ) |
| |
| @torch.compile(mode="max-autotune") |
| def f(x, y): |
| return torch.bmm(x, y) |
| |
| ref = torch.bmm(x, y) |
| act = f(x, y) |
| self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) |
| |
| def test_non_contiguous_input_mm_plus_mm(self): |
| x1 = rand_strided((50257, 32768), (1, 50304), device="cuda") |
| y1 = rand_strided((32768, 768), (768, 1), device="cuda") |
| |
| x2 = rand_strided((50257, 32768), (1, 50304), device="cuda") |
| y2 = rand_strided((32768, 768), (768, 1), device="cuda") |
| |
| @torch.compile(mode="max-autotune") |
| def f(x1, y1, x2, y2): |
| return x1 @ y1 + x2 @ y2 |
| |
| ref = x1 @ y1 + x2 @ y2 |
| act = f(x1, y1, x2, y2) |
| self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) |
| |
| @config.patch( |
| max_autotune=True, |
| max_autotune_gemm_backends="", |
| autotune_fallback_to_aten=False, |
| ) |
| def test_no_valid_choices(self): |
| a = torch.zeros([2, 2], device="cuda") |
| b = torch.zeros([2, 2], device="cuda") |
| with self.assertRaises(BackendCompilerFailed) as context: |
| torch.compile(lambda a, b: a.matmul(b))(a, b) |
| self.assertIn("NoValidChoicesError", str(context.exception)) |
| |
| @parametrize("multi_template", (True, False)) |
| @config.patch( |
| max_autotune=True, |
| max_autotune_gemm_backends="TRITON", |
| autotune_fallback_to_aten=False, |
| ) |
| def test_inf_timing(self, multi_template): |
| from unittest.mock import patch |
| |
| lookup = AlgorithmSelectorCache.lookup |
| |
| def mock_lookup(self, *args, **kwargs): |
| timings = lookup(self, *args, **kwargs) |
| return {choice: float("inf") for choice in timings.keys()} |
| |
| a = torch.zeros([16, 16], device="cuda") |
| b = torch.zeros([16, 16], device="cuda") |
| with patch.object(AlgorithmSelectorCache, "lookup", mock_lookup), config.patch( |
| benchmark_epilogue_fusion=multi_template |
| ): |
| with self.assertRaises(BackendCompilerFailed) as context: |
| torch.compile(lambda a, b: a.matmul(b))(a, b) |
| self.assertIn("NoValidChoicesError", str(context.exception)) |
| |
| |
| class TestBenchmarkRequest(BenchmarkRequest): |
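    """
    A stub BenchmarkRequest that returns a fixed value instead of running a
    kernel, and asserts that CUDA_VISIBLE_DEVICES is set as expected in the
    benchmarking subprocess.
    """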
| def __init__( |
| self, value: float, multi_device: bool, parent_visible_devices: Optional[str] |
| ) -> None: |
| self.value = value |
| self.multi_device = multi_device |
| self.parent_visible_devices = parent_visible_devices |
| |
| def benchmark( |
| self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None |
| ) -> float: |
| # Verify that the visible devices env var is set correctly. If multi-device |
| # auto-tuning is disabled, the visible devices should be unmanipulated from |
| # the parent process. If multi-device auto-tuning is enabled, the visible |
| # devices should be a _single_ valid device number. Note that we can't perform |
| # this validation directly from the test body because benchmarks execute in a |
| # separate process. If the check fails, however, the test will detect the |
| # failure by virtue of not receiving the expected result back. |
| visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES) |
| if not self.multi_device: |
| assert visible_devices == self.parent_visible_devices |
| else: |
| valid_devices = self.parent_visible_devices.split(",") |
| assert visible_devices in valid_devices |
| |
| return self.value |
| |
| |
| class TestTritonTemplateCaller(TritonTemplateCaller): |
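    """A minimal TritonTemplateCaller wrapper that carries only the benchmark request used by the tuning pool."""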
| def __init__(self, bmreq: TestBenchmarkRequest): |
| self.bmreq = bmreq |
| |
| def __str__(self) -> str: |
| return "test" |
| |
| |
| class TestTuningProcess(TestCase): |
| def test_tuning_pool_crash(self): |
        # Use only one device/subprocess so we test that the process restarts
        # and is usable after a "crash".
| with config.patch({"autotune_multi_device": False}): |
| tuning_pool = TuningProcessPool() |
| tuning_pool.initialize() |
| |
| # First force the tuning process to "crash" by setting a bogus |
| # string for the expected visible devices. |
| bmreq = TestBenchmarkRequest(3.14, False, "invalid") |
| choice = TestTritonTemplateCaller(bmreq) |
| |
| timings = tuning_pool.benchmark([choice]) |
| self.assertTrue(choice in timings) |
| self.assertEqual(timings[choice], float("inf")) |
| |
            # Then send another request and make sure the sub-process has
            # restarted and is operational. Since autotune_multi_device is off,
            # the subprocess should see the parent's CUDA_VISIBLE_DEVICES
            # unmodified (it may be unset, i.e. None).
| choice.bmreq.parent_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES) |
| |
| timings = tuning_pool.benchmark([choice]) |
| self.assertTrue(choice in timings) |
| self.assertEqual(timings[choice], bmreq.value) |
| |
| tuning_pool.terminate() |
| |
| def test_tuning_pool_multiple_devices(self): |
| with config.patch({"autotune_multi_device": True}): |
| # Adapt the test to the available devices (and whether CUDA_VISIBLE_DEVICES |
| # is already set in the environment); use a subset of the available devices |
| # to ensure only the subset are visible to the sub-processes. |
| if CUDA_VISIBLE_DEVICES in os.environ: |
| visible_devices = os.environ[CUDA_VISIBLE_DEVICES].split(",") |
| else: |
| visible_devices = [str(d) for d in range(torch.cuda.device_count())] |
| |
| parent_visible_devices = ",".join(visible_devices[-2:]) |
| os.environ[CUDA_VISIBLE_DEVICES] = parent_visible_devices |
| |
| tuning_pool = TuningProcessPool() |
| tuning_pool.initialize() |
| |
| choice1 = TestTritonTemplateCaller( |
| TestBenchmarkRequest(3.14, True, parent_visible_devices), |
| ) |
| choice2 = TestTritonTemplateCaller( |
| TestBenchmarkRequest(2.718, True, parent_visible_devices), |
| ) |
| |
| timings = tuning_pool.benchmark([choice1, choice2]) |
| self.assertEqual(timings[choice1], choice1.bmreq.value) |
| self.assertEqual(timings[choice2], choice2.bmreq.value) |
| |
| tuning_pool.terminate() |
| |
| |
| if __name__ == "__main__": |
| from torch._inductor.utils import is_big_gpu |
| |
| # Set env to make it work in CI. |
| if HAS_CUDA and HAS_CPU and is_big_gpu(0): |
| run_tests() |