Refactor AotCompile to return a pair (#65707)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65707
Refactors aotCompile to return a pair of the compiled function and the LLVM assembly code, instead of writing the assembly into a caller-provided string out-parameter.
Testing: the AOT model compiler binary produces the expected result when built and run:
```
(pytorch) ~/local/pytorch refactor_aot
└─ $ build/bin/aot_model_compiler --model mobilenetv3.pt --model_name=pytorch_dev_mobilenetv3 --model_version=v1 --input_dims="2,2,2"
The compiled model was saved to mobilenetv3.compiled.pt
```
Test Plan: Imported from OSS
Reviewed By: qihqi
Differential Revision: D31220452
Pulled By: priyaramani
fbshipit-source-id: f957c53ba83f876a2e7dbdd4b4571a760b3b6a9a
diff --git a/binaries/aot_model_compiler.cc b/binaries/aot_model_compiler.cc
index d757af1..7ca7f18 100644
--- a/binaries/aot_model_compiler.cc
+++ b/binaries/aot_model_compiler.cc
@@ -112,10 +112,10 @@
auto sizes = getInputSizesForMethod(method_compile_spec, method_name);
std::string llvm_asm_code;
- auto func =
- torch::jit::mobile::nnc::aotCompile(method_name, graph, sizes, &llvm_asm_code);
- writeOutputLlvmAssembly(llvm_asm_code);
+ auto compiled = torch::jit::mobile::nnc::aotCompile(method_name, graph, sizes);
+ writeOutputLlvmAssembly(compiled.second);
+ auto func = std::move(compiled.first);
func->set_nnc_kernel_id(getNncKernelId(method_name));
torch::jit::mobile::nnc::CompilationUnit cu;
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
index 0790fdf..de7db04 100644
--- a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
+++ b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
@@ -33,7 +33,7 @@
return r;
}
-void getCompiledFunction(
+void compileFunction(
std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
Function* func) {
std::vector<at::Tensor> parameters;
@@ -66,11 +66,10 @@
func->set_output_specs(out_spec);
}
-std::unique_ptr<Function> aotCompile(
+std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
const std::string& method_name,
std::shared_ptr<Graph>& g,
- const std::vector<int64_t>& sizes,
- std::string* compiled_assembly) {
+ const std::vector<int64_t>& sizes) {
auto g2 = g->copy();
GRAPH_DEBUG("Input sizes ", sizes);
@@ -90,7 +89,7 @@
std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
std::make_shared<tensorexpr::TensorExprKernel>(g);
- *compiled_assembly = kernel->getCodeText();
+ const std::string compiled_assembly = kernel->getCodeText();
g = g2;
@@ -102,8 +101,8 @@
input.dtype_ = c10::ScalarType::Float;
func->set_input_specs({input});
- getCompiledFunction(kernel, func.get());
- return func;
+ compileFunction(kernel, func.get());
+ return std::make_pair(std::move(func), compiled_assembly);
}
} // namespace nnc
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.h b/torch/csrc/jit/mobile/nnc/aot_compiler.h
index 71f6d92..966337e 100644
--- a/torch/csrc/jit/mobile/nnc/aot_compiler.h
+++ b/torch/csrc/jit/mobile/nnc/aot_compiler.h
@@ -11,11 +11,10 @@
// Performs Ahead Of Time compilation of a given method in a model
// returning the compiled function and LLVM assembly code
-TORCH_API std::unique_ptr<Function> aotCompile(
+TORCH_API std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
const std::string& method_name,
std::shared_ptr<Graph>& subgraph,
- const std::vector<int64_t>& sizes,
- std::string* compiled_assembly);
+ const std::vector<int64_t>& sizes);
} // namespace nnc
} // namespace mobile