#include <torch/csrc/jit/argument_spec.h>

#include <algorithm>
#include <iostream>

namespace torch {
namespace jit {

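// scan() walks an input type recursively and appends instructions to
// instructions_, producing a flat, preorder "program" that create() and
// getSpecializedTypes() later interpret. For example, a graph whose inputs
// have types (Tensor, Tuple[Tensor, int], float) yields the stream
//   SPECIALIZE_TENSOR, ENTER_TUPLE, SPECIALIZE_TENSOR, SKIP, LEAVE, SKIP
// which dump() prints as "SpecializeTensor Tuple[SpecializeTensor Skip ] Skip ".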
void ArgumentSpecCreator::scan(
    const TypePtr& typ,
    size_t depth,
    const WrittenSlots& written_slots) {
  auto finishAggregate = [&](size_t pos) {
    // it is possible that, after all the work we did to scan this aggregate,
    // we found no tensors to specialize. In this case, just generate
    // a SKIP for the whole aggregate.
    bool any_spec = std::any_of(
        instructions_.begin() + pos, instructions_.end(), [](Inst i) {
          return i == SPECIALIZE_TENSOR;
        });
    if (!any_spec) {
      instructions_[pos] = SKIP;
      instructions_.resize(pos + 1);
    } else {
      instructions_.emplace_back(LEAVE);
    }
  };
  // the simple vm that interprets instructions_ has a limited stack depth;
  // this prevents generating a program that goes deeper than that.
  if (depth >= DEPTH_LIMIT) {
    // skip the whole value and do not descend into it
    instructions_.emplace_back(SKIP);
    return;
  }
  if (typ->isSubtypeOf(TensorType::get())) {
    num_tensors_++;
    instructions_.emplace_back(SPECIALIZE_TENSOR);
  } else if (auto tup = typ->cast<TupleType>()) {
    size_t pos = instructions_.size();
    instructions_.emplace_back(ENTER_TUPLE);
    for (const auto& elem : tup->containedTypes()) {
      scan(elem, depth + 1, written_slots);
    }
    finishAggregate(pos);
  } else if (auto cls = typ->cast<ClassType>()) {
    size_t pos = instructions_.size();
    instructions_.emplace_back(ENTER_OBJECT);
    for (size_t i = 0; i < cls->numAttributes(); ++i) {
      auto key = cls->name() + cls->attributeNames().at(i);
      // it is only safe to specialize a slot if nothing in the graph
      // writes to it (see scanWrittenSlots below)
      if (!written_slots.count(key)) {
        scan(cls->containedTypes().at(i), depth + 1, written_slots);
      } else {
        instructions_.emplace_back(SKIP);
      }
    }
    finishAggregate(pos);
  } else {
    instructions_.emplace_back(SKIP);
  }
}

// this is a coarse-grained guarantee that the slots of a class will not be
// modified by the function. It works fine for things that used to be read-only
// modules, but will be overly conservative when some classes are written to.
// Doing alias analysis and looking for writes to the class would be more
// accurate.
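// The recorded key is the class name concatenated with the attribute name,
// matching the key computed in scan() above.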
static void scanWrittenSlots(
    Block* block,
    ArgumentSpecCreator::WrittenSlots& written_slots) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::SetAttr) {
      if (auto cls = n->inputs().at(0)->type()->cast<ClassType>()) {
        written_slots.insert(cls->name() + n->s(attr::name));
      }
    }
    for (Block* subblock : n->blocks()) {
      scanWrittenSlots(subblock, written_slots);
    }
    if (n->hasAttribute(attr::Subgraph)) {
      scanWrittenSlots(n->g(attr::Subgraph)->block(), written_slots);
    }
  }
}

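// The constructor first collects the written slots, then scans every graph
// input once, building the instruction program interpreted by create() and
// getSpecializedTypes().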
ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph)
    : num_inputs_(graph.inputs().size()) {
  WrittenSlots written_slots;
  scanWrittenSlots(graph.block(), written_slots);
  for (Value* input : graph.inputs()) {
    scan(input->type(), 0, written_slots);
  }
}

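// Prints the instruction program in a human-readable bracketed form, e.g.
// "SpecializeTensor Tuple[SpecializeTensor Skip ] Skip " for the example
// signature described above scan().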
void ArgumentSpecCreator::dump() const {
  for (Inst inst : instructions_) {
    switch (inst) {
      case LEAVE:
        std::cout << "] ";
        break;
      case ENTER_TUPLE:
        std::cout << "Tuple[";
        break;
      case ENTER_OBJECT:
        std::cout << "Object[";
        break;
      case SKIP:
        std::cout << "Skip ";
        break;
      case SPECIALIZE_TENSOR:
        std::cout << "SpecializeTensor ";
        break;
    }
  }
  std::cout << "\n";
}

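// Interprets the instruction program against the actual inputs on the
// interpreter stack: ENTER_* pushes a pointer to an aggregate's elements,
// SKIP advances past an element, and SPECIALIZE_TENSOR records the tensor's
// metadata in the ArgumentSpec.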
ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
    const {
  ArgumentSpec spec(num_tensors_);
  const IValue* stack[DEPTH_LIMIT]; // The stack of IValue lists
  // The stack gets initialized with the input list
  stack[0] = last(input, num_inputs_).begin();
  size_t stack_top = 0; // offset to the top of the stack
  for (Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_TENSOR:
        // consume a tensor and add to the argspec
        spec.addTensor(*stack[stack_top]++, with_grad);
        break;
      case ENTER_TUPLE: {
        // consume tuple
        const IValue* iv = stack[stack_top]++;
        AT_ASSERT(iv->isTuple());
        // see [argspec refcounting]
        auto p = *reinterpret_cast<const at::ivalue::Tuple* const*>(iv);
        auto tup_ptr = &p->elements()[0];
        // push list of tuple elements to the stack
        stack[++stack_top] = tup_ptr;
      } break;
      case ENTER_OBJECT: {
        // consume object
        const IValue* iv = stack[stack_top]++;
        AT_ASSERT(iv->isObject());
        // see [argspec refcounting]
        auto p = *reinterpret_cast<const at::ivalue::Object* const*>(iv);
        auto obj_ptr = &p->slots()[0];
        // push list of object elements to the stack
        stack[++stack_top] = obj_ptr;
      } break;
      case SKIP:
        // consume and skip an element
        stack[stack_top]++;
        break;
      case LEAVE:
        --stack_top;
        break;
    }
  }
  return spec;
}

// For every input of a given graph, returns the most detailed type that can be
// inferred for it based on this ArgumentSpec.
std::vector<TypePtr> ArgumentSpecCreator::getSpecializedTypes(
    Graph& graph,
    const ArgumentSpec& spec) const {
  auto input_types =
      fmap(graph.inputs(), [](Value* input) { return input->type(); });
  std::vector<std::vector<TypePtr>> result_stack;
  result_stack.emplace_back();
  std::vector<const TypePtr*> input_stack = {input_types.data()};
  std::vector<std::function<TypePtr()>> aggregate_creators;

  size_t arg_spec_offset = 0; // number of specialized tensors seen so far

  for (Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_TENSOR: {
        input_stack.back()++;
        auto& arg = spec.at(arg_spec_offset++);
        if (!arg.defined()) {
          result_stack.back().emplace_back(AutogradZeroTensorType::get());
        } else {
          result_stack.back().emplace_back(DimensionedTensorType::create(
              arg.type(),
              ConvertIntToCPUOrCUDA(arg.device()),
              arg.dim(),
              arg.requires_grad()));
        }
      } break;
      case ENTER_TUPLE: {
        auto tup = (*input_stack.back()++)->expect<TupleType>();
        input_stack.emplace_back(tup->elements().data());
        result_stack.emplace_back();
        aggregate_creators.emplace_back(
            [&] { return TupleType::create(result_stack.back()); });
      } break;
      case ENTER_OBJECT: {
        auto cls = (*input_stack.back()++)->expect<ClassType>();
        input_stack.emplace_back(cls->containedTypes().data());
        result_stack.emplace_back();
        aggregate_creators.emplace_back(
            [&result_stack, cls] { return cls->refine(result_stack.back()); });
      } break;
      case SKIP:
        result_stack.back().emplace_back(*input_stack.back()++);
        break;
      case LEAVE: {
        TypePtr result = aggregate_creators.back()();
        result_stack.pop_back();
        aggregate_creators.pop_back();
        input_stack.pop_back();
        result_stack.back().emplace_back(std::move(result));
      } break;
    }
  }
  AT_ASSERT(result_stack.size() == 1);
  return result_stack.back();
}

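// Applies the specialized types computed from `spec` back onto the
// corresponding graph inputs.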
void ArgumentSpecCreator::setInputTypes(Graph& g, const ArgumentSpec& spec)
    const {
  auto input_types = getSpecializedTypes(g, spec);
  auto inputs = g.inputs();
  for (size_t i = 0; i < inputs.size(); ++i) {
    inputs[i]->setType(input_types[i]);
  }
}

} // namespace jit
} // namespace torch