[ez][inductor] show kernel category in kernel benchmark result (#96991)
I feel it's useful to show whether a kernel is pointwise/reduction/persistent_reduction in the benchmark output. Only the upper-cased first 3 letters of the category are printed, to avoid wrapping the line:
- POI for pointwise
- RED for reduction
- PER for persistent_reduction
<img width="1091" alt="Screenshot 2023-03-16 at 5 10 21 PM" src="https://user-images.githubusercontent.com/52589240/225780546-07b8d345-2bbe-40bd-9e65-185e9294743e.png">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96991
Approved by: https://github.com/Chillee
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 63f6d4d..3362469 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -695,6 +695,29 @@
return arg[len("--only=") :]
+def get_kernel_category(kernel_mod):
+ """
+ Given the module defining a triton kernel, return the category of the kernel.
+ Category can be one of:
+ - pointwise
+ - reduction
+ - persistent_reduction
+
+ Currently we simply decide the category depending on what decorator is imported
+ by the kernel.
+ """
+ choices = [
+ "pointwise",
+ "reduction",
+ "persistent_reduction",
+ ]
+ choices = [ch for ch in choices if ch in kernel_mod.__dict__]
+ if len(choices) == 1:
+ return choices[0]
+ else:
+ return "unknown"
+
+
def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
"""
An experimental API used only when config.benchmark_kernel is true.
@@ -711,6 +734,8 @@
for kernel_key, kernel_mod in PyCodeCache.cache.items():
if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
continue
+
+ kernel_category = get_kernel_category(kernel_mod)
args = kernel_mod.get_args()
num_gb = get_num_bytes(*args) / 1e9
@@ -728,10 +753,13 @@
)
bench_result = []
+ kernel_desc = (
+ f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}"
+ )
if benchmark_all_configs:
assert hasattr(kernel_mod, "benchmark_all_configs")
bench_result = kernel_mod.benchmark_all_configs(args)
- print(f"{benchmark_name:20} {kernel_key[:10]}")
+ print(kernel_desc)
for launcher, ms in bench_result.items():
print(
f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
@@ -748,7 +776,7 @@
launcher.n_regs,
launcher.n_spills,
launcher.shared,
- prefix=f"{benchmark_name:20} {kernel_key[:10]} ",
+ prefix=f"{kernel_desc} ",
)
)