Add weight update for DSOModel. (#110273)

Summary: Add weight update for DSOModel and AOTInductorModel

Test Plan: buck2 test accelerators/workloads/models/slimdsnn:slimdsnn_dso_test - SlimDSNN.DSO_Update_Constants

Reviewed By: mikekgfb

Differential Revision: D49748685

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110273
Approved by: https://github.com/hl475
diff --git a/torch/_inductor/codegen/aoti_runtime/interface.cpp b/torch/_inductor/codegen/aoti_runtime/interface.cpp
index 50a2c47..e074a64 100644
--- a/torch/_inductor/codegen/aoti_runtime/interface.cpp
+++ b/torch/_inductor/codegen/aoti_runtime/interface.cpp
@@ -222,4 +222,19 @@
   })
 }
 
+AOTIRuntimeError AOTInductorModelUpdateConstants(
+    AOTInductorModelHandle model_handle,
+    AOTInductorConstantMapHandle constant_map_handle) {
+  auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
+  CONVERT_EXCEPTION_TO_ERROR_CODE({
+      auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
+      auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
+
+      for (auto const& kv : *input_map) {
+        constant_map->emplace(kv.first, kv.second);
+      }
+      model->update_constants_map(std::move(constant_map));
+  })
+}
+
 } // extern "C"
diff --git a/torch/csrc/inductor/aoti_runtime/interface.h b/torch/csrc/inductor/aoti_runtime/interface.h
index 4dcc9a3..f9e9f9d 100644
--- a/torch/csrc/inductor/aoti_runtime/interface.h
+++ b/torch/csrc/inductor/aoti_runtime/interface.h
@@ -136,6 +136,12 @@
     AtenTensorHandle* input_handles,
     AtenTensorHandle* output_handles);
 
+// Replace AOTInductorModel's constant map. Note it doesn't handle concurrency
+// so be sure to handle ordering if AOTInductorModelRun is ran concurrently.
+AOTIRuntimeError AOTInductorModelUpdateConstants(
+    AOTInductorModelHandle model_handle,
+    AOTInductorConstantMapHandle constant_map_handle);
+
 // Delete an AOTInductorModel created by AOTInductorModelCreate.
 AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle);
 
diff --git a/torch/csrc/inductor/aoti_runtime/model.h b/torch/csrc/inductor/aoti_runtime/model.h
index 0d0b4ed..d4a07ed 100644
--- a/torch/csrc/inductor/aoti_runtime/model.h
+++ b/torch/csrc/inductor/aoti_runtime/model.h
@@ -224,6 +224,10 @@
     return shape(outputs_info_, idx);
   }
 
+  void update_constants_map(std::shared_ptr<ConstantMap>&& constants_map) {
+    constants_ = std::move(constants_map);
+  }
+
   /// Returns true if the model is complete.
   bool is_finished() {
     if (!run_finished_) {