Document all public graph manipulation functions
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index ee270ad..187e320 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -91,6 +91,7 @@
// next_in_graph[0] is next pointer
// next_in_graph[1] is prev pointer
// using an array to allow the same iterator class for forward and reverse node lists
+ // This list represents a topological sort
Node * next_in_graph[2];
Node* & next() { return next_in_graph[0]; }
Node* & prev() { return next_in_graph[1]; }
@@ -131,11 +132,21 @@
// Graphs
+ // Note [Topological invariant]
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // We always maintain an up-to-date topological ordering of all nodes via
+ // the next()/prev() links. All transformations to graphs must preserve
+ // this topological ordering: for example, it is only valid to 'addInput'
+ // with an input which is topologically before the current node.
+ //
+ // Usually, it is obvious whether or not topological order is maintained;
+ // for example, if you are adding nodes to the end of the topsort, it's
+ // impossible for them to refer to inputs that are not in the topsort.
+ // If it is not obvious, please comment accordingly.
+
// Add 'node' as an input to 'this' at the end of existing
// arguments. Returns the added node for ease of chaining.
//
- // Precondition: 'node' must be topologically before 'this'.
- //
// Given: %3 = f(%1, %2)
// Execute: %3.addInput(%4)
// Result: %3 = f(%1, %2, %4)
@@ -149,8 +160,6 @@
// Replace the input of 'this' at position 'i' with
// 'newValue', returning the old node.
//
- // Precondition: 'newValue' must be topologically before 'this'.
- //
// Given: %3 = f(%1, %2)
// Execute: %3.replaceInput(1, %4)
// Result: %3 = f(%1, %4)
@@ -165,8 +174,6 @@
// Replace all occurrences of 'from' in the inputs of this
// node with 'to'. Corresponds to llvm's replaceUsesOfWith.
//
- // Precondition: 'to' must be topologically before 'this'.
- //
// Given: %3 = f(%1, %2, %1)
// Execute: %3.replaceInputWith(%1, %4)
// Result: %3 = f(%4, %2, %4)
@@ -187,10 +194,6 @@
// Replaces all uses of this node with 'newValue'.
//
- // Precondition: 'newValue' must be topologically before all uses
- // of 'this'. A sound approximation is that 'newVAlue' is topologically
- // before 'this'.
- //
// Given: %3 = f(%1, %2)
// %4 = g(%3)
// %5 = h(%3, %3)
@@ -199,7 +202,7 @@
// %4 = g(%6)
// %5 = h(%6, %6)
void replaceAllUsesWith(Node * newValue) {
- assert(graph_ == newValue->graph_);
+ JIT_ASSERT(graph_ == newValue->graph_);
for(auto u : uses()) {
u.user->inputs_[u.offset] = newValue;
newValue->uses_.push_back(u);
@@ -207,15 +210,12 @@
uses_.clear();
}
- // Insert node 'n' before this one in the topological order.
- //
- // Precondition: All inputs of 'n' must be topologically before
- // 'this'.
+ // Insert unattached 'this' node after 'n' in the topological order.
//
// Given: %3 = f(%1, %2)
// %4 = g(%3)
// and unattached: %5 = h(%1)
- // Execute: %4.insertBefore(%5)
+ // Execute: %5.insertBefore(%4)
// Result: %3 = f(%1, %2)
// %5 = h(%1)
// %4 = g(%3)
@@ -223,6 +223,16 @@
JIT_ASSERT(n->inGraphList());
insertAfter(n->prev());
}
+
+ // Insert unattached 'this' node after 'n' in the topological order.
+ //
+ // Given: %3 = f(%1, %2)
+ // %4 = g(%3)
+ // and unattached: %5 = h(%1)
+ // Execute: %5.insertAfter(%4)
+ // Result: %3 = f(%1, %2)
+ // %4 = g(%3)
+ // %5 = h(%1)
void insertAfter(Node * n) {
JIT_ASSERT(!inGraphList() && n->inGraphList());
Node * next = n->next();
@@ -231,16 +241,42 @@
this->next() = next;
next->prev() = this;
}
+
+ // Move 'this' (already in the graph) after 'n' in the topological order.
+ //
+ // Given: %2 = f(%1)
+ // %3 = g(%1)
+ // Execute: %2.moveAfter(%3)
+ // Result: %3 = g(%1)
+ // %2 = f(%1)
+ //
void moveAfter(Node * n) {
JIT_ASSERT(inGraphList());
removeFromList();
insertAfter(n);
}
+
+ // Move a node 'n' (already in the graph) before 'this' in the topological order.
+ //
+ // Given: %2 = f(%1)
+ // %3 = g(%1)
+ // Execute: %3.moveBefore(%2)
+ // Result: %3 = g(%1)
+ // %2 = f(%1)
void moveBefore(Node * n) {
JIT_ASSERT(inGraphList());
removeFromList();
insertBefore(n);
}
+
+ // Remove the input at 'i' from this node.
+ //
+ // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
+ // removeInput.
+ //
+ // Given: %3 = f(%1, %2)
+ // Execute: %3.removeInput(1)
+ // Result: %3 = f(%1)
void removeInput(size_t i) {
dropInput(i);
// everything after this input shifts left,
@@ -251,23 +287,44 @@
}
inputs_.erase(inputs_.begin() + i);
}
+
+ // Remove all inputs from a node.
+ //
+ // Given: %3 = f(%1, %2)
+ // Execute: %3.removeAllInputs()
+ // Result: %3 = f()
void removeAllInputs() {
for(size_t i = 0; i < inputs().size(); ++i)
dropInput(i);
inputs_.clear();
}
+
// iterators of the node list starting at this node
// useful for resuming a search starting at this node
graph_node_list_iterator iterator();
graph_node_list_iterator reverseIterator();
+
+ // Remove 'this' from the instruction list and deallocate it.
+ //
+ // Invariant: 'this' must not have any uses.
+ //
+ // Given: %2 = f(%1)
+ // %3 = g(%1)
+ // Execute: %2.eraseFromParent()
+ // Result: %3 = g(%1)
void eraseFromParent();
- // dynamic cast: if(auto s = n.cast<Select>()) { ... }
+
+ // Dynamically cast this node to the subclass indicated by the
+ // template variable, returning nullptr if the cast is invalid..
+ //
+ // Example usage: if(auto s = n.cast<Select>()) { ... }
template<typename T>
T* cast() {
if(T::Kind == kind())
return static_cast<T*>(this);
return nullptr;
}
+
virtual ~Node() {}
//initialize this Node by copying properties of 'other'
//translation of inputs is handled automatically in Graph::clone.