[aot_inductor] replace TORCH_CHECK with AOTI_CHECK in the generate cpp code (#119220)

In some cases where TORCH_CHECK appears inside loops, the host compiler
can spend hours optimizing the generated run_impl function. This PR
mitigates the issue by replacing TORCH_CHECK with a custom
AOTI_TORCH_CHECK, which forces the underlying assert function to be
noinline.
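
To illustrate why noinline helps, here is a minimal sketch of the shape of the generated code (the function and variable names are hypothetical, and the `[[gnu::noinline]]` attribute assumes GCC/Clang; the real out-of-line helper in this PR is aoti_torch_check() in shim_common.cpp). When the check's error-formatting logic is expanded inline at every call site, a loop body carries all of it on each iteration; routing the failure path through a noinline function keeps each call site down to a test and a call:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Out-of-line failure path (hypothetical name). Marking it noinline
// keeps the error-handling code out of every call site; the analogous
// helper in this PR is aoti_torch_check() in shim_common.cpp.
[[gnu::noinline]] static void check_fail_sketch(const char* msg) {
  std::fprintf(stderr, "check failed: %s\n", msg);
  std::abort();
}

// Sketch of a generated loop: each iteration costs one test-and-branch,
// and the cold path stays behind a single call instruction instead of
// being expanded inline on every iteration.
static void run_impl_sketch(const int64_t* idx, int64_t n, int64_t bound) {
  for (int64_t i = 0; i < n; ++i) {
    if (idx[i] >= bound) {
      check_fail_sketch("index out of bounds");
    }
  }
}
```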

If forcing noinline causes any serious perf regression, we could either
add an option to turn noinline on/off, or add an option to turn
AOTI_TORCH_CHECK into a no-op, similar to the ```assert``` macro from
cassert.
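
For instance, a hypothetical sketch of that no-op variant, modeled on the actual AOTI_TORCH_CHECK definition added below in shim.h (the AOTI_DISABLE_CHECKS flag is invented here for illustration and is not part of this PR; aoti_torch_check and TORCH_CHECK_MSG come from the shim and c10 headers):

```cpp
// Hypothetical opt-out, mirroring how -DNDEBUG turns assert(...) into
// ((void)0) in <cassert>. AOTI_DISABLE_CHECKS is a made-up flag.
#ifdef AOTI_DISABLE_CHECKS
#define AOTI_TORCH_CHECK(cond, ...) ((void)0)
#else
#define AOTI_TORCH_CHECK(cond, ...)    \
  aoti_torch_check(                    \
      cond,                            \
      __func__,                        \
      __FILE__,                        \
      static_cast<uint32_t>(__LINE__), \
      TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__));
#endif
```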

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119220
Approved by: https://github.com/hl475, https://github.com/desertfire
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 0d14321..111e69c 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -1733,7 +1733,10 @@
 
     @property
     def assert_function(self) -> str:
-        return "TORCH_CHECK"
+        if V.graph.aot_mode:
+            return "AOTI_TORCH_CHECK"
+        else:
+            return "TORCH_CHECK"
 
     def decide_parallel_depth(self, ranges, threads):
         seq = self.size_hint()
diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h
index 071bc4c..fb2f56a 100644
--- a/torch/csrc/inductor/aoti_torch/c/shim.h
+++ b/torch/csrc/inductor/aoti_torch/c/shim.h
@@ -487,6 +487,31 @@
     int num_tensors,
     AtenTensorHandle* flatten_tensor_args);
 
+AOTI_TORCH_EXPORT void aoti_torch_check(
+    bool cond,
+    const char* func,
+    const char* file,
+    uint32_t line,
+    const char* msg);
+
+#ifdef STRIP_ERROR_MESSAGES
+#define AOTI_TORCH_CHECK(cond, ...)    \
+  aoti_torch_check(                    \
+      cond,                            \
+      __func__,                        \
+      __FILE__,                        \
+      static_cast<uint32_t>(__LINE__), \
+      TORCH_CHECK_MSG(cond, "", __VA_ARGS__));
+#else
+#define AOTI_TORCH_CHECK(cond, ...)    \
+  aoti_torch_check(                    \
+      cond,                            \
+      __func__,                        \
+      __FILE__,                        \
+      static_cast<uint32_t>(__LINE__), \
+      TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__));
+#endif
+
 #ifdef __cplusplus
 } // extern "C"
 
diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp
index 5c0f5b6..eaf32ee 100644
--- a/torch/csrc/inductor/aoti_torch/shim_common.cpp
+++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -808,6 +808,17 @@
   });
 }
 
+void aoti_torch_check(
+    bool cond,
+    const char* func,
+    const char* file,
+    uint32_t line,
+    const char* msg) {
+  if (C10_UNLIKELY_OR_CONST(!cond)) {
+    ::c10::detail::torchCheckFail(func, file, line, msg);
+  }
+}
+
 AOTITorchError aoti_torch__alloc_from_pool(
     AtenTensorHandle self,
     int64_t offset_bytes,