| #include "ir.h" |
| |
| #include "torch/csrc/utils/auto_gil.h" |
| #include "torch/csrc/utils/python_strings.h" |
| |
| #include <iostream> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <set> |
| #include <stack> |
| #include <sstream> |
| |
| namespace torch { namespace jit { |
| |
| std::string getPythonName(const PyObject* obj, bool is_legacy) { |
| AutoGIL gil; |
| if (is_legacy) { |
| return std::string(obj->ob_type->tp_name); |
| } else { |
| // NB: hypothetically __name__ could mutate the Python |
| // object in a externally visible way. Please don't! |
| auto wobj = const_cast<PyObject*>(obj); |
| THPObjectPtr name{PyObject_GetAttrString(wobj, "__name__")}; |
| return THPUtils_unpackString(name.get()); |
| } |
| } |
| std::ostream& operator<<(std::ostream & out, Node & n) { |
| if(auto s = n.cast<Select>()) |
| out << "%" << s->base()->unique() << "." << s->offset(); |
| else |
| out << "%" << n.unique(); |
| return out; |
| } |
| std::ostream& operator<<(std::ostream & out, const node_list & nodes) { |
| size_t i = 0; |
| for(auto n : nodes) { |
| if(i++ > 0) |
| out << ", "; |
| out << *n; |
| } |
| return out; |
| } |
| |
| static std::ostream& operator<<(std::ostream & out, THPObjectPtr& obj) { |
| THPObjectPtr repr { PyObject_Repr(obj.get()) }; |
| return out << THPUtils_unpackString(repr.get()); |
| } |
| |
| std::string PythonOp::name() { |
| return getPythonName(pyobj.get(),is_legacy); |
| } |
| |
| static void emitUses(std::ostream & out, Node * n) { |
| size_t i = 0; |
| for(auto u : n->uses()) { |
| if(i++ > 0) |
| out << ", "; |
| out << *u.user << ".i" << u.offset; |
| } |
| } |
| |
| std::ostream& operator<<(std::ostream & out, Graph & g) { |
| out << "graph(" << g.inputs() << ") {\n"; |
| std::vector<FusionGroup*> groups; |
| for(auto n : g.nodes()) { |
| if(!n->cast<Select>()) { //improve readibility by printing selects inline |
| out << " "; |
| if(n->type()->kind() == TypeKind::Multi) { |
| node_list selects; |
| for(auto u : n->uses()) |
| selects.push_back(u.user); |
| out << selects; |
| } else { |
| out << "%" << n->unique(); |
| } |
| out << " = "; |
| IR_IF(n,PythonOp) |
| out << "^" << value->name(); |
| out << "("; |
| int i = 0; |
| for (auto& scalar : value->scalar_args) { |
| if (i++ > 0) |
| out << ", "; |
| out << scalar; |
| } |
| out << ")"; |
| IR_ELSEIF(FusionGroup) |
| out << "fusion_group_" << groups.size(); |
| groups.push_back(value); |
| IR_ELSE() |
| out << toString(n->kind()); |
| IR_END() |
| out << "(" << n->inputs() << "), uses = ["; |
| if(n->type()->kind() == TypeKind::Multi) { |
| size_t i = 0; |
| for(auto u : n->uses()) { |
| if(i++ > 0) |
| out << ", "; |
| out << "["; |
| emitUses(out,u.user); |
| out << "]"; |
| } |
| } else { |
| emitUses(out,n); |
| } |
| out << "];\n"; |
| } |
| } |
| out << " return (" << g.outputs() << ");\n}\n"; |
| size_t i = 0; |
| for(auto fg : groups) { |
| out << "with fusion_group_" <<i++ << " = " << fg->subgraph(); |
| } |
| return out; |
| } |
| |
// Ordered set of node pointers. Graph::lint builds these so that the ranges
// it hands to std::includes are sorted.
| using node_set = std::set<Node*>; |
// Expands to the `begin(), end()` iterator pair of `container`, for passing a
// whole container to a standard algorithm or range constructor.
// The argument is parenthesized so that expressions such as `*ptr` expand
// correctly. Note that the argument is still evaluated twice.
#define ALL_OF(container) (container).begin(), (container).end()
| |
| // These functions purposely operate on the internal members directly, to force |
| // you to think about how the invariants change if you change the data |
| // representation (even if the external API does not change.) |
| |
| // NB: This assert is written to assume you don't have any unattached |
| // nodes. Unattached nodes can occur while manipulations to the |
| // graph are occurring. |
| void Node::lint() { |
| // Node invariants |
| // - if node should live in list, nodes_iter is consistent |
| // - Inputs are all marked as a use by the nodes they refer to |
| // - Stage is consistent (stage is >= all input stages) |
| // - Owning graph is non-null and consistent |
| // - The "Select" invariant, when the node is MultiReturn |
| |
| if (kind_ != NodeKind::Param && kind_ != NodeKind::Return) { |
| JIT_ASSERT(*nodes_iter_ == this); |
| } |
| |
| { |
| size_t i = 0; |
| for (auto input : inputs_) { |
| // WARNING: O(n^2) |
| JIT_ASSERT(std::find(ALL_OF(input->uses_), Use(this, i)) != input->uses_.end()); |
| JIT_ASSERT(stage_ >= input->stage_); |
| JIT_ASSERT(graph_->all_nodes.count(this) == 1); |
| i++; |
| } |
| } |
| |
| { |
| size_t i = 0; |
| for (auto use : uses_) { |
| // Use invariants |
| // - Use is consistent with inputs |
| // - Every user node is live (checked in Graph) |
| JIT_ASSERT(use.user->inputs_[use.offset] == this); |
| // Select invariant |
| // - Multi-return nodes only have select uses |
| // - uses = [Select 0, Select 1, Select 2, ...] |
| if (type_->kind() == TypeKind::Multi) { |
| JIT_ASSERT(use.offset == 0); |
| IR_IF(use.user, Select) |
| JIT_ASSERT(value->offset() == i); |
| IR_ELSE() |
| JIT_ASSERT(0); |
| IR_END() |
| } |
| i++; |
| } |
| } |
| |
| // Node subclass invariants |
| // - Return uses is zero |
| // - Param inputs is zero |
| // - Select inputs is one |
| // - Python operator cconv is correct |
| |
| IR_IF(this, Return) |
| JIT_ASSERT(uses_.size() == 0); |
| IR_ELSEIF(Constant) |
| JIT_ASSERT(inputs_.size() == 0); |
| IR_ELSEIF(Param) |
| JIT_ASSERT(inputs_.size() == 0); |
| IR_ELSEIF(Select) |
| JIT_ASSERT(inputs_.size() == 1); |
| IR_ELSEIF(PythonOp) |
| std::size_t n_scalars = 0, n_tensors = 0; |
| for (auto c : value->cconv) { |
| if (c == 's') { |
| n_scalars++; |
| } else if (c == 't') { |
| n_tensors++; |
| } else { |
| JIT_ASSERT(0); |
| } |
| JIT_ASSERT(value->pyobj != nullptr); |
| } |
| JIT_ASSERT(n_scalars == value->scalar_args.size()); |
| JIT_ASSERT(n_tensors == inputs_.size()); |
| // TODO: It's not good for these ops to be top-level, it makes cases longer. |
| IR_ELSEIF(Add) |
| JIT_ASSERT(inputs_.size() == 2); |
| IR_ELSEIF(Mul) |
| JIT_ASSERT(inputs_.size() == 2); |
| IR_ELSEIF(Negate) |
| JIT_ASSERT(inputs_.size() == 1); |
| IR_ELSEIF(Sigmoid) |
| JIT_ASSERT(inputs_.size() == 1); |
| IR_ELSEIF(Tanh) |
| JIT_ASSERT(inputs_.size() == 1); |
| IR_ELSEIF(FusionGroup) |
| // TODO: Typecheck the parameters |
| value->subgraph().lint(); |
| IR_ELSEIF(Chunk) |
| JIT_ASSERT(inputs_.size() == 1); |
| IR_END() |
| |
| } |
| |
| void Graph::lint() { |
| // Graph invariants |
| |
| // std::cout << *this << std::endl; |
| |
| // nodes |
| // - nodes_ is a valid topological ordering for inputs |
| // - No repeated nodes |
| // - Params and return do NOT occur in nodes |
| // - next_unique_ is greater than all uniques in graph |
| // - uniques in all_nodes are unique |
| |
| std::unordered_set<Node*> in_scope; |
| std::unordered_set<size_t> seen_uniques; |
| auto check_node = [&](Node* n) { |
| auto b = in_scope.insert(n); |
| JIT_ASSERT(b.second); |
| auto b2 = seen_uniques.insert(n->unique_); |
| JIT_ASSERT(b2.second); |
| JIT_ASSERT(n->unique_ < next_unique_); |
| }; |
| |
| for (auto input : inputs_) { |
| JIT_ASSERT(input->kind_ == NodeKind::Param); |
| input->lint(); |
| check_node(input); |
| } |
| for (auto n : nodes_) { |
| n->lint(); |
| JIT_ASSERT(n->kind_ != NodeKind::Param); |
| JIT_ASSERT(n->kind_ != NodeKind::Return); |
| for (auto input : n->inputs_) { |
| JIT_ASSERT(in_scope.count(input) == 1); |
| } |
| for (auto use : n->uses_) { |
| JIT_ASSERT(in_scope.count(use.user) == 0); |
| JIT_ASSERT(all_nodes.count(use.user) == 1); |
| } |
| check_node(n); |
| } |
| JIT_ASSERT(output_->kind() == NodeKind::Return); |
| output_->lint(); |
| for (auto output : output_->inputs_) { |
| JIT_ASSERT(in_scope.count(output) == 1); |
| } |
| check_node(output_); |
| |
| // all_nodes |
| // - inputs_, output_ and nodes_ are all included in all_nodes |
| // - all_nodes does not contain dead nodes??? (likely to be temporarily |
| // suspended). Weaker: all_nodes contains all inputs and returns |
| // - only one return node??? |
| |
| node_set all_nodes_set(ALL_OF(all_nodes)); // NB: all_nodes is *unordered* |
| node_set nodes_set(ALL_OF(nodes_)); |
| node_set inputs_set(ALL_OF(inputs_)); |
| node_set output_set{output_}; |
| // TODO: Make a more type safe std::includes wrapper which disallows use on |
| // non-ordered containers |
| JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set))); |
| JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set))); |
| JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set))); |
| |
| node_set sum_set; |
| sum_set.insert(ALL_OF(nodes_set)); |
| sum_set.insert(ALL_OF(inputs_set)); |
| sum_set.insert(ALL_OF(output_set)); |
| JIT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set))); |
| |
| } |
| |
| void LintGraph(std::unique_ptr<Graph>& graph) { |
| graph->lint(); |
| } |
| |
| }} |