blob: 4bcbcad9810f6111bf6a962dc892fc5c93a38401 [file] [log] [blame]
#pragma once
#include <algorithm>
#include <iostream>
#include <type_traits>
#include <unordered_set>
#include <vector>
#include "torch/csrc/utils/functional.h"
namespace torch { namespace jit { namespace detail {
// DynamicDAG is a simple directed acyclic graph that dynamically maintains a
// topological order as edges/vertices are added and removed.
//
// [Example applications]
// - Let's say you have a DAG where each vertex is black or red. How do we
// merge black nodes that are directly connected by contracting the
// edge between them while still maintaining the DAG and a topological order?
// Use contractEdge().
// - Let's say you have a DAG where each vertex is a Node* and the edges represent
// data dependencies. We wish to determine if adding a new Node* with certain
// data dependencies (or moving an existing one to use new dependencies) is valid.
// Use DynamicDAG::addEdge() to add the new data dependencies to the DAG:
// it will either find a valid reordering of the DAG's topological order or throw
// if the resulting DAG is invalid.
//
// The implementation is based off of the PK algorithm in the following paper:
// "A Dynamic Topsort Algorithm for Directed Acyclic Graphs"
// by David Pearce and Paul Kelly
// https://www.doc.ic.ac.uk/~phjk/Publications/DynamicTopoSortAlg-JEA-07.pdf
// It is summarized in [Edge addition] (see DynamicDAG<T>::addEdge)
template <typename T> struct Vertex;
template <typename T> struct DynamicDAG;
template <typename T> using vertex_list = std::vector<Vertex<T>*>;
template <typename T> using unique_vertex = std::unique_ptr<Vertex<T>>;
enum class DFSDirection {forward, backward};
// Used to represent adjacency lists in DynamicDAG.
// Has set semantics: stores distinct elements.
//
// Because our graphs shouldn't fan out or in very much,
// we use std::vector<Vertex<T>*> to record edges.
// In all of the complexity analysis it is assumed that
// inserting, erasing, and finding take constant time.
template <typename T>
struct vertex_set {
using iterator = typename vertex_list<T>::iterator;
using reverse_iterator = typename vertex_list<T>::reverse_iterator;
// returns if we inserted v into the set.
bool insert(Vertex<T>* v) {
if (contains(v)) {
return false;
} else {
data_.push_back(v);
return true;
}
}
void erase(Vertex<T>* v) {
data_.erase(std::find(data_.begin(), data_.end(), v));
}
bool contains(Vertex<T>* v) const {
return std::find(data_.begin(), data_.end(), v) != data_.end();
}
void sort() {
std::sort(data_.begin(), data_.end(), [](Vertex<T>* a, Vertex<T>* b) {
return a->ord < b->ord;
});
}
size_t size() const { return data_.size(); }
iterator begin() { return data_.begin(); }
iterator end() { return data_.end(); }
reverse_iterator rbegin() { return data_.rbegin(); }
reverse_iterator rend() { return data_.rend(); }
private:
std::vector<Vertex<T>*> data_;
};
template <typename T>
struct IOEdges {
vertex_set<T> in_edges;
vertex_set<T> out_edges;
};
// Simple RAII wrapper around a vertex_list<T>.
// When adding a vertex to the list, mark it as visited.
// Clears the visited flag of each vertex in the vertex_list on deletion.
template <typename T>
struct visited_list {
~visited_list() {
for (auto* v : data_) {
v->visited_ = false;
}
}
void push_back(Vertex<T>* elt) {
JIT_ASSERT(!elt->visited_);
elt->visited_ = true;
data_.push_back(elt);
}
void sort() {
std::sort(data_.begin(), data_.end(), [](Vertex<T>* a, Vertex<T>* b) {
return a->ord < b->ord;
});
}
const vertex_list<T>& vector() { return data_; }
private:
vertex_list<T> data_;
};
template <typename T>
struct Vertex {
Vertex(size_t ord, T datum)
: ord(ord), visited_(false) { data.push_back(datum); }
std::vector<T> data;
size_t ord; // unique topological index
std::string toString();
vertex_set<T>& in_edges() { return edges_.in_edges; }
vertex_set<T>& out_edges() { return edges_.out_edges; }
IOEdges<T>&& move_edges() { return std::move(edges_); }
bool visited() { return visited_; }
private:
IOEdges<T> edges_;
friend visited_list<T>;
bool visited_; // If this vertex has been visited
};
template <typename T>
struct DynamicDAG {
Vertex<T>* newVertex(T datum);
IOEdges<T> removeVertex(Vertex<T>* v);
void addEdge(Vertex<T>* producer, Vertex<T>* consumer);
void removeEdge(Vertex<T>* producer, Vertex<T>* consumer);
bool contractEdge(Vertex<T>* producer, Vertex<T>* consumer);
// max_size() >= the number of live vertices.
// for all vertices v, v.ord < max_size()
size_t max_size() const { return vertices_.size(); };
c10::optional<Vertex<T>*> at(size_t ord) const;
std::string toString();
// Use for debugging. Don't call these often.
size_t debugNumVertices() const;
void debugCheckInvariants();
private:
void mergeProducerIntoConsumer(Vertex<T>* producer, Vertex<T>* consumer);
void mergeConsumerIntoProducer(Vertex<T>* producer, Vertex<T>* consumer);
void reorder(visited_list<T> deltaF, visited_list<T> deltaB);
bool contractionProducesCycle(Vertex<T>* producer, Vertex<T>* consumer);
bool dfsSearch(
DFSDirection direction,
Vertex<T>* start,
Vertex<T>* end,
size_t bound,
visited_list<T>& visited);
// Store vertices indexed by their topological order.
// If a vertex v has ord 5, then it can be found at vertices_[5].
// There may be gaps in vertices_; this is to enable fast deletion.
std::vector<unique_vertex<T>> vertices_;
};
// O(vertices_.size()). Used for testing, don't call this often.
template <typename T>
size_t DynamicDAG<T>::debugNumVertices() const {
return std::count_if(vertices_.begin(), vertices_.end(),
[](const unique_vertex<T>& v) {
if (v) return true;
return false;
});
}
template <typename T>
Vertex<T>* DynamicDAG<T>::newVertex(T datum) {
vertices_.emplace_back(new Vertex<T>(vertices_.size(), datum));
return vertices_.back().get();
}
template <typename T>
void DynamicDAG<T>::removeEdge(Vertex<T>* producer, Vertex<T>* consumer) {
JIT_ASSERT(producer != consumer);
JIT_ASSERT(producer->out_edges().contains(consumer));
JIT_ASSERT(consumer->in_edges().contains(producer));
producer->out_edges().erase(consumer);
consumer->in_edges().erase(producer);
}
template <typename T>
void DynamicDAG<T>::debugCheckInvariants() {
for (size_t ord = 0; ord < vertices_.size(); ++ord) {
const auto& vertex = vertices_.at(ord);
if (!vertex) continue;
AT_ASSERTM(vertex->ord == ord, toString());
for (auto* v : vertex->in_edges()) {
AT_ASSERTM(v->ord < ord, toString());
}
for (auto* v : vertex->out_edges()) {
AT_ASSERTM(v->ord > ord, toString());
}
}
}
template <typename T>
c10::optional<Vertex<T>*> DynamicDAG<T>::at(size_t ord) const {
const auto& vertex = vertices_.at(ord);
if (!vertex) {
return c10::nullopt;
} else {
return vertex.get();
}
}
template <typename T>
IOEdges<T> DynamicDAG<T>::removeVertex(Vertex<T>* v) {
for (auto* parent : v->in_edges()) {
parent->out_edges().erase(v);
}
for (auto* child : v->out_edges()) {
child->in_edges().erase(v);
}
auto edges = v->move_edges();
vertices_[v->ord] = nullptr;
return edges;
}
/*
* [Edge addition]
* When adding an edge x -> y,
* - if ord(x) < ord(y), don't do anything.
* - if ord(y) < ord(x), some graph reordering must occur.
*
* Assume we are adding an edge x -> y and that ord(x) > ord(y).
* First, if there is a path y ----> x through some other vertices, then this
* edge addition would create a cycle. Figure this out via DFS and throw if necessary.
*
* Now, consider the set of all vertices v such that ord(x) > ord(v) > ord(y).
* Call this set the affected region (AR) -- these are the only vertices we
* need to consider for reordering to make the resulting graph valid.
*
* Find all children of y (through DFS) in AR (call this set deltaF and add y to it)
* Find all parents of x in AR (call this set deltaB and add x to it).
*
* Move y and all the children of y to after x and all the parents of x. The
* result topological ordering is valid.
*
* [Visual algorithm reference]
* Higher nodes come earlier in topological order.
* We are adding an edge between x -> y.
* The topological ordering is e, y, c, a, d, b, x, f.
* The affected region is {y, c, a, d, b, x}. e and f cannot be involved
* in the reorder.
*
* (e) <- ord = 0 -> (e)
* | |
* v v
* (y) <- ord = 1 -> \ (c)
* ^ \ -----\ |
* (c) | v <- ord = 2 -> -----/ (d) v
* \ | (a) <- ord = 3 -> / \->(x)
* || | /\
* (d) || | <- ord = 4 -> (y)<-/ \
* | || v \ |
* \ v| (b) <- ord = 5 -> \->(a) |
* ->(x) <- ord = 6 -> (b)<--/ v
* \->(f) <- ord = 7 -> (f)
*
* We find all children of y in the affected region. deltaF = {y, a, b}
* We find all parents of x via DFS. deltaB = {c, d, x}
*
* Now, we reorder all vertices in deltaB to come before deltaF. This is
* a little involved and happens in four steps:
*
* 1) sort all vertices in deltaB, and all vertices in deltaF.
* deltaB (sorted) = {c(2), d(4), x(6)}. deltaB ords = { 2, 4, 6 }
* deltaF (sorted) = {y(1), a(3), b(5)}. deltaF ords = { 1, 3, 5 }
*
* 2) append the two lists: the result is the order we want these vertices to have.
* L = {c(2), d(4), x(6), y(1), a(3), b(5)}.
*
* 3) Merge the sorted ords: R = { 1, 2, 3, 4, 5, 6 }.
*
* 4) Reassign the vertices in L in order with the sorted ords.
* We always use the vertices in deltaB, then deltaF, in that order.
* L = { c(1), d(2), x(3), y(4) a(5), b(6) }
*
* This produces th graph shown on the right.
*
* [Analysis]
* This is O(|AR| log |AR|). |AR| is equal to ord(consumer) - ord(producer).
* AR is the "affected region": { v s.t. ord(v) in [ord(producer), ord(consumer)] }
* consisting of the only vertices that can possibly be moved around due to this
* edge addition.
*
* NB: Pearce and Kelly give a complexity bound of <<delta>> where
* delta = union(deltaF, deltaB) and <<S>> on a set S is
* <<S>> = |S| + |edges out of vertices of S| + |edges into vertices of S|.
*/
template <typename T>
void DynamicDAG<T>::addEdge(Vertex<T>* producer, Vertex<T>* consumer) {
JIT_ASSERT(producer != consumer);
// NB: DynamicDAG is a simple graph. If an edge exists already, don't do anything.
bool is_distinct = producer->out_edges().insert(consumer);
if (!is_distinct) return;
is_distinct = consumer->in_edges().insert(producer);
JIT_ASSERT(is_distinct);
if (producer->ord <= consumer->ord) {
// topological ordering is already consistent, no need to update.
return;
}
visited_list<T> deltaF;
visited_list<T> deltaB;
// Search for vertices that are reachable from consumer that have a now incorrect
// topological ordering.
if (dfsSearch(DFSDirection::forward, consumer, producer,
/*upper_bound=*/producer->ord, deltaF)) {
// Path found! This means there's a cycle.
AT_ERROR("Cycle detected while trying to add edge.");
}
// Search for vertices that can reach producer that have a now incorrect
// topological ordering
JIT_ASSERT(!dfsSearch(DFSDirection::backward, producer, consumer,
/*lower_bound=*/consumer->ord, deltaB));
// Reorder the vertices that are reachable from consumer to occur BEFORE
// the vertices that can reach producer.
reorder(std::move(deltaF), std::move(deltaB));
}
// Define the affected region for contractEdge(producer, consumer) as
// { v s.t. ord(v) in [ord(producer), ord(consumer)] }.
// These are the only vertices that can possibly be moved around
// during edge contraction.
//
// contractEdge is O(|AR| log |AR| * min(|out_edges(producer)|, |in_edges(consumer)|))
template <typename T>
bool DynamicDAG<T>::contractEdge(Vertex<T>* producer, Vertex<T>* consumer) {
JIT_ASSERT(producer != consumer);
if (contractionProducesCycle(producer, consumer)) {
return false;
}
removeEdge(producer, consumer);
// Optimization: pick which order to merge depending on potential complexity.
if (producer->out_edges().size() > consumer->in_edges().size()) {
mergeConsumerIntoProducer(producer, consumer);
} else {
mergeProducerIntoConsumer(producer, consumer);
}
return true;
}
template <typename T>
void DynamicDAG<T>::mergeProducerIntoConsumer(Vertex<T>* producer, Vertex<T>* consumer) {
// Optimization: we want to concat lists [producer.data, consumer.data].
// Instead of inserting into the beginning of consumer.data, do a swap.
producer->data.insert(producer->data.end(), consumer->data.begin(), consumer->data.end());
std::swap(consumer->data, producer->data);
auto edges = removeVertex(producer);
// Each of these are constant b/c ord(consumer) > ord(producer) > ord(parent)
// so the edge addition still preserves the existing topological order.
for (auto* parent : edges.in_edges) {
addEdge(parent, consumer);
}
// NB: each addEdge call is linear in (ord(consumer) - ord(child)).
// This makes this function O(|out_edges(producer)| * |AR| log |AR|).
for (auto* child : edges.out_edges) {
addEdge(consumer, child);
}
}
template <typename T>
void DynamicDAG<T>::mergeConsumerIntoProducer(Vertex<T>* producer, Vertex<T>* consumer) {
producer->data.insert(producer->data.end(), consumer->data.begin(), consumer->data.end());
auto edges = removeVertex(consumer);
// Each of these are constant b/c ord(child) > ord(consumer) > ord(producer)
// so the edge addition still preserves the existing topological order.
for (auto* child : edges.out_edges) {
addEdge(producer, child);
}
// NB: each addEdge call is linear in (ord(producer) - ord(parent)).
// This makes this function O(|in_edges(consumer)| * |AR| log |AR|).
for (auto* parent : edges.in_edges) {
addEdge(parent, producer);
}
}
template <typename T>
bool DynamicDAG<T>::contractionProducesCycle(Vertex<T>* producer, Vertex<T>* consumer) {
visited_list<T> visited;
// If there are multiple paths from producer to consumer then contracting
// (merging) producer and consumer would create a cycle.
//
// Search for a path from producer to consumer while ignoring the
// producer -> consumer edge.
size_t upper_bound = consumer->ord;
for (auto* child : producer->out_edges()) {
if (child == consumer) continue;
if (child->visited()) continue; // already visited by dfs
if (dfsSearch(DFSDirection::forward, child, consumer, upper_bound, visited)) {
return true;
}
}
return false;
}
static bool is_within_bound(DFSDirection direction, size_t value, size_t bound) {
if (direction == DFSDirection::forward) {
return value < bound; // upper bound
} else {
return value > bound; // lower bound
}
}
// Searches for a path from start to end via a forward or backward dfs.
// Returns if a path exists from start to end.
// In addition, dfsSearch inserts visited vertices into the visited list.
template <typename T>
bool DynamicDAG<T>::dfsSearch(
DFSDirection direction,
Vertex<T>* start,
Vertex<T>* end,
size_t bound,
visited_list<T>& visited) {
vertex_list<T> stack;
auto visit = [&](Vertex<T>* v) {
visited.push_back(v);
stack.push_back(v);
};
visit(start);
while (!stack.empty()) {
auto* vertex = stack.back();
stack.pop_back();
auto& next_edges = (direction == DFSDirection::forward) ?
vertex->out_edges() :
vertex->in_edges();
for (Vertex<T>* next : next_edges) {
if (next == end) {
// Path found from start to end.
visit(next);
return true;
}
if (!next->visited() && is_within_bound(direction, next->ord, bound)) {
visit(next);
}
}
}
return false;
}
// Reorder deltaB vertices to occur before deltaF vertices.
template <typename T>
void DynamicDAG<T>::reorder(visited_list<T> deltaF, visited_list<T> deltaB) {
deltaB.sort();
deltaF.sort();
const auto& deltaB_ = deltaB.vector();
const auto& deltaF_ = deltaF.vector();
size_t num_affected = deltaB_.size() + deltaF_.size();
// Gather vertices in the desired order. They don't have correct ords yet.
std::vector<unique_vertex<T>> desired_vertex_ordering;
desired_vertex_ordering.reserve(num_affected);
for (auto it = deltaB_.begin(); it != deltaB_.end(); ++it) {
desired_vertex_ordering.push_back(std::move(vertices_.at((*it)->ord)));
}
for (auto it = deltaF_.begin(); it != deltaF_.end(); ++it) {
desired_vertex_ordering.push_back(std::move(vertices_.at((*it)->ord)));
}
// Sort the ords by merging two already sorted lists into a large sorted list.
// input (example): deltaB = { v(1), v(4), v(7) } , deltaF = { v(0), v(2), v(5) }.
// output: { 0, 1, 2, 4, 5, 7 }.
std::vector<size_t> gathered_ords;
gathered_ords.reserve(num_affected);
for (const auto* v : deltaB_) {
gathered_ords.push_back(v->ord);
}
auto middle = gathered_ords.size();
for (const auto* v : deltaF_) {
gathered_ords.push_back(v->ord);
}
std::inplace_merge(gathered_ords.begin(), gathered_ords.begin() + middle, gathered_ords.end());
// Return the vertices back into the vertices_ storage.
for (size_t i = 0; i < num_affected; ++i) {
desired_vertex_ordering[i]->ord = gathered_ords[i];
vertices_[gathered_ords[i]] = std::move(desired_vertex_ordering[i]);
}
}
template <typename T>
std::string DynamicDAG<T>::toString() {
std::stringstream ss;
for (auto& v : vertices_) {
if (v) {
ss << v->toString() << "\n";
}
}
return ss.str();
}
template <typename T>
std::string Vertex<T>::toString() {
std::stringstream ss;
ss << "node(" << ord << ")\n";
ss << "[";
for (auto* c : in_edges()) {
ss << c->ord << " ";
}
ss << "] -> {\n";
for (auto& d : data) {
if (std::is_pointer<T>::value) {
ss << " " << *d;
} else {
ss << " " << d;
}
}
ss << "} ("<< ord << ") -> [";
for (auto* c : out_edges()) {
ss << c->ord << " ";
}
ss << "]\n";
return ss.str();
}
}}}