[TensorExpr] Add a way to define target triple/cpu/attrs for llvm codegen and turn on the AOT workflow. (#66527)
Summary:
Adds process-wide overrides for the target triple, CPU, and attributes used by the NNC LLVM codegen, and exposes them through the Python bindings together with a flag that turns on the AOT workflow. The lowered statement is now cached on TensorExprKernel (stmt_), and a new recompile() method rebuilds the LLVM codegen from it, so an existing kernel can be re-targeted after construction. When the AOT workflow is enabled, construction only generates code and skips resolving the "wrapper" symbol through the JIT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66527
Differential Revision: D31593869
Test Plan: Imported from OSS
Reviewed By: navahgar
Pulled By: ZolotukhinM
fbshipit-source-id: e7534c11fbcf0dab5f49d01d6053caf77b833ef0
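
The new knobs are process-wide and take effect the next time an LLVM codegen is constructed. A minimal sketch of the intended Python-level usage, assuming a build with TORCH_ENABLE_LLVM (the concrete triple/cpu/attrs values below are illustrative placeholders, not part of this change):

import torch._C._te as te

# Point the LLVM codegen at an explicit target instead of the host.
te.set_llvm_target_triple("aarch64-unknown-linux-gnu")
te.set_llvm_target_cpu("cortex-a72")
te.set_llvm_target_attrs("+neon")

# In the AOT workflow we only generate code; kernel construction
# skips resolving the "wrapper" symbol through the JIT.
te.set_llvm_aot_workflow(True)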
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index b66f63b..8b8d6ad 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1304,7 +1304,7 @@
}
BackendType backendType = inferBackendTypeFromDevice(device_);
- StmtPtr stmt = transformLoops(backendType, block);
+ stmt_ = transformLoops(backendType, block);
for (auto c : constants_) {
bufferArgs_.emplace_back(BufHandle(c.buf));
@@ -1318,12 +1318,17 @@
// Generate code.
codegen_ = CreateCodeGen(
getCodeGenName(backendType),
- stmt,
+ stmt_,
bufferArgs_,
device_,
kernel_func_name_);
}
+void TensorExprKernel::recompile() {
+ codegen_ = CreateCodeGen(
+ "llvm_codegen", stmt_, bufferArgs_, device_, kernel_func_name_);
+}
+
TensorExprKernel::TensorExprKernel(
const std::shared_ptr<Graph>& subgraph,
const std::string& kernel_func_name,
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index 5355894..e607136 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -137,6 +137,7 @@
void fallback(Stack& stack) {
InterpreterState(code_).run(stack);
}
+ void recompile();
StmtPtr getCodeGenStmt();
@@ -285,6 +286,7 @@
std::vector<ConstantDescr> constants_;
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings_;
+ StmtPtr stmt_ = nullptr;
bool pre_alloc_{false};
const std::string& kernel_func_name_;
};
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
index cc27a47..8f0b33b 100644
--- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -55,6 +55,24 @@
namespace torch {
namespace jit {
namespace tensorexpr {
+
+c10::optional<std::string>& LLVMTargetTriple() {
+ static c10::optional<std::string> triple = c10::nullopt;
+ return triple;
+}
+c10::optional<std::string>& LLVMTargetCPU() {
+ static c10::optional<std::string> cpu = c10::nullopt;
+ return cpu;
+}
+c10::optional<std::string>& LLVMTargetAttrs() {
+ static c10::optional<std::string> attrs = c10::nullopt;
+ return attrs;
+}
+bool& LLVMAOTWorkflow() {
+ static bool aot_workflow = false;
+ return aot_workflow;
+}
+
namespace {
llvm::CmpInst::Predicate llvm_comparison_predicate(
@@ -422,6 +440,15 @@
c10::optional<std::string> cpu,
c10::optional<std::string> attrs)
: context_(std::make_unique<llvm::LLVMContext>()), irb_(getContext()) {
+ if (!triple) {
+ triple = LLVMTargetTriple();
+ }
+ if (!cpu) {
+ cpu = LLVMTargetCPU();
+ }
+ if (!attrs) {
+ attrs = LLVMTargetAttrs();
+ }
// Manually map types to LLVM types.
ByteTy_ = llvm::Type::getInt8Ty(getContext());
CharTy_ = llvm::Type::getInt8Ty(getContext());
@@ -478,8 +505,10 @@
emitKernel(stmt, params);
jit_->addModule(std::move(module_), std::move(context_));
- auto sym = jit_->findSymbol("wrapper");
- kernelAddress_ = assertSuccess(sym.getAddress());
+ if (!LLVMAOTWorkflow()) {
+ auto sym = jit_->findSymbol("wrapper");
+ kernelAddress_ = assertSuccess(sym.getAddress());
+ }
}
llvm::LLVMContext& LLVMCodeGenImpl::getContext() {
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.h b/torch/csrc/jit/tensorexpr/llvm_codegen.h
index e592b05..b155ff8 100644
--- a/torch/csrc/jit/tensorexpr/llvm_codegen.h
+++ b/torch/csrc/jit/tensorexpr/llvm_codegen.h
@@ -131,6 +131,11 @@
c10::optional<std::string> attrs_ = c10::nullopt;
};
+TORCH_API c10::optional<std::string>& LLVMTargetTriple();
+TORCH_API c10::optional<std::string>& LLVMTargetCPU();
+TORCH_API c10::optional<std::string>& LLVMTargetAttrs();
+TORCH_API bool& LLVMAOTWorkflow();
+
} // namespace tensorexpr
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
index 62e6f69..f92a4a8 100644
--- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
+++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
@@ -809,7 +809,8 @@
[](TensorExprKernel& self, const std::string& attr = "") {
return self.getCodeText(attr);
},
- py::arg("attr") = "");
+ py::arg("attr") = "")
+ .def("recompile", [](TensorExprKernel& self) { self.recompile(); });
py::class_<CodeGen>(te, "CodeGen")
.def(
@@ -886,6 +887,20 @@
});
te.def("annotate_input_shapes", &tensorexpr::annotateInputShapes);
te.def("remove_unused_self_argument", &tensorexpr::removeUnusedSelfArgument);
+#ifdef TORCH_ENABLE_LLVM
+ te.def("set_llvm_target_triple", [](const c10::optional<std::string>& val) {
+ tensorexpr::LLVMTargetTriple() = val;
+ });
+ te.def("set_llvm_target_cpu", [](const c10::optional<std::string>& val) {
+ tensorexpr::LLVMTargetCPU() = val;
+ });
+ te.def("set_llvm_target_attrs", [](const c10::optional<std::string>& val) {
+ tensorexpr::LLVMTargetAttrs() = val;
+ });
+ te.def("set_llvm_aot_workflow", [](bool val) {
+ tensorexpr::LLVMAOTWorkflow() = val;
+ });
+#endif
}
} // namespace jit
} // namespace torch
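
For a kernel that has already been built, recompile() rebuilds the LLVM codegen ("llvm_codegen" is hardcoded in TensorExprKernel::recompile) from the cached stmt_, which makes re-targeting an existing kernel possible. A hedged sketch of that flow; the retarget helper is hypothetical, `kernel` is assumed to be a TensorExprKernel built from a shape-annotated TorchScript graph, and passing "asm" to get_code_text to retrieve assembly is an assumption about the LLVM codegen rather than something this diff adds:

import torch._C._te as te

def retarget(kernel, triple, cpu, attrs):
    # Override the process-wide target knobs first...
    te.set_llvm_target_triple(triple)
    te.set_llvm_target_cpu(cpu)
    te.set_llvm_target_attrs(attrs)
    # ...then rebuild the codegen from the kernel's cached statement.
    kernel.recompile()
    # Inspect the code regenerated for the new target.
    return kernel.get_code_text("asm")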