| #include <torch/csrc/jit/ir.h> |
| |
| |
| #include <torch/csrc/jit/operator.h> |
| #include <torch/csrc/autograd/function.h> |
| #include <torch/csrc/jit/constants.h> |
| #include <torch/csrc/jit/assertions.h> |
| #include <torch/csrc/jit/script/schema_matching.h> |
| #include <torch/csrc/jit/passes/python_print.h> |
| #include <torch/csrc/jit/passes/alias_analysis.h> |
| |
| #include <algorithm> |
| #include <iostream> |
| #include <set> |
| #include <sstream> |
| #include <stack> |
| #include <string> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <utility> |
| |
| namespace torch { namespace jit { |
| // Constants relating to maintaining the topological index of nodes. |
| // |
| // Lower and upper bounds of the index. Inclusive range. |
| static constexpr topo_position_t kLowerBound = INT64_MIN; |
| static constexpr topo_position_t kUpperBound = INT64_MAX; |
| static constexpr topo_position_t kMidPoint = 0; |
| // How far away to space nodes that are appended to the graph. |
| // should be 2^n, where: |
| // - n is the maximum number of repeated insertions without a re-index |
| // - 2^(64-n) is the maximum number of appends to the end without reindex |
| static constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */; |
| |
| // Out-of-line definition of the in-class static constexpr member; without it, odr-uses of PythonOp::Kind fail to link before C++17. See https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char |
| constexpr Symbol PythonOp::Kind; |
| |
| void printValueRef(std::ostream & out, const Value * n) { |
| out << "%" << n->uniqueName(); |
| } |
| |
| // NB: This overload will become ambiguous with the one Caffe2 provides in its |
| // logging, if they ever intersect. |
| template <typename T> |
| std::ostream& operator<<(std::ostream & out, const std::vector<T> & nodes) { |
| out << at::ArrayRef<T>{nodes}; |
| return out; |
| } |
| |
| template <typename T> |
| std::ostream& printValueRefs(std::ostream & out, const at::ArrayRef<T> & nodes) { |
| size_t i = 0; |
| for(auto n : nodes) { |
| if(i++ > 0) |
| out << ", "; |
| printValueRef(out, n); |
| } |
| return out; |
| } |
| |
| // Can't make these two overloads directly a template, it'll be ambiguous with |
| // the global printer for operator<<. |
| |
| std::ostream& operator<<(std::ostream & out, const at::ArrayRef<const Value*> & nodes) { |
| return printValueRefs(out, nodes); |
| } |
| |
| std::ostream& operator<<(std::ostream & out, const at::ArrayRef<Value*> & nodes) { |
| return printValueRefs(out, nodes); |
| } |
| |
| struct const_value_list_with_types { |
| const ArrayRef<const Value*> values; |
| bool use_newlines; |
| const_value_list_with_types(ArrayRef<const Value*> values, bool use_newlines = false) |
| : values(values), use_newlines(use_newlines) {} |
| }; |
| std::ostream& operator<<(std::ostream & out, const_value_list_with_types l) { |
| size_t i = 0; |
| for(auto n : l.values) { |
| if(i++ > 0) { |
| if (l.use_newlines) { |
| // TODO: Indent here is hard-coded for "graph(": un-hard-code it |
| out << "\n "; |
| } else { |
| out << ", "; |
| } |
| } |
| printValueRef(out, n); |
| out << " : "; |
| out << *n->type(); |
| } |
| return out; |
| } |
| |
| void printAttributes(std::ostream & out, const Node * n, bool ignore_subgraph=false) { |
| out << "["; |
| auto names = n->attributeNames(); |
| int i = 0; |
| for(auto name : names) { |
| if (ignore_subgraph && name == attr::Subgraph) |
| continue; |
| if(i++ > 0) |
| out << ", "; |
| // TODO: debugging mode to see the qualifier. We definitely |
| // don't want to print the qualifier since it should always |
| // be attribute, but you might be able to track down a weird |
| // bug by printing it out. |
| out << name.toUnqualString() << "="; |
| |
| n->printValue(out, name); |
| } |
| out << "]"; |
| } |
| |
// Writes one indentation unit per nesting level and returns the stream so the
// caller can keep chaining.
static std::ostream & indent(std::ostream & out, size_t level) {
  size_t remaining = level;
  while (remaining > 0) {
    out << " ";
    --remaining;
  }
  return out;
}
| |
| std::ostream& printNode(std::ostream & out, size_t level, const Node * n, std::vector<const Node*> * groups) { |
| auto outputs = n->outputs(); |
| indent(out, level) << const_value_list_with_types(outputs); |
| out << " = "; |
| IR_IFM_CONST(n,PythonOp) |
| out << "^" << value->name(); |
| value->writeScalars(out); |
| IR_ELSE() |
| if(n->hasAttribute(attr::Subgraph) && groups) { |
| out << n->kind().toQualString() << "_" << groups->size(); |
| if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) { |
| printAttributes(out, n, /*ignore_subgraph=*/true); |
| } |
| groups->push_back(n); |
| } else { |
| out << n->kind().toQualString(); |
| if(n->hasAttributes()) { |
| printAttributes(out,n); |
| } |
| } |
| IR_END() |
| out << "(" << n->inputs() << ")"; |
| std::string scopeName = n->scopeName(); |
| if (scopeName.empty()) { |
| out << "\n"; |
| } |
| else { |
| out << ", "; |
| out << "scope: " << scopeName << "\n"; |
| } |
| for(size_t i = 0; i < n->blocks().size(); ++i) { |
| auto b = n->blocks()[i]; |
| indent(out, level + 1) << "block" << i << "(" << const_value_list_with_types(b->inputs(), false) << ") {\n"; |
| for(auto n : b->nodes()) { |
| printNode(out, level + 2, n, groups); |
| } |
| indent(out, level + 2) << "-> (" << b->outputs() << ")\n"; |
| indent(out, level + 1) << "}\n"; |
| } |
| return out; |
| } |
| |
| std::ostream& operator<<(std::ostream & out, const Node & n) { |
| return printNode(out, 0, &n, nullptr); |
| } |
| |
| std::ostream& operator<<(std::ostream & out, const Graph & g) { |
| out << "graph(" << const_value_list_with_types(g.inputs(), true) << ") {\n"; |
| std::vector<const Node*> groups; |
| for(auto n : g.nodes()) { |
| printNode(out, 1, n, &groups); |
| } |
| out << " return (" << g.outputs() << ");\n}\n"; |
| size_t i = 0; |
| for(auto fg : groups) { |
| out << "with " << fg->kind().toQualString() << "_" <<i++ << " = " << *fg->g(attr::Subgraph); |
| } |
| /* |
| // Uncomment this to debug all_nodes issues |
| { |
| out << "\n"; |
| out << "all_nodes:\n"; |
| for (auto& n : g.all_nodes) { |
| printNode(out, const_cast<Node*>(n), nullptr); |
| } |
| } |
| */ |
| return out; |
| } |
| |
| std::ostream& Graph::prettyPrint(std::ostream & out) { |
| std::vector<at::Tensor> tensor_table; |
| PythonPrint(out, *this, tensor_table); |
| return out; |
| } |
| |
| void Graph::dumpPretty() { |
| std::vector<at::Tensor> tensor_table; |
| PythonPrint(std::cout, *this, tensor_table); |
| } |
| |
| static void checkSameDevice(const Node* node) { |
| bool has_device = false; |
| c10::optional<at::Device> device = c10::nullopt; |
| auto checkValue = [&](const Value* v) { |
| if(CompleteTensorTypePtr type = v->type()->cast<CompleteTensorType>()) { |
| if(!has_device) { |
| has_device = true; |
| device = type->device(); |
| } else { |
| JIT_ASSERT(device == type->device()); |
| } |
| } |
| }; |
| for(auto input : node->inputs()) { |
| checkValue(input); |
| } |
| for(auto output : node->outputs()) { |
| checkValue(output); |
| } |
| } |
| |
| using node_set = std::set<const Node*>; |
| #define ALL_OF(container) container.begin(), container.end() |
| |
| // These functions purposely operate on the internal members directly, to force |
| // you to think about how the invariants change if you change the data |
| // representation (even if the external API does not change.) |
| |
| // NB: This assert is written to assume you don't have any unattached |
| // nodes. Unattached nodes can occur while manipulations to the |
| // graph are occurring. |
| void Node::lint() const { |
| // Node invariants |
| // - if node should live in list, nodes_iter is consistent |
| // - Inputs are all marked as a use by the nodes they refer to |
| // - Owning graph is non-null and consistent |
| // - The "Select" invariant, when the node is MultiReturn |
| // |
| // The handle invariant: |
| // If a node takes a handle as an input, it is always the |
| // LAST input of the node. There is at most one handle input. |
| |
| { |
| size_t i = 0; |
| for (auto input : inputs_) { |
| // WARNING: O(n^2) |
| // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
| JIT_ASSERT(std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) != input->uses_.end()); |
| JIT_ASSERT(graph_->all_nodes.count(this) == 1); |
| i++; |
| } |
| } |
| |
| for(auto o : outputs()) { |
| size_t i = 0; |
| for (auto use : o->uses()) { |
| // Use invariants |
| // - Use is consistent with inputs |
| // - Every user node is live (checked in Graph) |
| JIT_ASSERT(use.user->inputs_[use.offset] == o); |
| i++; |
| } |
| } |
| |
| // Node subclass invariants |
| IR_IF(this,Constant) |
| JIT_ASSERT(inputs_.size() == 0); |
| IR_ELSEIF(Return) |
| // Return uses is zero |
| JIT_ASSERT(outputs().size() == 0); |
| IR_ELSEIF(Param) |
| // Param inputs is zero |
| JIT_ASSERT(inputs_.size() == 0); |
| IR_ELSEIFM_CONST(PythonOp) |
| // Python operator cconv is correct |
| size_t n_scalars = 0, n_tensors = 0; |
| for (auto c : value->cconv) { |
| if (c == 'c') { |
| n_scalars++; |
| } else if (c == 'd') { |
| n_tensors++; |
| } else { |
| JIT_ASSERT(0); |
| } |
| JIT_ASSERT(static_cast<bool>(value->pyobj)); |
| } |
| JIT_ASSERT(n_scalars == value->scalar_args.size()); |
| JIT_ASSERT(n_tensors == inputs_.size()); |
| IR_ELSEIF(Eval) |
| // TODO: add invariants |
| // TODO: It's not good for these ops to be top-level, it makes cases longer. |
| IR_ELSEIF(FusionGroup) |
| checkSameDevice(value); |
| // TODO: Typecheck the parameters |
| value->g(attr::Subgraph)->lint(); |
| IR_END() |
| |
| } |
| |
| // TODO: When lint fails, give better indication about which |
| // instruction triggered the failure. |
| void Graph::lint() const { |
| // Graph invariants |
| |
| // Uncomment the following to see the graph |
| // std::cout << *const_cast<Graph*>(this); |
| |
| // nodes |
| // - nodes_ is a valid topological ordering for inputs |
| // - No repeated nodes |
| // - Params and return do NOT occur in nodes |
| // - next_unique_ is greater than all uniques in graph |
| // - uniques in all_nodes are unique |
| // - every use will occur later in the topsort |
| |
| struct LintScope { |
| LintScope() = default; |
| LintScope(std::unique_ptr<LintScope> parent) |
| : parent(std::move(parent)) {} |
| bool contains(const Value * v) { |
| return values.count(v) > 0 || (parent && parent->contains(v)); |
| } |
| bool contains(const Node * n) { |
| return nodes.count(n) > 0 || (parent && parent->contains(n)); |
| } |
| void insert(const Value * v) { |
| JIT_ASSERT(!contains(v)); |
| values.insert(v); |
| } |
| void insert(const Node * n) { |
| JIT_ASSERT(!contains(n)); |
| nodes.insert(n); |
| } |
| std::unique_ptr<LintScope> parent; |
| private: |
| std::unordered_set<const Value*> values; |
| std::unordered_set<const Node*> nodes; |
| }; |
| // Struct enables mutual recursion in linting methods. |
| // Putting it inside Graph::lint enables access to private Graph members |
| struct LintImpl { |
| LintImpl(const Graph & g) |
| : g(g) |
| , scope(new LintScope()) |
| , all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered* |
| const Graph & g; |
| std::unique_ptr<LintScope> scope; |
| std::unordered_set<size_t> seen_uniques; |
| std::unordered_map<const Node*, int64_t> anticipated_uses; |
| node_set all_nodes_set; |
| node_set sum_set; |
| |
| void check_value(const Value* v) { |
| scope->insert(v); |
| auto b2 = seen_uniques.insert(v->unique()); |
| JIT_ASSERT(b2.second); // insertion took place |
| JIT_ASSERT(v->unique() < g.next_unique_); |
| |
| for (auto use : v->uses()) { |
| JIT_ASSERT(!scope->contains(use.user)); |
| JIT_ASSERT(g.all_nodes.count(use.user) == 1); |
| anticipated_uses[use.user]++; // int default constructs to 0 |
| } |
| } |
| void check_node(const Node* n) { |
| for (auto input : n->inputs_) { |
| if (!scope->contains(input)) { |
| JIT_ASSERTM(0, input->unique(), " not in scope"); |
| } |
| } |
| JIT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size())); |
| anticipated_uses[n] = -1; // we saw the anticipated user! |
| scope->insert(n); |
| for(auto block : n->blocks()) { |
| std::unique_ptr<LintScope> new_scope(new LintScope(std::move(scope))); |
| scope = std::move(new_scope); |
| check_block(block); |
| scope = std::move(scope->parent); |
| } |
| size_t i = 0; |
| for(auto o : n->outputs()) { |
| JIT_ASSERT(o->node() == n); |
| JIT_ASSERT(i++ == o->offset_); |
| check_value(o); |
| } |
| n->lint(); |
| } |
| void check_block(const Block *b) { |
| // Check topological ordering |
| JIT_ASSERT(b->param_node()->isBefore(*b->nodes().begin())); |
| auto curNode = *b->nodes().begin(); |
| while (curNode != b->return_node()) { |
| JIT_ASSERT(curNode->isBefore(curNode->next())); |
| curNode = curNode->next(); |
| } |
| |
| for (auto input : b->inputs()) { |
| check_value(input); |
| JIT_ASSERT(input->node()->kind_ == prim::Param); |
| } |
| |
| for (auto n : b->nodes()) { |
| JIT_ASSERT(n->kind_ != prim::Param); |
| JIT_ASSERT(n->kind_ != prim::Return); |
| check_node(n); |
| } |
| |
| JIT_ASSERT(b->output_->kind() == prim::Return); |
| check_node(b->output_); |
| |
| // all_nodes |
| // - inputs_, output_ and nodes_ are all included in all_nodes |
| // - all_nodes does not contain dead nodes??? (likely to be temporarily |
| // suspended). Weaker: all_nodes contains all inputs and returns |
| // - only one return node??? |
| |
| node_set nodes_set(ALL_OF(b->nodes())); |
| node_set inputs_set {b->input_}; |
| node_set output_set {b->output_}; |
| // TODO: Make a more type safe std::includes wrapper which disallows use on |
| // non-ordered containers |
| JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set))); |
| JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set))); |
| JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set))); |
| |
| sum_set.insert(ALL_OF(nodes_set)); |
| sum_set.insert(ALL_OF(inputs_set)); |
| sum_set.insert(ALL_OF(output_set)); |
| } |
| void check_graph() { |
| node_set all_nodes_set(ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered* |
| |
| check_block(g.block_); |
| for (auto kv : anticipated_uses) { |
| JIT_ASSERT(kv.second == -1); |
| } |
| JIT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set))); |
| } |
| }; |
| LintImpl(*this).check_graph(); |
| } |
| |
| void Graph::dump() const { |
| std::cout << *this << "\n"; |
| } |
| |
| void LintGraph(std::shared_ptr<Graph>& graph) { |
| graph->lint(); |
| } |
| |
| Block::Block(Graph* graph_, Node* node_) |
| : graph_(graph_), |
| output_(initOutput(graph_->create(prim::Return, 0))), |
| input_(graph_->create(prim::Param, 0)), |
| owning_node_(node_) { |
| graph_->all_blocks.emplace(this); |
| output_->owning_block_ = this; |
| output_->topo_position_ = kUpperBound; |
| input_->owning_block_ = this; |
| input_->topo_position_ = kLowerBound; |
| } |
| |
| void Block::reIndexTopology() { |
| auto curPos = kLowerBound; |
| for (auto node : nodes()) { |
| AT_ASSERT(curPos <= (kUpperBound - kAppendInterval)); |
| curPos += kAppendInterval; |
| node->topo_position_ = curPos; |
| } |
| } |
| |
| void Block::cloneFrom(Block * src, std::function<Value*(Value*)> value_map) { |
| std::unordered_map<Value*, Value*> local_map; |
| auto env = [&](Value * v) { |
| auto it = local_map.find(v); |
| if(it != local_map.end()) |
| return it->second; |
| return value_map(v); |
| }; |
| |
| auto graph = owningGraph(); |
| for(auto input : src->inputs()) { |
| local_map[input] = this->addInput()->copyMetadata(input); |
| } |
| |
| for(auto node : src->nodes()) { |
| auto new_node = this->appendNode(graph->createClone(node, env)); |
| for(size_t i = 0; i < node->outputs().size(); ++i) { |
| auto oo = node->outputs()[i]; |
| auto no = new_node->outputs()[i]; |
| local_map[oo] = no; |
| no->copyMetadata(oo); |
| } |
| } |
| for(auto output : src->outputs()) { |
| this->registerOutput(env(output)); |
| } |
| } |
| |
| void Block::destroy() { |
| // we cannot destroy the output because it is used as the sentinel |
| // for the nodes() list and has to remain valid for the loop |
| output_->removeAllInputs(); |
| for(auto it = this->nodes().reverse().begin(), |
| end = this->nodes().reverse().end(); |
| it != end; ++it) { |
| it.destroyCurrent(); |
| } |
| output_->destroy(); |
| input_->destroy(); |
| graph_->freeBlock(this); |
| } |
| |
| std::shared_ptr<Graph> Graph::copy() { |
| auto new_g = std::make_shared<Graph>(); |
| auto env = [](Value* v) -> Value* { |
| AT_ERROR( |
| "Graph::copy() encountered a use of a value not in scope. Run lint!"); |
| }; |
| new_g->block()->cloneFrom(this->block(), env); |
| return new_g; |
| } |
| |
| bool Value::mustBeNone() const { |
| return node_->kind() == prim::None; |
| } |
| |
| std::string Value::uniqueNameBase() const { |
| std::string name = uniqueName(); |
| std::string name_base = name; |
| auto last_dot_pos = name.find_last_of('.'); |
| if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) { |
| if (name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) { |
| name_base = name.substr(0, last_dot_pos); |
| } |
| } |
| return name_base; |
| } |
| |
| Value* Value::setUniqueName(const std::string & name) { |
| if (name.size() > 0 && name.find_first_not_of("0123456789") == std::string::npos) { |
| throw std::runtime_error("names may not be integers: " + name); |
| } |
| |
| auto & names = node()->owningGraph()->unique_names_; |
| |
| // clear any old name from the map |
| if(hasUniqueName()) { |
| names.erase(unique_name_); |
| unique_name_ = ""; |
| } |
| |
| // allow "" to clear the uniquename |
| if(name == "") |
| return this; |
| |
| // if someone else has this name, then rename the other value |
| auto old_owner_of_name = names.find(name); |
| if(old_owner_of_name != names.end()) { |
| size_t suffix = 1; |
| std::string name_base = name; |
| auto last_dot_pos = name.find_last_of('.'); |
| if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) { |
| if (name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) { |
| suffix = std::stoll(name.substr(last_dot_pos + 1)); |
| name_base = name.substr(0, last_dot_pos); |
| } |
| } |
| std::string replacement_name; |
| do { |
| std::stringstream ss; |
| ss << name_base << "." << suffix++; |
| replacement_name = ss.str(); |
| } while(names.count(replacement_name) > 0); |
| old_owner_of_name->second->setUniqueName(replacement_name); |
| } |
| |
| names[name] = this; |
| unique_name_ = name; |
| return this; |
| } |
| |
| Value* Value::copyMetadata(Value * from) { |
| setType(from->type()); |
| if (from->hasUniqueName()) |
| setUniqueName(from->uniqueName()); |
| return this; |
| } |
| |
| void Value::replaceFirstUseWith(Value * newValue) { |
| JIT_ASSERT(owningGraph() == newValue->owningGraph()); |
| auto u = uses()[0]; |
| u.user->inputs_[u.offset] = newValue; |
| newValue->uses_.push_back(u); |
| uses_.erase(uses_.begin()); |
| } |
| |
| void Value::replaceAllUsesWith(Value * newValue) { |
| while (!uses().empty()) { |
| replaceFirstUseWith(newValue); |
| } |
| } |
| |
| size_t findArgument(const FunctionSchema& the_schema, Symbol name) { |
| auto name_str = name.toUnqualString(); |
| for (size_t i = 0; i < the_schema.arguments().size(); ++i) { |
| const Argument* arg = &the_schema.arguments()[i]; |
| if (arg->name() == name_str) { |
| return i; |
| } |
| } |
| throw std::runtime_error(std::string("Couldn't find an argument called ") + name.toQualString()); |
| } |
| |
| c10::optional<IValue> Node::get(Symbol name) const { |
| return toIValue(namedInput(name)); |
| } |
| |
| Value* Node::namedInput(Symbol name) const { |
| return input(findArgument(schema(), name)); |
| } |
| |
| bool Node::matches(const char *signature_literal, at::ArrayRef<Symbol> const_inputs) const { |
| if (!sig(signature_literal).matches(this)) return false; |
| for (Symbol s : const_inputs) { |
| if (!is_constant(s)) return false; |
| } |
| return true; |
| } |
| |
| void Node::dump() const { |
| std::cout << *this << "\n"; |
| } |
| |
| void Node::findSchema() const { |
| schema_ = &getOperatorFor(this).schema(); |
| } |
| |
| const FunctionSchema* Node::maybeSchema() const { |
| if(!schema_) { |
| if(auto op = findOperatorFor(this)) { |
| schema_ = &op->schema(); |
| } |
| } |
| return schema_; |
| } |
| |
| bool Node::isNondeterministic() const { |
| static const OperatorSet nondeterministic_ops = { |
| "aten::dropout(Tensor input, float p, bool train) -> Tensor", |
| "aten::_fused_dropout(Tensor self, float p, Generator generator) -> (Tensor, Tensor)", |
| "aten::_standard_gamma(Tensor self, Generator generator) -> Tensor", |
| "aten::bernoulli(Tensor self, *, Generator generator) -> Tensor", |
| "aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor", |
| "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator generator) -> Tensor", |
| "aten::normal(Tensor mean, Tensor std, *, Generator generator) -> Tensor", |
| "aten::normal(float mean, Tensor std, *, Generator generator) -> Tensor", |
| "aten::normal(Tensor mean, float std, *, Generator generator) -> Tensor", |
| "aten::poisson(Tensor self, Generator generator) -> Tensor", |
| "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor", |
| "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor", |
| "aten::rand(int[] size, *, int dtype, int layout, Device device) -> Tensor", |
| "aten::rand_like(Tensor self) -> Tensor", |
| "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor", |
| "aten::randint(int high, int[] size, *, int dtype, int layout, Device device) -> Tensor", |
| "aten::randint(int low, int high, int[] size, *, int dtype, int layout, Device device) -> Tensor", |
| "aten::randint_like(Tensor self, int high) -> Tensor", |
| "aten::randint_like(Tensor self, int low, int high) -> Tensor", |
| "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor", |
| "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor", |
| "aten::randn(int[] size, *, int dtype, int layout, Device device) -> Tensor", |
| "aten::randn_like(Tensor self) -> Tensor", |
| "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor", |
| "aten::randperm(int n, *, int dtype, int layout, Device device) -> Tensor" |
| }; |
| |
| if (nondeterministic_ops.find(this) == nullptr) { |
| return false; |
| } |
| // Dropout with train = False is deterministic |
| if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) { |
| return false; |
| } |
| return true; |
| } |
| |
| bool Node::hasSideEffects() const { |
| switch (kind_) { |
| case prim::PythonOp: |
| case prim::Print: |
| case prim::RaiseException: |
| case aten::warn: |
| return true; |
| } |
| return false; |
| } |
| |
| // Assign this node a topological position, to facilitate fast isBefore() and |
| // isAfter() queries. Must be called right after a node is inserted into the |
| // node list. |
| // |
| // The basic scheme is: assign every node a position (uint64_t). The common |
| // case (appending to the end of the graph) is made more efficient by advancing |
| // a fixed interval past the previous node and placing `this` there. Otherwise, |
| // assign `this` a position at the midpoint between its prev() and next() |
| // nodes. |
| // |
| // If we ever run out of space (by, e.g. inserting too much in place), we |
| // reindex by spreading out all the nodes again. |
| void Node::assignTopoPosition() { |
| auto returnNode = owningBlock()->return_node(); |
| const auto prevPos = prev()->topo_position_; |
| const auto nextPos = next()->topo_position_; |
| |
| // Append to the end of the graph |
| if (next() == returnNode) { |
| if (next() == prev()) { |
| // the node list is empty, assign the first position |
| topo_position_ = kMidPoint; |
| return; |
| } |
| |
| if (prevPos >= (kUpperBound - kAppendInterval)) { |
| // we're running off the edge |
| owningBlock()->reIndexTopology(); |
| return; |
| } |
| |
| topo_position_ = prevPos + kAppendInterval; |
| |
| // Prepend to the graph |
| } else if (prev() == returnNode) { |
| // next() is the first element in the block list |
| if (nextPos <= (kLowerBound + kAppendInterval)) { |
| // we're running off the edge |
| owningBlock()->reIndexTopology(); |
| return; |
| } |
| |
| topo_position_ = nextPos - kAppendInterval; |
| |
| // insert between two existing nodes |
| } else { |
| const auto posBetween = prevPos + (nextPos - prevPos) / 2; |
| if (posBetween == prevPos) { |
| // There was no room |
| owningBlock()->reIndexTopology(); |
| return; |
| } |
| topo_position_ = posBetween; |
| } |
| } |
| |
| Node::Node(Graph * graph_, NodeKind kind_) : |
| kind_(kind_), |
| graph_(graph_), |
| owning_block_(nullptr), |
| scope_(graph_->current_scope_), |
| schema_(nullptr), |
| topo_position_(0) { |
| graph_->all_nodes.emplace(this); |
| } |
| |
| void Node::eraseOutput(size_t i) { |
| JIT_ASSERT(i < outputs_.size()); |
| JIT_ASSERT(outputs_[i]->uses().empty()); |
| schema_ = nullptr; |
| Value * n = outputs_[i]; |
| outputs_.erase(outputs_.begin() + i); |
| owningGraph()->freeValue(n); |
| for(size_t j = i; j < outputs_.size(); j++) { |
| outputs_[j]->offset_--; |
| } |
| } |
| |
| Block * Node::addBlock() { |
| schema_ = nullptr; |
| blocks_.push_back(new Block(owningGraph(), this)); |
| return blocks_.back(); |
| } |
| |
| void Node::eraseBlock(size_t i) { |
| JIT_ASSERT(i < blocks_.size()); |
| schema_ = nullptr; |
| Block * n = blocks_[i]; |
| blocks_.erase(blocks_.begin() + i); |
| n->destroy(); |
| } |
| |
| void Node::destroy() { |
| while(!outputs().empty()) |
| eraseOutput(outputs().size() - 1); |
| while(!blocks().empty()) |
| eraseBlock(blocks().size() - 1); |
| removeAllInputs(); |
| if(inBlockList()) |
| removeFromList(); |
| graph_->freeNode(this); |
| } |
| |
| void Node::cloneFrom(Node * s) { |
| setSourceLocation(s->getSourceLocation()); |
| if(s->scope_ && !s->scope_->isBlank()) |
| scope_ = s->scope_; |
| copyAttributes(*s); |
| } |
| |
| void Node::replaceAllUsesWith(Node * n) { |
| JIT_ASSERT(outputs().size() == n->outputs().size()); |
| size_t nOutputs = outputs().size(); |
| for(size_t i = 0; i < nOutputs; i++) { |
| outputs()[i]->replaceAllUsesWith(n->outputs()[i]); |
| } |
| } |
| |
| Value* Node::insertInput(size_t i, Value* value) { |
| JIT_ASSERT(graph_ == value->owningGraph()); |
| schema_ = nullptr; |
| // First we update the offsets for all existing inputs that will reside |
| // after the one we're inserting. Concretely, these are the inputs at |
| // indices [i, # input). Since we're inserting one input before all of |
| // these inputs, increment their use offsets for this value by 1 |
| for (size_t use_itr = i; use_itr < inputs_.size(); ++use_itr) { |
| // See Note [User node does not uniquely identify use] |
| auto use = findUseForInput(use_itr); |
| use->offset += 1; |
| } |
| // Insert the actual input at the specified index |
| inputs_.insert(inputs_.begin() + i, value); |
| // Register the new use of the value we're inserted as an input. |
| value->uses_.emplace_back(this, i); |
| return value; |
| } |
| |
| Value* Node::addInput(Value * value) { |
| JIT_ASSERT(graph_ == value->owningGraph()); |
| schema_ = nullptr; |
| value->uses_.emplace_back(this, inputs_.size()); |
| inputs_.push_back(value); |
| return value; |
| } |
| |
| Value* Node::replaceInput(size_t i, Value * newValue) { |
| JIT_ASSERT(newValue->owningGraph() == graph_); |
| schema_ = nullptr; |
| Value * old = dropInput(i); |
| inputs_[i] = newValue; |
| newValue->uses_.emplace_back(this, i); |
| return old; |
| } |
| |
| void Node::replaceInputWith(Value * from, Value * to) { |
| JIT_ASSERT(from->owningGraph() == graph_); |
| JIT_ASSERT(to->owningGraph() == graph_); |
| schema_ = nullptr; |
| size_t i = 0; |
| for(auto input : inputs()) { |
| if(input == from) |
| replaceInput(i, to); |
| i++; |
| } |
| } |
| |
| Value* Node::addOutput() { |
| outputs_.push_back(new Value(this, outputs_.size())); |
| schema_ = nullptr; |
| return outputs_.back(); |
| } |
| |
| Value* Node::insertOutput(size_t i) { |
| schema_ = nullptr; |
| outputs_.insert(outputs_.begin() + i, new Value(this, i)); |
| for (size_t itr = i + 1; itr < outputs_.size(); ++itr) { |
| outputs_[itr]->setOffset(outputs_[itr]->offset() + 1); |
| } |
| return outputs_.at(i); |
| } |
| |
| bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const { |
| if (this->owningBlock() == n->owningBlock()) { |
| if (moveSide == MoveSide::BEFORE) { |
| return this->topo_position_ < n->topo_position_; |
| } |
| |
| if (moveSide == MoveSide::AFTER) { |
| return this->topo_position_ > n->topo_position_; |
| } |
| |
| JIT_ASSERT(this == n); |
| return false; |
| } |
| |
| // These nodes don't share a common block. Traverse the blockchains upward |
| // until we find the first common block. |
| auto lhs = this; |
| while (lhs) { |
| JIT_ASSERT(lhs->owningBlock()); |
| |
| auto rhs = n; |
| while (rhs) { |
| if (!rhs->owningBlock()) { |
| break; |
| } |
| |
| if (lhs->owningBlock() == rhs->owningBlock()) { |
| return lhs->isBeforeOrAfter(rhs, moveSide); |
| } |
| rhs = rhs->owningBlock()->owningNode(); |
| } |
| |
| lhs = lhs->owningBlock()->owningNode(); |
| } |
| // should never reach here, since both nodes are ultimately in the same graph |
| JIT_ASSERT(false); |
| |
| } |
| |
| bool Node::isBefore(const Node * n) const { |
| return isBeforeOrAfter(n, MoveSide::BEFORE); |
| } |
| |
| bool Node::isAfter(const Node * n) const { |
| return isBeforeOrAfter(n, MoveSide::AFTER); |
| } |
| |
| Node* Node::insertBefore(Node * n) { |
| JIT_ASSERT(n->inBlockList()); |
| insertAfter(n->prev()); |
| return this; |
| } |
| |
| Node* Node::insertAfter(Node * n) { |
| JIT_ASSERT(!inBlockList() && n->inBlockList()); |
| JIT_ASSERT(n->owningBlock()); |
| this->owning_block_ = n->owningBlock(); |
| Node * next = n->next(); |
| n->next() = this; |
| this->prev() = n; |
| this->next() = next; |
| next->prev() = this; |
| assignTopoPosition(); |
| return this; |
| } |
| |
| bool Node::moveAfterTopologicallyValid(Node* n, const AliasDb& aliasDb) { |
| return tryMove(n, MoveSide::AFTER, aliasDb, /*dryRun=*/false); |
| } |
| |
| bool Node::couldMoveAfterTopologically(Node* n, const AliasDb& aliasDb) { |
| return tryMove(n, MoveSide::AFTER, aliasDb, /*dryRun=*/true); |
| } |
| |
| bool Node::moveBeforeTopologicallyValid(Node* n, const AliasDb& aliasDb) { |
| // We have to distinguish the move side (instead of just moving after |
| // n->prev()). Consider the following example: |
| // If the dependency graph looks like this -> n -> o then moveBefore(o) will |
| // end up with [this, o, n], but moveAfter(n) will return false. |
| return tryMove(n, MoveSide::BEFORE, aliasDb, /*dryRun=*/false); |
| } |
| |
| bool Node::couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb) { |
| return tryMove(n, MoveSide::BEFORE, aliasDb, /*dryRun=*/true); |
| } |
| |
| // Helper for topologically-safe node moves. See `tryMove()` for details. |
| namespace { |
| struct WorkingSet { |
| public: |
| explicit WorkingSet(Node* mover, const AliasDb& aliasDb) |
| : aliasDb_(aliasDb) { |
| add(mover); |
| } |
| |
| // Add `n` to the working set |
| void add(Node* n) { |
| nodes_.push_back(n); |
| for (const auto user : getUsersSameBlock(n)) { |
| users_[user]++; |
| } |
| |
| for (const auto& writer : getWritersSameBlock(n)) { |
| writers_[writer]++; |
| } |
| if (aliasDb_.hasWildcard(n)) { |
| numWildcards_++; |
| } |
| if (aliasDb_.hasWrites(n)) { |
| numWriterNodes_++; |
| } |
| } |
| |
| void eraseMover() { |
| auto mover = nodes_.front(); |
| for (const auto user : getUsersSameBlock(mover)) { |
| // If this user node only uses the mover, we can remove it |
| if (users_[user] == 1) { |
| users_.erase(user); |
| } |
| } |
| |
| for (const auto& writer : getWritersSameBlock(mover)) { |
| if (writers_[writer] == 1) { |
| writers_.erase(writer); |
| } |
| } |
| if (aliasDb_.hasWildcard(mover)) { |
| numWildcards_--; |
| } |
| if (aliasDb_.hasWrites(mover)) { |
| numWriterNodes_--; |
| } |
| nodes_.pop_front(); |
| } |
| |
| const std::list<Node*>& nodes() { |
| return nodes_; |
| } |
| |
| // Does the working set depend on `n`? |
| bool dependsOn(Node* n) const { |
| if (nodes_.empty()) { |
| return false; |
| } |
| |
| return hasDataDependency(n) || hasMutabilityDependency(n); |
| } |
| |
| private: |
| bool hasDataDependency(Node* n) const { |
| if (n->isAfter(nodes_.front())) { |
| return producesFor(n); |
| } else { |
| return consumesFrom(n); |
| } |
| } |
| |
| bool hasMutabilityDependency(Node* n) const { |
| // 1. Handle wildcard dependencies: |
| // If the working set has a wildcard, `n` can't write to anything. |
| if (numWildcards_ > 0 && aliasDb_.hasWrites(n)) { |
| return true; |
| } |
| |
| // If `n` has a wildcard, the working set can't write to anything. |
| if (aliasDb_.hasWildcard(n) && numWriterNodes_ > 0) { |
| return true; |
| } |
| |
| // 2. Handle regular mutable dependencies |
| // Check that this node does not write to anything used by the working set |
| if (writers_.count(n) != 0) { |
| return true; |
| } |
| |
| // Check that the working set does not write to anything used by this node |
| const auto writersToNode = getWritersSameBlock(n); |
| return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) { |
| return writersToNode.count(node) != 0; |
| }); |
| } |
| |
| // Does the working set produce any values consumed by `n`? |
| bool producesFor(Node* n) const { |
| // This equivalent to asking: does the total use-set of all the nodes in the |
| // working set include `n`? |
| return users_.count(n) != 0; |
| } |
| |
| // Does the working set consume any values produced by `n`? |
| bool consumesFrom(Node* n) const { |
| const auto users = getUsersSameBlock(n); |
| return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) { |
| return users.count(node) != 0; |
| }); |
| } |
| |
| // Get all users of outputs of `n`, in the same block as `n`. |
| // This means if there is an `if` node that uses an output of `n` in some |
| // inner sub-block, we will consider the whole `if` node a user of `n`. |
| std::unordered_set<Node*> getUsersSameBlock(Node* n) const { |
| std::unordered_set<Node*> users; |
| for (const auto output : n->outputs()) { |
| for (const auto& use : output->uses()) { |
| if (auto sameBlock = findSameBlock(use.user, n)) { |
| users.insert(sameBlock); |
| } |
| } |
| } |
| return users; |
| } |
| |
| std::unordered_set<Node*> getWritersSameBlock(Node* n) const { |
| std::unordered_set<Node*> writers; |
| for (const auto writer : aliasDb_.getWriters(n)) { |
| if (auto sameBlock = findSameBlock(writer, n)) { |
| writers.insert(sameBlock); |
| } |
| } |
| return writers; |
| } |
| |
| // Traverse `target`'s blockchain upward until we find a node that shares a |
| // block with `n`. |
| // |
| // If one can't be found (say, because `n` is an inner block and target is |
| // outside), then return nullptr. Since we can only reorder nodes within a |
| // block, `target` would be irrelevant. |
| static Node* findSameBlock(Node* target, Node* n) { |
| if (target->owningBlock() == n->owningBlock()) { |
| return target; |
| } else { |
| // This user is in a sub-block. Traverse the blockchain upward until |
| // we arrive at a node that shares a block with `this` |
| auto curNode = target; |
| while (curNode->owningBlock() != n->owningBlock()) { |
| curNode = curNode->owningBlock()->owningNode(); |
| if (curNode == nullptr) { |
| return curNode; |
| } |
| } |
| return curNode; |
| } |
| } |
| |
| const AliasDb& aliasDb_; |
| std::list<Node*> nodes_; |
| // users => # of working set nodes it uses |
| std::unordered_map<Node*, size_t> users_; |
| std::unordered_map<Node*, size_t> writers_; |
| size_t numWildcards_ = 0; |
| size_t numWriterNodes_ = 0; |
| }; |
| } // namespace |
| |
| // Try to move `this` before/after `movePoint` while preserving value |
| // dependencies. Returns false iff such a move could not be made |
| // |
| // The basic approach is: have a "working set" that we are moving forward, one |
| // node at a time. When we can't move past a node (because it depends on the |
| // working set), then add it to the working set and keep moving until we hit |
| // `moveAfter`. |
| bool Node::tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun) { |
| JIT_ASSERT(this->inBlockList() && movePoint->inBlockList()); |
| JIT_ASSERT(this->owningBlock() == movePoint->owningBlock()); |
| if (this == movePoint) { |
| return true; |
| } |
| |
| // 1. Move from `this` toward movePoint, building up the working set of |
| // dependencies |
| WorkingSet workingSet(this, aliasDb); |
| |
| int direction; |
| if (this->isAfter(movePoint)) { |
| direction = kPrevDirection; |
| } else { |
| direction = kNextDirection; |
| } |
| |
| auto curNode = this->next_in_graph[direction]; |
| // Move forward one node at a time |
| while (curNode != movePoint) { |
| if (workingSet.dependsOn(curNode)) { |
| // If we can't move past this node, add it to the working set |
| workingSet.add(curNode); |
| } |
| curNode = curNode->next_in_graph[direction]; |
| } |
| |
| // 2. Decide whether we can move it all to `movePoint`. |
| |
| // Say we are moving directly before movePoint and `this` starts before |
| // movePoint in the graph. The move looks like |
| // |
| // `this` `this` | |
| // <dependencies> -> `movePoint` | `this` and deps are split |
| // `movePoint` <dependencies> | |
| // |
| // Contrast with the case where `this` starts AFTER movePoint: |
| // |
| // `movePoint` <dependencies> | |
| // <dependencies> -> `this` | `this` and deps are together |
| // `this` `movePoint` | |
| // |
| // In the first case, we need to split `this` off from its dependencies, so we |
| // can move the dependencies below `movePoint` and keep `this` above. |
| const bool splitThisAndDeps = |
| (moveSide == MoveSide::BEFORE && this->isBefore(movePoint)) || |
| (moveSide == MoveSide::AFTER && this->isAfter(movePoint)); |
| |
| if (splitThisAndDeps) { |
| // remove `this` from dependencies to be moved past `movePoint` |
| workingSet.eraseMover(); |
| } |
| |
| // Check if we can move the working set past the move point |
| if (workingSet.dependsOn(movePoint)) { |
| // if we can't, then there are intermediate dependencies between the |
| // `this` and `movePoint`, so we can't do the move |
| return false; |
| } |
| |
| if (dryRun) { |
| return true; |
| } |
| |
| // 3. Execute the move |
| JIT_ASSERT(curNode == movePoint); |
| if (splitThisAndDeps) { |
| // Move `this` |
| this->move(movePoint, moveSide); |
| |
| // Then move all of its dependencies on the other side of `movePoint` |
| const auto reversed = |
| moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE; |
| for (auto toMove : workingSet.nodes()) { |
| toMove->move(curNode, reversed); |
| curNode = toMove; |
| } |
| } else { |
| // Just append/prepend everything to `movePoint` |
| for (auto toMove : workingSet.nodes()) { |
| toMove->move(curNode, moveSide); |
| curNode = toMove; |
| } |
| } |
| return true; |
| } |
| |
| // Helper function so we can generalize `tryMove` |
| void Node::move(Node* movePoint, MoveSide moveSide) { |
| switch (moveSide) { |
| case MoveSide::BEFORE: |
| this->moveBefore(movePoint); |
| break; |
| case MoveSide::AFTER: |
| this->moveAfter(movePoint); |
| break; |
| } |
| } |
| |
| void Node::moveAfter(Node * n) { |
| removeFromList(); |
| insertAfter(n); |
| } |
| |
| void Node::moveBefore(Node * n) { |
| removeFromList(); |
| insertBefore(n); |
| } |
| |
| void Node::removeInput(size_t i) { |
| schema_ = nullptr; |
| dropInput(i); |
| // everything after this input shifts left, |
| // so we need to update their use offsets to match |
| for(size_t j = i+1; j < inputs_.size(); j++) { |
| auto it = findUseForInput(j); |
| it->offset--; |
| } |
| inputs_.erase(inputs_.begin() + i); |
| } |
| |
| void Node::removeAllInputs() { |
| schema_ = nullptr; |
| for(size_t i = 0; i < inputs().size(); ++i) |
| dropInput(i); |
| inputs_.clear(); |
| } |
| |
| use_list::iterator Node::findUseForInput(size_t i) { |
| auto & input_uses = inputs_[i]->uses_; |
| // O(N) on the use list, but unless we get nodes with +100 uses |
| // vector traversal still is probably faster than linked list |
| auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i)); |
| JIT_ASSERT(use_it != input_uses.end()); |
| return use_it; |
| } |
| |
| Value* Node::dropInput(size_t i) { |
| JIT_ASSERT(i < inputs_.size()); |
| auto input_node = inputs_[i]; |
| auto use_it = findUseForInput(i); |
| input_node->uses_.erase(use_it); |
| inputs_[i] = nullptr; |
| return input_node; |
| } |
| |
| void Node::removeFromList() { |
| JIT_ASSERT(inBlockList()); |
| this->owning_block_ = nullptr; |
| Node * next = this->next(); |
| Node * prev = this->prev(); |
| prev->next() = next; |
| next->prev() = prev; |
| this->next() = nullptr; |
| this->prev() = nullptr; |
| } |
| |
| inline const SourceRange& fakeRange() { |
| static SourceRange range(std::make_shared<std::string>("<internally-created-node>"), 0, 1); |
| return range; |
| } |
| |
| Value* Graph::insert( |
| Symbol opname, |
| at::ArrayRef<NamedValue> args, |
| at::ArrayRef<NamedValue> kwargs, |
| const c10::optional<SourceRange>& range) { |
| return script::emitBuiltinCall( |
| range.value_or(fakeRange()), |
| *this, |
| opname, |
| c10::nullopt, |
| args, |
| kwargs, |
| /*required=*/true); |
| } |
| |
| Node* Graph::create(NodeKind kind, size_t num_outputs) { |
| // NB: Node constructor adds node to all_nodes |
| auto n = new Node(this, kind); |
| for(size_t i = 0; i < num_outputs; i++) |
| n->addOutput(); |
| return n; |
| } |
| |
| Node* Graph::create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs) { |
| auto n = create(kind, num_outputs); |
| for(auto i : inputs) |
| n->addInput(i); |
| return n; |
| } |
| |
| Node* Graph::createUndefined() { |
| return create(prim::Undefined); |
| } |
| |
| Node* Graph::createNone(TypePtr typ) { |
| Node * n = create(prim::None); |
| n->output()->setType(OptionalType::create(std::move(typ))); |
| return n; |
| } |
| |
| Node * Graph::createNoneGenerator() { |
| auto n = create(prim::NoneGenerator); |
| n->output()->setType(GeneratorType::get()); |
| return n; |
| } |
| |
| Node * Graph::createFusionGroup() { |
| auto n = create(prim::FusionGroup, 0); |
| n->g_(attr::Subgraph,std::make_shared<Graph>(current_scope())); |
| return n; |
| } |
| |
| Node* Graph::createTuple(at::ArrayRef<Value*> values) { |
| auto types = fmap(values, [](Value* v) { return v->type(); }); |
| auto tt = TupleType::create(std::move(types)); |
| auto n = create(prim::TupleConstruct, values); |
| n->output()->setType(tt); |
| return n; |
| } |
| |
| Node* Graph::createTupleUnpack(Value * v) { |
| TupleTypePtr tt = v->type()->expect<TupleType>(); |
| auto n = create(prim::TupleUnpack, {v}, 0); |
| for(auto & element : tt->elements()) { |
| n->addOutput()->setType(element); |
| } |
| return n; |
| } |
| |
| Node* Graph::createTupleIndex(Value * tup, int64_t index) { |
| auto n = create(prim::TupleIndex, {tup}); |
| n->i_(attr::index, index); |
| auto tuple_type = tup->type()->expect<TupleType>(); |
| n->output()->setType(tuple_type->elements().at(index)); |
| return n; |
| } |
| |
| Node* Graph::createTupleSlice(Value * tup, int64_t beg, int64_t end) { |
| auto n = create(prim::TupleSlice, {tup}); |
| auto tuple_type = tup->type()->expect<TupleType>(); |
| n->i_(attr::beg, beg); |
| n->i_(attr::end, end); |
| std::vector<TypePtr> output_types; |
| for (auto i = beg; i < end; ++i) { |
| output_types.push_back(tuple_type->elements().at(i)); |
| } |
| auto tt = TupleType::create(std::move(output_types)); |
| n->output()->setType(tt); |
| return n; |
| } |
| |
| Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) { |
| auto n = create(prim::ListConstruct, values); |
| for(const auto & v : values) { |
| JIT_ASSERT(v->type()->isSubtypeOf(elem_type)); |
| } |
| n->output()->setType(ListType::create(elem_type)); |
| return n; |
| } |
| Node* Graph::createListUnpack(Value *v, size_t size) { |
| ListTypePtr list_type = v->type()->expect<ListType>(); |
| TypePtr elem_type = list_type->getElementType(); |
| auto n = create(prim::ListUnpack, {v}, 0); |
| for (size_t i = 0; i < size; ++i) { |
| n->addOutput()->setType(elem_type); |
| } |
| return n; |
| } |
| |
| Node* Graph::createNumToTensor(Value* value) { |
| auto typ = value->type(); |
| Node * result = create(prim::NumToTensor, {value}); |
| result->output()->setType(CompleteTensorType::fromNumberType(std::move(typ))); |
| return result; |
| } |
| |
| Node* Graph::createImplicitTensorToNum(const TypePtr& type, Value* value) { |
| auto* result = create(prim::ImplicitTensorToNum, {value}); |
| result->output()->setType(type); |
| return result; |
| } |
| |
| Node* Graph::createClone(Node * n, const std::function<Value*(Value*)>& value_map, bool copy_blocks) { |
| //n can be from a different graph |
| Node * r = n->allocNewInstance(this); |
| for(auto o : n->outputs()) { |
| r->addOutput()->copyMetadata(o); |
| } |
| r->cloneFrom(n); |
| for(auto i : n->inputs()) { |
| r->addInput(value_map(i)); |
| } |
| if(copy_blocks) { |
| for(auto b : n->blocks()) { |
| r->addBlock()->cloneFrom(b, value_map); |
| } |
| } |
| return r; |
| } |
| |
| Value* Graph::insertConstant( |
| IValue val, |
| c10::optional<SourceRange> loc, |
| c10::optional<ScopePtr> scope) { |
| return jit::insertConstant(*this, std::move(val), std::move(loc), std::move(scope)); |
| } |
| |
| std::string Graph::toString() const { |
| std::ostringstream oss; |
| oss << *this; |
| return oss.str(); |
| } |
| |
| Graph::~Graph() { |
| for (const Node * n : all_nodes) |
| delete n; |
| for (const Value * v : all_values) |
| delete v; |
| for (const Block * b : all_blocks) |
| delete b; |
| } |
| |
| void Graph::freeNode(Node * n) { |
| auto it = all_nodes.find(n); |
| JIT_ASSERT(it != all_nodes.end()); |
| delete *it; |
| all_nodes.erase(it); |
| } |
| void Graph::freeValue(Value * v) { |
| v->setUniqueName(""); |
| auto it = all_values.find(v); |
| JIT_ASSERT(it != all_values.end()); |
| delete *it; |
| all_values.erase(it); |
| } |
| void Graph::freeBlock(Block * b) { |
| auto it = all_blocks.find(b); |
| JIT_ASSERT(it != all_blocks.end()); |
| delete *it; |
| all_blocks.erase(it); |
| } |
| |
| at::ArrayRef<Value*> createTupleUnpack(Value* v) { |
| // small peephole optimization to ensure IntList attributes can still turn |
| // into constants e.g. in x.expand([3, 4]) |
| if(v->node()->kind() == prim::TupleConstruct) |
| return v->node()->inputs(); |
| auto & g = *v->owningGraph(); |
| return g.insertNode(g.createTupleUnpack(v))->outputs(); |
| } |
| |
| std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) { |
| std::unordered_map<Value*, Value*> value_map; |
| auto value_map_func = [&](Value* v) { return value_map.at(v); }; |
| JIT_ASSERT(callee.inputs().size() == inputs.size()); |
| for (size_t i = 0; i < inputs.size(); ++i) { |
| value_map[callee.inputs()[i]] = inputs[i]; |
| } |
| for (auto* node : callee.nodes()) { |
| auto* new_node = |
| g.insertNode(g.createClone(node, value_map_func)); |
| for (size_t i = 0; i < node->outputs().size(); ++i) { |
| value_map[node->outputs()[i]] = new_node->outputs()[i]; |
| } |
| } |
| |
| std::vector<Value*> outputs; |
| for (auto* output : callee.outputs()) { |
| outputs.push_back(value_map_func(output)); |
| } |
| return outputs; |
| } |
| |
| PythonOp* defaultAllocPythonOp(Graph*g) { |
| throw std::runtime_error("Trying to allocate a Python object without python bindings loaded"); |
| } |
| std::atomic<decltype(&defaultAllocPythonOp)> alloc_python_op; |
| |
| // patched in when python bindings are loaded |
| PythonOp* allocPythonOp(Graph* g) { |
| return alloc_python_op.load()(g); |
| } |
| void setAllocPythonOp(PythonOp* (*v)(Graph* g)) { |
| alloc_python_op.store(v); |
| } |
| |
| |
| }} // namespace torch::jit |