Add weight update for DSOModel. (#110273)
Summary: Add weight update for DSOModel and AOTInductorModel
Test Plan: buck2 test accelerators/workloads/models/slimdsnn:slimdsnn_dso_test - SlimDSNN.DSO_Update_Constants
Reviewed By: mikekgfb
Differential Revision: D49748685
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110273
Approved by: https://github.com/hl475
diff --git a/torch/_inductor/codegen/aoti_runtime/interface.cpp b/torch/_inductor/codegen/aoti_runtime/interface.cpp
index 50a2c47..e074a64 100644
--- a/torch/_inductor/codegen/aoti_runtime/interface.cpp
+++ b/torch/_inductor/codegen/aoti_runtime/interface.cpp
@@ -222,4 +222,19 @@
})
}
+AOTIRuntimeError AOTInductorModelUpdateConstants(
+ AOTInductorModelHandle model_handle,
+ AOTInductorConstantMapHandle constant_map_handle) {
+ auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
+ auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
+ auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
+
+ for (auto const& kv : *input_map) {
+ constant_map->emplace(kv.first, kv.second);
+ }
+ model->update_constants_map(std::move(constant_map));
+ })
+}
+
} // extern "C"
diff --git a/torch/csrc/inductor/aoti_runtime/interface.h b/torch/csrc/inductor/aoti_runtime/interface.h
index 4dcc9a3..f9e9f9d 100644
--- a/torch/csrc/inductor/aoti_runtime/interface.h
+++ b/torch/csrc/inductor/aoti_runtime/interface.h
@@ -136,6 +136,12 @@
AtenTensorHandle* input_handles,
AtenTensorHandle* output_handles);
+// Replace AOTInductorModel's constant map. Note it doesn't handle concurrency
+// so be sure to handle ordering if AOTInductorModelRun is ran concurrently.
+AOTIRuntimeError AOTInductorModelUpdateConstants(
+ AOTInductorModelHandle model_handle,
+ AOTInductorConstantMapHandle constant_map_handle);
+
// Delete an AOTInductorModel created by AOTInductorModelCreate.
AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle);
diff --git a/torch/csrc/inductor/aoti_runtime/model.h b/torch/csrc/inductor/aoti_runtime/model.h
index 0d0b4ed..d4a07ed 100644
--- a/torch/csrc/inductor/aoti_runtime/model.h
+++ b/torch/csrc/inductor/aoti_runtime/model.h
@@ -224,6 +224,10 @@
return shape(outputs_info_, idx);
}
+ void update_constants_map(std::shared_ptr<ConstantMap>&& constants_map) {
+ constants_ = std::move(constants_map);
+ }
+
/// Returns true if the model is complete.
bool is_finished() {
if (!run_finished_) {