[AOTI][Tooling] Add a test case where `config.debug_intermediate_value_printer=True` to check codegen (#133326)
Summary:
As title.
Add a test case in test_aot_inductor to check for codegen (i.e. `aoti_torch_print_tensor_handle` is inserted as expected for debugging printer) for both cpu and cuda based on a simple `addmm` test model.
Test Plan:
```
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_aoti_debug_printer_codegen_abi_compatible_{cuda/cpu}
```
Differential Revision: D61169068
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133326
Approved by: https://github.com/ColinPeppler
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index e9694fa..ebbf79c 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -20,6 +20,7 @@
from torch._inductor.exc import CppWrapperCodeGenError
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import TestCase
+from torch._inductor.utils import run_and_get_cpp_code
from torch.export import Dim, export
from torch.testing import FileCheck
from torch.testing._internal import common_utils
@@ -3107,6 +3108,90 @@
Model(), example_inputs, options=dict(max_autotune=max_autotune)
)
+ def test_aoti_debug_printer_codegen(self):
+ # basic addmm model to test codegen for aoti intermediate debug printer
+ class Model(torch.nn.Module):
+ def __init__(self, n, k, device):
+ super().__init__()
+ self.weight = torch.randn(n, k, device=device)
+ self.bias = torch.randn(n, device=device)
+
+ def forward(self, a):
+ return torch.nn.functional.linear(a, self.weight, self.bias)
+
+ M = 8
+ N = 6
+ K = 16
+ model = Model(N, K, self.device)
+ batch = 2
+ a = torch.randn(batch, M, K, device=self.device)
+ example_inputs = (a,)
+
+ kernel_calls = (
+ [
+ ("triton_poi_fused_0", 1),
+ ("aoti_torch_cuda_addmm_out", 2),
+ ]
+ if self.device == "cuda"
+ else [
+ ("aoti_torch_cpu_addmm_out", 2),
+ ]
+ )
+
+ # test the default debug printing codegen
+ with config.patch({"aot_inductor.debug_intermediate_value_printer": 1}):
+ result, code = run_and_get_cpp_code(
+ AOTIRunnerUtil.compile, model, example_inputs
+ )
+
+ # check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected
+ self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
+
+ # check the codegen for debug printing around the actual kernel call is expected
+
+ for kernel_call, count in kernel_calls:
+ FileCheck().check_count(
+ f"before_launch - {kernel_call}",
+ count,
+ ).run(code)
+ FileCheck().check_count(
+ f"after_launch - {kernel_call}",
+ count,
+ ).run(code)
+
+ # test the filtered kernel names printing codegen
+ filtered_kernel_name = f"aoti_torch_{self.device}_addmm_out"
+ with config.patch(
+ {
+ "aot_inductor.debug_intermediate_value_printer": 1,
+ "aot_inductor.filtered_kernel_names": filtered_kernel_name,
+ }
+ ):
+ result, code = run_and_get_cpp_code(
+ AOTIRunnerUtil.compile, model, example_inputs
+ )
+ filtered_kernel_calls = [
+ (filtered_kernel_name, 2),
+ ]
+ for kernel_call, count in filtered_kernel_calls:
+ FileCheck().check_count(
+ f"before_launch - {kernel_call}",
+ count,
+ ).run(code)
+ FileCheck().check_count(
+ f"after_launch - {kernel_call}",
+ count,
+ ).run(code)
+
+ kernel_calls_not_to_print = [
+ kernel_call
+ for kernel_call in kernel_calls
+ if kernel_call[0] != filtered_kernel_name
+ ]
+ for kernel_name, _ in kernel_calls_not_to_print:
+ FileCheck().check_not(f"before_launch - {kernel_name}").run(code)
+ FileCheck().check_not(f"after_launch - {kernel_name}").run(code)
+
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
@@ -3303,6 +3388,8 @@
"test_custom_op_add": fail_non_abi_compatible_cuda(is_skip=True),
# fp8 to be re-enabled for AOTI
"test_fp8": fail_cuda(is_skip=True),
+ # non-abi compatible mode debug printer is not supported yet
+ "test_aoti_debug_printer_codegen": fail_non_abi_compatible_cuda(is_skip=True),
}
@@ -3347,6 +3434,15 @@
"test_with_offset": fail_minimal_arrayref_interface(is_skip=True),
"test_with_profiler": fail_minimal_arrayref_interface(is_skip=True),
"test_zero_size_weight": fail_minimal_arrayref_interface(is_skip=True),
+ "test_aoti_debug_printer_codegen": fail_with_and_without_stack_allocation(
+ is_skip=True
+ ),
+ }
+ ),
+ # The following test passes internally but fails in OSS CI. To be investigated.
+ CUDA_TEST_FAILURES.update(
+ {
+ "test_aoti_debug_printer_codegen": fail_cuda(is_skip=True),
}
)
@@ -3446,6 +3542,9 @@
("non_abi_compatible_cpu",), is_skip=True
),
"test_custom_op_add": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
+ "test_aoti_debug_printer_codegen": TestFailure(
+ ("non_abi_compatible_cpu",), is_skip=True
+ ),
},
)
diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py
index a2e9aff..a44fa5d 100644
--- a/torch/_inductor/codegen/debug_utils.py
+++ b/torch/_inductor/codegen/debug_utils.py
@@ -2,7 +2,6 @@
from __future__ import annotations
import functools
-import os
from typing import List, Optional
from .. import config
@@ -65,12 +64,7 @@
def get_debug_filtered_kernel_names(self) -> List[str]:
return [
x.strip()
- for x in os.environ.get(
- "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT",
- self.DEBUG_FILTER_DEFAULT_PRINT_ALL,
- )
- .lower()
- .split(",")
+ for x in config.aot_inductor.filtered_kernel_names.lower().split(",")
]
def codegen_intermediate_tensor_value_printer(
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 7ea956b..8dc12ae 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -922,7 +922,9 @@
)
# filtered nodes to be printed for debug values. If not set, it will dump all debug tensor value info by default
- filtered_kernel_names = os.environ.get("AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", "")
+ filtered_kernel_names = os.environ.get(
+ "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", "default"
+ )
# Serialized tree spec for flattening inputs
serialized_in_spec = ""