[aot_inductor] replace TORCH_CHECK with AOTI_CHECK in the generate cpp code (#119220)
In some cases where we have TORCH_CHECK in loops, it may cause the host
compiler to spend hours optimizing the run_impl function. This PR
mitigates the issue by replacing TORCH_CHECK with a custom AOTI_TORCH_CHECK,
which forces the underlying assert function to be noinline.
If forcing noinline causes any serious perf regression, we could
either add an option to turn noinline on/off, or add an option to
turn AOTI_TORCH_CHECK into a no-op, similar
to the ```assert``` macro from cassert.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119220
Approved by: https://github.com/hl475, https://github.com/desertfire
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 0d14321..111e69c 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -1733,7 +1733,10 @@
@property
def assert_function(self) -> str:
- return "TORCH_CHECK"
+ if V.graph.aot_mode:
+ return "AOTI_TORCH_CHECK"
+ else:
+ return "TORCH_CHECK"
def decide_parallel_depth(self, ranges, threads):
seq = self.size_hint()
diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h
index 071bc4c..fb2f56a 100644
--- a/torch/csrc/inductor/aoti_torch/c/shim.h
+++ b/torch/csrc/inductor/aoti_torch/c/shim.h
@@ -487,6 +487,31 @@
int num_tensors,
AtenTensorHandle* flatten_tensor_args);
+AOTI_TORCH_EXPORT void aoti_torch_check(
+ bool cond,
+ const char* func,
+ const char* file,
+ uint32_t line,
+ const char* msg);
+
+#ifdef STRIP_ERROR_MESSAGES
+#define AOTI_TORCH_CHECK(cond, ...) \
+ aoti_torch_check( \
+ cond, \
+ __func__, \
+ __FILE__, \
+ static_cast<uint32_t>(__LINE__), \
+ TORCH_CHECK_MSG(cond, "", __VA_ARGS__));
+#else
+#define AOTI_TORCH_CHECK(cond, ...) \
+ aoti_torch_check( \
+ cond, \
+ __func__, \
+ __FILE__, \
+ static_cast<uint32_t>(__LINE__), \
+ TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__));
+#endif
+
#ifdef __cplusplus
} // extern "C"
diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp
index 5c0f5b6..eaf32ee 100644
--- a/torch/csrc/inductor/aoti_torch/shim_common.cpp
+++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -808,6 +808,17 @@
});
}
+void aoti_torch_check(
+ bool cond,
+ const char* func,
+ const char* file,
+ uint32_t line,
+ const char* msg) {
+ if (C10_UNLIKELY_OR_CONST(!cond)) {
+ ::c10::detail::torchCheckFail(func, file, line, msg);
+ }
+}
+
AOTITorchError aoti_torch__alloc_from_pool(
AtenTensorHandle self,
int64_t offset_bytes,