Don't precompile if search_autotune_cache is set but max-autotune is not (#124870)
Differential Revision: [D56534950](https://our.internmc.facebook.com/intern/diff/D56534950)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124870
Approved by: https://github.com/xw285cornell
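
In short: when only `search_autotune_cache` is enabled (and neither `max_autotune` nor `max_autotune_gemm` is), the autotuner should serve choices from the cache without kicking off precompilation. A minimal sketch of the expected behavior, mirroring the new test below; it assumes a CUDA device is available and uses the same `counters["inductor"]["select_algorithm_precompile"]` counter the test checks:

```python
# Sketch: with search_autotune_cache=True but max-autotune off, Inductor
# should not precompile autotuning choices, so the
# select_algorithm_precompile counter stays at 0. Assumes a CUDA device.
import torch
from torch._dynamo.utils import counters
from torch._inductor import config


def mm_chain(a, b, c):
    return (a @ b) @ c


with config.patch(
    search_autotune_cache=True, max_autotune=False, max_autotune_gemm=False
):
    compiled = torch.compile(mm_chain)
    inputs = [torch.rand(256, 256, device="cuda") for _ in range(3)]
    compiled(*inputs)
    assert counters["inductor"]["select_algorithm_precompile"] == 0
```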
diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index f74fa4e..c8622de 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -447,6 +447,22 @@
fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
+ @skipIfRocm
+ @fresh_inductor_cache()
+ @config.patch(search_autotune_cache=True)
+ def test_search_autotune_cache(self):
+ def fn(a, b, c):
+ a = (a @ b) @ c
+ a, b, c = (t.to(torch.float16) for t in [a, b, c])
+ return (a @ b) @ c
+
+ fn_c = torch.compile()(fn)
+ inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]
+ from torch._dynamo.utils import counters
+
+ self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)
+ self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
+
@config.patch(autotune_local_cache=False, autotune_remote_cache=False)
def test_precompilations(self):
def fn(a, b, c):
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index 3a09238..056bcce 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -990,6 +990,11 @@
if timings:
return no_op
+ if config.search_autotune_cache and not (
+ config.max_autotune or config.max_autotune_gemm
+ ):
+ return no_op
+
precompile_key = (
f"{name}: {inputs_key} : {torch.get_float32_matmul_precision()}"
)