Revert D26333953: [StaticRuntime] Clean up output references and remove dead code

Test Plan: revert-hammer

Differential Revision: D26333953 (https://github.com/pytorch/pytorch/commit/0c9d72b5e11e9d3e706eaf1c07406a0cebfe27c7)

Original commit changeset: cadc0595ad6a

fbshipit-source-id: 75d0b33099342653cd8867b129139325789aee6c
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index a220a17..c8f2d5a 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -220,16 +220,8 @@
         at::Tensor output_1 = getTensor(module.forward(inputs));
 
         // run static runtime
-        c10::IValue output_ivalue = runtime.run(inputs, {});
-
-        at::Tensor output_2 = getTensor(output_ivalue);
+        at::Tensor output_2 = getTensor(runtime.run(inputs, {}));
         EXPECT_TRUE(output_1.equal(output_2));
-
-        // check for output aliasing
-        EXPECT_EQ(output_ivalue.use_count(), 1);
-        output_ivalue = IValue();
-
-        EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
       }
 
       // check for input aliasing (deep & wide does not have ops
@@ -264,16 +256,8 @@
              {"wide", wide}});
 
         // run static runtime
-        c10::IValue output_ivalue = runtime.run({}, kwargs);
-
-        at::Tensor output_2 = getTensor(output_ivalue);
+        at::Tensor output_2 = getTensor(runtime.run({}, kwargs));
         EXPECT_TRUE(output_1.equal(output_2));
-
-        // check for output aliasing
-        EXPECT_EQ(output_ivalue.use_count(), 1);
-        output_ivalue = IValue();
-
-        EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
       }
 
       EXPECT_EQ(ad_emb_packed.getIntrusivePtr().use_count(), 1);
diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp
index 6134753..b4ef21e 100644
--- a/torch/csrc/jit/runtime/static/fusion.cpp
+++ b/torch/csrc/jit/runtime/static/fusion.cpp
@@ -23,7 +23,8 @@
 Operation createStaticSubgraphRuntime(const Node* node) {
   auto g = torch::jit::PrepareForStaticRuntime(node->g(attr::Subgraph));
   auto runtime = std::make_shared<torch::jit::StaticRuntime>(g);
-  auto num_inputs = runtime->num_inputs();
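+  // the InferenceModule assigns one input register per graph input, so the
+  // register count doubles as the input count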
+  auto num_inputs = runtime->get_inference_module()->input_regs.size();
   return [runtime, num_inputs](Stack* stack) {
     RECORD_FUNCTION("Static Runtime", std::vector<c10::IValue>());
     auto inps = torch::jit::last(stack, num_inputs);
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 87526d4..56c9fa0 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -198,12 +198,155 @@
   }
   return can_reuse;
 }
+
+size_t AssignRegisters(
+    const std::shared_ptr<torch::jit::Graph>& graph,
+    std::unordered_map<Value*, size_t>& value_to_reg,
+    std::vector<Value*>& values,
+    std::vector<size_t>& input_regs,
+    std::vector<size_t>& output_regs,
+    bool optimize_memory) {
+  auto lm = LivenessMap(graph);
+  auto optimizable_values = GetOptimizableValues(graph);
+
+  size_t num_regs = 0;
+  size_t reused_regs = 0;
+  std::unordered_map<size_t, std::set<Value*>> reg_to_val;
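+  // getReg picks a register for v: with memory optimization enabled it
+  // greedily re-uses the register of a value whose lifetime has ended,
+  // falling back to a fresh register whenever re-use is not provably safe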
+  auto getReg = [&](Value* v) -> size_t {
+    if (!optimize_memory) {
+      return num_regs++;
+    }
+    TORCH_CHECK(!value_to_reg.count(v));
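+    // fall back to a fresh register when v has no liveness record, is not
+    // an optimizable value, or is in the set of values excluded from
+    // liveness tracking (lm.second)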
+    auto iter = lm.first.find(v);
+    if (iter == lm.first.end()) {
+      return num_regs++;
+    }
+    if (!optimizable_values.count(v)) {
+      return num_regs++;
+    }
+    if (lm.second.count(v)) {
+      return num_regs++;
+    }
+    const auto& live_values = iter->second;
+    // iterate through all the allocated registers
+    // and check for potential re-use, greedily
+    for (const auto& v2r : value_to_reg) {
+      auto candidate_v = v2r.first;
+
+      if (!optimizable_values.count(candidate_v)) {
+        continue;
+      }
+      if (lm.second.count(candidate_v)) {
+        continue;
+      }
+
+      // Only re-use Tensor registers whose dtype is float (or not yet known)
+      auto t = candidate_v->type()->cast<TensorType>();
+      if (!t) {
+        continue;
+      }
+      // TODO audit this assumption (passes tests, but is scary)
+      if (t->scalarType() && *(t->scalarType()) != at::kFloat) {
+        continue;
+      }
+      // TODO
+      // if (*(t->scalarType()) != at::kFloat) {
+      //  continue;
+      //}
+      if (!live_values.count(candidate_v)) {
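+        // candidate_v's lifetime does not overlap v's; the register is
+        // still unsafe to re-use if any other value assigned to it is
+        // live concurrently with v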
+        bool already_used = false;
+        for (auto use : reg_to_val.at(v2r.second)) {
+          if (live_values.count(use)) {
+            already_used = true;
+          }
+        }
+        if (already_used) {
+          continue;
+        }
+        reused_regs++;
+        return v2r.second;
+      }
+    }
+    return num_regs++;
+  };
+
+  // assign register to Value*
+  for (Value* input : graph->inputs()) {
+    TORCH_CHECK(value_to_reg.count(input) == 0);
+    auto reg = getReg(input);
+    value_to_reg[input] = reg;
+    reg_to_val[reg].insert(input);
+    input_regs.push_back(reg);
+  }
+  for (Node* node : graph->nodes()) {
+    for (Value* input : node->inputs()) {
+      TORCH_CHECK(value_to_reg.count(input) > 0);
+    }
+    for (Value* output : node->outputs()) {
+      TORCH_CHECK(
+          value_to_reg.count(output) == 0, "the graph needs to be in SSA form");
+      auto reg = getReg(output);
+      value_to_reg[output] = reg;
+      reg_to_val[reg].insert(output);
+    }
+  }
+  TORCH_CHECK(graph->outputs().size() > 0);
+  for (Value* output : graph->outputs()) {
+    TORCH_CHECK(value_to_reg.count(output) > 0);
+    output_regs.push_back(value_to_reg[output]);
+  }
+
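+  // build the reverse mapping (register index -> Value*) for debugging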
+  values.resize(value_to_reg.size());
+  for (const auto& p : value_to_reg) {
+    values[p.second] = p.first;
+  }
+  return reused_regs;
+}
+
+// Internal values are discarded after run if
+// opts_.cleanup_activations is true.
+void DeduceInternalValues(
+    const std::shared_ptr<torch::jit::Graph>& graph,
+    const std::unordered_map<Value*, size_t>& value_to_reg,
+    std::vector<size_t>& internals) {
+  std::unordered_set<Value*> outputs{
+      graph->outputs().begin(), graph->outputs().end()};
+  for (Node* node : graph->nodes()) {
+    if (node->kind() != prim::Constant) {
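+      // skip constants; every other value that is not a graph output is
+      // internal and may be discarded after each run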
+      for (Value* output : node->outputs()) {
+        if (outputs.count(output) == 0) {
+          internals.push_back(value_to_reg.at(output));
+        }
+      }
+    }
+  }
+}
 } // namespace
 
 void InferenceModule::init() {
   OptimizeGraph(graph);
   CheckGraphEligibility(graph);
   RemoveSelfFromGraphInput(graph);
+  reused_regs = AssignRegisters(
+      graph,
+      value_to_reg,
+      values,
+      input_regs,
+      output_regs,
+      opts.optimize_memory);
+  DeduceInternalValues(graph, value_to_reg, internals);
 }
 
 InferenceModule::InferenceModule(
@@ -298,6 +429,10 @@
   }
 }
 
+size_t StaticRuntime::num_outputs() const {
+  return module_->output_regs.size();
+}
+
 std::vector<at::Tensor> StaticRuntime::run(
     const std::vector<at::Tensor>& inps) {
   std::vector<c10::IValue> stack;
@@ -375,13 +510,11 @@
     std::vector<c10::IValue> outputs;
     outputs.reserve(num_outputs());
     for (auto i = 0; i < num_outputs(); ++i) {
-      // use move here. Otherwise, clean up outputs_[i] explicitly
-      outputs.emplace_back(std::move(*outputs_[i]));
+      outputs.emplace_back(Output(i));
     }
-    return c10::ivalue::Tuple::create(std::move(outputs));
+    return c10::ivalue::Tuple::create(outputs);
   }
-  // use move here. Otherwise, clean up outputs_[0] explicitly
-  return std::move(*outputs_[0]);
+  return Output(0);
 }
 
 void StaticRuntime::benchmark(
@@ -428,8 +561,8 @@
               << " bytes" << std::endl;
   }
   if (module_->opts.optimize_memory) {
-    // std::cout << "Total number of reused registers: " << module_->reused_regs
-    //           << std::endl;
+    std::cout << "Total number of reused registers: " << module_->reused_regs
+              << std::endl;
   }
 }
 
@@ -537,7 +670,6 @@
   // collect register indices of outputs of ops with out variant
   std::unordered_set<Value*> managed_values;
   std::unordered_set<IValue*> unmanaged_value_set;
-  std::unordered_map<Value*, IValue*> values_map;
   for (ProcessedNode& pnode : runtime->get_nodes()) {
     bool should_manage = pnode.has_out_variant();
     if (should_manage && isViewOp(pnode.get_node())) {
@@ -553,23 +685,32 @@
     }
     if (should_manage) {
       // Types are stored in the underlying TorchScript IR
-      for (size_t i = 0; i < pnode.outputs().size(); i++) {
-        Value* out = pnode.get_node()->output(i);
+      for (Value* out : pnode.get_node()->outputs()) {
         if (out->type()->cast<TensorType>()) {
           managed_values.insert(out);
-          values_map[out] = &pnode.Output(i);
         }
       }
     } else {
       for (auto i = 0; i < pnode.outputs().size(); ++i) {
         unmanaged_value_set.insert(&pnode.Output(i));
-        values_map[pnode.get_node()->output(i)] = &pnode.Output(i);
       }
     }
   }
 
   const InferenceModule* module = runtime->get_inference_module();
 
+  // remove model outputs from managed_values and unmanaged_value_set
+  for (Value* output : module->graph->outputs()) {
+    managed_values.erase(output);
+  }
+  for (IValue* output : runtime->outputs()) {
+    unmanaged_value_set.erase(output);
+  }
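+  // unmanaged_value_set => unmanaged_values_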
+  for (IValue* out : unmanaged_value_set) {
+    unmanaged_values_.emplace_back(out);
+  }
+
   // remove tensors in output List/Tuple from managed_values
   for (Value* output : module->graph->outputs()) {
     Node* output_node = output->node();
@@ -577,28 +717,10 @@
         output_node->kind() == prim::ListConstruct) {
       for (Value* input : output_node->inputs()) {
         managed_values.erase(input);
-        // Elements in Tuples and Lists are refcounted. MemoryPlanner should not
-        // hold refs of elements in output Tuples/Lists
-        if (graph_input_values.count(input) == 0) {
-          unmanaged_value_set.insert(values_map[input]);
-        }
       }
     }
   }
 
-  // remove model outputs from managed_values and unmanaged_value_set
-  for (Value* output : module->graph->outputs()) {
-    managed_values.erase(output);
-  }
-  for (IValue* output : runtime->outputs()) {
-    unmanaged_value_set.erase(output);
-  }
-
-  // unmanaged_value_set => unmanaged_values_
-  for (IValue* out : unmanaged_value_set) {
-    unmanaged_values_.emplace_back(out);
-  }
-
   // some Values should share storage, this map will
   // keep track of the index into managed_storage_
   std::unordered_map<Value*, size_t> shared;
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index 4385c3f..a8e2cb3 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -71,6 +71,12 @@
   std::shared_ptr<torch::jit::Graph> graph;
   std::unique_ptr<c10::FunctionSchema> schema;
 
+  std::unordered_map<Value*, size_t> value_to_reg;
+  std::vector<Value*> values; // useful for debugging
+  std::vector<size_t> input_regs; // inputs to the graph
+  std::vector<size_t> output_regs; // outputs of the graph
+  std::vector<size_t> internals; // values discarded after each run
+  size_t reused_regs = 0;
   InferenceModuleOptions opts;
 
  private:
@@ -153,13 +159,11 @@
     return nodes_;
   }
 
-  [[nodiscard]] size_t num_inputs() const {
-    return inputs_.size();
+  const std::vector<IValue>& get_registers() {
+    return reg_;
   }
 
-  [[nodiscard]] size_t num_outputs() const {
-    return outputs_.size();
-  }
+  size_t num_outputs() const;
 
   inline const std::vector<IValue*>& outputs() const {
     return outputs_;
@@ -170,6 +174,7 @@
   std::shared_ptr<InferenceModule> module_;
   StaticRuntimeOptions opts_;
   // IValue table (including inputs, outputs, intermediates, and weights)
+  std::vector<IValue> reg_;
   std::vector<IValue> constants_;
   std::vector<IValue> inputs_;
   std::vector<IValue*> outputs_;