| #pragma once |
| |
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
| |
| #include <ATen/ATen.h> |
| |
| #include "torch/csrc/utils/object_ptr.h" |
| #include "torch/csrc/utils/auto_gpu.h" |
| #include "torch/csrc/utils/disallow_copy.h" |
| #include "torch/csrc/utils/python_stub.h" |
| |
| #include "ATen/ArrayRef.h" |
| #include "torch/csrc/jit/generic_if.h" |
| #include "torch/csrc/assertions.h" |
| #include "torch/csrc/jit/interned_strings.h" |
| #include "torch/csrc/jit/attributes.h" |
| #include "torch/csrc/jit/resource_guard.h" |
| #include "torch/csrc/jit/type.h" |
| #include "torch/csrc/jit/graph_node_list.h" |
| #include "torch/csrc/jit/variable_flags.h" |
| |
| namespace torch { namespace autograd { |
| |
| struct Function; |
| |
| }} // namespace torch::autograd |
| |
| namespace torch { namespace jit { |
| |
| // Graph represents one "function" of computation. |
| // It uses a simple ownership model where the graph owns all the nodes inside it. |
| // All references inside the graph are raw pointers. |
| // Destroying the Graph will invalidate any pointers to nodes in the graph. |
| struct Graph; |
| |
| // Node is the base class of the IR graph. It represents one computation |
| // and dependencies on a list of Values. The "prim-ops", so to speak. |
| struct Node; |
| |
// A Value represents an input or output of a node that is either a
// Tensor or an opaque Handle object, as determined by type().
| struct Value; |
| |
// Each use is represented by this type, see Node::uses().
// 'user' is the consumer of the value; 'offset' is the index into
// 'user's input list at which the produced value will be found.
| struct Use { |
| Use(Node * user, size_t offset) |
| : user(user), offset(offset) {} |
| Node * user; |
| size_t offset; |
| }; |
| static inline bool operator==(const Use & a, const Use & b) { |
| return a.user == b.user && a.offset == b.offset; |
| } |
| |
| // SourceLocation represents source code-level debug information for a node. |
| // It contains a Python stack trace that represents the provenance of a given |
| // node in the trace. |
// SourceLocation carries source-level debug information for a node:
// the Python stack trace captured when the node was traced, which
// records the provenance of the node.
struct SourceLocation {
  SourceLocation(std::string python_traceback) {
    this->python_traceback = std::move(python_traceback);
  }
  std::string python_traceback;
};
| |
| // Scope is a node of a trie that represents the tree of nested scopes. |
| // Individual scopes are pushed and popped from Graph, which holds a |
| // pointer to the current scope. Each Node in Graph holds a pointer |
| // to the scope that was current when the node was created. |
| // The trie never needs to shrink, it only grows until it is disposed |
| // of when Graph is deallocated. Hence, pointers to scopes held by nodes |
| // will always be valid as long as Graph is alive. |
| struct Scope { |
| private: |
| Scope* parent_; |
| Symbol name_; |
| std::vector<std::unique_ptr<Scope> > children_; |
| public: |
| Scope() { |
| name_ = stringToSymbol(""); |
| parent_ = NULL; |
| } |
| Scope(Scope* parent, Symbol name) { |
| name_ = name; |
| parent_ = parent; |
| } |
| Scope* push(Symbol name) { |
| children_.push_back(std::unique_ptr<Scope>(new Scope(this, name))); |
| return children_.back().get(); |
| } |
| Scope* parent() { |
| if (parent_ == NULL) { |
| throw std::runtime_error("Cannot get parent from Scope with no parent"); |
| } |
| return parent_; |
| } |
| bool isRoot() { |
| return parent_ == NULL; |
| } |
| Scope* getRoot() { |
| Scope* current = this; |
| while (current->parent_) { |
| current = current->parent_; |
| } |
| return current; |
| } |
| Symbol name() { |
| return name_; |
| } |
| std::string namesFromRoot(const std::string& separator="/") { |
| std::string out = std::string(symbolToString(this->name_)); |
| if (this->isRoot()) { |
| return out; |
| } |
| Scope* parent = this->parent_; |
| while (!parent->isRoot()) { |
| out = std::string(symbolToString(parent->name_)) + separator + out; |
| parent = parent->parent_; |
| } |
| return out; |
| } |
| }; |
| |
| // the list types are intentionally simple, but we type-def |
| // them here so if we need to change them, refactoring will be easier |
| using node_list = std::vector<Node*>; |
| using value_list = std::vector<Value*>; |
| using use_list = std::vector<Use>; |
| using pyobj_list = std::vector<THPObjectPtr>; |
| template<typename T> |
| using ArrayRef = at::ArrayRef<T>; |
| using NodeKind = Symbol; |
| |
| struct Value { |
| TH_DISALLOW_COPY_AND_ASSIGN(Value); |
| Value(Node * node_, size_t offset_); |
| private: |
| friend struct Node; |
| friend struct Graph; |
| Node * node_; |
| size_t offset_; |
| size_t unique_ = 0; // unique id |
| size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,... |
| use_list uses_; |
| std::string unique_name_; |
| TypePtr type_; |
| public: |
| bool hasType() const { |
| return type_ != nullptr; |
| } |
| Value* setType(const TypePtr type) { |
| type_ = type; |
| return this; |
| } |
| void inferTypeFrom(const at::Tensor& output) { |
| setType(std::make_shared<TensorType>(output)); |
| } |
| const TypePtr & type() const { |
| JIT_ASSERT(type_ != nullptr); |
| return type_; |
| } |
| const TypePtr & typeOption() const { |
| return type_; |
| } |
| bool isHandle() const { |
| return hasType() && type()->kind() == TypeKind::HandleType; |
| } |
| size_t unique() const { |
| return unique_; |
| } |
| Value* setUniqueName(const std::string & name); |
| std::string uniqueName() const { |
| if (unique_name_ != "") |
| return unique_name_; |
| return std::to_string(unique()); |
| } |
| Value* setStage(size_t s) { |
| stage_ = s; |
| return this; |
| } |
| size_t stage() const { |
| return stage_; |
| } |
| Node* node() { |
| return node_; |
| } |
| size_t offset() const { |
| return offset_; |
| } |
| const Node * node() const { |
| return node_; |
| } |
| Graph * owningGraph(); |
| const Graph * owningGraph() const; |
| // TODO: make this more const correct |
| const use_list & uses() const { |
| return uses_; |
| } |
| |
| // Replaces all uses of this node with 'newValue'. |
| // |
| // Given: %3 = f(%1, %2) |
| // %4 = g(%3) |
| // %5 = h(%3, %3) |
| // Execute: %3.replaceAllUsesWith(%6) |
| // Result: %3 = f(%1, %2) |
| // %4 = g(%6) |
| // %5 = h(%6, %6) |
| void replaceAllUsesWith(Value * newValue); |
| |
| Value* copyMetadata(Value * from) { |
| if(from->hasType()) setType(from->type()); |
| if (from->unique_name_ != "") |
| setUniqueName(from->uniqueName()); |
| return this; |
| } |
| |
| }; |
| |
| struct Node : public Attributes<Node> { |
| TH_DISALLOW_COPY_AND_ASSIGN(Node); |
| friend struct Graph; |
| friend struct Value; |
| friend graph_node_list; |
| friend const_graph_node_list; |
| friend graph_node_list_iterator; |
| friend const_graph_node_list_iterator; |
| private: |
| // each node but Return/Param |
| // is associated with exactly one place in the node list... |
| // of the graph_ |
| // this circular is a doubly-linked list, the Return node is used as the sentinel for the beginning and end of the list |
| // such that the list never has null pointers |
| // next_in_graph[0] is next pointer |
| // next_in_graph[1] is prev pointer |
| // using an array to allow the same iterator class for forward and reverse node lists |
| // This list represents a topological sort |
| |
| Node* next_in_graph[2] = { nullptr, nullptr }; |
| Node* & next() { return next_in_graph[kNextDirection]; } |
| Node* & prev() { return next_in_graph[kPrevDirection]; } |
| Node* const & next() const { return next_in_graph[kNextDirection]; } |
| Node* const & prev() const { return next_in_graph[kPrevDirection]; } |
| |
| const NodeKind kind_; |
| std::vector<Value*> inputs_; |
| std::vector<Value*> outputs_; |
| Graph* graph_; |
| std::shared_ptr<SourceLocation> source_location_; |
| size_t stage_; |
| Scope* scope_; |
| protected: |
| Node(Graph * graph_, NodeKind kind_); //defined after graph |
| public: |
| NodeKind kind() const { |
| return kind_; |
| } |
| Node* setSourceLocation(std::shared_ptr<SourceLocation> sl) { |
| source_location_ = sl; |
| return this; |
| } |
| std::shared_ptr<SourceLocation> getSourceLocation() const { |
| return source_location_; |
| } |
| Graph * owningGraph() { |
| return graph_; |
| } |
| const Graph * owningGraph() const { |
| return graph_; |
| } |
| size_t stage() const { |
| return stage_; |
| } |
| Node* setStage(size_t s) { |
| stage_ = s; |
| return this; |
| } |
| Scope* scope() { |
| return scope_; |
| } |
| void setScope(Scope* scope) { |
| scope_ = scope; |
| } |
| std::string scopeName() const { |
| if (scope_ == NULL) { |
| return ""; |
| } |
| return scope_->namesFromRoot(); |
| } |
| // NB: This returns an ArrayRef; that means that it will |
| // get invalidated if you resize inputs (e.g., using addInput) |
| // We can't return a std::vector<Node*>& because there's no |
| // way to soundly cast to std::vector<const Node*> (an insane |
| // implementation of std::vector could make this representationally |
| // different.) |
| at::ArrayRef<Value*> inputs() { |
| return inputs_; |
| } |
| at::ArrayRef<const Value*> inputs() const { |
| // Vectors are not convertible in const-ness of elements, but |
| // raw pointers are. |
| return {inputs_.data(), inputs_.size()}; |
| } |
| // NB: This returns an ArrayRef; that means that it will |
| // get invalidated if you resize inputs (e.g., using addInput) |
| // We can't return a std::vector<Node*>& because there's no |
| // way to soundly cast to std::vector<const Node*> (an insane |
| // implementation of std::vector could make this representationally |
| // different.) |
| at::ArrayRef<Value*> outputs() { |
| return outputs_; |
| } |
| at::ArrayRef<const Value*> outputs() const { |
| // Vectors are not convertible in const-ness of elements, but |
| // raw pointers are. |
| return {outputs_.data(), outputs_.size()}; |
| } |
| bool hasUses() const { |
| for(auto o : outputs()) { |
| if(o->uses().size() > 0) |
| return true; |
| } |
| return false; |
| } |
| void replaceAllUsesWith(Node * n) { |
| JIT_ASSERT(outputs().size() == n->outputs().size()); |
| size_t nOutputs = outputs().size(); |
| for(size_t i = 0; i < nOutputs; i++) { |
| outputs()[i]->replaceAllUsesWith(n->outputs()[i]); |
| } |
| } |
| // lots of things like chunk have a single input or singel output, so we have a |
| // helper to make accessing it easier |
| Value * input() { |
| JIT_ASSERT(inputs_.size() == 1); |
| return inputs_.at(0); |
| } |
| Value * output() { |
| JIT_ASSERT(outputs_.size() == 1); |
| return outputs_.at(0); |
| } |
| const Value * input() const { |
| JIT_ASSERT(inputs_.size() == 1); |
| return inputs_.at(0); |
| } |
| // Access a particular input. This is a checked index. |
| Value * input(size_t i) { |
| return inputs_.at(i); |
| } |
| const Value * input(size_t i) const { |
| return inputs_.at(i); |
| } |
| |
| // Graphs |
| |
| // Note [Topological invariant] |
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // We always maintain an up-to-date topological ordering of all nodes via |
| // the next()/prev() links. All transformations to graphs must preserve |
| // this topological ordering: for example, it is only valid to 'addInput' |
| // with an input which is topologically before the current node. |
| // |
| // Usually, it is obvious whether or not topological order is maintained; |
| // for example, if you are adding nodes to the end of the topsort, it's |
| // impossible for them to refer to inputs that are not in the topsort. |
| // If it is not obvious, please comment accordingly. |
| |
| // Add 'node' as an input to 'this' at the end of existing |
| // arguments. Returns the added node for ease of chaining. |
| // |
| // Given: %3 = f(%1, %2) |
| // Execute: %3.addInput(%4) |
| // Result: %3 = f(%1, %2, %4) |
| Value* addInput(Value * node) { |
| JIT_ASSERT(graph_ == node->owningGraph()); |
| node->uses_.emplace_back(this, inputs_.size()); |
| inputs_.push_back(node); |
| return node; |
| } |
| |
| // Replace the input of 'this' at position 'i' with |
| // 'newValue', returning the old node. |
| // |
| // Given: %3 = f(%1, %2) |
| // Execute: %3.replaceInput(1, %4) |
| // Result: %3 = f(%1, %4) |
| Value * replaceInput(size_t i, Value * newValue) { |
| JIT_ASSERT(newValue->owningGraph() == graph_); |
| Value * old = dropInput(i); |
| inputs_[i] = newValue; |
| newValue->uses_.emplace_back(this, i); |
| return old; |
| } |
| |
| // Replace all occurrences of 'from' in the inputs of this |
| // node with 'to'. Corresponds to llvm's replaceUsesOfWith. |
| // |
| // Given: %3 = f(%1, %2, %1) |
| // Execute: %3.replaceInputWith(%1, %4) |
| // Result: %3 = f(%4, %2, %4) |
| void replaceInputWith(Value * from, Value * to) { |
| JIT_ASSERT(from->owningGraph() == graph_); |
| JIT_ASSERT(to->owningGraph() == graph_); |
| size_t i = 0; |
| for(auto input : inputs()) { |
| if(input == from) |
| replaceInput(i, to); |
| i++; |
| } |
| } |
| |
| Value* addOutput() { |
| outputs_.push_back(new Value(this, outputs_.size())); |
| return outputs_.back(); |
| } |
| |
| void eraseOutput(size_t i); |
| |
| // Insert unattached 'this' node after 'n' in the topological order. |
| // Returns this (for chaining). |
| // |
| // Given: %3 = f(%1, %2) |
| // %4 = g(%3) |
| // and unattached: %5 = h(%1) |
| // Execute: %5.insertBefore(%4) |
| // Result: %3 = f(%1, %2) |
| // %5 = h(%1) |
| // %4 = g(%3) |
| Node* insertBefore(Node * n) { |
| JIT_ASSERT(n->inGraphList()); |
| insertAfter(n->prev()); |
| return this; |
| } |
| |
| // Insert unattached 'this' node after 'n' in the topological order. |
| // Returns this (for chaining). |
| // |
| // Given: %3 = f(%1, %2) |
| // %4 = g(%3) |
| // and unattached: %5 = h(%1) |
| // Execute: %5.insertAfter(%4) |
| // Result: %3 = f(%1, %2) |
| // %4 = g(%3) |
| // %5 = h(%1) |
| Node* insertAfter(Node * n) { |
| JIT_ASSERT(!inGraphList() && n->inGraphList()); |
| Node * next = n->next(); |
| n->next() = this; |
| this->prev() = n; |
| this->next() = next; |
| next->prev() = this; |
| return this; |
| } |
| |
| // Move 'this' (already in the graph) after 'n' in the topological order. |
| // |
| // Given: %2 = f(%1) |
| // %3 = g(%1) |
| // Execute: %2.moveAfter(%3) |
| // Result: %3 = g(%1) |
| // %2 = f(%1) |
| // |
| void moveAfter(Node * n) { |
| removeFromList(); |
| insertAfter(n); |
| } |
| |
| // Move a node 'n' (already in the graph) before 'this' in the topological order. |
| // |
| // Given: %2 = f(%1) |
| // %3 = g(%1) |
| // Execute: %3.moveBefore(%2) |
| // Result: %3 = g(%1) |
| // %2 = f(%1) |
| void moveBefore(Node * n) { |
| removeFromList(); |
| insertBefore(n); |
| } |
| |
| // Remove the input at 'i' from this node. |
| // |
| // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling |
| // removeInput. |
| // |
| // Given: %3 = f(%1, %2) |
| // Execute: %3.removeInput(1) |
| // Result: %3 = f(%1) |
| void removeInput(size_t i) { |
| dropInput(i); |
| // everything after this input shifts left, |
| // so we need to update their use offsets to match |
| for(size_t j = i+1; j < inputs_.size(); j++) { |
| auto it = findUseForInput(j); |
| it->offset--; |
| } |
| inputs_.erase(inputs_.begin() + i); |
| } |
| |
| // Remove all inputs from a node. |
| // |
| // Given: %3 = f(%1, %2) |
| // Execute: %3.removeAllInputs() |
| // Result: %3 = f() |
| void removeAllInputs() { |
| for(size_t i = 0; i < inputs().size(); ++i) |
| dropInput(i); |
| inputs_.clear(); |
| } |
| |
| // iterators of the node list starting at this node |
| // useful for resuming a search starting at this node |
| graph_node_list_iterator iterator(); |
| graph_node_list_iterator reverseIterator(); |
| const_graph_node_list_iterator iterator() const; |
| const_graph_node_list_iterator reverseIterator() const; |
| |
| // Remove 'this' from the instruction list and deallocate it. |
| // |
| // Invariant: no outputs of 'this' may have any uses. |
| // |
| // Given: %2 = f(%1) |
| // %3 = g(%1) |
| // Execute: %2.destroy() |
| // Result: %3 = g(%1) |
| void destroy(); |
| |
| // Dynamically cast this node to the subclass indicated by the |
| // template variable, returning nullptr if the cast is invalid.. |
| // |
| // Example usage: if(auto s = n.cast<Select>()) { ... } |
| // |
| // TODO: Make this const correct |
| template<typename T> |
| T* cast() { |
| if(T::Kind == kind()) |
| return static_cast<T*>(this); |
| return nullptr; |
| } |
| template<typename T> |
| T* expect() { |
| JIT_ASSERTM(T::Kind == kind(), "expected a %s but found a %s", symbolToString(T::Kind), symbolToString(kind())); |
| return static_cast<T*>(this); |
| } |
| |
| virtual ~Node() {} |
| private: |
| // Lookup iterator in use list of _input i_ that corresponds to its use of _this_ |
| use_list::iterator findUseForInput(size_t i) { |
| auto & input_uses = inputs_[i]->uses_; |
| // O(N) on the use list, but unless we get nodes with +100 uses |
| // vector traversal still is probably faster than linked list |
| auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i)); |
| JIT_ASSERT(use_it != input_uses.end()); |
| return use_it; |
| } |
| |
| // remove the use of input i, this sets input i to nullptr, but |
| // is only used internally to Node before setting it to a new value |
| // or erasing the entry from the list. |
| Value* dropInput(size_t i) { |
| JIT_ASSERT(i < inputs_.size()); |
| auto input_node = inputs_[i]; |
| auto use_it = findUseForInput(i); |
| input_node->uses_.erase(use_it); |
| inputs_[i] = nullptr; |
| return input_node; |
| } |
| |
| bool inGraphList() const { |
| JIT_ASSERT(next() != nullptr || prev() == nullptr); |
| return next() != nullptr; |
| } |
| void removeFromList() { |
| JIT_ASSERT(inGraphList()); |
| Node * next = this->next(); |
| Node * prev = this->prev(); |
| prev->next() = next; |
| next->prev() = prev; |
| this->next() = nullptr; |
| this->prev() = nullptr; |
| } |
| void lint() const; |
| protected: |
| // subclasses must override |
| // this function is used by createClone to initialize a new version |
| // of a node in another graph. It should allocate a new instance of the same |
| // concrete type as 'this', but in graph 'g' which might be different |
| // than graph_ |
| virtual Node * allocNewInstance(Graph * g) { |
| return new Node(g, kind()); |
| } |
| // create a copy of all properties of Node s into this. |
| // subclasses should extend if they have additional information to copy. |
| // 'this' will be allocated with s->allocNewInstance(g) so it should have |
| // the same concrete type as 's' |
| // |
| // NB: This does NOT clone stages. You're expected to set the stage correctly |
| // if you are going to preserve it. |
| virtual void cloneFrom(Node * s) { |
| setSourceLocation(s->getSourceLocation()); |
| scope_ = s->scope_; |
| copyAttributes(*s); |
| } |
| }; |
| |
| struct Graph { |
| TH_DISALLOW_COPY_AND_ASSIGN(Graph); |
| friend struct Node; |
| friend struct Value; |
| private: |
| |
| // only used to keep track of allocated nodes |
| // actual representation of Graph is done with |
| // inputs, outputs, nodes |
| |
| std::unordered_set<const Node*> all_nodes; |
| std::unordered_set<const Value*> all_values; |
| size_t next_unique_; |
| |
| std::unordered_set<std::string> unique_names_; |
| |
| size_t new_node_stage_; |
| |
| std::shared_ptr<Scope> scope_root_; |
| Scope * current_scope_; |
| |
| // holds outputs in a way that can be reflected |
| // as a Use object |
| // also used as the beginning/end of the circular node list to avoid |
| // having corner cases where the list is empty. |
| Node * const output_; |
| Node * const input_; |
| |
| public: |
| |
| Graph(std::shared_ptr<Scope> scope_root) |
| : next_unique_(0) |
| , new_node_stage_(0) |
| , scope_root_(scope_root) |
| , current_scope_(scope_root_.get()) |
| , output_(initOutput(create(kReturn, 0))), input_(create(kParam, 0)) {} |
| |
| Graph() |
| : Graph( std::make_shared<Scope>()) {} |
| |
| at::ArrayRef<Value*> inputs() { |
| return input_->outputs(); |
| } |
| at::ArrayRef<const Value*> inputs() const { |
| const auto & inputs = input_->outputs(); |
| return {inputs.data(), inputs.size()}; |
| } |
| at::ArrayRef<Value*> outputs() { |
| return output_->inputs(); |
| } |
| at::ArrayRef<const Value*> outputs() const { |
| return static_cast<const Node*>(output_)->inputs(); |
| } |
| graph_node_list nodes() { |
| return graph_node_list(output_, kNextDirection); |
| } |
| const_graph_node_list nodes() const { |
| return const_graph_node_list(output_, kNextDirection); |
| } |
| // These invocations of begin() on output of function are OK |
| // because graph_node_list is non-owning, so it doesn't matter |
| // if it immediately dies after the invocation. |
| graph_node_list_iterator begin() { |
| return nodes().begin(); |
| } |
| const_graph_node_list_iterator begin() const { |
| return nodes().begin(); |
| } |
| graph_node_list_iterator end() { |
| return nodes().end(); |
| } |
| const_graph_node_list_iterator end() const { |
| return nodes().end(); |
| } |
| graph_node_list_iterator rbegin() { |
| return nodes().rbegin(); |
| } |
| const_graph_node_list_iterator rbegin() const { |
| return nodes().rbegin(); |
| } |
| graph_node_list_iterator rend() { |
| return nodes().rend(); |
| } |
| const_graph_node_list_iterator rend() const { |
| return nodes().rend(); |
| } |
| Node * return_node() { |
| return output_; |
| } |
| const Node * return_node() const { |
| return output_; |
| } |
| void push_scope(const std::string& scope_name) { |
| current_scope_ = current_scope_->push(stringToSymbol(scope_name)); |
| } |
| void pop_scope() { |
| current_scope_ = current_scope_->parent(); |
| } |
| Scope * current_scope() { |
| return current_scope_; |
| } |
| void set_current_scope(Scope* scope) { |
| if (scope->getRoot() != scope_root_.get()) { |
| throw std::runtime_error("trying to set a scope as current that does not belong to the Graph's scope trie"); |
| } |
| current_scope_ = scope; |
| } |
| ResourceGuard set_current_scope_temporary(Scope* scope) { |
| auto prev_scope = current_scope_; |
| this->set_current_scope(scope); |
| return ResourceGuard([prev_scope, this]() { this->current_scope_ = prev_scope; }); |
| } |
| std::shared_ptr<Scope> scope_root() { |
| return scope_root_; |
| } |
| Value * addInput(std::string name="") { |
| Value * v = input_->addOutput(); |
| if (name != "") v->setUniqueName(name); |
| return v; |
| } |
| void eraseInput(size_t i) { |
| input_->eraseOutput(i); |
| } |
| void advanceStage() { |
| new_node_stage_++; |
| } |
| void setStage(size_t new_stage) { |
| new_node_stage_ = new_stage; |
| } |
| size_t stage() const { |
| return new_node_stage_; |
| } |
| ResourceGuard setStageTemporary(size_t s) { |
| auto prev_stage = new_node_stage_; |
| new_node_stage_ = s; |
| return ResourceGuard([prev_stage, this]() { this->new_node_stage_ = prev_stage; }); |
| } |
| |
| size_t registerOutput(Value * n) { |
| output_->addInput(n); |
| return outputs().size() - 1; |
| } |
| |
| Node * create(NodeKind kind, size_t num_outputs=1) { |
| // NB: Node constructor adds node to all_nodes |
| auto n = new Node(this, kind); |
| for(size_t i = 0; i < num_outputs; i++) |
| n->addOutput(); |
| return n; |
| } |
| |
| Node * create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs=1) { |
| auto n = create(kind, num_outputs); |
| for(auto i : inputs) |
| n->addInput(i); |
| return n; |
| } |
| |
| Node * createUndefined() { |
| return create(kUndefined); |
| } |
| Node * createConstant(const at::Tensor& ref) { |
| JIT_ASSERT(ref.defined()); |
| AutoGPU guard(ref.type().is_cuda() ? ref.get_device() : -1); |
| auto n = create(kConstant); |
| n->t_(kvalue, ref.clone()); |
| return n; |
| } |
| Node * createFusionGroup(bool is_cuda) { |
| auto n = create(kFusionGroup, 0); |
| n->g_(kSubgraph,std::make_shared<Graph>(scope_root_)); |
| n->i_(kis_cuda, is_cuda); |
| return n; |
| } |
| Node * createPythonOp(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, std::vector<VariableFlags> && var_flags, pyobj_list&& scalar_args); |
| Node * createCppOp(const std::shared_ptr<torch::autograd::Function> & fn, std::vector<VariableFlags> && var_flags); |
| // clone n, making a new node in _this_ graph. |
| // use node_map to translate inputs of n to inputs of the cloned node |
| Node * createClone(Node * n, std::function<Value*(Value*)> value_map) { |
| //n can be from a different graph |
| Node * r = n->allocNewInstance(this); |
| for(auto o : n->outputs()) { |
| r->addOutput()->copyMetadata(o); |
| } |
| r->cloneFrom(n); |
| for(auto i : n->inputs()) { |
| r->addInput(value_map(i)); |
| } |
| return r; |
| } |
| |
| Node * appendNode(Node * n) { |
| JIT_ASSERT(n->graph_ == this && !n->inGraphList()); |
| n->insertBefore(output_); |
| return n; |
| } |
| |
| Node * prependNode(Node * n) { |
| JIT_ASSERT(n->graph_ == this && !n->inGraphList()); |
| n->insertAfter(output_); |
| return n; |
| } |
| |
| // Checks well-formedness and invariants of graph |
| void lint() const; |
| // for use in debugger |
| void dump() const; |
| |
| ~Graph() { |
| for (const Node * n : all_nodes) |
| delete n; |
| for (const Value * v : all_values) |
| delete v; |
| } |
| |
| std::string toString() const { |
| std::ostringstream oss; |
| oss << *this; |
| return oss.str(); |
| } |
| |
| friend std::ostream& operator<<(std::ostream & out, const Graph & g); |
| |
| private: |
| |
| // should only be called in the constructor |
| Node* initOutput(Node* p) { |
| p->next() = p; |
| p->prev() = p; |
| p->setStage(std::numeric_limits<size_t>::max()); |
| return p; |
| } |
| |
| void freeNode(Node * n) { |
| auto it = all_nodes.find(n); |
| JIT_ASSERT(it != all_nodes.end()); |
| delete *it; |
| all_nodes.erase(it); |
| } |
| void freeValue(Value * v) { |
| auto it = all_values.find(v); |
| JIT_ASSERT(it != all_values.end()); |
| all_values.erase(it); |
| } |
| }; |
| |
| inline Value::Value(Node * node_, size_t offset_) |
| : node_(node_), |
| offset_(offset_), |
| unique_(node_->graph_->next_unique_++), |
| stage_(node_->graph_->new_node_stage_) { |
| node_->graph_->all_values.emplace(this); |
| } |
| |
| inline Graph * Value::owningGraph() { |
| return node()->owningGraph(); |
| } |
| |
| inline const Graph * Value::owningGraph() const { |
| return node()->owningGraph(); |
| } |
| |
| inline void Value::replaceAllUsesWith(Value * newValue) { |
| JIT_ASSERT(owningGraph() == newValue->owningGraph()); |
| for(auto u : uses()) { |
| u.user->inputs_[u.offset] = newValue; |
| newValue->uses_.push_back(u); |
| } |
| uses_.clear(); |
| } |
| |
| inline Node::Node(Graph * graph_, NodeKind kind_) : |
| kind_(kind_), |
| graph_(graph_), |
| stage_(graph_->new_node_stage_), |
| scope_(graph_->current_scope_) { |
| graph_->all_nodes.emplace(this); |
| } |
| |
| inline void Node::eraseOutput(size_t i) { |
| JIT_ASSERT(i < outputs_.size()); |
| JIT_ASSERT(outputs_[i]->uses().size() == 0); |
| Value * n = outputs_[i]; |
| outputs_.erase(outputs_.begin() + i); |
| owningGraph()->freeValue(n); |
| for(size_t j = i; j < outputs_.size(); j++) { |
| outputs_[j]->offset_--; |
| } |
| } |
| |
| inline void Node::destroy() { |
| JIT_ASSERT(inGraphList()); |
| while(outputs().size() > 0) |
| eraseOutput(outputs().size() - 1); |
| removeAllInputs(); |
| removeFromList(); |
| graph_->freeNode(this); |
| } |
| |
| inline Value* Value::setUniqueName(const std::string & name) { |
| if (name.find_first_not_of("0123456789") == std::string::npos) { |
| throw std::runtime_error("names may not be integers: " + name); |
| } |
| if (node_->graph_->unique_names_.find(name) != node_->graph_->unique_names_.end()) { |
| throw std::runtime_error("name is already in use in this graph: " + name); |
| } |
| node_->graph_->unique_names_.insert(name); |
| unique_name_ = name; |
| return this; |
| } |
| |
| // Helper macros for constructing switch statements over Node types |
| // instead of heavy-weight visitors |
| // read 'between' these defines to see how they turn into a big switch |
| // statement |
| |
| // Mutable case |
| // The IFM/ELSEIFM indicate that subclass *refinement* occurs. |
| // This is only valid for node types for which we have subclasses. |
| #define IR_IFM(x,Kind) GENERIC_IF(,k##Kind,x,Kind) |
| #define IR_ELSEIFM(Kind) GENERIC_ELSEIF(,k##Kind,Kind) |
| |
| #define IR_IFM_CONST(x,Kind) GENERIC_IF(const,k##Kind,x,Kind) |
| #define IR_ELSEIFM_CONST(Kind) GENERIC_ELSEIF(const,k##Kind,Kind) |
| |
| #define IR_IF(x, Kind) \ |
| auto && __match_key = x; \ |
| switch(__match_key->kind()) { \ |
| case ::torch::jit::k##Kind: { \ |
| auto * value = __match_key; (void) value; |
| #define IR_ELSEIF(Kind) \ |
| } break; \ |
| case ::torch::jit::k##Kind: { \ |
| auto * value = __match_key; (void) value; |
| |
| #define IR_ELSE() GENERIC_ELSE() |
| #define IR_END() GENERIC_END() |
| |
| /* example: |
| Node * n = ...; |
| IR_IF(n,Select) |
| cout << "Select of" << value->input() << "\n"; |
| IR_ELSEIF(PythonOp) |
| cout << value->pyobj << "\n"; |
| IR_ELSEIF(Add) |
    cout << "Add" << "\n";
| IR_ELSE() // optional |
| cout << "something else\n"; |
| IR_END() |
| */ |
| |
| std::ostream& operator<<(std::ostream & out, const Graph & g); |
| std::ostream& operator<<(std::ostream & out, const Type & t); |
| std::ostream& operator<<(std::ostream & out, const Node & t); |
| |
| /************* All nodes not required to be defined before Graph **************/ |
| |
| // execute a Python function, used for Ops we can't optimize but that we want to optimize around |
| struct PythonOp : public Node { |
| static const NodeKind Kind = kPythonOp; |
| PythonOp(Graph * graph) |
| : Node(graph,kPythonOp) {} |
| PythonOp* init(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, std::vector<VariableFlags> && var_flags, pyobj_list&& scalar_args) { |
| this->pyobj = std::move(pyobj); |
| this->scalar_args = std::move(scalar_args); |
| this->cconv = cconv; |
| this->var_flags = std::move(var_flags); |
| this->is_legacy = is_legacy; |
| return this; |
| } |
| virtual Node * allocNewInstance(Graph * g) override { |
| return new PythonOp(g); |
| } |
| //TODO: make this non-autograd specific |
| //remove is_legacy, avoid THPObjectPtr to avoid big PyTorch dependency |
| |
| // The Python object which contains the implementation of this function. |
| // This is either a class (non-legacy) or an object (legacy). See |
| // TraceInterpreterState for execution semantics. |
| THPObjectPtr pyobj; |
| // The calling convention for the Python function. |
| // 's' -- python scalar argument |
| // 't' -- tensor argument |
| std::string cconv; |
| bool is_legacy; |
| // Scalar arguments to the Python function. Not necessarily passed to |
| // the function in this order; see cconv for the correct order. |
| std::vector<THPObjectPtr> scalar_args; |
| std::vector<VariableFlags> var_flags; |
| std::string name() const; |
| virtual void cloneFrom(Node * other_) override; |
| }; |
| inline Node * Graph::createPythonOp(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, std::vector<VariableFlags> && var_flags, pyobj_list&& scalar_args) { |
| auto op = new PythonOp(this); |
| return op->init(std::move(pyobj),cconv,is_legacy,std::move(var_flags), std::move(scalar_args)); |
| } |
| |
| // A Cpp operator is an operator which dispatches directly to an autograd function. |
| // TODO: These are not executable without reentrant engine. |
| struct CppOp : public Node { |
| static const NodeKind Kind = kCppOp; |
| CppOp(Graph * g) |
| : Node(g,kCppOp) {} |
| std::shared_ptr<torch::autograd::Function> fn; |
| std::vector<VariableFlags> var_flags; |
| std::string name() const; |
| CppOp* init(std::shared_ptr<torch::autograd::Function> fn, std::vector<VariableFlags> && var_flags) { |
| JIT_ASSERT(fn); |
| this->fn = std::move(fn); |
| this->var_flags = std::move(var_flags); |
| return this; |
| } |
| virtual Node * allocNewInstance(Graph * g) override { |
| return new CppOp(g); |
| } |
| virtual void cloneFrom(Node * other_) override { |
| Node::cloneFrom(other_); |
| auto other = other_->cast<CppOp>(); |
| this->fn = other->fn; |
| this->var_flags = other->var_flags; |
| } |
| }; |
| inline Node * Graph::createCppOp(const std::shared_ptr<torch::autograd::Function> & fn, std::vector<VariableFlags> && var_flags) { |
| auto op = new CppOp(this); |
| return op->init(fn, std::move(var_flags)); |
| } |
| |
// Iterator over the owning graph's node list, positioned at this node.
// NOTE(review): the second constructor argument selects the traversal
// direction; 0 is presumably kNextDirection (forward) — confirm against
// graph_node_list.h.
inline graph_node_list_iterator Node::iterator() {
  return graph_node_list_iterator(this, 0);
}
// Same position, but walking the list in the opposite direction.
inline graph_node_list_iterator Node::reverseIterator() {
  return iterator().reverse();
}
inline const_graph_node_list_iterator Node::iterator() const {
  return const_graph_node_list_iterator(this, 0);
}
inline const_graph_node_list_iterator Node::reverseIterator() const {
  return iterator().reverse();
}
| |
| void LintGraph(std::shared_ptr<Graph>& graph); |
| |
| }} // namespace torch::jit |