| #include <torch/csrc/jit/ir/ir.h> |
| |
| #include <ATen/core/builtin_function.h> |
| #include <ATen/core/function.h> |
| #include <c10/util/Exception.h> |
| #include <c10/util/StringUtil.h> |
| #include <c10/util/irange.h> |
| #include <torch/csrc/jit/api/function_impl.h> |
| #include <torch/csrc/jit/frontend/error_report.h> |
| #include <torch/csrc/jit/frontend/schema_matching.h> |
| #include <torch/csrc/jit/ir/constants.h> |
| #include <torch/csrc/jit/runtime/operator.h> |
| #include <torch/csrc/jit/serialization/python_print.h> |
| |
| #include <algorithm> |
| #include <iostream> |
| #include <locale> |
| #include <memory> |
| #include <set> |
| #include <sstream> |
| #include <string> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <utility> |
| |
| namespace torch { |
| namespace jit { |
| |
| namespace utils { |
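| // Builds a dotted module-hierarchy string from a node's inlined call stack. |
| // As a rough example (the exact per-entry text comes from |
| // utils::get_module_info), a node inlined through top.sub.linear might yield |
| // "sub(SubModule).linear(Linear)"; entries without module instance info are |
| // rendered as "UNKNOWN_INSTANCE(UNKNOWN_TYPE)". |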
| std::string getNodesModuleHierarchy(const Node& n) { |
| if (!n.callstack().has_value()) { |
| return std::string(); |
| } |
| InlinedCallStackPtr callstack_ptr = n.callstack().value(); |
| std::string module_hierarchy; |
| for (auto& entry : callstack_ptr->vec()) { |
| const auto& opt_module_info = std::get<kModuleInstanceInfo>(entry); |
| if (opt_module_info.has_value()) { |
| const auto& module_instance_info = opt_module_info.value(); |
| if (!module_hierarchy.empty()) { |
| module_hierarchy.append("."); |
| } |
| module_hierarchy.append(utils::get_module_info(module_instance_info)); |
| } else { |
| module_hierarchy += ".UNKNOWN_INSTANCE(UNKNOWN_TYPE)"; |
| } |
| } |
| return module_hierarchy; |
| } |
| } // namespace utils |
| |
| namespace { |
| |
| // Constants relating to maintaining the topological index of nodes. |
| // |
| // Lower and upper bounds of the index. Inclusive range. |
| constexpr topo_position_t kLowerBound = INT64_MIN; |
| constexpr topo_position_t kUpperBound = INT64_MAX; |
| constexpr topo_position_t kMidPoint = 0; |
| |
| // How far apart to space nodes that are appended to the graph. |
| // Should be 2^n, where: |
| //   - n is the maximum number of repeated insertions without a re-index |
| //   - 2^(64-n) is the maximum number of appends to the end without a re-index |
| constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */; |
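| // Worked example (positions only, using the constants above): appending |
| // three nodes to an empty block assigns positions 0, 2^40, and 2*2^40; |
| // inserting a node between the first two places it at their midpoint, 2^39. |
| // When a midpoint would collide with an existing position, or an append |
| // would run past kUpperBound, the block is re-indexed (see |
| // Block::reIndexTopology and Node::assignTopoPosition). |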
| |
| void printValueRef(std::ostream& out, const Value* n) { |
| out << "%" << n->debugName(); |
| } |
| |
| bool isNumber(c10::string_view str) { |
| return str.find_first_not_of("0123456789") == std::string::npos; |
| } |
| |
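| // Makes a field name usable as a value debug name: purely numeric fields |
| // (which Value::setDebugName rejects) get a leading underscore, e.g. |
| // "1" becomes "_1". |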
| std::string normalizeAttrName(c10::string_view field) { |
| if (isNumber(field)) { |
| return "_" + std::string{field}; |
| } |
| return std::string{field}; |
| } |
| |
| void findAllNodes( |
| Block& block, |
| Symbol kind, |
| bool recurse, |
| std::vector<Node*>& ret) { |
| for (Node* n : block.nodes()) { |
| if (n->kind() == kind) { |
| ret.push_back(n); |
| } |
| if (recurse) { |
| for (auto b : n->blocks()) { |
| findAllNodes(*b, kind, recurse, ret); |
| } |
| } |
| } |
| } |
| |
| } // namespace |
| |
| // NB: This overload will become ambiguous with the one Caffe2 provides in its |
| // logging, if they ever intersect. |
| template <typename T> |
| std::ostream& operator<<(std::ostream& out, const std::vector<T>& nodes) { |
| out << at::ArrayRef<T>{nodes}; |
| return out; |
| } |
| |
| template <typename T> |
| static std::ostream& printValueRefs( |
| std::ostream& out, |
| const at::ArrayRef<T> nodes) { |
| size_t i = 0; |
| for (auto n : nodes) { |
| if (i++ > 0) { |
| out << ", "; |
| } |
| printValueRef(out, n); |
| } |
| return out; |
| } |
| |
| // Can't make these two overloads directly a template; it would be ambiguous |
| // with the global printer for operator<<. |
| |
| std::ostream& operator<<( |
| std::ostream& out, |
| const at::ArrayRef<const Value*> nodes) { |
| return printValueRefs(out, nodes); |
| } |
| |
| std::ostream& operator<<(std::ostream& out, const at::ArrayRef<Value*> nodes) { |
| return printValueRefs(out, nodes); |
| } |
| |
| struct const_value_list_with_types { |
| const ArrayRef<const Value*> values; |
| std::string delim; |
| const_value_list_with_types( |
| ArrayRef<const Value*> values, |
| std::string delim_ = ", ") |
| : values(values), delim(std::move(delim_)) {} |
| }; |
| |
| std::ostream& operator<<( |
| std::ostream& out, |
| const const_value_list_with_types& l) { |
| size_t i = 0; |
| for (auto n : l.values) { |
| if (i++ > 0) { |
| out << l.delim; |
| } |
| printValueRef(out, n); |
| if (c10::type_verbosity() >= c10::TypeVerbosity::Type) { |
| out << " : "; |
| out << *n->type(); |
| } |
| } |
| return out; |
| } |
| |
| static void printAttribute(std::ostream& out, const at::Tensor& tensor) { |
| // 1-elem tensors are usually boxed scalars, so print them that way |
| if (tensor.numel() == 1) { |
| auto scalar_tensor = tensor.view(std::vector<int64_t>{}).item(); |
| out << "{"; |
| if (scalar_tensor.isFloatingPoint()) { |
| out << scalar_tensor.toDouble(); |
| } else if (scalar_tensor.isComplex()) { |
| out << scalar_tensor.toComplexDouble(); |
| } else { |
| out << scalar_tensor.toLong(); |
| } |
| out << "}"; |
| } else if (tensor.numel() <= max_tensor_display_size) { |
| // TODO: This is awful code. Also it doesn't work on Windows. |
| std::ostringstream tensor_ss; |
| tensor_ss << tensor; |
| std::string tensor_s{tensor_ss.str()}; |
| // Remove newlines |
| std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' '); |
| out << tensor_s; |
| } else { |
| out << "<Tensor>"; |
| } |
| } |
| |
| static void printAttribute(std::ostream& out, const IValue& ival) { |
| const auto customFormatter = [](std::ostream& ss, const IValue& input) { |
| if (input.isTensor()) { |
| printAttribute(ss, input.toTensor()); |
| return true; |
| } else if (input.isTensorList()) { |
| ss << "[<Tensors>]"; |
| return true; |
| } else if (input.isObject() && !input.type()->is_module()) { |
| ss << "object(" << &input.toObjectRef() << ")"; |
| return true; |
| } |
| return false; |
| }; |
| ival.repr(out, customFormatter); |
| } |
| |
| static void printTypeList( |
| std::ostream& out, |
| const std::vector<TypePtr>& items) { |
| out << "["; |
| int i = 0; |
| for (auto& item : items) { |
| if (i++ > 0) |
| out << ", "; |
| out << *item; |
| } |
| out << "]"; |
| } |
| |
| void Node::printAttrValue(std::ostream& out, const Symbol& name) const { |
| switch (kindOf(name)) { |
| case AttributeKind::c: |
| printAttribute(out, c(name)); |
| break; |
| case AttributeKind::cs: |
| // TODO(@anjali411): fix this |
| AT_ASSERT(false); |
| break; |
| case AttributeKind::f: |
| printAttribute(out, f(name)); |
| break; |
| case AttributeKind::fs: |
| printAttribute(out, fs(name)); |
| break; |
| case AttributeKind::i: |
| printAttribute(out, i(name)); |
| break; |
| case AttributeKind::is: |
| printAttribute(out, is(name)); |
| break; |
| case AttributeKind::s: |
| printAttribute(out, s(name)); |
| break; |
| case AttributeKind::ss: |
| printAttribute(out, ss(name)); |
| break; |
| case AttributeKind::t: |
| printAttribute(out, t(name)); |
| break; |
| case AttributeKind::ts: |
| out << "[<Tensors>]"; |
| break; |
| case AttributeKind::ival: |
| printAttribute(out, ival(name)); |
| break; |
| case AttributeKind::g: |
| out << "<Graph>"; |
| break; |
| case AttributeKind::gs: |
| out << "[<Graphs>]"; |
| break; |
| case AttributeKind::ty: |
| out << *ty(name); |
| break; |
| case AttributeKind::tys: |
| printTypeList(out, tys(name)); |
| break; |
| } |
| } |
| |
| void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false) |
| const { |
| out << "["; |
| auto names = attributeNames(); |
| int i = 0; |
| for (auto name : names) { |
| if (ignore_subgraph && name == attr::Subgraph) { |
| continue; |
| } |
| if (i++ > 0) { |
| out << ", "; |
| } |
| // TODO: debugging mode to see the qualifier. We definitely |
| // don't want to print the qualifier since it should always |
| // be attribute, but you might be able to track down a weird |
| // bug by printing it out. |
| out << name.toUnqualString() << "="; |
| |
| printAttrValue(out, name); |
| } |
| out << "]"; |
| } |
| |
| SourceRange Node::sourceRange() const { |
| if (source_range_) { |
| return *source_range_; |
| } |
| return SourceRange(); |
| } |
| |
| static std::ostream& indent(std::ostream& out, size_t level) { |
| for (const auto i : c10::irange(level)) { |
| (void)i; // Suppress unused variable warning |
| out << " "; |
| } |
| return out; |
| } |
| |
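| // Prints one node in roughly the textual IR form, e.g. (a sketch): |
| //   %out : Float(2, 2) = aten::add(%a, %b, %alpha) # model.py:10:8 |
| // followed, when print_body is true, by any nested blocks indented one |
| // level deeper. |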
| std::ostream& Node::print( |
| std::ostream& out, |
| size_t level, |
| std::vector<const Node*>* groups, |
| bool print_source_locations, |
| bool print_attributes, |
| bool print_scopes, |
| bool print_body) const { |
| auto outs = outputs(); |
| indent(out, level) << const_value_list_with_types(outs); |
| out << " = "; |
| if (kind() == prim::PythonOp) { |
| auto* pyOp = static_cast<const ::torch::jit::PythonOp*>(this); |
| out << "^" << pyOp->name(); |
| pyOp->writeScalars(out); |
| } else if (hasAttribute(attr::Subgraph) && groups) { |
| out << kind().toQualString() << "_" << groups->size(); |
| if (print_attributes && numAttributes() > 1 && |
| kind() != prim::DifferentiableGraph) { |
| printAttributes(out, /*ignore_subgraph=*/true); |
| } |
| |
| groups->push_back(this); |
| } else { |
| out << kind().toQualString(); |
| if (print_attributes && hasAttributes()) { |
| printAttributes(out); |
| } |
| } |
| out << "(" << inputs() << ")"; |
| |
| if (print_scopes) { |
| std::string scName = scopeName(); |
| if (!scName.empty()) { |
| out << ", "; |
| out << "scope: " << scName; |
| } |
| } |
| |
| // In debug print, append file:line:col as a comment after each node |
| if (print_source_locations) { |
| SourceRange r = sourceRange(); |
| if (sourceRange().source()) { |
| if (auto orig = sourceRange().source()->findSourceRangeThatGenerated(r)) { |
| r = *orig; |
| } |
| } |
| if (auto file_line_col = r.file_line_col()) { |
| std::string filename; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| size_t line, col; |
| std::tie(filename, line, col) = *file_line_col; |
| out << " # " << filename << ":" << line << ":" << col; |
| } |
| } |
| |
| if (!print_body) { |
| return out; |
| } |
| |
| out << "\n"; |
| |
| for (const auto i : c10::irange(blocks().size())) { |
| auto b = blocks()[i]; |
| indent(out, level + 1) << "block" << i << "(" |
| << const_value_list_with_types(b->inputs()) |
| << "):\n"; |
| for (auto nested : b->nodes()) { |
| nested->print(out, level + 2, groups); |
| } |
| indent(out, level + 2) << "-> (" << b->outputs() << ")\n"; |
| } |
| |
| return out; |
| } |
| |
| std::ostream& operator<<(std::ostream& out, const Node& n) { |
| return n.print(out, 0, nullptr); |
| } |
| |
| std::ostream& Graph::print(std::ostream& out, bool print_source_locations) |
| const { |
| out << "graph(" << const_value_list_with_types(inputs(), ",\n ") |
| << "):\n"; |
| std::vector<const Node*> groups; |
| for (auto n : nodes()) { |
| n->print(out, 1, &groups, print_source_locations); |
| } |
| out << " return (" << outputs() << ")\n"; |
| size_t i = 0; |
| for (auto fg : groups) { |
| out << "with " << fg->kind().toQualString() << "_" << i++ << " = " |
| << *fg->g(attr::Subgraph); |
| } |
| out.flush(); |
| |
| /* |
| // Uncomment this to debug all_nodes issues |
| { |
| out << "\n"; |
| out << "all_nodes:\n"; |
| for (auto& n : all_nodes) { |
| printNode(out, const_cast<Node*>(n), nullptr); |
| } |
| } |
| */ |
| return out; |
| } |
| |
| std::ostream& operator<<(std::ostream& out, const Graph& g) { |
| return g.print(out, true); |
| } |
| |
| static void checkSameDevice(const Node* node) { |
| bool has_device = false; |
| c10::optional<at::Device> device = c10::nullopt; |
| auto checkValue = [&](const Value* v) { |
| if (TensorTypePtr type = v->type()->cast<TensorType>()) { |
| if (type->device() && !has_device) { |
| has_device = true; |
| device = *type->device(); |
| } else { |
| AT_ASSERT(device == type->device()); |
| } |
| } |
| }; |
| for (auto input : node->inputs()) { |
| checkValue(input); |
| } |
| for (auto output : node->outputs()) { |
| checkValue(output); |
| } |
| } |
| |
| using node_set = std::set<const Node*>; |
| #define ALL_OF(container) container.begin(), container.end() |
| |
| // These functions purposely operate on the internal members directly, to force |
| // you to think about how the invariants change if you change the data |
| // representation (even if the external API does not change). |
| |
| // NB: This assert is written to assume you don't have any unattached |
| // nodes. Unattached nodes can occur while the graph is being |
| // manipulated. |
| void Node::lint() const { |
| // Node invariants |
| // - if node should live in list, nodes_iter is consistent |
| // - Inputs are all marked as a use by the nodes they refer to |
| // - Owning graph is non-null and consistent |
| // - The "Select" invariant, when the node is MultiReturn |
| // |
| // The handle invariant: |
| // If a node takes a handle as an input, it is always the |
| // LAST input of the node. There is at most one handle input. |
| |
| { |
| size_t i = 0; |
| for (auto input : inputs_) { |
| // WARNING: O(n^2) |
| // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
| AT_ASSERT( |
| std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) != |
| input->uses_.end()); |
| AT_ASSERT(graph_->all_nodes.count(this) == 1); |
| i++; |
| } |
| } |
| |
| for (auto o : outputs()) { |
| size_t i = 0; |
| for (auto use : o->uses()) { |
| // Use invariants |
| // - Use is consistent with inputs |
| // - Every user node is live (checked in Graph) |
| AT_ASSERT(use.user->inputs_[use.offset] == o); |
| i++; |
| } |
| } |
| |
| // Node subclass invariants |
| switch (kind()) { |
| case prim::Constant: |
| AT_ASSERT(inputs_.size() == 0); |
| break; |
| case prim::Return: |
| // Return nodes have no outputs |
| AT_ASSERT(outputs().size() == 0); |
| break; |
| case prim::Param: |
| // Param nodes have no inputs |
| AT_ASSERT(inputs_.size() == 0); |
| break; |
| case prim::PythonOp: { |
| // Check that the Python operator's calling convention (cconv) is correct |
| auto* value = static_cast<const PythonOp*>(this); |
| value->lint_python(); |
| break; |
| } |
| case prim::Eval: |
| // TODO: add invariants |
| // TODO: It's not good for these ops to be top-level, it makes cases |
| // longer. |
| break; |
| case prim::FusionGroup: |
| case prim::CudaFusionGroup: |
| case prim::oneDNNFusionGroup: |
| checkSameDevice(this); |
| // TODO: Typecheck the parameters |
| g(attr::Subgraph)->lint(); |
| break; |
| } |
| } |
| |
| // TODO: When lint fails, give better indication about which |
| // instruction triggered the failure. |
| void Graph::lint() const { |
| // Graph invariants |
| |
| // Uncomment the following to see the graph |
| // std::cout << *const_cast<Graph*>(this); |
| |
| // nodes |
| // - nodes_ is a valid topological ordering for inputs |
| // - No repeated nodes |
| // - Params and return do NOT occur in nodes |
| // - next_unique_ is greater than all uniques in graph |
| // - uniques in all_nodes are unique |
| // - every use will occur later in the toposort |
| |
| struct LintScope { |
| LintScope() = default; |
| LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {} |
| bool contains(const Value* v) { |
| return values.count(v) > 0 || (parent && parent->contains(v)); |
| } |
| bool contains(const Node* n) { |
| return nodes.count(n) > 0 || (parent && parent->contains(n)); |
| } |
| void insert(const Value* v) { |
| AT_ASSERT(!contains(v)); |
| values.insert(v); |
| } |
| void insert(const Node* n) { |
| AT_ASSERT(!contains(n)); |
| nodes.insert(n); |
| } |
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| std::unique_ptr<LintScope> parent; |
| |
| private: |
| std::unordered_set<const Value*> values; |
| std::unordered_set<const Node*> nodes; |
| }; |
| // Struct enables mutual recursion in linting methods. |
| // Putting it inside Graph::lint enables access to private Graph members |
| struct LintImpl { |
| LintImpl(const Graph& g) |
| : g(g), |
| scope(new LintScope()), |
| all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered* |
| const Graph& g; |
| std::unique_ptr<LintScope> scope; |
| std::unordered_set<size_t> seen_uniques; |
| std::unordered_map<const Node*, int64_t> anticipated_uses; |
| node_set all_nodes_set; |
| node_set sum_set; |
| |
| void check_value(const Value* v) { |
| scope->insert(v); |
| auto b2 = seen_uniques.insert(v->unique()); |
| AT_ASSERT(b2.second); // insertion took place |
| AT_ASSERT(v->unique() < g.next_unique_); |
| |
| for (auto use : v->uses()) { |
| AT_ASSERT(!scope->contains(use.user)); |
| AT_ASSERT(g.all_nodes.count(use.user) == 1); |
| anticipated_uses[use.user]++; // int default constructs to 0 |
| } |
| } |
| void check_node(const Node* n) { |
| for (auto input : n->inputs_) { |
| if (!scope->contains(input)) { |
| AT_ASSERTM(0, input->unique(), " not in scope"); |
| } |
| } |
| AT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size())); |
| anticipated_uses[n] = -1; // we saw the anticipated user! |
| scope->insert(n); |
| for (auto block : n->blocks()) { |
| std::unique_ptr<LintScope> new_scope(new LintScope(std::move(scope))); |
| scope = std::move(new_scope); |
| check_block(block); |
| scope = std::move(scope->parent); |
| } |
| size_t i = 0; |
| for (auto o : n->outputs()) { |
| AT_ASSERT(o->node() == n); |
| AT_ASSERT(i++ == o->offset_); |
| check_value(o); |
| } |
| n->lint(); |
| } |
| void check_block(const Block* b) { |
| // Check topological ordering |
| AT_ASSERT(b->param_node()->isBefore(*b->nodes().begin())); |
| auto curNode = *b->nodes().begin(); |
| while (curNode != b->return_node()) { |
| AT_ASSERT(curNode->isBefore(curNode->next())); |
| curNode = curNode->next(); |
| } |
| |
| for (auto input : b->inputs()) { |
| check_value(input); |
| AT_ASSERT(input->node()->kind_ == prim::Param); |
| } |
| |
| for (auto n : b->nodes()) { |
| AT_ASSERT(n->kind_ != prim::Param); |
| AT_ASSERT(n->kind_ != prim::Return); |
| check_node(n); |
| } |
| |
| AT_ASSERT(b->output_->kind() == prim::Return); |
| check_node(b->output_); |
| |
| // all_nodes |
| // - inputs_, output_ and nodes_ are all included in all_nodes |
| // - all_nodes does not contain dead nodes??? (likely to be temporarily |
| // suspended). Weaker: all_nodes contains all inputs and returns |
| // - only one return node??? |
| |
| node_set nodes_set(ALL_OF(b->nodes())); |
| node_set inputs_set{b->input_}; |
| node_set output_set{b->output_}; |
| // TODO: Make a more type safe std::includes wrapper which disallows use |
| // on non-ordered containers |
| AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set))); |
| AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set))); |
| AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set))); |
| |
| sum_set.insert(ALL_OF(nodes_set)); |
| sum_set.insert(ALL_OF(inputs_set)); |
| sum_set.insert(ALL_OF(output_set)); |
| } |
| void check_graph() { |
| node_set all_nodes_set( |
| ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered* |
| |
| check_block(g.block_); |
| for (auto kv : anticipated_uses) { |
| AT_ASSERT(kv.second == -1); |
| } |
| AT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set))); |
| } |
| }; |
| LintImpl(*this).check_graph(); |
| } |
| |
| void Graph::dump() const { |
| std::cout << *this << "\n"; |
| } |
| |
| void Graph::push_scope(const std::string& scope_name) { |
| current_scope_ = current_scope_->push(Symbol::scope(scope_name)); |
| Node* block_node = insertNode(create(prim::TracedModuleForward, 0)); |
| block_node->s_(attr::scope, scope_name); |
| Block* b = block_node->addBlock(); |
| setInsertPoint(b); |
| } |
| void Graph::pop_scope() { |
| current_scope_ = current_scope_->parent(); |
| if (insertPoint()->owningBlock()->owningNode()->kind() == |
| prim::TracedModuleForward) { |
| setInsertPoint(insertPoint()->owningBlock()->owningNode()->next()); |
| } |
| } |
| |
| void LintGraph(const std::shared_ptr<Graph>& graph) { |
| graph->lint(); |
| } |
| |
| Block::Block(Graph* graph_, Node* node_) |
| : graph_(graph_), |
| output_(graph_->create(prim::Return, 0)), |
| input_(graph_->create(prim::Param, 0)), |
| owning_node_(node_) { |
| input_->next() = output_; |
| input_->prev() = output_; |
| output_->next() = input_; |
| output_->prev() = input_; |
| |
| graph_->all_blocks.emplace(this); |
| output_->owning_block_ = this; |
| output_->topo_position_ = kUpperBound; |
| input_->owning_block_ = this; |
| input_->topo_position_ = kLowerBound; |
| } |
| |
| void Block::reIndexTopology() { |
| auto curPos = kLowerBound; |
| for (auto node : nodes()) { |
| AT_ASSERT(curPos <= (kUpperBound - kAppendInterval)); |
| curPos += kAppendInterval; |
| node->topo_position_ = curPos; |
| } |
| } |
| |
| void Block::cloneFrom(Block* src, std::function<Value*(Value*)> value_map) { |
| std::unordered_map<Value*, Value*> local_map; |
| auto env = [&](Value* v) { |
| auto it = local_map.find(v); |
| if (it != local_map.end()) { |
| return it->second; |
| } |
| return value_map(v); |
| }; |
| |
| auto graph = owningGraph(); |
| for (auto input : src->inputs()) { |
| local_map[input] = this->addInput()->copyMetadata(input); |
| } |
| |
| for (auto node : src->nodes()) { |
| auto new_node = this->appendNode(graph->createClone(node, env)); |
| for (size_t i = 0; i < node->outputs().size(); ++i) { |
| auto oo = node->outputs()[i]; |
| auto no = new_node->outputs()[i]; |
| local_map[oo] = no; |
| no->copyMetadata(oo); |
| } |
| } |
| for (auto output : src->outputs()) { |
| this->registerOutput(env(output)); |
| } |
| } |
| |
| void Block::destroy() { |
| // we cannot destroy the output because it is used as the sentinel |
| // for the nodes() list and has to remain valid for the loop |
| output_->removeAllInputs(); |
| for (auto it = this->nodes().reverse().begin(), |
| end = this->nodes().reverse().end(); |
| it != end; |
| ++it) { |
| it.destroyCurrent(); |
| } |
| output_->destroy(); |
| input_->destroy(); |
| graph_->freeBlock(this); |
| } |
| |
| void Graph::cloneFrom(Graph& src) { |
| auto env = [](Value* v) -> Value* { |
| AT_ERROR( |
| "Graph::copy() encountered a use of a value " + v->debugName() + |
| " not in scope. Run lint!"); |
| }; |
| block()->cloneFrom(src.block(), env); |
| } |
| |
| std::shared_ptr<Graph> Graph::copy() { |
| auto new_g = std::make_shared<Graph>(); |
| new_g->cloneFrom(*this); |
| return new_g; |
| } |
| |
| std::unique_ptr<Graph> Graph::copyUnique() { |
| auto new_g = std::make_unique<Graph>(); |
| new_g->cloneFrom(*this); |
| return new_g; |
| } |
| |
| void Block::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) { |
| for (Value* input : inputs()) { |
| input->setType(type_map(input->type())); |
| } |
| for (Node* node : nodes()) { |
| for (Value* output : node->outputs()) { |
| output->setType(type_map(output->type())); |
| } |
| for (Block* sub_block : node->blocks()) { |
| sub_block->remapTypes(type_map); |
| } |
| for (Symbol name : node->attributeNames()) { |
| if (node->kindOf(name) == AttributeKind::g) { |
| node->g(name)->remapTypes(type_map); |
| } else if (node->kindOf(name) == AttributeKind::gs) { |
| for (const auto& g : node->gs(name)) { |
| g->remapTypes(type_map); |
| } |
| } |
| } |
| } |
| } |
| |
| void Graph::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) { |
| block()->remapTypes(type_map); |
| } |
| |
| void Value::inferTypeFrom(const at::Tensor& output) { |
| setType(TensorType::create(output)); |
| } |
| |
| void Value::inferTypeFrom( |
| const c10::intrusive_ptr<c10::ivalue::Object>& output) { |
| setType(output->type()); |
| } |
| |
| bool Value::mustBeNone() const { |
| return type()->cast<NoneType>() || node_->mustBeNone(); |
| } |
| bool Value::mustNotBeNone() const { |
| return node_->kind() != prim::AutogradAdd && type() != NoneType::get() && |
| !type()->cast<OptionalType>() && |
| !(type()->cast<UnionType>() && |
| type()->expect<UnionType>()->canHoldType(*NoneType::get())); |
| } |
| |
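| // Returns the debug name with any trailing numeric ".<suffix>" stripped, |
| // e.g. "x.1" becomes "x", while a name like "x.foo" is returned unchanged. |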
| std::string Value::debugNameBase() const { |
| std::string name = debugName(); |
| std::string name_base = name; |
| auto last_dot_pos = name.find_last_of('.'); |
| if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) { |
| if (name.find_first_not_of("0123456789", last_dot_pos + 1) == |
| std::string::npos) { |
| name_base = name.substr(0, last_dot_pos); |
| } |
| } |
| return name_base; |
| } |
| |
| bool Value::isValidName(const std::string& name) { |
| // Empty strings are legal |
| if (!name.size()) { |
| return true; |
| } |
| |
| // Numbers are not legal |
| if (isNumber(name)) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
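| // Sets this value's debug name, keeping names unique within the graph. If |
| // another value already owns `name`, that value is renamed by appending the |
| // next free numeric suffix. A sketch of the resulting behavior: |
| //   a->setDebugName("x");  // a is "x" |
| //   b->setDebugName("x");  // b becomes "x", a is renamed to "x.1" |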
| Value* Value::setDebugName(const std::string& name) { |
| if (!isValidName(name)) { |
| throw std::runtime_error("Invalid name: '" + name + "'"); |
| } |
| |
| auto& names = node()->owningGraph()->unique_names_; |
| |
| // clear any old name from the map |
| if (hasDebugName()) { |
| names.erase(unique_name_); |
| unique_name_ = ""; |
| } |
| |
| // allow "" to clear the unique name |
| if (name == "") { |
| return this; |
| } |
| |
| // if someone else has this name, then rename the other value |
| auto old_owner_of_name = names.find(name); |
| if (old_owner_of_name != names.end()) { |
| size_t suffix = 1; |
| std::string name_base = name; |
| auto last_dot_pos = name.find_last_of('.'); |
| if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) { |
| if (name.find_first_not_of("0123456789", last_dot_pos + 1) == |
| std::string::npos) { |
| suffix = c10::stoll(name.substr(last_dot_pos + 1)); |
| name_base = name.substr(0, last_dot_pos); |
| } |
| } |
| |
| auto& names_suffixes = node()->owningGraph()->name_base_suffix_; |
| auto it = names_suffixes.find(name_base); |
| if (it != names_suffixes.end()) { |
| suffix = std::max(suffix, it->second + 1); |
| } |
| |
| // Verify that new name is not used and find next usable name in case |
| // suffix is used. |
| std::string replacement_name; |
| do { |
| std::stringstream ss; |
| #ifndef _WIN32 |
| // Protect an integer like 12345 from being formatted with digit grouping |
| // (e.g. "12,345") if some other code sets the global locale. For more details see |
| // https://github.com/pytorch/pytorch/issues/79583#issuecomment-1161260061 |
| static std::locale c_locale("C"); |
| ss.imbue(c_locale); |
| #endif |
| ss << name_base << "." << suffix++; |
| replacement_name = ss.str(); |
| } while (names.count(replacement_name) > 0); |
| |
| names_suffixes[name_base] = suffix; |
| |
| old_owner_of_name->second->setDebugName(replacement_name); |
| } |
| |
| names[name] = this; |
| unique_name_ = name; |
| return this; |
| } |
| |
| Value* Value::copyMetadata(Value* from) { |
| setType(from->type()); |
| if (from->hasDebugName()) { |
| setDebugName(from->debugName()); |
| } |
| return this; |
| } |
| |
| void Value::replaceFirstUseWith(Value* newValue) { |
| AT_ASSERT(owningGraph() == newValue->owningGraph()); |
| auto u = uses()[0]; |
| u.user->inputs_[u.offset] = newValue; |
| newValue->uses_.push_back(u); |
| uses_.erase(uses_.begin()); |
| } |
| |
| void Value::replaceAllUsesWith(Value* newValue) { |
| while (!uses().empty()) { |
| replaceFirstUseWith(newValue); |
| } |
| } |
| |
| void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) { |
| std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) { |
| if (u.user->isAfter(node)) { |
| u.user->inputs_[u.offset] = newValue; |
| newValue->uses_.push_back(u); |
| } |
| }); |
| |
| uses_.erase( |
| std::remove_if( |
| uses_.begin(), |
| uses_.end(), |
| [&node](const Use& u) { return u.user->isAfter(node); }), |
| uses_.end()); |
| } |
| |
| void Value::replaceAllUsesDominatedByNodeWith( |
| const Node* node, |
| Value* newValue) { |
| std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) { |
| if (u.user->isDominatedBy(node)) { |
| u.user->inputs_[u.offset] = newValue; |
| newValue->uses_.push_back(u); |
| } |
| }); |
| |
| uses_.erase( |
| std::remove_if( |
| uses_.begin(), |
| uses_.end(), |
| [&node](const Use& u) { return u.user->isDominatedBy(node); }), |
| uses_.end()); |
| } |
| |
| size_t findArgument( |
| const FunctionSchema& the_schema, |
| const std::string& unqualName) { |
| for (const auto i : c10::irange(the_schema.arguments().size())) { |
| const Argument* arg = &the_schema.arguments()[i]; |
| if (arg->name() == unqualName) { |
| return i; |
| } |
| } |
| throw std::runtime_error( |
| std::string("Couldn't find an argument called ") + unqualName); |
| } |
| |
| size_t findArgument(const FunctionSchema& the_schema, Symbol name) { |
| const auto unqualName = name.toUnqualString(); |
| return findArgument(the_schema, unqualName); |
| } |
| |
| c10::optional<IValue> Node::get(Symbol name) const { |
| return toIValue(namedInput(name)); |
| } |
| |
| bool Node::hasNamedInput(const std::string& name) const { |
| for (const auto& argument : schema().arguments()) { |
| if (argument.name() == name) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| Value* Node::namedInput(const std::string& unqualName) const { |
| return input(findArgument(schema(), unqualName)); |
| } |
| Value* Node::namedInput(Symbol name) const { |
| return input(findArgument(schema(), name)); |
| } |
| |
| bool Node::matches(const FunctionSchema& schema) const { |
| // wrong name |
| if (kind().toQualString() != schema.name()) { |
| return false; |
| } |
| at::ArrayRef<const Value*> actuals = inputs(); |
| const auto& formals = schema.arguments(); |
| |
| // not enough inputs |
| if (actuals.size() < formals.size()) { |
| return false; |
| } |
| |
| TypeEnv type_env; |
| for (const auto i : c10::irange(formals.size())) { |
| auto formal = formals[i].type(); |
| const MatchTypeReturn matched_type = |
| matchTypeVariables(formal, actuals[i]->type(), type_env); |
| if (!matched_type.success()) { |
| return false; |
| } |
| |
| TypePtr resolved = tryEvalTypeVariables(formal, type_env); |
| if (resolved) { |
| formal = resolved; |
| } |
| // note: it is possible at this point that type variable matching has |
| // not resolved all type variables, e.g. if None was matched to Optional[T] |
| // we will not succeed at matching T. However None <: Optional[T] so this |
| // check can still succeed. |
| |
| if (!actuals[i]->type()->isSubtypeOf(*formal)) { |
| return false; |
| } |
| } |
| |
| // too many inputs |
| if (!schema.is_vararg() && actuals.size() != formals.size()) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
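| // Convenience overload: matches against the schema registered for the given |
| // signature literal and additionally requires the listed inputs to be |
| // constants. Sketch of a typical call (the literal must name a registered |
| // operator): |
| //   n->matches( |
| //       "aten::add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", |
| //       /*const_inputs=*/{attr::alpha}) |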
| bool Node::matches( |
| const char* signature_literal, |
| at::ArrayRef<Symbol> const_inputs) const { |
| if (!matches(getOperatorForLiteral(signature_literal)->schema())) { |
| return false; |
| } |
| for (Symbol s : const_inputs) { |
| if (!is_constant(s)) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| bool Node::mustBeNone() const { |
| // We can statically deduce that this Node returns None if: |
| return |
| // It's an AutogradZero node, or ... |
| kind_ == prim::AutogradZero || |
| // It has only one output and that output is NoneType, or ... |
| (outputs().size() == 1 && output()->type() == NoneType::get()) || |
| // It's a constant optional with no value in the attributes. |
| (kind_ == prim::Constant && !this->hasAttributes() && |
| output()->type()->cast<OptionalType>()); |
| } |
| |
| void Node::dump() const { |
| std::cout << *this << "\n"; |
| } |
| |
| const FunctionSchema& Node::schema() const { |
| if (op_) { |
| return op_->schema(); |
| } |
| return getOperator().schema(); |
| } |
| |
| const FunctionSchema* Node::maybeSchema() const { |
| if (auto op = maybeOperator()) { |
| return &op->schema(); |
| } |
| return nullptr; |
| } |
| |
| const Operator* Node::maybeOperator() const { |
| if (!op_) { |
| const auto& candidates = getAllOperatorsFor(kind()); |
| for (const auto& candidate : candidates) { |
| if (matches(candidate->schema())) { |
| op_ = candidate.get(); |
| break; |
| } |
| } |
| } |
| return op_; |
| } |
| |
| const Operator& Node::getOperator() const { |
| const Operator* maybe = maybeOperator(); |
| if (maybe) |
| return *maybe; |
| |
| auto er = ErrorReport(sourceRange()); |
| er << "Schema not found for node. File a bug report.\n"; |
| er << "Node: " << *this << "\n"; |
| er << "Input types:"; |
| for (const auto i : c10::irange(inputs().size())) { |
| if (i > 0) |
| er << ", "; |
| er << *inputs()[i]->type(); |
| } |
| const auto& candidates = getAllOperatorsFor(kind()); |
| if (candidates.size() > 0) { |
| er << "\ncandidates were:\n"; |
| for (auto& candidate : candidates) { |
| er << " " << candidate->schema() << "\n"; |
| } |
| } else { |
| er << "\nno candidates found\n"; |
| } |
| er << "within the graph:\n"; |
| er << *owningGraph() << "\n"; |
| throw er; |
| } |
| |
| Operation Node::getOperation() const { |
| // note: some operators require the node to produce a runnable operation, |
| // which is why 'this' is passed here. getOperator() ensures that 'this' |
| // matches the schema of the returned operator. |
| return getOperator().getOperation(this); |
| } |
| |
| bool Node::isNondeterministic() const { |
| const auto schema = maybeSchema(); |
| if (!kind().is_aten()) { |
| return false; |
| } |
| // All aten ops are expected to have a schema. However, this is left as a |
| // warning instead of an assert to ensure that previous use cases do not |
| // break. |
| if (!schema) { |
| TORCH_WARN("aten Schema not found."); |
| return false; |
| } |
| torch::utils::SchemaInfo schema_info(*schema); |
| if (hasNamedInput("train")) { |
| auto value = constant_as<bool>(namedInput("train")); |
| if (value.has_value()) { |
| schema_info.addArgumentValue("train", *value); |
| } |
| } |
| return schema_info.is_nondeterministic(); |
| } |
| |
| bool Node::hasSideEffects() const { |
| switch (kind_) { |
| case prim::PythonOp: |
| case prim::IgnoredPythonOp: |
| case prim::Print: |
| case prim::RaiseException: |
| case aten::warn: |
| case aten::save: |
| case aten::manual_seed: |
| case prim::AddStatValue: |
| case prim::TimePoint: |
| case prim::CallFunction: |
| case prim::CallMethod: |
| case prim::BailoutTemplate: |
| case prim::BailOut: |
| case prim::rpc_async: // It represents RPC message sent. |
| case prim::rpc_sync: // It represents RPC message sent. |
| case prim::rpc_remote: // It represents RPC message sent. |
| case aten::wait: // It can represent RPC message received. |
| #if !defined(USE_ROCM) |
| case cuda::set_stream: |
| case cuda::_set_device: |
| case cuda::_current_device: |
| case cuda::synchronize: |
| #endif |
| case prim::Enter: |
| case prim::Exit: |
| return true; |
| } |
| |
| auto op = maybeOperator(); |
| if (!op) { |
| TORCH_INTERNAL_ASSERT( |
| kind_.is_prim(), |
| "Only prim ops are allowed to not have a registered operator but ", |
| kind_.toDisplayString(), |
| " doesn't have one either. We don't know if this op has side effects."); |
| return false; |
| } |
| |
| if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) { |
| // TODO There is nothing in the system that relies on aten:: and prim:: |
| // ops using AliasAnalysisKind::FROM_SCHEMA, |
| // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, or |
| // AliasAnalysisKind::CONSERVATIVE but this is the intended behavior for all |
| // current ops and a good error check. We can consider lifting this |
| // constraint later if we have a use case for it. |
| TORCH_INTERNAL_ASSERT( |
| op->aliasAnalysisKind() == AliasAnalysisKind::INTERNAL_SPECIAL_CASE || |
| op->aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA || |
| op->aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE, |
| "aten:: and prim:: ops should have AliasAnalysisKind::INTERNAL_SPECIAL_CASE" |
| ", AliasAnalysisKind::FROM_SCHEMA or AliasAnalysisKind::CONSERVATIVE but ", |
| kind_.toDisplayString(), |
| " has ", |
| toString(op->aliasAnalysisKind())); |
| } |
| |
| switch (op->aliasAnalysisKind()) { |
| case AliasAnalysisKind::PURE_FUNCTION: |
| case AliasAnalysisKind::FROM_SCHEMA: |
| case AliasAnalysisKind::INTERNAL_SPECIAL_CASE: |
| return false; |
| case AliasAnalysisKind::CONSERVATIVE: |
| return true; |
| } |
| TORCH_INTERNAL_ASSERT(false, "Unhandled AliasAnalysisKind case"); |
| return false; // silence compiler warning |
| } |
| |
| // Assign this node a topological position, to facilitate fast isBefore() and |
| // isAfter() queries. Must be called right after a node is inserted into the |
| // node list. |
| // |
| // The basic scheme is: assign every node a position (a signed 64-bit |
| // topo_position_t). The common case (appending to the end of the graph) is |
| // made more efficient by advancing a fixed interval past the previous node |
| // and placing `this` there. Otherwise, assign `this` a position at the |
| // midpoint between its prev() and next() nodes. |
| // |
| // If we ever run out of space (e.g. by repeatedly inserting at the same |
| // spot), we re-index by spreading out all the nodes again. |
| void Node::assignTopoPosition() { |
| bool is_first = prev() == owningBlock()->param_node(); |
| bool is_last = next() == owningBlock()->return_node(); |
| |
| const auto prevPos = prev()->topo_position_; |
| const auto nextPos = next()->topo_position_; |
| |
| // Append to the end of the graph |
| if (is_last) { |
| if (is_first) { |
| // the node list is empty, assign the first position |
| topo_position_ = kMidPoint; |
| return; |
| } |
| |
| if (prevPos >= (kUpperBound - kAppendInterval)) { |
| // we're running off the edge |
| owningBlock()->reIndexTopology(); |
| return; |
| } |
| |
| topo_position_ = prevPos + kAppendInterval; |
| |
| // Prepend to the graph |
| } else if (is_first) { |
| // next() is the first element in the block list |
| if (nextPos <= (kLowerBound + kAppendInterval)) { |
| // we're running off the edge |
| owningBlock()->reIndexTopology(); |
| return; |
| } |
| topo_position_ = nextPos - kAppendInterval; |
| |
| // insert between two existing nodes |
| } else { |
| const auto posBetween = prevPos + (nextPos - prevPos) / 2; |
| if (posBetween == prevPos) { |
| // There was no room |
| owningBlock()->reIndexTopology(); |
| return; |
| } |
| topo_position_ = posBetween; |
| } |
| } |
| |
| Node::Node(Graph* graph_, NodeKind kind_) |
| : kind_(kind_), |
| graph_(graph_), |
| owning_block_(nullptr), |
| scope_(graph_->current_scope_), |
| callstack_(c10::nullopt), |
| op_(nullptr), |
| topo_position_(0) { |
| graph_->all_nodes.emplace(this); |
| } |
| |
| void Node::eraseOutput(size_t i) { |
| AT_ASSERT(i < outputs_.size()); |
| AT_ASSERT(outputs_[i]->uses().empty()); |
| op_ = nullptr; |
| Value* n = outputs_[i]; |
| outputs_.erase(outputs_.begin() + i); |
| owningGraph()->freeValue(n); |
| for (const auto j : c10::irange(i, outputs_.size())) { |
| outputs_[j]->offset_--; |
| } |
| } |
| |
| Block* Node::addBlock() { |
| op_ = nullptr; |
| blocks_.push_back(new Block(owningGraph(), this)); |
| return blocks_.back(); |
| } |
| |
| void Node::eraseBlock(size_t i) { |
| AT_ASSERT(i < blocks_.size()); |
| op_ = nullptr; |
| Block* n = blocks_[i]; |
| blocks_.erase(blocks_.begin() + i); |
| n->destroy(); |
| } |
| |
| void Node::destroy() { |
| while (!outputs().empty()) { |
| eraseOutput(outputs().size() - 1); |
| } |
| while (!blocks().empty()) { |
| eraseBlock(blocks().size() - 1); |
| } |
| removeAllInputs(); |
| if (inBlockList()) { |
| removeFromList(); |
| } |
| graph_->freeNode(this); |
| } |
| |
| void Node::cloneFrom(Node* s) { |
| source_range_ = s->source_range_; |
| if (s->scope_ && !s->scope_->isBlank()) { |
| scope_ = s->scope_; |
| } |
| copyAttributes(*s); |
| callstack_ = s->callstack_; |
| } |
| |
| void Node::replaceAllUsesWith(Node* n) { |
| AT_ASSERT(outputs().size() == n->outputs().size()); |
| size_t nOutputs = outputs().size(); |
| for (const auto i : c10::irange(nOutputs)) { |
| outputs()[i]->replaceAllUsesWith(n->outputs()[i]); |
| } |
| } |
| |
| Node* Node::replaceWithNewSymbol(Symbol new_symbol) { |
| WithInsertPoint insert_guard{this}; |
| bool had_operator = maybeOperator() != nullptr; |
| auto graph = owningGraph(); |
| auto replace_node = graph->insertNode(graph->create(new_symbol, 0)); |
| for (Value* v : inputs()) { |
| replace_node->addInput(v); |
| } |
| for (Value* v : outputs()) { |
| auto new_out = replace_node->addOutput()->copyMetadata(v); |
| v->replaceAllUsesWith(new_out); |
| } |
| replace_node->copyMetadata(this); |
| replace_node->copyAttributes(*this); |
| TORCH_INTERNAL_ASSERT( |
| (replace_node->maybeOperator() != nullptr) == had_operator, |
| "invalid symbol replacement:", |
| new_symbol, |
| kind()); |
| return replace_node; |
| } |
| |
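| // True when `dominator` appears before this node, either in this node's own |
| // block or in one of its enclosing blocks. For example, a node emitted |
| // before a prim::If dominates the nodes inside the If's blocks, but a node |
| // inside those blocks does not dominate nodes after the If. |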
| bool Node::isDominatedBy(const Node* dominator) const { |
| const Node* node = this; |
| while (node) { |
| if (node->owningBlock() == dominator->owningBlock()) { |
| return dominator->isBefore(node); |
| } |
| node = node->owningBlock()->owningNode(); |
| } |
| return false; |
| } |
| |
| Value* Node::insertInput(size_t i, Value* value) { |
| AT_ASSERT(graph_ == value->owningGraph()); |
| op_ = nullptr; |
| // First we update the offsets for all existing inputs that will reside |
| // after the one we're inserting. Concretely, these are the inputs at |
| // indices [i, #inputs). Since we're inserting one input before all of |
| // these inputs, increment their use offsets for this value by 1. |
| for (const auto use_itr : c10::irange(i, inputs_.size())) { |
| // See Note [User node does not uniquely identify use] |
| auto use = findUseForInput(use_itr); |
| use->offset += 1; |
| } |
| // Insert the actual input at the specified index |
| inputs_.insert(inputs_.begin() + i, value); |
| // Register the new use of the value we just inserted as an input. |
| value->uses_.emplace_back(this, i); |
| return value; |
| } |
| |
| Value* Node::addInput(Value* value) { |
| AT_ASSERT(graph_ == value->owningGraph()); |
| op_ = nullptr; |
| value->uses_.emplace_back(this, inputs_.size()); |
| inputs_.push_back(value); |
| return value; |
| } |
| |
| Value* Node::replaceInput(size_t i, Value* newValue) { |
| AT_ASSERT(newValue->owningGraph() == graph_); |
| op_ = nullptr; |
| Value* old = dropInput(i); |
| inputs_[i] = newValue; |
| newValue->uses_.emplace_back(this, i); |
| return old; |
| } |
| |
| void Node::replaceInputWith(Value* from, Value* to) { |
| AT_ASSERT(from->owningGraph() == graph_); |
| AT_ASSERT(to->owningGraph() == graph_); |
| op_ = nullptr; |
| size_t i = 0; |
| for (auto input : inputs()) { |
| if (input == from) { |
| replaceInput(i, to); |
| } |
| i++; |
| } |
| } |
| |
| Value* Node::addOutput() { |
| outputs_.push_back(new Value(this, outputs_.size())); |
| op_ = nullptr; |
| return outputs_.back(); |
| } |
| |
| Value* Node::insertOutput(size_t i) { |
| op_ = nullptr; |
| outputs_.insert(outputs_.begin() + i, new Value(this, i)); |
| for (size_t itr = i + 1; itr < outputs_.size(); ++itr) { |
| outputs_[itr]->setOffset(outputs_[itr]->offset() + 1); |
| } |
| return outputs_.at(i); |
| } |
| |
| bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const { |
| if (this->owningBlock() == n->owningBlock()) { |
| if (moveSide == MoveSide::BEFORE) { |
| return this->topo_position_ < n->topo_position_; |
| } |
| |
| if (moveSide == MoveSide::AFTER) { |
| return this->topo_position_ > n->topo_position_; |
| } |
| |
| AT_ASSERT(this == n); |
| return false; |
| } |
| |
| // These nodes don't share a common block. Traverse the chains of owning |
| // blocks upward until we find the first common block. |
| auto lhs = this; |
| while (lhs) { |
| AT_ASSERT(lhs->owningBlock()); |
| |
| auto rhs = n; |
| while (rhs) { |
| if (!rhs->owningBlock()) { |
| break; |
| } |
| |
| if (lhs->owningBlock() == rhs->owningBlock()) { |
| return lhs->isBeforeOrAfter(rhs, moveSide); |
| } |
| rhs = rhs->owningBlock()->owningNode(); |
| } |
| |
| lhs = lhs->owningBlock()->owningNode(); |
| } |
| // should never reach here, since both nodes are ultimately in the same graph |
| AT_ASSERT(false); |
| } |
| |
| bool Node::isBefore(const Node* n) const { |
| return isBeforeOrAfter(n, MoveSide::BEFORE); |
| } |
| |
| bool Node::isAfter(const Node* n) const { |
| return isBeforeOrAfter(n, MoveSide::AFTER); |
| } |
| |
| Node* Node::insertBefore(Node* n) { |
| AT_ASSERT(n->inBlockList()); |
| insertAfter(n->prev()); |
| return this; |
| } |
| |
| Node* Node::insertAfter(Node* n) { |
| AT_ASSERT(!inBlockList() && n->inBlockList()); |
| AT_ASSERT(n->owningBlock()); |
| AT_ASSERTM( |
| n->kind() != prim::Return, |
| "Attempting to insert a Node after the Return node or before the Param node. Tried to insert ", |
| *this, |
| " after ", |
| *n, |
| "."); |
| this->owning_block_ = n->owningBlock(); |
| Node* next = n->next(); |
| n->next() = this; |
| this->prev() = n; |
| this->next() = next; |
| next->prev() = this; |
| assignTopoPosition(); |
| return this; |
| } |
| |
| void Node::moveAfter(Node* n) { |
| removeFromList(); |
| insertAfter(n); |
| } |
| |
| void Node::moveBefore(Node* n) { |
| removeFromList(); |
| insertBefore(n); |
| } |
| |
| void Node::removeInput(size_t i) { |
| op_ = nullptr; |
| dropInput(i); |
| // everything after this input shifts left, |
| // so we need to update their use offsets to match |
| for (size_t j = i + 1; j < inputs_.size(); j++) { |
| auto it = findUseForInput(j); |
| it->offset--; |
| } |
| inputs_.erase(inputs_.begin() + i); |
| } |
| |
| void Node::removeAllInputs() { |
| op_ = nullptr; |
| for (const auto i : c10::irange(inputs().size())) { |
| dropInput(i); |
| } |
| inputs_.clear(); |
| } |
| |
| void Node::removeAllOutputs() { |
| op_ = nullptr; |
| size_t init_osize = outputs_.size(); |
| for (auto i : c10::irange(init_osize)) { |
| eraseOutput(init_osize - i - 1); |
| } |
| } |
| |
| void Node::permuteInputs(const std::vector<size_t>& new_order) { |
| op_ = nullptr; |
| AT_ASSERT(new_order.size() == inputs_.size()); |
| std::vector<Value*> new_inputs; |
| new_inputs.reserve(new_order.size()); |
| for (const auto i : c10::irange(new_order.size())) { |
| AT_ASSERTM(inputs_.at(new_order[i]) != nullptr, "Repeated index"); |
| new_inputs.push_back(inputs_.at(new_order[i])); |
| auto it = findUseForInput(new_order[i]); |
| it->offset = i; |
| inputs_.at(new_order[i]) = nullptr; |
| } |
| inputs_ = std::move(new_inputs); |
| } |
| |
| void Node::permuteOutputs(const std::vector<size_t>& new_order) { |
| op_ = nullptr; |
| AT_ASSERT(new_order.size() == outputs_.size()); |
| std::vector<Value*> new_outputs; |
| new_outputs.reserve(new_order.size()); |
| for (const auto i : c10::irange(new_order.size())) { |
| AT_ASSERTM(outputs_.at(new_order[i]) != nullptr, "Repeated index"); |
| new_outputs.push_back(outputs_.at(new_order[i])); |
| outputs_.at(new_order[i])->setOffset(i); |
| outputs_.at(new_order[i]) = nullptr; |
| } |
| outputs_ = std::move(new_outputs); |
| } |
| |
| use_list::iterator Node::findUseForInput(size_t i) { |
| auto& input_uses = inputs_[i]->uses_; |
| // O(N) on the use list, but unless we get nodes with 100+ uses, |
| // vector traversal is still probably faster than a linked list. |
| auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i)); |
| AT_ASSERT(use_it != input_uses.end()); |
| return use_it; |
| } |
| |
| Value* Node::dropInput(size_t i) { |
| AT_ASSERT(i < inputs_.size()); |
| auto input_node = inputs_[i]; |
| auto use_it = findUseForInput(i); |
| input_node->uses_.erase(use_it); |
| inputs_[i] = nullptr; |
| return input_node; |
| } |
| |
| void Node::removeFromList() { |
| AT_ASSERT(inBlockList()); |
| this->owning_block_ = nullptr; |
| Node* next = this->next(); |
| Node* prev = this->prev(); |
| prev->next() = next; |
| next->prev() = prev; |
| this->next() = nullptr; |
| this->prev() = nullptr; |
| } |
| |
| Block* Node::findCommonAncestorBlockWith(Node* n) { |
| if (n->owningBlock() == owningBlock()) { |
| return owningBlock(); |
| } |
| |
| Node* n1 = this; |
| Node* n2 = n; |
| |
| size_t d_1 = n1->blocksFromGraphBlock(); |
| size_t d_2 = n2->blocksFromGraphBlock(); |
| |
| for (; d_1 > d_2; --d_1) { |
| n1 = n1->owningBlock()->owningNode(); |
| // n2 contains n1 |
| } |
| |
| for (; d_2 > d_1; --d_2) { |
| n2 = n2->owningBlock()->owningNode(); |
| } |
| |
| // Now they are the same number of blocks from the graph block; |
| // recurse upwards, checking if they are in the same block. |
| while (true) { |
| if (n1->owningBlock() == n2->owningBlock()) { |
| return n1->owningBlock(); |
| } |
| |
| n1 = n1->owningBlock()->owningNode(); |
| n2 = n2->owningBlock()->owningNode(); |
| |
| AT_ASSERT(n1 != nullptr); |
| AT_ASSERT(n2 != nullptr); |
| } |
| } |
| |
| size_t Node::blocksFromGraphBlock() { |
| Node* n = this; |
| size_t dist = 0; |
| while (n->owningBlock()->owningNode()) { |
| n = n->owningBlock()->owningNode(); |
| ++dist; |
| } |
| return dist; |
| } |
| |
| inline const SourceRange& fakeRange() { |
| static SourceRange range(std::make_shared<Source>(std::string("")), 0, 1); |
| return range; |
| } |
| |
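| // Emits a call to `opname` by matching args/kwargs against its registered |
| // schemas (via emitBuiltinCall). A minimal sketch, assuming `a` and `b` are |
| // Values already in the graph: |
| //   Value* out = graph.insert(aten::add, {a, b}); |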
| Value* Graph::insert( |
| Symbol opname, |
| at::ArrayRef<NamedValue> args, |
| at::ArrayRef<NamedValue> kwargs, |
| const c10::optional<SourceRange>& range) { |
| return emitBuiltinCall( |
| range.value_or(fakeRange()), *this, opname, args, kwargs); |
| } |
| |
| Node* Graph::create(NodeKind kind, size_t num_outputs) { |
| // NB: Node constructor adds node to all_nodes |
| auto n = new Node(this, kind); |
| for (const auto i : c10::irange(num_outputs)) { |
| (void)i; |
| n->addOutput(); |
| } |
| return n; |
| } |
| |
| Node* Graph::create( |
| NodeKind kind, |
| ArrayRef<Value*> inputs, |
| size_t num_outputs) { |
| auto n = create(kind, num_outputs); |
| for (auto i : inputs) { |
| n->addInput(i); |
| } |
| return n; |
| } |
| |
| Node* Graph::createAutogradZero() { |
| return create(prim::AutogradZero); |
| } |
| |
| Node* Graph::createNone() { |
| Node* n = create(prim::Constant); |
| n->output()->setType(NoneType::get()); |
| return n; |
| } |
| |
| Node* Graph::createUninitialized(TypePtr typ) { |
| Node* n = create(prim::Uninitialized); |
| n->output()->setType(std::move(typ)); |
| return n; |
| } |
| |
| Node* Graph::createWithSubgraph(Symbol kind) { |
| auto n = create(kind, 0); |
| n->g_(attr::Subgraph, std::make_shared<Graph>(current_scope())); |
| return n; |
| } |
| |
| Node* Graph::createTuple(at::ArrayRef<Value*> values, TupleTypePtr tuple_type) { |
| TORCH_INTERNAL_ASSERT( |
| !tuple_type || tuple_type->schema(), |
| "only pass tuple_type when creating a named tuple"); |
| if (!tuple_type) { |
| auto types = fmap(values, [](Value* v) { return v->type(); }); |
| tuple_type = TupleType::create(std::move(types)); |
| } |
| auto n = create(prim::TupleConstruct, values); |
| |
| n->output()->setType(tuple_type); |
| return n; |
| } |
| |
| Node* Graph::createTupleUnpack(Value* v) { |
| TupleTypePtr tt = v->type()->expect<TupleType>(); |
| auto n = create(prim::TupleUnpack, {v}, 0); |
| for (auto& element : tt->elements()) { |
| n->addOutput()->setType(element); |
| } |
| return n; |
| } |
| |
| Node* Graph::createTupleIndex( |
| Value* tup, |
| Value* idx, |
| const TypePtr& output_type) { |
| auto n = create(prim::TupleIndex, {tup, idx}); |
| n->output()->setType(output_type); |
| return n; |
| } |
| |
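| // Builds a strided tuple slice by emitting one TupleIndex node per selected |
| // element and packing the results into a new tuple. For example, given |
| // tup : (a, b, c, d), createTupleSlice(tup, /*beg=*/1, /*step_size=*/1, |
| // /*num_values=*/2) selects elements 1 and 2, producing a tuple (b, c). |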
| Node* Graph::createTupleSlice( |
| Value* tup, |
| int64_t beg, |
| int64_t step_size, |
| int64_t num_values) { |
| std::vector<Value*> new_vals; |
| TupleTypePtr tt = tup->type()->expect<TupleType>(); |
| new_vals.reserve(num_values); |
| |
| int64_t i = beg; |
| for (const auto j : c10::irange(num_values)) { |
| (void)j; // Suppress unused variable warning |
| auto idx = insertConstant(IValue(static_cast<int64_t>(i))); |
| auto tupleIndex = insertNode(createTupleIndex(tup, idx, tt->elements()[i])); |
| |
| new_vals.push_back(tupleIndex->output()); |
| i += step_size; |
| } |
| |
| auto n = createTuple(new_vals); |
| return n; |
| } |
| |
| Node* Graph::createEnumName(Value* e) { |
| e->type()->expect<EnumType>(); |
| assert(e->type()->cast<EnumType>()); |
| auto n = create(prim::EnumName, {e}); |
| n->output()->setType(StringType::get()); |
| return n; |
| } |
| |
| Node* Graph::createEnumValue(Value* e) { |
| auto enum_type = e->type()->expect<EnumType>(); |
| auto n = create(prim::EnumValue, {e}); |
| n->output()->setType(enum_type->getValueType()); |
| return n; |
| } |
| |
| Node* Graph::createList( |
| const TypePtr& contained_type, |
| at::ArrayRef<Value*> values) { |
| auto n = create(prim::ListConstruct, values); |
| for (const auto& v : values) { |
| TORCH_CHECK( |
| v->type()->isSubtypeOf(*contained_type), |
| "Expected a list element that subtypes '", |
| contained_type->repr_str(), |
| "' but got an element of type '", |
| v->type()->repr_str(), |
| "'"); |
| } |
| n->output()->setType(ListType::create(contained_type)); |
| return n; |
| } |
| |
| Node* Graph::createListUnpack(Value* v, size_t size) { |
| ListTypePtr list_type = v->type()->expect<ListType>(); |
| TypePtr elem_type = list_type->getElementType(); |
| auto n = create(prim::ListUnpack, {v}, 0); |
| for (const auto i : c10::irange(size)) { |
| (void)i; // Suppress unused variable warning |
| n->addOutput()->setType(elem_type); |
| } |
| return n; |
| } |
| |
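| // The resulting prim::DictConstruct node's inputs are interleaved as |
| // (key0, value0, key1, value1, ...); each key and value is checked to be a |
| // subtype of key_type and value_type respectively. |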
| Node* Graph::createDict( |
| const TypePtr& key_type, |
| const TypePtr& value_type, |
| at::ArrayRef<Value*> keys, |
| at::ArrayRef<Value*> values) { |
| AT_ASSERT(keys.size() == values.size()); |
| auto n = create(prim::DictConstruct, 1); |
| for (const auto i : c10::irange(keys.size())) { |
| AT_ASSERT(keys[i]->type()->isSubtypeOf(*key_type)); |
| AT_ASSERT(values[i]->type()->isSubtypeOf(*value_type)); |
| |
| n->addInput(keys[i]); |
| n->addInput(values[i]); |
| } |
| n->output()->setType(DictType::create(key_type, value_type)); |
| return n; |
| } |
| |
| Node* Graph::createNumToTensor(Value* value) { |
| Node* result = create(prim::NumToTensor, {value}); |
| result->output()->setType(TensorType::fromNumberType(*value->type())); |
| return result; |
| } |
| |
| Node* Graph::createObject(const ClassTypePtr& type) { |
| auto result = create(prim::CreateObject); |
| result->output()->setType(type); |
| return result; |
| } |
| |
| Node* Graph::createSetAttr( |
| Value* obj, |
| const std::string& field, |
| Value* newValue) { |
| auto n = create(prim::SetAttr, {obj, newValue}, /*num_outputs=*/0); |
| n->s_(attr::name, field); |
| return n; |
| } |
| |
| Node* Graph::createGetAttr(Value* obj, const std::string& field) { |
| const auto classType = obj->type()->expect<ClassType>(); |
| |
| auto n = create(prim::GetAttr, {obj}, /*num_outputs=*/1); |
| n->s_(attr::name, field); |
| |
| const auto outputType = classType->getAttribute(field); |
| n->output()->setType(outputType); |
| n->output()->setDebugName(normalizeAttrName(field)); |
| return n; |
| } |
| |
| Node* Graph::createStore(const std::string& name, Value* v) { |
| auto n = create(prim::Store, {v}, /*num_outputs*/ 0); |
| n->s_(attr::name, name); |
| return n; |
| } |
| |
| Node* Graph::createLoad(const std::string& name, const TypePtr& type) { |
| auto n = create(prim::Load, {}, /*num_outputs*/ 1); |
| n->s_(attr::name, name); |
| n->output()->setType(type); |
| return n; |
| } |
| |
| Node* Graph::createIsInstance(Value* v, at::ArrayRef<TypePtr> types) { |
| auto n = create(prim::isinstance, {v}, /*num_outputs*/ 1); |
| n->tys_(attr::types, types.vec()); |
| n->output()->setType(BoolType::get()); |
| return n; |
| } |
| Value* Graph::insertUncheckedCast(Value* v, TypePtr type) { |
| Node* n = insertNode(create(prim::unchecked_cast, {v})); |
| n->output()->setType(std::move(type)); |
| return n->output(); |
| } |
| |
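| // prim::tolist takes the value plus two constant ints: the number of list |
| // dimensions and a code for the element type (0=int, 1=float, 2=bool, |
| // 3=complex). For example, for type List[List[int]] this inserts |
| // prim::tolist(v, 2, 0). |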
| Value* Graph::insertToList(Value* v, TypePtr type) { |
| int dim = 0; |
| TypePtr ptr = type; |
| |
| // Unwrap the type to determine the number of dimensions. |
| while (auto list_type = ptr->cast<ListType>()) { |
| ptr = list_type->getElementType(); |
| ++dim; |
| } |
| |
| // Encode the base element type as an integer. |
| int elem_ty = 0; |
| if (ptr == IntType::get()) { |
| elem_ty = 0; |
| } else if (ptr == FloatType::get()) { |
| elem_ty = 1; |
| } else if (ptr == BoolType::get()) { |
| elem_ty = 2; |
| } else if (ptr == ComplexType::get()) { |
| elem_ty = 3; |
| } else { |
| TORCH_CHECK( |
| false, |
| ptr->repr_str(), |
| " is not one of the supported element types for tolist: int, float, complex, bool"); |
| } |
| |
| // Pass in the number of dimensions and base element type as arguments |
| // to the op. |
| Value* dim_val = insertConstant(IValue(dim)); |
| Value* elem_ty_val = insertConstant(IValue(elem_ty)); |
| Node* n = insertNode(create(prim::tolist, {v, dim_val, elem_ty_val})); |
| n->output()->setType(std::move(type)); |
| return n->output(); |
| } |
| |
| Value* Graph::insertFunctionCall( |
| Function* callee, |
| const MatchedSchema& matched) { |
| std::string func_name = callee->name(); |
| Value* fn_constant = insertNode(create(prim::Constant)) |
| ->s_(attr::name, func_name) |
| ->output() |
| ->setType(FunctionType::create(callee)); |
| std::vector<Value*> inputs = {fn_constant}; |
| inputs.insert(inputs.end(), matched.inputs.begin(), matched.inputs.end()); |
| Value* result = insertNode(create(prim::CallFunction, inputs)) |
| ->output() |
| ->setType(matched.return_types.at(0)); |
| return result; |
| } |
| |
| Value* Graph::insertMethodCall( |
| std::string method_name, |
| const MatchedSchema& matched) { |
| Value* result = insertNode(create(prim::CallMethod, matched.inputs)) |
| ->s_(attr::name, std::move(method_name)) |
| ->output() |
| ->setType(matched.return_types.at(0)); |
| return result; |
| } |
| |
| Node* Graph::createClone( |
| Node* n, |
| const std::function<Value*(Value*)>& value_map, |
| bool copy_blocks) { |
| // n can be from a different graph |
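| // value_map tells the clone where to find its inputs in the destination |
| // graph; when cloning within the same graph, the identity function |
| // ([](Value* v) { return v; }) is a common choice, while inlining maps the |
| // callee's values to their freshly inserted copies (see insertGraph below). |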
| Node* r = n->allocNewInstance(this); |
| for (auto o : n->outputs()) { |
| r->addOutput()->copyMetadata(o); |
| } |
| r->cloneFrom(n); |
| for (auto i : n->inputs()) { |
| r->addInput(value_map(i)); |
| } |
| if (copy_blocks) { |
| for (auto b : n->blocks()) { |
| r->addBlock()->cloneFrom(b, value_map); |
| } |
| } |
| return r; |
| } |
| |
| Value* Graph::insertConstant( |
| const IValue& val, |
| c10::optional<SourceRange> loc, |
| c10::optional<ScopePtr> scope) { |
| return jit::insertConstant(*this, val, std::move(loc), std::move(scope)); |
| } |
| |
| std::string Graph::toString(bool print_source_locations) const { |
| std::ostringstream oss; |
| print(oss, print_source_locations); |
| return oss.str(); |
| } |
| |
| Graph::~Graph() { |
| for (const Node* n : all_nodes) { |
| delete n; |
| } |
| for (const Value* v : all_values) { |
| delete v; |
| } |
| for (const Block* b : all_blocks) { |
| delete b; |
| } |
| } |
| |
| void Graph::freeNode(Node* n) { |
| auto it = all_nodes.find(n); |
| AT_ASSERT(it != all_nodes.end()); |
| delete *it; |
| all_nodes.erase(it); |
| } |
| void Graph::freeValue(Value* v) { |
| v->setDebugName(""); |
| auto it = all_values.find(v); |
| AT_ASSERT(it != all_values.end()); |
| delete *it; |
| all_values.erase(it); |
| } |
| void Graph::freeBlock(Block* b) { |
| auto it = all_blocks.find(b); |
| AT_ASSERT(it != all_blocks.end()); |
| delete *it; |
| all_blocks.erase(it); |
| } |
| |
| at::ArrayRef<Value*> createTupleUnpack(Value* v) { |
| // Small peephole optimization: ensure IntArrayRef attributes can still turn |
| // into constants, e.g. in x.expand([3, 4]). |
| if (v->node()->kind() == prim::TupleConstruct) { |
| return v->node()->inputs(); |
| } |
| auto& g = *v->owningGraph(); |
| return g.insertNode(g.createTupleUnpack(v))->outputs(); |
| } |
| |
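| // Helpers that stamp an inlined-call-stack frame (callee, call-site source |
| // range, and optional module instance info) onto every node copied in by |
| // inlineCallTo, memoizing one new entry per original call stack in |
| // new_cs_entries. |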
| void inlineCallStackOfNode( |
| Node* n, |
| std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries, |
| Function* callee, |
| Node* to_replace, |
| c10::optional<ModuleInstanceInfo> m_info); |
| |
| void inlineCallStackOfBlock( |
| Block* b, |
| std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries, |
| Function* callee, |
| Node* to_replace, |
| c10::optional<ModuleInstanceInfo> m_info) { |
| for (auto n : b->nodes()) { |
| inlineCallStackOfNode(n, new_cs_entries, callee, to_replace, m_info); |
| } |
| } |
| |
| void inlineCallStackOfNode( |
| Node* new_node, |
| std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries, |
| Function* callee, |
| Node* to_replace, |
| c10::optional<ModuleInstanceInfo> m_info) { |
| auto new_node_cs = new_node->callstack(); |
| |
| InlinedCallStack* raw_callstack_ptr = |
| new_node_cs ? new_node_cs->get() : nullptr; |
| |
| if (!new_cs_entries.count(raw_callstack_ptr)) { |
| if (new_node_cs) { |
| new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>( |
| *new_node_cs, callee, to_replace->sourceRange(), m_info); |
| } else { |
| new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>( |
| callee, to_replace->sourceRange(), m_info); |
| } |
| } |
| new_node->setCallStack(new_cs_entries.at(raw_callstack_ptr)); |
| // We updated the inlined callstack of new_node. The same must be done for |
| // the nodes inside new_node's blocks; otherwise, for example, the body of |
| // an If node would not be annotated appropriately. |
| for (auto block : new_node->blocks()) { |
| inlineCallStackOfBlock(block, new_cs_entries, callee, to_replace, m_info); |
| } |
| } |
| |
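| // Inlines `callee_graph` at the call site `to_replace`: the callee's nodes |
| // are cloned in at the insert point, the new nodes' call stacks are |
| // annotated (including module instance info when the call is a |
| // prim::CallMethod), uses of the old outputs are rewired to the inlined |
| // outputs, and `to_replace` is destroyed. |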
| std::vector<Value*> inlineCallTo( |
| Node* to_replace, |
| GraphFunction* callee, |
| Graph* callee_graph) { |
| WithInsertPoint guard(to_replace); |
| std::unordered_map<Value*, Value*> value_map; |
| std::vector<torch::jit::Value*> new_outputs = insertGraph( |
| *to_replace->owningGraph(), |
| *callee_graph, |
| to_replace->inputs(), |
| value_map); |
| |
| std::unordered_map<InlinedCallStack*, InlinedCallStackPtr> |
| new_callstack_entries; |
| |
| c10::optional<ModuleInstanceInfo> module_instance_info = c10::nullopt; |
| if (to_replace->kind() == prim::CallMethod) { |
| auto class_type_ptr = to_replace->input(0)->type()->cast<c10::ClassType>(); |
| if (to_replace->input(0)->node()->kind() == prim::GetAttr) { |
| module_instance_info = c10::make_optional(ModuleInstanceInfo( |
| class_type_ptr, to_replace->input(0)->node()->s(attr::name))); |
| } else if ( |
| to_replace->owningGraph()->inputs().size() > 0 && |
| to_replace->input(0) == to_replace->owningGraph()->inputs()[0]) { |
| // This CallMethod must correspond to a method of the same object |
| // to which this graph belongs. |
| module_instance_info = |
| c10::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF")); |
| } else { |
| // It is unclear whether this branch is ever reachable. |
| // TODO: remove this else branch, or add an assert. |
| module_instance_info = c10::make_optional( |
| ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN")); |
| } |
| } |
| |
| // TODO: We might need to use nodes_map instead of value_map. Otherwise, we |
| // are missing nodes without outputs (e.g. prim::Print). |
| std::unordered_set<Node*> updated_nodes; |
| for (const auto& kv : value_map) { |
| /* Skip the old value if it is a graph input. |
| * The reason is that value_map contains not only the values produced by the |
| * nodes of the graph but also its inputs, which would create duplicates |
| * when the first inlined graph is fed as input to the next one. To avoid |
| * this issue, skip the old value when it is one of |
| * callee->optimized_graph()->inputs() or callee->graph()->inputs(), |
| * depending on whether the optimized graph is being inlined. |
| */ |
| auto is_graph_input = std::find( |
| callee_graph->inputs().begin(), callee_graph->inputs().end(), kv.first); |
| if (is_graph_input != callee_graph->inputs().end()) { |
| continue; |
| } |
| |
| Node* new_node = kv.second->node(); |
| if (!updated_nodes.insert(new_node).second) { |
| continue; |
| } |
| |
| inlineCallStackOfNode( |
| new_node, |
| new_callstack_entries, |
| callee, |
| to_replace, |
| module_instance_info); |
| } |
| const auto& old_outputs = to_replace->outputs(); |
| |
| AT_ASSERT(new_outputs.size() == old_outputs.size()); |
| for (const auto i : c10::irange(old_outputs.size())) { |
| if (old_outputs[i]->hasDebugName()) { |
| new_outputs[i]->setDebugName(old_outputs[i]->debugName()); |
| } |
| old_outputs[i]->replaceAllUsesWith(new_outputs[i]); |
| } |
| to_replace->destroy(); |
| |
| return new_outputs; |
| } |
| |
| // The inline_optimized_graph argument is used when substituting function |
| // calls during ONNX conversion. |
| std::vector<Value*> inlineCallTo( |
| Node* to_replace, |
| GraphFunction* callee, |
| bool inline_optimized_graph /*=true*/) { |
| auto graph = |
| inline_optimized_graph ? callee->optimized_graph() : callee->graph(); |
| return inlineCallTo(to_replace, callee, graph.get()); |
| } |
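| // Illustrative sketch only (not part of this file's API surface): an |
| // inlining pass that encounters a prim::CallFunction node (here the |
| // hypothetical `call_node`) might invoke this overload roughly as follows, |
| // assuming the node's first input is the function constant: |
| // |
| //   auto fun_type = |
| //       call_node->input(0)->type()->expect<FunctionType>(); |
| //   if (auto* graph_fn = tryToGraphFunction(*fun_type->function())) { |
| //     call_node->removeInput(0); // drop the function constant input |
| //     inlineCallTo(call_node, graph_fn); |
| //   } |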
| |
| std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs) { |
| std::vector<Value*> new_outputs; |
| if (outputs.size() != 1 || outputs.at(0)->type()->kind() != TupleType::Kind) { |
| return outputs; |
| } |
| |
| auto tup = outputs[0]; |
| for (Value* v : createTupleUnpack(tup)) { |
| new_outputs.emplace_back(v); |
| } |
| // If this was a peephole tuple unpack, we can just get rid of the |
| // TupleConstruct here and avoid the need for DCE. |
| if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) { |
| tup->node()->destroy(); |
| } |
| return new_outputs; |
| } |
| |
| std::vector<Node*> findAllNodes( |
| at::ArrayRef<Block*> array, |
| Symbol kind, |
| bool recurse) { |
| std::vector<Node*> ret; |
| for (auto block : array) { |
| findAllNodes(*block, kind, recurse, ret); |
| } |
| return ret; |
| } |
| |
| std::vector<Node*> findAllNodes(Block& block, Symbol kind, bool recurse) { |
| return findAllNodes({&block}, kind, recurse); |
| } |
| |
| std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse) { |
| return findAllNodes(*g.block(), kind, recurse); |
| } |
| |
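| // Clones the nodes of `callee` into `g` at the current insert point, binding |
| // callee.inputs() to `inputs`, recording every old-to-new Value |
| // correspondence in `value_map`, and returning the Values that correspond to |
| // callee.outputs(). |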
| std::vector<Value*> insertGraph( |
| Graph& g, |
| Graph& callee, |
| ArrayRef<Value*> inputs, |
| std::unordered_map<Value*, Value*>& value_map) { |
| auto value_map_func = [&](Value* v) { return value_map.at(v); }; |
| AT_ASSERT(callee.inputs().size() == inputs.size()); |
| for (const auto i : c10::irange(inputs.size())) { |
| value_map[callee.inputs()[i]] = inputs[i]; |
| } |
| for (auto* node : callee.nodes()) { |
| auto* new_node = g.insertNode(g.createClone(node, value_map_func)); |
| for (size_t i = 0; i < node->outputs().size(); ++i) { |
| value_map[node->outputs()[i]] = new_node->outputs()[i]; |
| } |
| } |
| |
| std::vector<Value*> outputs; |
| for (auto* output : callee.outputs()) { |
| outputs.push_back(value_map_func(output)); |
| } |
| |
| return outputs; |
| } |
| |
| std::vector<Value*> insertGraph( |
| Graph& g, |
| Graph& callee, |
| ArrayRef<Value*> inputs) { |
| std::unordered_map<Value*, Value*> value_map; |
| return insertGraph(g, callee, inputs, value_map); |
| } |
| |
| void ProfileOp::cloneFrom(Node* other_) { |
| Node::cloneFrom(other_); |
| auto other = other_->cast<ProfileOp>(); |
| this->callback_ = other->getCallback(); |
| } |
| |
| Node* ProfileOp::allocNewInstance(Graph* g) { |
| return new ProfileOp(g, {nullptr}); |
| } |
| |
| void ProfileIValueOp::cloneFrom(Node* other_) { |
| Node::cloneFrom(other_); |
| auto other = other_->cast<ProfileIValueOp>(); |
| this->callback_ = other->getCallback(); |
| } |
| |
| Node* ProfileIValueOp::allocNewInstance(Graph* g) { |
| return new ProfileIValueOp(g, {nullptr}); |
| } |
| |
| TypePtr NamedValue::type() const { |
| if (value_) { |
| return value_->type(); |
| } else { |
| return ivalue_.type(); |
| } |
| } |
| |
| const Symbol ProfileOp::Kind = ::c10::prim::profile; |
| const Symbol ProfileIValueOp::Kind = ::c10::prim::profile_ivalue; |
| |
| OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) { |
| insert(sig_literals); |
| } |
| |
| std::vector<std::shared_ptr<Operator>> OperatorSet::getOps() const { |
| std::vector<std::shared_ptr<Operator>> result; |
| for (const auto& kv : ops) { |
| auto ops_for_symbol = kv.second; |
| result.insert(result.end(), ops_for_symbol.begin(), ops_for_symbol.end()); |
| } |
| return result; |
| } |
| |
| void OperatorSet::insert(std::initializer_list<const char*> sig_literals) { |
| for (const char* sig : sig_literals) { |
| auto op = getOperatorForLiteral(sig); |
| ops[Symbol::fromQualString(op->schema().name())].push_back(op); |
| } |
| } |
| |
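| // Illustrative sketch only: an OperatorSet is typically built from full |
| // schema literals and then queried through Node::isMemberOf, e.g. |
| // |
| //   static const OperatorSet relu_ops{ |
| //       "aten::relu(Tensor self) -> Tensor", |
| //   }; |
| //   if (node->isMemberOf(relu_ops)) { |
| //     /* ... */ |
| //   } |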
| bool Node::isMemberOf(const OperatorSet& os) const { |
| auto it = os.ops.find(kind()); |
| if (it == os.ops.end()) { |
| return false; |
| } |
| for (auto& op : it->second) { |
| if (matches(op->schema())) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| } // namespace jit |
| } // namespace torch |