| #pragma once |
| |
| // TODO: Remove Python dependency with layer of indirection |
| |
#include <Python.h>

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <list>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include "ATen/ArrayRef.h"

#include "torch/csrc/utils/object_ptr.h"

#include "torch/csrc/jit/DisallowCopy.h"
#include "torch/csrc/jit/assert.h"
| |
| namespace torch { namespace jit { |
| |
| // Graph represents one "function" of computation. |
| // It uses a simple ownership model where the graph owns all the nodes inside it. |
| // All references inside the graph are raw pointers. |
| // Destroying the Graph will invalidate any pointers to nodes in the graph. |
| struct Graph; |
| |
| // Node is the base class of the IR graph. It represents one computation |
| // and dependencies on a list of values. The "prim-ops", so to speak. |
| struct Node; |
| |
// X-macro listing every kind of Type. Invoke it with a one-argument macro to
// stamp out per-kind code (enumerators, switch cases, ...).
#define TH_FORALL_TYPES(_) \
  _(Multi) \
  _(Single)

// Tag stored in every Type; one enumerator per TH_FORALL_TYPES entry, in the
// order they are listed there.
enum class TypeKind {
#define TH_TYPE_ENUMERATOR(name) name,
  TH_FORALL_TYPES(TH_TYPE_ENUMERATOR)
#undef TH_TYPE_ENUMERATOR
};
| |
| struct Type { |
| private: |
| TypeKind kind_; |
| |
| protected: |
| Type(TypeKind kind) |
| : kind_(kind) {} |
| |
| public: |
| TypeKind kind() const { |
| return kind_; |
| } |
| |
| // Dynamically cast this object to the subclass indicated by the |
| // template variable, returning nullptr if the cast is invalid.. |
| template<typename T> |
| T* cast() { |
| if (T::Kind == kind()) |
| return static_cast<T*>(this); |
| return nullptr; |
| } |
| |
| std::unique_ptr<Type> clone(); |
| |
| static std::unique_ptr<Type> newWithKind(TypeKind kind); |
| }; |
| |
| // Type of nodes with a single output |
| struct TypeSingle : public Type { |
| private: |
| friend struct Type; |
| |
| std::vector<std::int64_t> sizes_; |
| std::vector<std::int64_t> strides_; |
| |
| TypeSingle() |
| : Type(TypeKind::Single) {} |
| |
| public: |
| static const TypeKind Kind = TypeKind::Single; |
| |
| void inferFrom(const at::Tensor& tensor) { |
| auto ndim = tensor.dim(); |
| sizes_.resize(ndim); |
| strides_.resize(ndim); |
| // NOTE: This is not memcpy! These are assignments. |
| std::copy(tensor.sizes().begin(), tensor.sizes().end(), sizes_.begin()); |
| std::copy(tensor.strides().begin(), tensor.strides().end(), strides_.begin()); |
| } |
| |
| const std::vector<std::int64_t>& sizes() { |
| return sizes_; |
| } |
| const std::vector<std::int64_t>& strides() { |
| return strides_; |
| } |
| }; |
| |
| // Type of multireturn nodes. Note that it doesn't mean that they must always |
| // have multiple outputs. |
| struct TypeMulti : public Type { |
| private: |
| friend struct Type; |
| |
| TypeMulti() |
| : Type(TypeKind::Multi) {} |
| |
| public: |
| static const TypeKind Kind = TypeKind::Multi; |
| }; |
| |
| inline std::unique_ptr<Type> Type::newWithKind(TypeKind kind) { |
| switch (kind) { |
| #define HANDLE_KIND(KIND) \ |
| case TypeKind::KIND: \ |
| return std::unique_ptr<Type>(static_cast<Type*>(new Type##KIND())); |
| TH_FORALL_TYPES(HANDLE_KIND) |
| } |
| #undef HANDLE_KIND |
| __builtin_unreachable(); |
| } |
| |
| inline std::unique_ptr<Type> Type::clone() { |
| #define HANDLE_KIND(KIND) \ |
| case TypeKind::KIND: \ |
| return std::unique_ptr<Type>(static_cast<Type*>( \ |
| new Type##KIND(*(static_cast<Type##KIND*>(this))))); |
| switch (kind_) { |
| TH_FORALL_TYPES(HANDLE_KIND) |
| } |
| #undef HANDLE_KIND |
| __builtin_unreachable(); |
| } |
| |
| // Each use is represented by this type, see Node::uses() |
| // 'user' is the consumer of the node, offset is the index into |
| // 'user's input this where the produces will be found. |
| struct Use { |
| Use(Node * user, size_t offset) |
| : user(user), offset(offset) {} |
| Node * user; |
| size_t offset; |
| }; |
| static inline bool operator==(const Use & a, const Use & b) { |
| return a.user == b.user && a.offset == b.offset; |
| } |
| |
| // Param represents an input to the Graph, it has no inputs itself. |
| // Graph holds a list of parameters. |
| struct Param; |
| |
| // the list types are intentionally simple, but we type-def |
| // them here so if we need to change them, refactoring will be easier |
| using node_list = std::vector<Node*>; |
| using param_list = node_list; |
| using use_list = std::vector<Use>; |
| using pyobj_list = std::vector<THPObjectPtr>; |
| template<typename T> |
| using ArrayRef = at::ArrayRef<T>; |
| |
| |
// Every kind of node, defined as an x-macro so that other per-kind code
// (such as toString) can be generated from the same list.
#define TH_FORALL_NODES(_) \
  _(PythonOp) \
  _(Param) \
  _(Select) \
  _(Return) \
  _(Add) \
  _(Mul) \
  _(Negate) \
  _(Sigmoid) \
  _(Tanh) \
  _(FusionGroup)

// Tag identifying the concrete subclass of a Node; one enumerator per
// TH_FORALL_NODES entry, in the order they are listed there.
enum class NodeKind {
#define TH_NODE_ENUMERATOR(name) name,
  TH_FORALL_NODES(TH_NODE_ENUMERATOR)
#undef TH_NODE_ENUMERATOR
};
| |
| using graph_node_list = std::list<Node*>; |
| |
| struct Node { |
| TH_DISALLOW_COPY_AND_ASSIGN(Node); |
| friend struct Graph; |
| private: |
| graph_node_list::iterator nodes_iter_; |
| graph_node_list::iterator next() { return std::next(nodes_iter_); } |
| graph_node_list::iterator prev() { return std::prev(nodes_iter_); } |
| |
| const NodeKind kind_; |
| std::vector<Node*> inputs_; |
| use_list uses_; |
| Graph* graph_; |
| size_t unique_ = 0; // unique id |
| size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,... |
| protected: |
| std::unique_ptr<Type> type_; |
| Node(NodeKind kind_, TypeKind type_kind) |
| : kind_(kind_), type_(Type::newWithKind(type_kind)) {} |
| public: |
| NodeKind kind() { |
| return kind_; |
| } |
| const Type* type() { |
| return type_.get(); |
| } |
| void inferTypeFrom(const at::Tensor& output) { |
| auto single_type = type_->cast<TypeSingle>(); |
| JIT_ASSERT(single_type); |
| single_type->inferFrom(output); |
| } |
| Graph * owningGraph() { |
| return graph_; |
| } |
| size_t unique() { |
| return unique_; |
| } |
| void setStage(size_t s) { |
| stage_ = s; |
| } |
| size_t stage() { |
| return stage_; |
| } |
| const std::vector<Node*>& inputs() { |
| return inputs_; |
| } |
| const use_list & uses() { |
| return uses_; |
| } |
| |
| // Graphs |
| |
| // Note [Topological invariant] |
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // We always maintain an up-to-date topological ordering of all nodes via |
| // the next()/prev() links. All transformations to graphs must preserve |
| // this topological ordering: for example, it is only valid to 'addInput' |
| // with an input which is topologically before the current node. |
| // |
| // Usually, it is obvious whether or not topological order is maintained; |
| // for example, if you are adding nodes to the end of the topsort, it's |
| // impossible for them to refer to inputs that are not in the topsort. |
| // If it is not obvious, please comment accordingly. |
| |
| // Add 'node' as an input to 'this' at the end of existing |
| // arguments. Returns the added node for ease of chaining. |
| // |
| // Given: %3 = f(%1, %2) |
| // Execute: %3.addInput(%4) |
| // Result: %3 = f(%1, %2, %4) |
| Node* addInput(Node * node) { |
| JIT_ASSERT(graph_ == node->graph_); |
| node->uses_.emplace_back(this, inputs_.size()); |
| inputs_.push_back(node); |
| return node; |
| } |
| |
| // Replace the input of 'this' at position 'i' with |
| // 'newValue', returning the old node. |
| // |
| // Given: %3 = f(%1, %2) |
| // Execute: %3.replaceInput(1, %4) |
| // Result: %3 = f(%1, %4) |
| Node * replaceInput(size_t i, Node * newValue) { |
| JIT_ASSERT(newValue->graph_ == graph_); |
| Node * old = dropInput(i); |
| inputs_[i] = newValue; |
| newValue->uses_.emplace_back(this, i); |
| return old; |
| } |
| |
| // Replace all occurrences of 'from' in the inputs of this |
| // node with 'to'. Corresponds to llvm's replaceUsesOfWith. |
| // |
| // Given: %3 = f(%1, %2, %1) |
| // Execute: %3.replaceInputWith(%1, %4) |
| // Result: %3 = f(%4, %2, %4) |
| void replaceInputWith(Node * from, Node * to) { |
| JIT_ASSERT(from->graph_ == graph_); |
| JIT_ASSERT(to->graph_ == graph_); |
| size_t i = 0; |
| for(auto input : inputs()) { |
| if(input == from) |
| replaceInput(i, to); |
| i++; |
| } |
| } |
| |
| // Replaces all uses of this node with 'newValue'. |
| // |
| // Given: %3 = f(%1, %2) |
| // %4 = g(%3) |
| // %5 = h(%3, %3) |
| // Execute: %3.replaceAllUsesWith(%6) |
| // Result: %3 = f(%1, %2) |
| // %4 = g(%6) |
| // %5 = h(%6, %6) |
| void replaceAllUsesWith(Node * newValue) { |
| JIT_ASSERT(graph_ == newValue->graph_); |
| for(auto u : uses()) { |
| u.user->inputs_[u.offset] = newValue; |
| newValue->uses_.push_back(u); |
| } |
| uses_.clear(); |
| } |
| |
| // Insert unattached 'this' node after 'n' in the topological order. |
| // |
| // Given: %3 = f(%1, %2) |
| // %4 = g(%3) |
| // and unattached: %5 = h(%1) |
| // Execute: %5.insertBefore(%4) |
| // Result: %3 = f(%1, %2) |
| // %5 = h(%1) |
| // %4 = g(%3) |
| void insertBefore(Node * n) { |
| JIT_ASSERT(n->inGraphList()); |
| insertAt(n->nodes_iter_); |
| } |
| |
| // Insert unattached 'this' node after 'n' in the topological order. |
| // |
| // Given: %3 = f(%1, %2) |
| // %4 = g(%3) |
| // and unattached: %5 = h(%1) |
| // Execute: %5.insertAfter(%4) |
| // Result: %3 = f(%1, %2) |
| // %4 = g(%3) |
| // %5 = h(%1) |
| void insertAfter(Node * n) { |
| JIT_ASSERT(n->inGraphList()); |
| insertAt(n->next()); |
| } |
| |
| // Move 'this' (already in the graph) after 'n' in the topological order. |
| // |
| // Given: %2 = f(%1) |
| // %3 = g(%1) |
| // Execute: %2.moveAfter(%3) |
| // Result: %3 = g(%1) |
| // %2 = f(%1) |
| // |
| void moveAfter(Node * n) { |
| removeFromList(); |
| insertAfter(n); |
| } |
| |
| // Move a node 'n' (already in the graph) before 'this' in the topological order. |
| // |
| // Given: %2 = f(%1) |
| // %3 = g(%1) |
| // Execute: %3.moveBefore(%2) |
| // Result: %3 = g(%1) |
| // %2 = f(%1) |
| void moveBefore(Node * n) { |
| removeFromList(); |
| insertBefore(n); |
| } |
| |
| // Remove the input at 'i' from this node. |
| // |
| // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling |
| // removeInput. |
| // |
| // Given: %3 = f(%1, %2) |
| // Execute: %3.removeInput(1) |
| // Result: %3 = f(%1) |
| void removeInput(size_t i) { |
| dropInput(i); |
| // everything after this input shifts left, |
| // so we need to update their use offsets to match |
| for(size_t j = i+1; j < inputs_.size(); j++) { |
| auto it = findUseForInput(j); |
| it->offset--; |
| } |
| inputs_.erase(inputs_.begin() + i); |
| } |
| |
| // Remove all inputs from a node. |
| // |
| // Given: %3 = f(%1, %2) |
| // Execute: %3.removeAllInputs() |
| // Result: %3 = f() |
| void removeAllInputs() { |
| for(size_t i = 0; i < inputs().size(); ++i) |
| dropInput(i); |
| inputs_.clear(); |
| } |
| |
| // iterators of the node list starting at this node |
| // useful for resuming a search starting at this node |
| graph_node_list::iterator iterator() { |
| JIT_ASSERT(inGraphList()); |
| return nodes_iter_; |
| } |
| graph_node_list::reverse_iterator reverseIterator() { |
| JIT_ASSERT(inGraphList()); |
| // newly created reverse_iterator points to an element preceding |
| // (in forward order) the one pointed to by forward iter used to create it |
| return graph_node_list::reverse_iterator(std::next(nodes_iter_)); |
| } |
| |
| // Remove 'this' from the instruction list and deallocate it. |
| // |
| // Invariant: 'this' must not have any uses. |
| // |
| // Given: %2 = f(%1) |
| // %3 = g(%1) |
| // Execute: %2.destroy() |
| // Result: %3 = g(%1) |
| void destroy(); |
| |
| // Dynamically cast this node to the subclass indicated by the |
| // template variable, returning nullptr if the cast is invalid.. |
| // |
| // Example usage: if(auto s = n.cast<Select>()) { ... } |
| template<typename T> |
| T* cast() { |
| if(T::Kind == kind()) |
| return static_cast<T*>(this); |
| return nullptr; |
| } |
| |
| virtual ~Node() {} |
| private: |
| // Lookup iterator in use list of _input i_ that corresponds to its use of _this_ |
| use_list::iterator findUseForInput(size_t i) { |
| auto & input_uses = inputs_[i]->uses_; |
| // O(N) on the use list, but unless we get nodes with +100 uses |
| // vector traversal still is probably faster than linked list |
| auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i)); |
| JIT_ASSERT(use_it != input_uses.end()); |
| return use_it; |
| } |
| |
| void insertAt(graph_node_list::iterator it); |
| |
| // remove the use of input i, this sets input i to nullptr, but |
| // is only used internally to Node before setting it to a new value |
| // or erasing the entry from the list. |
| Node* dropInput(size_t i) { |
| JIT_ASSERT(i < inputs_.size()); |
| auto input_node = inputs_[i]; |
| auto use_it = findUseForInput(i); |
| input_node->uses_.erase(use_it); |
| inputs_[i] = nullptr; |
| return input_node; |
| } |
| |
| bool inGraphList(); |
| void removeFromList(); |
| protected: |
| virtual Node * allocClone(Graph * in_graph) = 0; |
| }; |
| |
| /******************* Nodes required inside a Graph ****************************/ |
| |
| // NodeWithKind handles the mapping between concrete node type and |
| // its NodeKind tag. |
| // |
| // It also sets up the clone infrastructure |
| // using CRTP so that we can alloc new clones and dispatch to custom clone code |
| // without significant boilerplate |
| |
| template<typename Self, NodeKind K, TypeKind T> |
| struct NodeWithKind : public Node { |
| friend struct Graph; /* so it can access allocClone() */ |
| static const NodeKind Kind = K; |
| NodeWithKind() |
| : Node(K, T) {} |
| // virtual so we can easily define a default here |
| // defined using CRTP so cloneFrom doesn't need casts. |
| // called from allocClone |
| virtual void cloneFrom(Self * s) {} |
| protected: |
| // allocate a new Node with the same type as this node, and |
| // get it initialized in in_graph |
| // in_graph may not be the same graph as this->graph_ because we might be |
| // cloning the node into a new graph |
| // defined here because we need to know Self to allocate a new node. |
| // user-defined cloneFrom is called. |
| virtual Node * allocClone(Graph * in_graph); |
| }; |
| |
| // helper to define simple primitive Ops. |
| template<typename Self, NodeKind K> |
| struct Primitive : public NodeWithKind<Self, K, TypeKind::Single> { |
| void init() {} |
| void init(ArrayRef<Node*> inputs) { |
| for(auto i : inputs) |
| this->addInput(i); |
| } |
| }; |
| |
| // the outputs of the Graph are represented as an Node so that its inputs |
| // can be tracked as Uses. |
| struct Return : public Primitive<Return, NodeKind::Return> {}; |
| |
| // an input tensor to the graph |
| struct Param : public NodeWithKind<Param, NodeKind::Param, TypeKind::Single> { |
| void init() {} |
| }; |
| |
| struct Graph { |
| TH_DISALLOW_COPY_AND_ASSIGN(Graph); |
| friend struct Node; |
| template<typename Self, NodeKind K, TypeKind T> |
| friend struct NodeWithKind; |
| private: |
| param_list inputs_; |
| |
| // holds outputs in a way that can be reflected |
| // as a Use object |
| // also used as the beginning/end of the circular node list to avoid |
| // having corner cases where the list is empty. |
| Return * output_; |
| |
| // only used to keep track of allocated nodes |
| // actual representation of Graph is done with |
| // inputs, outputs, nodes |
| |
| graph_node_list nodes_; |
| std::unordered_set<Node*> all_nodes; |
| size_t next_unique_; |
| |
| public: |
| Graph() |
| : next_unique_(0) { |
| output_ = create<Return>(); |
| } |
| |
| const param_list & inputs() { |
| return inputs_; |
| } |
| const node_list & outputs() { |
| return output_->inputs(); |
| } |
| const graph_node_list & nodes() { |
| return nodes_; |
| } |
| Node * return_node() { |
| return output_; |
| } |
| |
| Param * addInput() { |
| Param* p = create<Param>(); |
| inputs_.push_back(p); |
| return p; |
| } |
| |
| void eraseInput(size_t i) { |
| JIT_ASSERT(i < inputs_.size()); |
| JIT_ASSERT(inputs_[i]->uses().size() == 0); |
| Node * n = inputs_[i]; |
| inputs_.erase(inputs_.begin() + i); |
| freeNode(n); |
| } |
| |
| size_t registerOutput(Node * n) { |
| output_->addInput(n); |
| return outputs().size() - 1; |
| } |
| |
| // like make_shared, forward arguments to node initializers |
| // while also correctly allocating the node to live in this graph |
| // e.g. g.create<Select>(another,0); |
| template<typename T, typename... Args > |
| T * create(Args&&... args) { |
| // default construction of all nodes |
| T* r = new T(); |
| // common initialization for all nodes when they live in this graph |
| initNewNodeForGraph(r); |
| // custom per-node initialization |
| r->init(std::forward<Args>(args)...); |
| return r; |
| } |
| |
| // clone n, making a new node in _this_ graph. |
| // use node_map to translate inputs of n to inputs of the cloned node |
| Node * createClone(Node * n, std::function<Node*(Node*)> node_map) { |
| //n can be from a different graph |
| Node * r = n->allocClone(this); |
| for(auto i : n->inputs()) { |
| r->addInput(node_map(i)); |
| } |
| return r; |
| } |
| |
| Node * appendNode(Node * n) { |
| n->insertAt(nodes_.end()); |
| return n; |
| } |
| |
| Node * prependNode(Node * n) { |
| n->insertAt(nodes_.begin()); |
| return n; |
| } |
| |
| template<typename T, typename... Args > |
| Node * appendNewNode(Args&&... args) { |
| T* n = create<T>(std::forward<Args>(args)...); |
| return appendNode(n); |
| } |
| |
| template<typename T, typename... Args > |
| Node * prependNewNode(Args&&... args) { |
| T* n = create<T>(std::forward<Args>(args)...); |
| return prependNode(n); |
| } |
| |
| ~Graph() { |
| for (Node * n : all_nodes) |
| delete n; |
| } |
| |
| private: |
| // per graph initialization for any node |
| // called from NodeWithKind::allocClone and Graph::create |
| void initNewNodeForGraph(Node * r) { |
| r->graph_ = this; |
| r->unique_ = next_unique_++; |
| r->nodes_iter_ = nodes_.end(); |
| all_nodes.emplace(r); |
| } |
| |
| void freeNode(Node * n) { |
| auto it = all_nodes.find(n); |
| JIT_ASSERT(it != all_nodes.end()); |
| delete *it; |
| all_nodes.erase(it); |
| } |
| }; |
| |
| inline void Node::insertAt(graph_node_list::iterator it) { |
| JIT_ASSERT(!inGraphList()) |
| nodes_iter_ = graph_->nodes_.insert(it, this); |
| } |
| |
| inline bool Node::inGraphList() { |
| return nodes_iter_ != graph_->nodes_.end(); |
| } |
| |
| inline void Node::removeFromList() { |
| JIT_ASSERT(inGraphList()); |
| graph_->nodes_.erase(nodes_iter_); |
| nodes_iter_ = graph_->nodes_.end(); |
| } |
| |
| inline void Node::destroy() { |
| JIT_ASSERT(inGraphList()); |
| JIT_ASSERTM(uses().size() == 0, "attempting to erase a Node that still has uses."); |
| removeAllInputs(); |
| removeFromList(); |
| graph_->freeNode(this); |
| } |
| |
| template<typename Self, NodeKind K, TypeKind T> |
| Node * NodeWithKind<Self,K,T>::allocClone(Graph * in_graph) { |
| auto s = new Self(); |
| s->type_ = this->type_->clone(); |
| in_graph->initNewNodeForGraph(s); |
| // static cast is needed because the compiler doesn't know NodeWithKind is a CRTP. |
| s->cloneFrom(static_cast<Self*>(this)); |
| return s; |
| } |
| |
// Helper macros for constructing switch statements over Node types
// instead of heavy-weight visitors
// read 'between' these defines to see how they turn into a big switch
// statement
//
// IR_IF(x,Kind) starts the match on node pointer 'x'; inside each branch
// 'value' is 'x' cast to the matched node class. IR_ELSEIF(Kind) adds a
// branch, IR_ELSE() adds the optional default branch, and IR_END() closes
// the match. The whole match lives in its own block scope, so several
// matches can appear in the same enclosing scope.

#define IR_IF(x,Kind) \
  { \
    auto && __match_key = (x); \
    switch(__match_key->kind()) { \
      case NodeKind::Kind: { \
        auto * value = static_cast<::torch::jit::Kind*>(__match_key); (void) value;
#define IR_ELSEIF(Kind) \
      } break; \
      case NodeKind::Kind: { \
        auto * value = static_cast<::torch::jit::Kind*>(__match_key); (void) value;
#define IR_ELSE() \
      } break; \
      default: {
#define IR_END() \
      } break; \
    } \
  }
| |
/* example:
  Node * n = ...;
  IR_IF(n,Select)
    cout << "Select of" << value->base() << "\n";
  IR_ELSEIF(PythonOp)
    cout << value->pyobj << "\n";
  IR_ELSEIF(Add)
    cout << "Add" << "\n";
  IR_ELSE() // optional
    cout << "something else\n";
  IR_END()
*/
| |
| std::ostream& operator<<(std::ostream & out, Graph & g); |
| |
| static inline const char * toString(NodeKind kind) { |
| switch(kind) { |
| #define DEFINE_CASE(Kind) \ |
| case NodeKind::Kind: return #Kind; |
| TH_FORALL_NODES(DEFINE_CASE) |
| #undef DEFINE_CASE |
| default: |
| __builtin_unreachable(); |
| } |
| } |
| |
| |
| /************* All nodes not required to be defined before Graph **************/ |
| |
| // execute a Python function, used for Ops we can't optimize but that we want to optimize around |
| struct PythonOp : public NodeWithKind<PythonOp,NodeKind::PythonOp,TypeKind::Multi> { |
| //TODO: make this non-autograd specific |
| //remove is_legacy, avoid THPObjectPtr to avoid big PyTorch dependency |
| |
| // The Python object which contains the implementation of this function. |
| // This is either a class (non-legacy) or an object (legacy). See |
| // TraceInterpreter for execution semantics. |
| THPObjectPtr pyobj; |
| // The calling convention for the Python function. |
| // 's' -- python scalar argument |
| // 't' -- tensor argument |
| std::string cconv; |
| bool is_legacy; |
| // Scalar arguments to the Python function. Not necessarily passed to |
| // the function in this order; see cconv for the correct order. |
| std::vector<THPObjectPtr> scalar_args; |
| std::string name(); |
| void init(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, pyobj_list&& scalar_args) { |
| this->pyobj = std::move(pyobj); |
| this->scalar_args = std::move(scalar_args); |
| this->cconv = cconv; |
| this->is_legacy = is_legacy; |
| } |
| virtual void cloneFrom(PythonOp * other) override { |
| this->cconv = other->cconv; |
| this->is_legacy = other->is_legacy; |
| Py_INCREF(other->pyobj.get()); |
| this->pyobj = THPObjectPtr(other->pyobj.get()); |
| for(auto & sa : other->scalar_args) { |
| Py_INCREF(sa.get()); |
| this->scalar_args.emplace_back(sa.get()); |
| } |
| } |
| }; |
| |
| // Select nodes are used to handle multiple returns for the ops that actually return |
| // multiple values like PythonOp |
| // By convension, there is a unique select node for each output of an op |
| // so you can iterate over uses of a multi-return op to get all the select nodes. |
| // in this case |
| // number_of_outputs = op.uses().size() |
| // this will change if Tuples ever become first class. |
| struct Select : public NodeWithKind<Select,NodeKind::Select,TypeKind::Single> { |
| void init(Node * node, size_t offset) { |
| addInput(node); |
| this->offset_ = offset; |
| } |
| // which multi-return op is it? |
| Node * base() { |
| return inputs()[0]; |
| } |
| // which output is it? |
| size_t offset() { |
| return offset_; |
| } |
| virtual void cloneFrom(Select * other) override { |
| this->offset_ = other->offset_; |
| } |
| private: |
| size_t offset_; |
| }; |
| |
| // NB: non-nullary constructors don't get forwarded to the |
| // parents, so you have to spell out the constructors you want explicitly. |
| |
| struct Add : public Primitive<Add,NodeKind::Add> {}; |
| struct Mul : public Primitive<Mul,NodeKind::Mul> {}; |
| struct Negate : public Primitive<Negate,NodeKind::Negate> {}; |
| struct Sigmoid : public Primitive<Sigmoid,NodeKind::Sigmoid> {}; |
| struct Tanh : public Primitive<Tanh,NodeKind::Tanh> {}; |
| |
| struct FusionGroup : public NodeWithKind<FusionGroup,NodeKind::FusionGroup,TypeKind::Multi> { |
| void init() { |
| subgraph_ = std::make_shared<Graph>(); |
| } |
| virtual void cloneFrom(FusionGroup * other) { |
| subgraph_ = other->subgraph_; |
| } |
| Graph & subgraph() { |
| return *subgraph_; |
| } |
| private: |
| std::shared_ptr<Graph> subgraph_; |
| }; |
| |
| }} |
| |
| namespace std { |
| |
| template<> |
| struct hash<torch::jit::NodeKind> { |
| std::size_t operator()(const torch::jit::NodeKind& k) const { |
| return hash<int>()(static_cast<int>(k)); |
| } |
| }; |
| |
| } // namespace std |