[TensorExpr] Add a way to define target triple/cpu/attrs for llvm codegen and turn on the AOT workflow. (#66527)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66527
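
This adds process-wide overrides for the LLVM target triple/CPU/attrs, which LLVMCodeGenImpl falls back to when no explicit values are passed, plus an AOT-workflow flag that makes the JIT skip resolving the "wrapper" symbol (so codegen can target a non-host machine), a TensorExprKernel::recompile() entry point that rebuilds codegen from the saved stmt_, and Python bindings for all of the above.

A rough usage sketch from Python (the torch._C._te module path and the parsed-IR setup follow the existing test_tensorexpr_pybind.py conventions; the target strings are placeholder examples, and the sketch assumes the corresponding LLVM backend is compiled in):

```python
import torch

te = torch._C._te

# A tiny shape-annotated graph, as in test_tensorexpr_pybind.py.
graph_str = """
graph(%a.1 : Float(4, strides=[1], requires_grad=0, device=cpu)):
  %1 : Float(4, strides=[1], requires_grad=0, device=cpu) = aten::mul(%a.1, %a.1)
  return (%1)
"""
graph = torch._C.parse_ir(graph_str)

# Point LLVM codegen at an explicit target instead of the host
# (placeholder values; requires the AArch64 backend in your LLVM build).
te.set_llvm_target_triple("aarch64-unknown-linux-gnu")
te.set_llvm_target_cpu("cortex-a72")
te.set_llvm_target_attrs("+neon")

# AOT workflow: the JIT does not resolve the "wrapper" symbol, so the
# kernel compiles for the non-host target but is not callable here.
te.set_llvm_aot_workflow(True)

kernel = te.TensorExprKernel(graph)
print(kernel.get_code_text("asm"))  # inspect the cross-compiled assembly

# Reset to the host target and rebuild codegen for the same kernel
# without redoing the lowering pipeline.
te.set_llvm_target_triple(None)
te.set_llvm_target_cpu(None)
te.set_llvm_target_attrs(None)
te.set_llvm_aot_workflow(False)
kernel.recompile()

x = torch.rand(4)
print(kernel.run((x,)))  # now executable on the host
```

Note that the overrides are plain process-wide globals (function-local statics), so they should be set before kernels are built and are not synchronized across threads.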

Differential Revision: D31593869

Test Plan: Imported from OSS

Reviewed By: navahgar

Pulled By: ZolotukhinM

fbshipit-source-id: e7534c11fbcf0dab5f49d01d6053caf77b833ef0
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index b66f63b..8b8d6ad 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1304,7 +1304,7 @@
   }
 
   BackendType backendType = inferBackendTypeFromDevice(device_);
-  StmtPtr stmt = transformLoops(backendType, block);
+  stmt_ = transformLoops(backendType, block);
 
   for (auto c : constants_) {
     bufferArgs_.emplace_back(BufHandle(c.buf));
@@ -1318,12 +1318,17 @@
   // Generate code.
   codegen_ = CreateCodeGen(
       getCodeGenName(backendType),
-      stmt,
+      stmt_,
       bufferArgs_,
       device_,
       kernel_func_name_);
 }
 
+void TensorExprKernel::recompile() {
+  codegen_ = CreateCodeGen(
+      "llvm_codegen", stmt_, bufferArgs_, device_, kernel_func_name_);
+}
+
 TensorExprKernel::TensorExprKernel(
     const std::shared_ptr<Graph>& subgraph,
     const std::string& kernel_func_name,
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index 5355894..e607136 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -137,6 +137,7 @@
   void fallback(Stack& stack) {
     InterpreterState(code_).run(stack);
   }
+  void recompile();
 
   StmtPtr getCodeGenStmt();
 
@@ -285,6 +286,7 @@
   std::vector<ConstantDescr> constants_;
 
   std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings_;
+  StmtPtr stmt_ = nullptr;
   bool pre_alloc_{false};
   const std::string& kernel_func_name_;
 };
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
index cc27a47..8f0b33b 100644
--- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -55,6 +55,24 @@
 namespace torch {
 namespace jit {
 namespace tensorexpr {
+
+c10::optional<std::string>& LLVMTargetTriple() {
+  static c10::optional<std::string> triple = c10::nullopt;
+  return triple;
+}
+c10::optional<std::string>& LLVMTargetCPU() {
+  static c10::optional<std::string> cpu = c10::nullopt;
+  return cpu;
+}
+c10::optional<std::string>& LLVMTargetAttrs() {
+  static c10::optional<std::string> attrs = c10::nullopt;
+  return attrs;
+}
+bool& LLVMAOTWorkflow() {
+  static bool aot_workflow = false;
+  return aot_workflow;
+}
+
 namespace {
 
 llvm::CmpInst::Predicate llvm_comparison_predicate(
@@ -422,6 +440,15 @@
     c10::optional<std::string> cpu,
     c10::optional<std::string> attrs)
     : context_(std::make_unique<llvm::LLVMContext>()), irb_(getContext()) {
+  if (!triple) {
+    triple = LLVMTargetTriple();
+  }
+  if (!cpu) {
+    cpu = LLVMTargetCPU();
+  }
+  if (!attrs) {
+    attrs = LLVMTargetAttrs();
+  }
   // Manually map types to LLVM types.
   ByteTy_ = llvm::Type::getInt8Ty(getContext());
   CharTy_ = llvm::Type::getInt8Ty(getContext());
@@ -478,8 +505,10 @@
   emitKernel(stmt, params);
 
   jit_->addModule(std::move(module_), std::move(context_));
-  auto sym = jit_->findSymbol("wrapper");
-  kernelAddress_ = assertSuccess(sym.getAddress());
+  if (!LLVMAOTWorkflow()) {
+    auto sym = jit_->findSymbol("wrapper");
+    kernelAddress_ = assertSuccess(sym.getAddress());
+  }
 }
 
 llvm::LLVMContext& LLVMCodeGenImpl::getContext() {
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.h b/torch/csrc/jit/tensorexpr/llvm_codegen.h
index e592b05..b155ff8 100644
--- a/torch/csrc/jit/tensorexpr/llvm_codegen.h
+++ b/torch/csrc/jit/tensorexpr/llvm_codegen.h
@@ -131,6 +131,11 @@
   c10::optional<std::string> attrs_ = c10::nullopt;
 };
 
+TORCH_API c10::optional<std::string>& LLVMTargetTriple();
+TORCH_API c10::optional<std::string>& LLVMTargetCPU();
+TORCH_API c10::optional<std::string>& LLVMTargetAttrs();
+TORCH_API bool& LLVMAOTWorkflow();
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
index 62e6f69..f92a4a8 100644
--- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
+++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
@@ -809,7 +809,8 @@
           [](TensorExprKernel& self, const std::string& attr = "") {
             return self.getCodeText(attr);
           },
-          py::arg("attr") = "");
+          py::arg("attr") = "")
+      .def("recompile", [](TensorExprKernel& self) { self.recompile(); });
 
   py::class_<CodeGen>(te, "CodeGen")
       .def(
@@ -886,6 +887,20 @@
       });
   te.def("annotate_input_shapes", &tensorexpr::annotateInputShapes);
   te.def("remove_unused_self_argument", &tensorexpr::removeUnusedSelfArgument);
+#ifdef TORCH_ENABLE_LLVM
+  te.def("set_llvm_target_triple", [](const c10::optional<std::string>& val) {
+    tensorexpr::LLVMTargetTriple() = val;
+  });
+  te.def("set_llvm_target_cpu", [](const c10::optional<std::string>& val) {
+    tensorexpr::LLVMTargetCPU() = val;
+  });
+  te.def("set_llvm_target_attrs", [](const c10::optional<std::string>& val) {
+    tensorexpr::LLVMTargetAttrs() = val;
+  });
+  te.def("set_llvm_aot_workflow", [](bool val) {
+    tensorexpr::LLVMAOTWorkflow() = val;
+  });
+#endif
 }
 } // namespace jit
 } // namespace torch