| #pragma once |
| #include <c10/util/Exception.h> |
| #include <torch/csrc/autograd/variable.h> |
| #include <torch/csrc/jit/argument_spec.h> |
| #include <torch/csrc/jit/graph_executor.h> |
| #include <torch/csrc/jit/ir.h> |
| #include <torch/csrc/jit/named_value.h> |
| #include <torch/csrc/jit/passes/shape_analysis.h> |
| #include <torch/csrc/jit/script/slot.h> |
| #include <torch/csrc/jit/source_range.h> |
| |
| #include <torch/csrc/WindowsTorchApiMacro.h> |
| #include <torch/csrc/api/include/torch/ordered_dict.h> |
| #include <torch/csrc/jit/script/compilation_unit.h> |
| #include <torch/csrc/utils/memory.h> |
| |
| #include <ATen/core/function_schema.h> |
| #include <ATen/core/qualified_name.h> |
| #include <c10/util/ArrayRef.h> |
| #include <c10/util/Optional.h> |
| |
| #include <functional> |
| #include <memory> |
| #include <mutex> |
| #include <ostream> |
| #include <string> |
| #include <unordered_map> |
| #include <vector> |
| |
| // This file contains classes which assist in desugaring Python style |
| // modules and their methods into flattened graphs which don't have any |
| // function calls. |
| |
| namespace torch { |
| namespace jit { |
| namespace script { |
| |
| using ::c10::Argument; |
| using ::c10::FunctionSchema; |
| using ::c10::QualifiedName; |
| // Map which stores filename to content. |
| using ExtraFilesMap = std::unordered_map<std::string, std::string>; |
| |
| using ModulePtr = c10::intrusive_ptr<c10::ivalue::Object>; |
| // A method in a module, e.g. f in: |
| // |
| // class M(ScriptModule): |
| // @script_method |
| // def f(self, x): |
| // ... |
| // Note: because Method/Module are exposed to python these |
| // classes use python method naming conventions |
| |
| struct Module; |
| |
| template <typename T> |
| struct slot_list_impl; |
| using slot_list = slot_list_impl<Slot>; |
| using module_list = slot_list_impl<Module>; |
| using ModuleLookup = std::function<Module(const std::vector<std::string>&)>; |
| |
| struct TORCH_API Method { |
| Method(ModulePtr owner, Function* function); |
| |
| // the module that contains this method. |
| Module owner() const; |
| void run(Stack& stack); |
| void run(Stack&& stack) { |
| run(stack); |
| } |
| |
| IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs()); |
| |
| std::shared_ptr<Graph> graph() const { |
| return function_->graph(); |
| } |
| |
| const std::string& name() const { |
| return function_->name(); |
| } |
| |
| size_t num_inputs() const { |
| return function_->num_inputs(); |
| } |
| |
| GraphExecutor& get_executor() { |
| return function_->get_executor(); |
| } |
| |
| Function& function() const { |
| return *function_; |
| } |
| |
| // Used for ONNX export. Return a tuple (graph, parameters) where |
| // the last parameters.size() inputs to the graph are the trainable parameters |
| // used in this method. The remaining inputs are the true inputs to the function. |
| std::pair<std::shared_ptr<Graph>, std::vector<at::Tensor>> _lowered_graph(); |
| |
| private: |
| // Methods are uniqued onwed by a single module. This raw pointer allows |
| // looking up the module. |
| ModulePtr owner_; |
| |
| // Underlying unbound function |
| // This is the _lowered_ function and is different than the |
| // first-class function in class_compilation_unit() |
| Function* function_; |
| }; |
| |
| struct TORCH_API Module { |
| explicit Module(c10::QualifiedName class_name); |
| Module(c10::QualifiedName, std::shared_ptr<CompilationUnit> cu); |
| // module_value_ null and will be lazily initialized if is needed |
| Module() {} |
| Module(ModulePtr module_value) : module_value_(std::move(module_value)) {} |
| ~Module() {} |
| |
| const c10::QualifiedName& name() const { |
| return *module_object()->type()->qualified_name_obj(); |
| } |
| |
| void set_optimized(bool o) { |
| AT_WARN( |
| "Module::set_optimized() is deprecated and has no effect. " |
| "Please use setGraphExecutorOptimize()"); |
| } |
| |
| bool is_optimized() const { |
| AT_WARN( |
| "Module::is_optimized() is deprecated and always returns true. " |
| "Please use getGraphExecutorOptimize()"); |
| return true; |
| } |
| |
| IValue forward(std::vector<IValue> inputs) { |
| return get_method("forward")(std::move(inputs)); |
| } |
| |
| // In script modules, buffers are Tensors attribute that are _not_ registered |
| // as parameters. This is different than in nn.Module where there is a special |
| // register_buffer method. With this simplification, we only need to track |
| // whether a slot is a parameter to be able to classify it. |
| void register_buffer(const std::string& name, autograd::Variable v) { |
| set_or_add_slot(name, TensorType::get(), v, EntityType::ATTRIBUTE); |
| } |
| |
| void register_parameter( |
| const std::string& name, |
| autograd::Variable v, |
| bool is_buffer) { |
| set_or_add_slot( |
| name, |
| TensorType::get(), |
| v, |
| is_buffer ? EntityType::ATTRIBUTE : EntityType::PARAMETER); |
| } |
| void register_attribute( |
| const std::string& name, |
| const TypePtr type, |
| IValue ivalue) { |
| set_or_add_slot(name, type, ivalue, EntityType::ATTRIBUTE); |
| } |
| void register_module(const std::string& name, const Module& module) { |
| set_or_add_slot( |
| name, module.type(), module.module_object(), EntityType::MODULE); |
| } |
| |
| void set_parameter(const std::string& name, at::Tensor v) { |
| get_slot(name, EntityType::PARAMETER).setValue(v); |
| } |
| |
| autograd::Variable get_parameter(const std::string& name) const { |
| return autograd::as_variable_ref( |
| get_slot(name, EntityType::PARAMETER).value().toTensor()); |
| } |
| |
| IValue get_attribute(const std::string& name) const { |
| return get_slot(name, EntityType::ATTRIBUTE).value(); |
| } |
| |
| autograd::Variable get_buffer(const std::string& name) const { |
| return autograd::as_variable_ref(get_attribute(name).toTensor()); |
| } |
| |
| // each module owns its method. The reference returned here |
| // is guarenteed to stay valid until this module has been destroyed |
| Method get_method(const std::string& name) const { |
| if (auto method = find_method(name)) { |
| return *method; |
| } |
| AT_ERROR("Method '", name, "' is not defined."); |
| } |
| |
| Module get_module(const std::string& name) const { |
| auto obj = get_slot(name, EntityType::MODULE).value().toObject(); |
| return Module(obj); |
| } |
| |
| module_list get_modules() const; |
| slot_list get_slots() const; |
| slot_list get_parameters() const; |
| slot_list get_attributes() const; |
| slot_list get_module_slots() const; |
| |
| const std::vector<Method> get_methods() const { |
| return fmap( |
| class_compilation_unit()->get_functions(), |
| [&](Function* func) { |
| return Method(module_object(), func); |
| }); |
| } |
| |
| c10::optional<Slot> find_parameter(const std::string& name) const { |
| return find_slot(name, EntityType::PARAMETER); |
| } |
| c10::optional<Slot> find_attribute(const std::string& name) { |
| return find_slot(name, EntityType::ATTRIBUTE); |
| } |
| c10::optional<Slot> find_buffer(const std::string& name) { |
| auto iv = find_attribute(name); |
| if (iv && iv->type()->isSubtypeOf(TensorType::get())) { |
| return iv; |
| } |
| return c10::nullopt; |
| } |
| c10::optional<Module> find_module(const std::string& name) const { |
| if (auto slot = find_slot(name, EntityType::MODULE)) { |
| return Module(slot->value().toObject()); |
| } |
| return c10::nullopt; |
| } |
| c10::optional<Method> find_method(const std::string& basename) const { |
| if (const auto fn = class_compilation_unit()->find_function( |
| getNameForMethod(basename))) { |
| return Method(module_object(), fn); |
| } |
| return c10::nullopt; |
| } |
| void apply(const std::function<void(Module&)>& fn); |
| |
| /// Enables "training" mode. |
| void train(bool on = true); |
| /// Calls train(false) to enable "eval" mode. |
| /// Do not override this method, override `train()` instead. |
| void eval() { |
| train(/*on=*/false); |
| } |
| /// True if the module is in training mode. |
| bool is_training() { |
| if (auto p = find_attribute("training")) { |
| return p->value().toBool(); |
| } |
| // We are in training mode by default |
| return true; |
| } |
| |
| /// Recursively casts all parameters to the given `dtype` and `device`. |
| /// |
| /// If `non_blocking` is true and the source is in pinned memory and |
| /// destination is on the GPU or vice versa, the copy is performed |
| /// asynchronously with respect to the host. Otherwise, the argument has no |
| /// effect. |
| void to(at::Device device, at::ScalarType dtype, bool non_blocking = false); |
| |
| /// Recursively casts all parameters to the given dtype. |
| /// |
| /// If `non_blocking` is true and the source is in pinned memory and |
| /// destination is on the GPU or vice versa, the copy is performed |
| /// asynchronously with respect to the host. Otherwise, the argument has no |
| /// effect. |
| void to(at::ScalarType dtype, bool non_blocking = false); |
| |
| /// Recursively moves all parameters to the given device. |
| /// |
| /// If `non_blocking` is true and the source is in pinned memory and |
| /// destination is on the GPU or vice versa, the copy is performed |
| /// asynchronously with respect to the host. Otherwise, the argument has no |
| /// effect. |
| void to(at::Device device, bool non_blocking = false); |
| |
| /// Run a method from this module. |
| /// |
| /// For example: |
| /// @code |
| /// IValue output = module->run("relu_script", a, b); |
| /// @endcode |
| /// |
| /// To get a compile a module from a source string, see torch::jit::compile |
| /// |
| /// @param method_name The name of the method to run |
| /// @param args Arguments to be passed to the method |
| /// @return An IValue containing the return value (or values if it is a tuple) |
| /// from the method |
| template <typename... Types> |
| IValue run_method(const std::string& method_name, Types&&... args) { |
| return get_method(method_name)({IValue(std::forward<Types>(args))...}); |
| } |
| |
| void save( |
| std::ostream& out, |
| const ExtraFilesMap& extra_files = ExtraFilesMap()) const; |
| |
| void save( |
| const std::string& filename, |
| const ExtraFilesMap& extra_files = ExtraFilesMap()) const; |
| |
| void copy_into( |
| const ModuleLookup& module_lookup, |
| // translate current module singleton type to new module |
| // singleton type. |
| std::unordered_map<TypePtr, TypePtr>& type_remap, |
| std::vector<std::string> names = {}) const; |
| |
| void clone_method(const Module& orig, const std::string& name); |
| |
| at::optional<EntityType> kind_of(const std::string& name) const { |
| if (class_compilation_unit()->find_function(getNameForMethod(name))) { |
| return EntityType::METHOD; |
| } |
| if (auto offset = type()->findAttributeSlot(name)) { |
| return get_slot(*offset).entity_type(); |
| } |
| return c10::nullopt; |
| } |
| |
| ModulePtr module_object() const; |
| |
| ClassTypePtr type() const { |
| return module_object()->type(); |
| } |
| std::shared_ptr<CompilationUnit> class_compilation_unit() const { |
| return module_object()->compilation_unit(); |
| } |
| |
| // so that C++ users can easily add methods |
| void define(const std::string& src, const ResolverPtr& resolver = nullptr); |
| |
| template <typename... Types> |
| IValue create_class(const c10::QualifiedName& name, Types&&... args) const { |
| return create_class(name, {IValue(std::forward<Types>(args))...}); |
| } |
| |
| IValue create_class(const c10::QualifiedName& name, Stack stack) const; |
| |
| Slot get_slot(size_t slot) const { |
| TORCH_CHECK( |
| slot < module_object()->slots().size(), "not a valid slot offset"); |
| return Slot(module_object(), slot); |
| } |
| |
| size_t num_slots() const { |
| return module_object()->slots().size(); |
| } |
| |
| private: |
| void clone_method( |
| const Module& orig, |
| const QualifiedName& orig_method_name, |
| const std::unordered_map<TypePtr, TypePtr>& type_remap); |
| |
| c10::QualifiedName getNameForMethod(std::string basename) const { |
| return QualifiedName(name(), basename); |
| } |
| static const char* toString(EntityType t) { |
| switch (t) { |
| case EntityType::MODULE: |
| return "module"; |
| case EntityType::PARAMETER: |
| return "parameter"; |
| case EntityType::ATTRIBUTE: |
| return "attribute"; |
| case EntityType::METHOD: |
| return "method"; |
| } |
| return nullptr; |
| } |
| void check_entity(EntityType expected, size_t slot) const { |
| EntityType actual = get_slot(slot).entity_type(); |
| TORCH_CHECK( |
| expected == actual, |
| "The field '", |
| type()->getAttributeName(slot), |
| "' is a ", |
| toString(actual), |
| " but this call is" |
| " trying to use it as a ", |
| toString(expected)); |
| } |
| |
| void set_or_add_slot( |
| const std::string& name, |
| const TypePtr& slot_type, |
| IValue v, |
| EntityType etype) { |
| auto slot = type()->findAttributeSlot(name); |
| if (!slot) { |
| slot = |
| type()->addAttribute(name, slot_type, etype == EntityType::PARAMETER); |
| } else { |
| check_entity(etype, *slot); |
| } |
| TypePtr atype = type()->getAttribute(*slot); |
| TORCH_CHECK(slot_type->isSubtypeOf(atype)); |
| module_object()->setSlot(*slot, std::move(v)); |
| } |
| |
| Slot get_slot(const std::string& name, EntityType etype) const { |
| size_t slot = type()->getAttributeSlot(name); |
| check_entity(etype, slot); |
| return get_slot(slot); |
| } |
| c10::optional<Slot> find_slot(const std::string& name, EntityType etype) |
| const { |
| auto slot = type()->findAttributeSlot(name); |
| if (!slot) { |
| return c10::nullopt; |
| } |
| Slot r = get_slot(*slot); |
| if (r.entity_type() != etype) { |
| return c10::nullopt; |
| } |
| return r; |
| } |
| |
| void to_impl( |
| const c10::optional<at::Device>& device, |
| const c10::optional<at::ScalarType>& dtype, |
| bool non_blocking); |
| |
| // mutable be we lazily initialize in module_object. |
| mutable ModulePtr module_value_; |
| }; |
| |
| // this iterator for the slot list defined below has a position in the list i_ |
| // and an optional field type_ that if present |
| // restricts iteration to only the slots of module_ that |
| // have EntityType *type_. This allows it to return, e.g. |
| // only the parameter slots. |
| // The template parameter allows us to use the same implementation for a list |
| // that returns Module via template specialization of the operator* method. |
| template <typename T> |
| struct TORCH_API slot_iterator_impl { |
| slot_iterator_impl(Module module, c10::optional<EntityType> type, size_t i) |
| : module_(module), type_(type), i_(i) { |
| advance_to_valid(); |
| } |
| T operator*() const; |
| T operator->() const { |
| return **this; |
| } |
| slot_iterator_impl& operator++() { |
| ++i_; |
| advance_to_valid(); |
| return *this; |
| } |
| slot_iterator_impl operator++(int) { |
| slot_iterator_impl old = *this; |
| ++(*this); |
| return old; |
| } |
| |
| private: |
| void advance_to_valid() { |
| while (i_ < module_.num_slots() && |
| (type_ && module_.get_slot(i_).entity_type() != *type_)) { |
| ++i_; |
| } |
| } |
| Module module_; |
| c10::optional<EntityType> type_; |
| size_t i_; |
| |
| template <typename TT> |
| friend inline bool operator!=( |
| const slot_iterator_impl<TT>& a, |
| const slot_iterator_impl<TT>& b); |
| }; |
| |
| template <> |
| inline Slot slot_iterator_impl<Slot>::operator*() const { |
| return module_.get_slot(i_); |
| } |
| |
| template <> |
| inline Module slot_iterator_impl<Module>::operator*() const { |
| return Module(module_.get_slot(i_).to_module()); |
| } |
| |
| template <typename T> |
| inline bool operator!=( |
| const slot_iterator_impl<T>& a, |
| const slot_iterator_impl<T>& b) { |
| return a.i_ != b.i_; |
| } |
| |
| // This type represents lists of parameters, attributes, and |
| // submodules contained in the module. It is abstract because |
| // they are not stored directly in std::vectors but inside the |
| // module's IValue object itself. |
| template <typename T> |
| struct TORCH_API slot_list_impl { |
| using iterator = slot_iterator_impl<T>; |
| using const_iterator = slot_iterator_impl<T>; |
| slot_iterator_impl<T> begin() const { |
| return slot_iterator_impl<T>(module_, type_, 0); |
| } |
| slot_iterator_impl<T> end() const { |
| return slot_iterator_impl<T>(module_, type_, module_.num_slots()); |
| } |
| size_t size() const { |
| if (!size_) { |
| size_ = size_t(0); |
| for (Slot s : *(this)) { |
| ++*size_; |
| } |
| } |
| return *size_; |
| } |
| |
| private: |
| slot_list_impl(Module module, c10::optional<EntityType> type) |
| : module_(std::move(module)), type_(type) { |
| if (!type_) { |
| size_ = module_.num_slots(); |
| } |
| } |
| Module module_; |
| // only include Slots of the following type |
| c10::optional<EntityType> type_; |
| // size of this list, cached on first request |
| // when we need to filter the slot list |
| mutable c10::optional<size_t> size_; |
| friend struct Module; |
| }; |
| |
| TORCH_API bool& getInlineEverythingMode(); |
| |
| } // namespace script |
| } // namespace jit |
| } // namespace torch |