More name refactoring of memory planning codes to make it more readable (#54272)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54272
Test Plan: Imported from OSS
Reviewed By: bwasti
Differential Revision: D27233881
fbshipit-source-id: f257f16ac0684df055961e539f17d002cb8f1bfe
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 3fc4869..5930a78 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -125,35 +125,38 @@
LivenessInformation GetLivenessInformation(
const std::shared_ptr<torch::jit::Graph>& graph,
AliasDb& db) {
+ // map a Value to a set of Values that overlap live-ranges with the Value's
std::unordered_map<const Value*, std::set<const Value*>> liveness_map;
+ // a set of Values whose live-range exceed current inference
std::unordered_set<const Value*> always_alive;
+ // map Values to its creation order in graph (Note: only traverse top-level
+ // nodes such that nodes under control-flows are represented by top-level
+ // block nodes)
std::vector<const Value*> values_in_creation_order;
- std::unordered_map<const Value*, size_t> values_in_creation_order_idx;
+ std::unordered_map<const Value*, size_t> values_to_idx_in_creation_order;
for (const auto* node : graph->nodes()) {
for (const auto* v : node->outputs()) {
- values_in_creation_order_idx[v] = values_in_creation_order.size();
+ values_to_idx_in_creation_order[v] = values_in_creation_order.size();
values_in_creation_order.emplace_back(v);
}
}
- // maps values to any nodes that consume or produce them
- //
- // updated as we traverse the graph. the presence of a key in `live_values`
- // means that the value is currently alive.
- //
- // invariant: set.size() > 0
- std::unordered_map<const Value*, std::set<const Node*>> live_values;
- std::unordered_map<const Node*, std::set<const Value*>> live_nodes;
+ // presence of a Value in live_values_use_chain means the Value alive
+ // Value mapped to set of Nodes that may use the Value (i.e., use-chain of
+ // Value)
+ std::unordered_map<const Value*, std::set<const Node*>> live_values_use_chain;
+ // Node mapped to set of Values that the Node may use (i.e., def-chain of node
+ // inputs)
+ std::unordered_map<const Node*, std::set<const Value*>> live_nodes_def_chain;
- // inputs and outputs are marked permanently alive
+ // mark inputs, constants, outputs as always_alive
for (const auto* input : graph->inputs()) {
always_alive.insert(input);
}
for (const auto* output : graph->outputs()) {
always_alive.insert(output);
}
-
for (const auto* node : graph->nodes()) {
if (node->kind() == prim::Constant) {
for (const auto* output : node->outputs()) {
@@ -162,13 +165,14 @@
}
}
+ // add v to the current liveness_map
std::function<void(const Value* v)> add_live_value_fn = [&](const Value* v) {
if (liveness_map.count(v)) {
return;
}
liveness_map[v] = {};
- for (const auto& live_v : live_values) {
+ for (const auto& live_v : live_values_use_chain) {
liveness_map.at(v).insert(live_v.first);
liveness_map.at(live_v.first).insert(v);
}
@@ -176,45 +180,53 @@
// only add values to the live set if they
// have deps, otherwise they die immediately
if (v->uses().size()) {
- live_values[v] = {};
+ live_values_use_chain[v] = {};
}
+ // record the relationship between v (Value) and its uses (Node)
for (const auto& u : v->uses()) {
const auto* node = u.user;
- // track deps of this value
- live_values.at(v).insert(node);
- live_nodes[node].insert(v);
+ live_values_use_chain.at(v).insert(node);
+ live_nodes_def_chain[node].insert(v);
}
- // values created after this one that alias it
- std::vector<const Value*> aliased_vs;
- auto idx = values_in_creation_order_idx[v];
+ // FIXME(penguin): the following alias refinement seems to assume
+ // that `v` refers to a new tensor created by the node that defines
+ // v, thus other Values "before" the node that defines `v` cannot
+ // possibly be aliased to `v`.
+ // TODO(penguin): Is it a limitation of TS alias analysis
+ // so that we need to do such refinement? If so, better improve
+ // alias analysis so that we dont need this special handling here
+ //
+ // Refine aliases of v by include only those created after v
+ std::vector<const Value*> refined_aliases;
+ auto idx = values_to_idx_in_creation_order[v];
for (; idx < values_in_creation_order.size(); ++idx) {
auto* alias_v = values_in_creation_order[idx];
if (mayContainAlias(db, v, alias_v)) {
- aliased_vs.emplace_back(alias_v);
+ refined_aliases.emplace_back(alias_v);
}
}
// for all the values in the alias set,
// we set them "alive"
- for (auto* aliased_v : aliased_vs) {
+ for (auto* aliased_v : refined_aliases) {
add_live_value_fn(aliased_v);
for (const auto& u : aliased_v->uses()) {
const auto* node = u.user;
// track deps of the aliased values is if they
// are our own
- live_values.at(v).insert(node);
- live_nodes[node].insert(v);
+ live_values_use_chain.at(v).insert(node);
+ live_nodes_def_chain[node].insert(v);
}
}
};
auto traverse_node_fn = [&](const Node* node,
std::vector<const Value*>& dead) {
- if (live_nodes.count(node)) {
- for (const auto* v : live_nodes.at(node)) {
- live_values.at(v).erase(node);
- if (!live_values.at(v).size()) {
+ if (live_nodes_def_chain.count(node)) {
+ for (const auto* v : live_nodes_def_chain.at(node)) {
+ live_values_use_chain.at(v).erase(node);
+ if (!live_values_use_chain.at(v).size()) {
dead.emplace_back(v);
}
}
@@ -233,11 +245,11 @@
std::vector<const Value*> dead;
traverse_node_fn(node, dead);
for (const auto* dead_value : dead) {
- live_values.erase(dead_value);
+ live_values_use_chain.erase(dead_value);
}
}
- for (const auto& v : live_values) {
+ for (const auto& v : live_values_use_chain) {
TORCH_CHECK(always_alive.count(v.first));
}
@@ -255,16 +267,16 @@
return std::make_pair(liveness_map, always_alive);
}
-// Implementation specific pruning of values
-// from "optimzable" set. GetLivenessInformation and FindSameStorageValues
-// work with any graph, but we prune out values
-// that aren't produced by "_out" variants here.
+// Collect the set of Values that are candidates for memory planning:
+// - Values that are used in in-place operators (i.e., _out variants), and
+// - excluding those that are either inputs or outputs of
+// non in-place operators
//
// Returns
-// first: Values that can be optimized
+// first: Values that are candidates for memory planning
// second: A deterministc order of all values
std::pair<std::vector<const Value*>, std::vector<const Value*>>
-GetOptimizableValues(const std::shared_ptr<torch::jit::Graph>& graph) {
+GetMemoryPlanningCandidates(const std::shared_ptr<torch::jit::Graph>& graph) {
// for determinism
std::unordered_set<const Value*> seen_values;
std::vector<const Value*> all_values;
@@ -334,7 +346,7 @@
// NB: This is a deterministic implementation, which makes it easier to tune
// and debug.
std::unordered_map<const Value*, std::vector<const Value*>>
-FindSameStorageValues(
+GenerateSameStorageValues(
const LivenessInformation& lm,
const std::pair<std::vector<const Value*>, std::vector<const Value*>>&
optimizable,
@@ -399,33 +411,44 @@
// to preserve determinism
std::vector<const Value*> seen;
+ auto compute_liveset_fn =
+ [&always_alive, &alive_during, &same_storage_values](
+ std::set<const Value*>& live, const Value* v) {
+ for (const auto* sv : same_storage_values.at(v)) {
+ const auto& l = alive_during.count(sv) ? alive_during.at(sv)
+ : std::set<const Value*>{};
+ live.insert(l.begin(), l.end());
+ }
+ live.insert(always_alive.begin(), always_alive.end());
+ };
+
+ // check if same_storage_values[s] intersects with live
+ auto intersect_fn = [&same_storage_values](
+ std::set<const Value*>& live, const Value* s) {
+ bool intersect = false;
+ for (const auto* v : same_storage_values.at(s)) {
+ if (live.count(v)) {
+ intersect = true;
+ break;
+ }
+ }
+ return intersect;
+ };
+
for (const auto* v : optimizable_values) {
if (always_alive.count(v)) {
continue;
}
// get values that are live during the lifetime of v
std::set<const Value*> live;
- for (const auto* sv : same_storage_values.at(v)) {
- const auto& l = alive_during.count(sv) ? alive_during.at(sv)
- : std::set<const Value*>{};
- live.insert(l.begin(), l.end());
- }
- live.insert(always_alive.begin(), always_alive.end());
-
+ compute_liveset_fn(live, v);
for (const auto* s : seen) {
- // check if any values in this set of same_storage_values
- // are alive at the time of v
- // effectively finding | set_intersection(live, set_of_shared(s)) | > 0
- bool intersects = false;
- for (const auto* candidate_v : same_storage_values.at(s)) {
- if (live.count(candidate_v)) {
- intersects = true;
- break;
- }
- }
- // we can share memory if there's no overlap
- if (!intersects) {
+ // if live(same_storage_values[v]) and same_storage_values[s]
+ // do not overlap, then s and v can share the same storage
+ if (!intersect_fn(live, s)) {
share_storage_fn(v, s);
+ // since s is added to same_storage_values[v], live needs
+ // to be recomputed, so bail out here
break;
}
}
@@ -556,11 +579,12 @@
auto lm = GetLivenessInformation(graph_, alias_db);
external_values_ = lm.second;
if (opts_.optimize_memory) {
- auto values = GetOptimizableValues(graph_);
+ auto values = GetMemoryPlanningCandidates(graph_);
if (!opts_.enable_out_variant) {
values.first = {};
}
- value_to_same_storage_values_ = FindSameStorageValues(lm, values, alias_db);
+ value_to_same_storage_values_ =
+ GenerateSameStorageValues(lm, values, alias_db);
}
}
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index 90e6173..2116016 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -16,8 +16,8 @@
bool cleanup_activations{true};
bool enable_out_variant{true};
bool optimize_memory{true};
- bool optimize_output_memory{
- false}; // to enable MemoryPlanner on output tensors
+ // to enable MemoryPlanner on output tensors
+ bool optimize_output_memory{false};
};
/// The static runime supports two execution modes.
@@ -81,7 +81,6 @@
typedef enum {
CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant
INPUT_VALUE = -1, // VALUE nodes representing graph inputs
- OTHER_VALUE = 0 // other VALUE nodes (use non-negative index)
} VALUE_KIND;
private: