blob: 7aa6ac045749a8a0731136265286894d00a9a42f [file] [log] [blame]
#include <torch/csrc/jit/passes/liveness.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <iostream>
#include <memory>
namespace torch {
namespace jit {
// LivenessAnalyzer computes "bailout" liveness which is equivalent to
// "{LIVE_IN} or {GEN}" or "{LIVE_OUT} - {KILL}"
//
// For every node it visits, the analyzer records the set of values that are
// live after the node, minus the node's own outputs, plus the node's inputs
// (and, for prim::Loop / prim::If, whatever their nested blocks require).
// Sets are stored as bit vectors indexed by Value::unique() and are converted
// back to Value* lists when the analysis finishes.
struct LivenessAnalyzer {
  explicit LivenessAnalyzer(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)), changed_(false) {}

  // Runs the analysis to a fixed point and returns, for each analysed node,
  // the values recorded as live at that node.
  //
  // The graph is modified temporarily: helper prim::Store nodes are inserted
  // into loop bodies before the analysis and destroyed again afterwards.
  std::unordered_map<Node*, std::vector<Value*>> run() {
    std::vector<Node*> counters;
    insertExplicitUsesOfLoopCounters(graph_->block(), counters);

    // we implement the canonical fixed-point liveness
    // the analysis is run until there are no more changes
    // to liveness sets for each node
    do {
      changed_ = false;
      processBlock(graph_->block(), SparseBitVector{});
    } while (changed_);

    // The helper nodes are about to be destroyed. Drop their bookkeeping
    // first so the returned map is keyed only by nodes that stay in the
    // graph; otherwise callers (and `dump`) would be handed keys that refer
    // to destroyed nodes.
    for (Node* counter : counters) {
      liveness_sets_.erase(counter);
    }
    removeCounterNodes(counters);

    std::unordered_map<Node*, std::vector<Value*>> result;
    for (const auto& entry : liveness_sets_) {
      result.insert({entry.first, toValueVector(entry.second)});
    }
    return result;
  }

  // temporary make loop counts live for the duration of the loop
  // as they are needed by BailOuts in the loop
  //
  // For every prim::Loop found in `b` (recursing into nested blocks), two
  // output-less prim::Store nodes are inserted into the loop body: one that
  // uses the current trip count and one that uses the max trip count. Every
  // inserted node is appended to `counters` so it can be removed later.
  void insertExplicitUsesOfLoopCounters(
      Block* b,
      std::vector<Node*>& counters) {
    for (Node* node : b->nodes()) {
      if (node->kind() == prim::Loop) {
        LoopView loop(node);
        // Scope the insertion point to the loop body for the two inserts.
        WithInsertPoint guard(loop.bodyBlock());

        Node* ctc = graph_->create(prim::Store, {loop.currentTripCount()}, 0);
        graph_->insertNode(ctc);
        counters.push_back(ctc);

        Node* mtc = graph_->create(prim::Store, {loop.maxTripCount()}, 0);
        graph_->insertNode(mtc);
        counters.push_back(mtc);
      }

      for (Block* inner : node->blocks()) {
        insertExplicitUsesOfLoopCounters(inner, counters);
      }
    }
  }

  // Destroys every node in `counters`. The vector itself is left unchanged,
  // so its elements must not be used afterwards.
  void removeCounterNodes(std::vector<Node*>& counters) {
    for (Node* counter : counters) {
      counter->destroy();
    }
  }

  // Debug helper: prints each node of `liveness_sets` (first output name, if
  // any, and node kind) followed by its liveness set, then dumps the graph.
  // Every key in `liveness_sets` is dereferenced, so all keys must refer to
  // nodes that still exist.
  void dump(
      const std::unordered_map<Node*, std::vector<Value*>>& liveness_sets) {
    std::cout << "Liveness info:\n";
    // Iterate by const reference: each entry carries a vector that would
    // otherwise be copied on every iteration.
    for (const auto& entry : liveness_sets) {
      Node* node = entry.first;
      if (!node->outputs().empty()) {
        std::cout << node->outputs()[0]->debugName();
      }
      std::cout << " " << node->kind().toQualString();
      std::cout << " = ";
      dump(entry.second);
      std::cout << std::endl;
    }
    std::cout << "graph :\n";
    graph_->dump();
  }

  // Debug helper: prints `set` as "[name(id), name(id), ...]".
  void dump(const std::vector<Value*>& set) {
    const char* separator = "";
    std::cout << "[";
    for (Value* value : set) {
      std::cout << separator << value->debugName() << "(" << value->unique()
                << ")";
      separator = ", ";
    }
    std::cout << "]";
  }

 private:
  // Builds a bit vector with one bit per value (indexed by Value::unique())
  // and remembers the id -> Value* mapping so that the bits can be turned
  // back into values by `toValueVector`.
  SparseBitVector toSparseBitVector(at::ArrayRef<Value*> values) {
    SparseBitVector bits;
    for (Value* value : values) {
      ids_to_values_[value->unique()] = value;
      bits.set(value->unique());
    }
    return bits;
  }

  // Maps every id set in `sbv` back to the Value* registered for it by
  // `toSparseBitVector`.
  std::vector<Value*> toValueVector(const SparseBitVector& sbv) {
    std::vector<Value*> values;
    for (auto id : sbv) {
      values.push_back(ids_to_values_[id]);
    }
    return values;
  }

  // Performs one backward sweep over block `b`.
  //
  // `liveness` is the set of values live after the block; it is taken by
  // value because each block works on its own copy. The sweep updates
  // `liveness_sets_` for every node in the block (and nested blocks) and
  // sets `changed_` whenever one of those sets grows. Returns the set of
  // values live at the start of the block.
  SparseBitVector processBlock(Block* b, SparseBitVector liveness) {
    // block outputs are the uses
    liveness |= toSparseBitVector(b->outputs());

    for (Node* node : b->nodes().reverse()) {
      // kill outputs
      liveness -= toSparseBitVector(node->outputs());

      if (node->kind() == prim::Loop) {
        LoopView loop(node);
        // N.B. merge in changes from the loop header
        // The set recorded so far for the first node of the body is fed in
        // as live at the end of the body (presumably modelling the loop's
        // back edge), together with everything live after the loop.
        Node* loop_header = *loop.bodyBlock()->nodes().begin();
        auto loop_block = liveness | liveness_sets_[loop_header];
        loop_block = processBlock(loop.bodyBlock(), loop_block);
        // loop block's inputs die outside loop's block
        loop_block -= toSparseBitVector(loop.bodyBlock()->inputs());
        liveness |= loop_block;
      } else if (node->kind() == prim::If) {
        IfView branch(node);
        // Both branches start from the same "live after the If" set; their
        // results are only merged once both have been processed.
        auto true_liveness = processBlock(branch.thenBlock(), liveness);
        auto false_liveness = processBlock(branch.elseBlock(), liveness);
        liveness |= true_liveness;
        liveness |= false_liveness;
      }

      liveness |= toSparseBitVector(node->inputs());

      // `|=` returns true if new bits were set in LHS
      // after or/union with `liveness`
      const bool grew = (liveness_sets_[node] |= liveness);
      changed_ = changed_ || grew;
    }
    return liveness;
  }

  // Graph under analysis.
  std::shared_ptr<Graph> graph_;
  // Set during a sweep when any per-node liveness set gained new bits.
  bool changed_;
  // Per-node liveness, as bit vectors indexed by Value::unique().
  std::map<Node*, SparseBitVector> liveness_sets_;
  // Reverse mapping from Value::unique() to the value itself.
  std::map<size_t, Value*> ids_to_values_;
};
// Entry point of the pass: constructs a LivenessAnalyzer for `graph`, runs it
// and returns the map it produces (node -> list of values).
std::unordered_map<Node*, std::vector<Value*>> BuildLivenessSets(
    std::shared_ptr<Graph> graph) {
  return LivenessAnalyzer(std::move(graph)).run();
}
} // namespace jit
} // namespace torch