More name refactoring of memory planning codes to make it more readable (#54272)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54272

Test Plan: Imported from OSS

Reviewed By: bwasti

Differential Revision: D27233881

fbshipit-source-id: f257f16ac0684df055961e539f17d002cb8f1bfe
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 3fc4869..5930a78 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -125,35 +125,38 @@
 LivenessInformation GetLivenessInformation(
     const std::shared_ptr<torch::jit::Graph>& graph,
     AliasDb& db) {
+  // map a Value to a set of Values that overlap live-ranges with the Value's
   std::unordered_map<const Value*, std::set<const Value*>> liveness_map;
+  // a set of Values whose live-range exceed current inference
   std::unordered_set<const Value*> always_alive;
 
+  // map Values to its creation order in graph (Note: only traverse top-level
+  // nodes such that nodes under control-flows are represented by top-level
+  // block nodes)
   std::vector<const Value*> values_in_creation_order;
-  std::unordered_map<const Value*, size_t> values_in_creation_order_idx;
+  std::unordered_map<const Value*, size_t> values_to_idx_in_creation_order;
   for (const auto* node : graph->nodes()) {
     for (const auto* v : node->outputs()) {
-      values_in_creation_order_idx[v] = values_in_creation_order.size();
+      values_to_idx_in_creation_order[v] = values_in_creation_order.size();
       values_in_creation_order.emplace_back(v);
     }
   }
 
-  // maps values to any nodes that consume or produce them
-  //
-  // updated as we traverse the graph. the presence of a key in `live_values`
-  // means that the value is currently alive.
-  //
-  // invariant: set.size() > 0
-  std::unordered_map<const Value*, std::set<const Node*>> live_values;
-  std::unordered_map<const Node*, std::set<const Value*>> live_nodes;
+  // presence of a Value in live_values_use_chain means the Value alive
+  // Value mapped to set of Nodes that may use the Value (i.e., use-chain of
+  // Value)
+  std::unordered_map<const Value*, std::set<const Node*>> live_values_use_chain;
+  // Node mapped to set of Values that the Node may use (i.e., def-chain of node
+  // inputs)
+  std::unordered_map<const Node*, std::set<const Value*>> live_nodes_def_chain;
 
-  // inputs and outputs are marked permanently alive
+  // mark inputs, constants, outputs as always_alive
   for (const auto* input : graph->inputs()) {
     always_alive.insert(input);
   }
   for (const auto* output : graph->outputs()) {
     always_alive.insert(output);
   }
-
   for (const auto* node : graph->nodes()) {
     if (node->kind() == prim::Constant) {
       for (const auto* output : node->outputs()) {
@@ -162,13 +165,14 @@
     }
   }
 
+  // add v to the current liveness_map
   std::function<void(const Value* v)> add_live_value_fn = [&](const Value* v) {
     if (liveness_map.count(v)) {
       return;
     }
     liveness_map[v] = {};
 
-    for (const auto& live_v : live_values) {
+    for (const auto& live_v : live_values_use_chain) {
       liveness_map.at(v).insert(live_v.first);
       liveness_map.at(live_v.first).insert(v);
     }
@@ -176,45 +180,53 @@
     // only add values to the live set if they
     // have deps, otherwise they die immediately
     if (v->uses().size()) {
-      live_values[v] = {};
+      live_values_use_chain[v] = {};
     }
 
+    // record the relationship between v (Value) and its uses (Node)
     for (const auto& u : v->uses()) {
       const auto* node = u.user;
-      // track deps of this value
-      live_values.at(v).insert(node);
-      live_nodes[node].insert(v);
+      live_values_use_chain.at(v).insert(node);
+      live_nodes_def_chain[node].insert(v);
     }
 
-    // values created after this one that alias it
-    std::vector<const Value*> aliased_vs;
-    auto idx = values_in_creation_order_idx[v];
+    // FIXME(penguin): the following alias refinement seems to assume
+    // that `v` refers to a new  tensor created by the node that defines
+    // v, thus other Values "before" the node that defines `v` cannot
+    // possibly be aliased to `v`.
+    // TODO(penguin): Is it a limitation of TS alias analysis
+    // so that we need to do such refinement? If so, better improve
+    // alias analysis so that we dont need this special handling here
+    //
+    // Refine aliases of v by include only those created after v
+    std::vector<const Value*> refined_aliases;
+    auto idx = values_to_idx_in_creation_order[v];
     for (; idx < values_in_creation_order.size(); ++idx) {
       auto* alias_v = values_in_creation_order[idx];
       if (mayContainAlias(db, v, alias_v)) {
-        aliased_vs.emplace_back(alias_v);
+        refined_aliases.emplace_back(alias_v);
       }
     }
     // for all the values in the alias set,
     // we set them "alive"
-    for (auto* aliased_v : aliased_vs) {
+    for (auto* aliased_v : refined_aliases) {
       add_live_value_fn(aliased_v);
       for (const auto& u : aliased_v->uses()) {
         const auto* node = u.user;
         // track deps of the aliased values is if they
         // are our own
-        live_values.at(v).insert(node);
-        live_nodes[node].insert(v);
+        live_values_use_chain.at(v).insert(node);
+        live_nodes_def_chain[node].insert(v);
       }
     }
   };
 
   auto traverse_node_fn = [&](const Node* node,
                               std::vector<const Value*>& dead) {
-    if (live_nodes.count(node)) {
-      for (const auto* v : live_nodes.at(node)) {
-        live_values.at(v).erase(node);
-        if (!live_values.at(v).size()) {
+    if (live_nodes_def_chain.count(node)) {
+      for (const auto* v : live_nodes_def_chain.at(node)) {
+        live_values_use_chain.at(v).erase(node);
+        if (!live_values_use_chain.at(v).size()) {
           dead.emplace_back(v);
         }
       }
@@ -233,11 +245,11 @@
     std::vector<const Value*> dead;
     traverse_node_fn(node, dead);
     for (const auto* dead_value : dead) {
-      live_values.erase(dead_value);
+      live_values_use_chain.erase(dead_value);
     }
   }
 
-  for (const auto& v : live_values) {
+  for (const auto& v : live_values_use_chain) {
     TORCH_CHECK(always_alive.count(v.first));
   }
 
@@ -255,16 +267,16 @@
   return std::make_pair(liveness_map, always_alive);
 }
 
-// Implementation specific pruning of values
-// from "optimzable" set.  GetLivenessInformation and FindSameStorageValues
-// work with any graph, but we prune out values
-// that aren't produced by "_out" variants here.
+// Collect the set of Values that are candidates for memory planning:
+//   - Values that are used in in-place operators (i.e., _out variants), and
+//   - excluding those that are either inputs or outputs of
+//     non in-place operators
 //
 // Returns
-//   first: Values that can be optimized
+//   first: Values that are candidates for memory planning
 //   second: A deterministc order of all values
 std::pair<std::vector<const Value*>, std::vector<const Value*>>
-GetOptimizableValues(const std::shared_ptr<torch::jit::Graph>& graph) {
+GetMemoryPlanningCandidates(const std::shared_ptr<torch::jit::Graph>& graph) {
   // for determinism
   std::unordered_set<const Value*> seen_values;
   std::vector<const Value*> all_values;
@@ -334,7 +346,7 @@
 // NB: This is a deterministic implementation, which makes it easier to tune
 // and debug.
 std::unordered_map<const Value*, std::vector<const Value*>>
-FindSameStorageValues(
+GenerateSameStorageValues(
     const LivenessInformation& lm,
     const std::pair<std::vector<const Value*>, std::vector<const Value*>>&
         optimizable,
@@ -399,33 +411,44 @@
   // to preserve determinism
   std::vector<const Value*> seen;
 
+  auto compute_liveset_fn =
+      [&always_alive, &alive_during, &same_storage_values](
+          std::set<const Value*>& live, const Value* v) {
+        for (const auto* sv : same_storage_values.at(v)) {
+          const auto& l = alive_during.count(sv) ? alive_during.at(sv)
+                                                 : std::set<const Value*>{};
+          live.insert(l.begin(), l.end());
+        }
+        live.insert(always_alive.begin(), always_alive.end());
+      };
+
+  // check if same_storage_values[s] intersects with live
+  auto intersect_fn = [&same_storage_values](
+                          std::set<const Value*>& live, const Value* s) {
+    bool intersect = false;
+    for (const auto* v : same_storage_values.at(s)) {
+      if (live.count(v)) {
+        intersect = true;
+        break;
+      }
+    }
+    return intersect;
+  };
+
   for (const auto* v : optimizable_values) {
     if (always_alive.count(v)) {
       continue;
     }
     // get values that are live during the lifetime of v
     std::set<const Value*> live;
-    for (const auto* sv : same_storage_values.at(v)) {
-      const auto& l = alive_during.count(sv) ? alive_during.at(sv)
-                                             : std::set<const Value*>{};
-      live.insert(l.begin(), l.end());
-    }
-    live.insert(always_alive.begin(), always_alive.end());
-
+    compute_liveset_fn(live, v);
     for (const auto* s : seen) {
-      // check if any values in this set of same_storage_values
-      // are alive at the time of v
-      // effectively finding | set_intersection(live, set_of_shared(s)) | > 0
-      bool intersects = false;
-      for (const auto* candidate_v : same_storage_values.at(s)) {
-        if (live.count(candidate_v)) {
-          intersects = true;
-          break;
-        }
-      }
-      // we can share memory if there's no overlap
-      if (!intersects) {
+      // if live(same_storage_values[v]) and same_storage_values[s]
+      // do not overlap, then s and v can share the same storage
+      if (!intersect_fn(live, s)) {
         share_storage_fn(v, s);
+        // since s is added to same_storage_values[v], live needs
+        // to be recomputed, so bail out here
         break;
       }
     }
@@ -556,11 +579,12 @@
   auto lm = GetLivenessInformation(graph_, alias_db);
   external_values_ = lm.second;
   if (opts_.optimize_memory) {
-    auto values = GetOptimizableValues(graph_);
+    auto values = GetMemoryPlanningCandidates(graph_);
     if (!opts_.enable_out_variant) {
       values.first = {};
     }
-    value_to_same_storage_values_ = FindSameStorageValues(lm, values, alias_db);
+    value_to_same_storage_values_ =
+        GenerateSameStorageValues(lm, values, alias_db);
   }
 }
 
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index 90e6173..2116016 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -16,8 +16,8 @@
   bool cleanup_activations{true};
   bool enable_out_variant{true};
   bool optimize_memory{true};
-  bool optimize_output_memory{
-      false}; // to enable MemoryPlanner on output tensors
+  // to enable MemoryPlanner on output tensors
+  bool optimize_output_memory{false};
 };
 
 /// The static runime supports two execution modes.
@@ -81,7 +81,6 @@
   typedef enum {
     CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant
     INPUT_VALUE = -1, // VALUE nodes representing graph inputs
-    OTHER_VALUE = 0 // other VALUE nodes (use non-negative index)
   } VALUE_KIND;
 
  private: