| #pragma once |
| #include <c10/util/Exception.h> |
| #include <torch/csrc/autograd/variable.h> |
| #include <torch/csrc/jit/argument_spec.h> |
| #include <torch/csrc/jit/graph_executor.h> |
| #include <torch/csrc/jit/ir.h> |
| #include <torch/csrc/jit/named_value.h> |
| #include <torch/csrc/jit/passes/shape_analysis.h> |
| #include <torch/csrc/jit/script/slot.h> |
| #include <torch/csrc/jit/source_range.h> |
| |
| #include <torch/csrc/WindowsTorchApiMacro.h> |
| #include <torch/csrc/api/include/torch/ordered_dict.h> |
| #include <torch/csrc/jit/script/compilation_unit.h> |
| #include <torch/csrc/utils/memory.h> |
| |
| #include <ATen/core/function_schema.h> |
| #include <ATen/core/qualified_name.h> |
| #include <c10/util/ArrayRef.h> |
| #include <c10/util/Optional.h> |
| |
| #include <functional> |
| #include <memory> |
| #include <mutex> |
| #include <ostream> |
| #include <string> |
| #include <unordered_map> |
| #include <vector> |
| |
| // This file contains classes which assist in desugaring Python style |
| // modules and their methods into flattened graphs which don't have any |
| // function calls. |
| |
| namespace torch { |
| namespace jit { |
| namespace script { |
| |
| using ::c10::Argument; |
| using ::c10::FunctionSchema; |
| using ::c10::QualifiedName; |
| // Map which stores filename to content. |
| using ExtraFilesMap = std::unordered_map<std::string, std::string>; |
| |
| using ModulePtr = c10::intrusive_ptr<c10::ivalue::Object>; |
| // A method in a module, e.g. f in: |
| // |
| // class M(ScriptModule): |
| // @script_method |
| // def f(self, x): |
| // ... |
| // Note: because Method/Module are exposed to python these |
| // classes use python method naming conventions |
| |
| struct Module; |
| |
| using ModuleLookup = |
| std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>; |
| |
| struct TORCH_API Method { |
| Method(Module* owner, Function* function); |
| |
| // the module that contains this method. |
| Module& owner() const { |
| return *owner_; |
| } |
| |
| void run(Stack& stack) { |
| for (auto input : initial_ivalues_) { |
| push(stack, input.value()); |
| } |
| function_->run(stack); |
| } |
| void run(Stack&& stack) { |
| run(stack); |
| } |
| |
| IValue operator()( |
| std::vector<IValue> stack, |
| const Kwargs& kwargs = Kwargs()) { |
| getSchema().checkAndNormalizeInputs(stack, kwargs); |
| for (auto input : initial_ivalues_) { |
| push(stack, input.value()); |
| } |
| // use run rather than operator() to skip the second schema check. |
| function_->run(std::move(stack)); |
| return stack.front(); |
| } |
| |
| const std::vector<Slot>& initial_ivalues() const { |
| return initial_ivalues_; |
| } |
| |
| std::shared_ptr<Graph> graph() const { |
| return function_->graph(); |
| } |
| |
| const std::string& name() const { |
| return function_->name(); |
| } |
| |
| size_t num_inputs() const { |
| return function_->num_inputs() - initial_ivalues_.size(); |
| } |
| |
| const FunctionSchema& getSchema() const { |
| return schema_; |
| } |
| |
| GraphExecutor& get_executor() { |
| return function_->get_executor(); |
| } |
| |
| Function& function() const { |
| return *function_; |
| } |
| |
| private: |
| // Methods are uniqued onwed by a single module. This raw pointer allows |
| // looking up the module. |
| Module* owner_; |
| |
| // Underlying unbound function |
| // This is the _lowered_ function and is different than the |
| // first-class function in class_compilation_unit() |
| std::shared_ptr<Function> function_; |
| |
| // parameters and attributes loaded from the Module and appending |
| // before calling function_ |
| std::vector<Slot> initial_ivalues_; |
| FunctionSchema schema_; |
| }; |
| |
| struct Module; |
| |
| struct TORCH_API Module { |
| TH_DISALLOW_COPY_AND_ASSIGN(Module); |
| Module() |
| : name_("__main__"), |
| module_value_(c10::ivalue::Object::create( |
| ClassType::createModuleType(std::make_shared<CompilationUnit>()), |
| 0)) {} |
| |
| ~Module() { |
| // ClassType own the compilation unit of their Functions, but each |
| // Function has a self argument which owns the ClassType, created a |
| // reference cycle. By dropping all the methods of the module's class |
| // here we break the cycle. |
| class_compilation_unit().drop_all_functions(); |
| } |
| const std::string& name() const { |
| return name_; |
| } |
| |
| // note this doesn't change the flags of existing methods just ones |
| // added afterward. |
| void set_optimized(bool o) { |
| class_compilation_unit().set_optimized(o); |
| } |
| |
| bool is_optimized() const { |
| return class_compilation_unit().is_optimized(); |
| } |
| |
| IValue forward(std::vector<IValue> inputs) { |
| return get_method("forward")(std::move(inputs)); |
| } |
| |
| void register_buffer(const std::string& name, autograd::Variable v) { |
| if (auto b = find_attribute(name)) { |
| AT_ASSERT(b->type()->isSubtypeOf(TensorType::get())); |
| b->setValue(v); |
| return; |
| } |
| insert( |
| name, |
| attributes_, |
| EntityType::ATTRIBUTE, |
| appendSlot(name, TensorType::get(), std::move(v))); |
| } |
| |
| void register_parameter( |
| const std::string& name, |
| autograd::Variable v, |
| bool is_buffer) { |
| if (is_buffer) { |
| register_buffer(name, std::move(v)); |
| return; |
| } |
| if (auto p = find_parameter(name)) { |
| p->setValue(v); |
| return; |
| } |
| insert( |
| name, |
| parameters_, |
| EntityType::PARAMETER, |
| appendSlot(name, TensorType::get(), std::move(v))); |
| } |
| void register_attribute( |
| const std::string& name, |
| const TypePtr type, |
| IValue ivalue) { |
| insert( |
| name, |
| attributes_, |
| EntityType::ATTRIBUTE, |
| appendSlot(name, type, ivalue)); |
| } |
| void register_module( |
| const std::string& name, |
| std::shared_ptr<Module> module) { |
| // We would like to enable more stringent error checking at this point, |
| // but because script functions are considered modules, it is possible |
| // to hit this situation without knowing it. For now this is disabled |
| // until a later PR that distinguishes script functions from script modules. |
| // See TestScript.test_submodule_twice for example failure |
| // if (module->parent_) { |
| // AT_WARN( |
| // "Attempting to assign submodule '", |
| // name, |
| // "' but it is already a submodule of another ScriptModule '", |
| // module->parent_->name(), "'", " Modules of this form do not import |
| // and export correctly. This use is deprecated and may be" " removed |
| // in a future version."); |
| // } |
| module->parent_ = this; |
| module->name_ = name; |
| appendSlot(name, module->module_value_->type(), module->module_value_); |
| insert(name, modules_, EntityType::MODULE, std::move(module)); |
| } |
| |
| Slot parameter_slot(const std::string& name) const { |
| return parameters_[get_offset(name, EntityType::PARAMETER)]; |
| } |
| |
| void set_parameter(const std::string& name, at::Tensor v) { |
| parameter_slot(name).setValue(std::move(v)); |
| } |
| |
| autograd::Variable get_parameter(const std::string& name) const { |
| return autograd::as_variable_ref(parameter_slot(name).value().toTensor()); |
| } |
| |
| IValue get_attribute(const std::string& name) const { |
| return attributes_[get_offset(name, EntityType::ATTRIBUTE)].value(); |
| } |
| |
| autograd::Variable get_buffer(const std::string& name) const { |
| return autograd::as_variable_ref(get_attribute(name).toTensor()); |
| } |
| |
| // each module owns its method. The reference returned here |
| // is guarenteed to stay valid until this module has been destroyed |
| Method& get_method(const std::string& name) const { |
| if (Method* method = find_method(name)) { |
| return *method; |
| } |
| // temporary: force the error message |
| // once the on-demand creation of Method is removed, this code |
| // can be removed as well |
| get_offset(name, EntityType::METHOD); |
| AT_ERROR("unreachable"); |
| } |
| |
| std::shared_ptr<Module> get_module(const std::string& name) const { |
| return modules_[get_offset(name, EntityType::MODULE)]; |
| } |
| |
| c10::ArrayRef<std::shared_ptr<Module>> get_modules() const { |
| return modules_; |
| } |
| c10::ArrayRef<Slot> get_parameters() const { |
| return parameters_; |
| } |
| c10::ArrayRef<Slot> get_attributes() const { |
| return attributes_; |
| } |
| const std::vector<std::unique_ptr<Method>>& get_methods() const { |
| // force methods_ to be up to date by querying all |
| // methods. |
| for (const auto& m : class_compilation_unit().get_functions()) { |
| get_method(m->name()); |
| } |
| return methods_; |
| } |
| |
| Slot* find_parameter(const std::string& name) { |
| auto offset = find_offset(name, EntityType::PARAMETER); |
| return offset ? ¶meters_[*offset] : nullptr; |
| } |
| Slot* find_attribute(const std::string& name) { |
| auto offset = find_offset(name, EntityType::ATTRIBUTE); |
| return offset ? &attributes_[*offset] : nullptr; |
| } |
| Slot* find_buffer(const std::string& name) { |
| auto iv = find_attribute(name); |
| if (iv && iv->type()->isSubtypeOf(TensorType::get())) { |
| return iv; |
| } |
| return nullptr; |
| } |
| std::shared_ptr<Module> find_module(const std::string& name) { |
| auto offset = find_offset(name, EntityType::MODULE); |
| return offset ? modules_[*offset] : nullptr; |
| } |
| Method* find_method(const std::string& name) const { |
| // find_offset() method reads "dict_" object. |
| // Lock because another thread can modify "dict_" object at the same time |
| // calling insert() method. |
| // Ideally recursive_mutex should be replaced with shared_mutex (C++ 17) |
| // for the performance reasons. |
| std::unique_lock<std::recursive_mutex> keeper(create_method_guard_); |
| auto offset = find_offset(name, EntityType::METHOD); |
| if (offset) { |
| return methods_[*offset].get(); |
| } |
| |
| if (Function* fn = class_compilation_unit().find_function(name).get()) { |
| // lock because technically this is marked const, |
| // but we have to update the internal Method cache. |
| // This can be removed when class_compilation_unit() is the source of |
| // truth for methods. |
| Module* mutable_this = const_cast<Module*>(this); |
| std::unique_ptr<Method> m(new Method(mutable_this, fn)); |
| return mutable_this |
| ->insert( |
| fn->name(), |
| mutable_this->methods_, |
| EntityType::METHOD, |
| std::move(m)) |
| .get(); |
| } |
| |
| return nullptr; |
| } |
| void apply(std::function<void(Module&)> fn) { |
| for (auto& submod : get_modules()) { |
| submod->apply(fn); |
| } |
| fn(*this); |
| } |
| /// Enables "training" mode. |
| void train(bool on = true); |
| /// Calls train(false) to enable "eval" mode. |
| /// Do not override this method, override `train()` instead. |
| void eval() { |
| train(/*on=*/false); |
| } |
| /// True if the module is in training mode. |
| bool is_training() { |
| if (auto p = find_attribute("training")) { |
| return p->value().toBool(); |
| } |
| // We are in training mode by default |
| return true; |
| } |
| |
| /// Recursively casts all parameters to the given `dtype` and `device`. |
| /// |
| /// If `non_blocking` is true and the source is in pinned memory and |
| /// destination is on the GPU or vice versa, the copy is performed |
| /// asynchronously with respect to the host. Otherwise, the argument has no |
| /// effect. |
| void to(at::Device device, at::ScalarType dtype, bool non_blocking = false); |
| |
| /// Recursively casts all parameters to the given dtype. |
| /// |
| /// If `non_blocking` is true and the source is in pinned memory and |
| /// destination is on the GPU or vice versa, the copy is performed |
| /// asynchronously with respect to the host. Otherwise, the argument has no |
| /// effect. |
| void to(at::ScalarType dtype, bool non_blocking = false); |
| |
| /// Recursively moves all parameters to the given device. |
| /// |
| /// If `non_blocking` is true and the source is in pinned memory and |
| /// destination is on the GPU or vice versa, the copy is performed |
| /// asynchronously with respect to the host. Otherwise, the argument has no |
| /// effect. |
| void to(at::Device device, bool non_blocking = false); |
| |
| /// Run a method from this module. |
| /// |
| /// For example: |
| /// @code |
| /// IValue output = module->run("relu_script", a, b); |
| /// @endcode |
| /// |
| /// To get a compile a module from a source string, see torch::jit::compile |
| /// |
| /// @param method_name The name of the method to run |
| /// @param args Arguments to be passed to the method |
| /// @return An IValue containing the return value (or values if it is a tuple) |
| /// from the method |
| template <typename... Types> |
| IValue run_method(const std::string& method_name, Types&&... args) { |
| return get_method(method_name)({IValue(std::forward<Types>(args))...}); |
| } |
| |
| void save( |
| std::ostream& out, |
| const ExtraFilesMap& extra_files = ExtraFilesMap()); |
| |
| void save( |
| const std::string& filename, |
| const ExtraFilesMap& extra_files = ExtraFilesMap()); |
| |
| void copy_into( |
| const ModuleLookup& module_lookup, |
| // translate current module singleton type to new module |
| // singleton type. |
| std::unordered_map<TypePtr, TypePtr>& type_remap, |
| std::vector<std::string> names = {}) const; |
| |
| void clone_method( |
| const Module& orig, |
| const std::string& name, |
| const std::unordered_map<TypePtr, TypePtr>& type_remap); |
| |
| void clone_method(const Module& orig, const std::string& name); |
| |
| enum class EntityType { MODULE, PARAMETER, ATTRIBUTE, METHOD }; |
| |
| at::optional<EntityType> kind_of(const std::string& name) const { |
| auto it = dict_.find(name); |
| if (it == dict_.end()) { |
| // methods are lazily created, see if this is, in face, |
| // a method that has not been created yet. |
| if (auto fn = class_compilation_unit().find_function(name)) { |
| return EntityType::METHOD; |
| } |
| return at::nullopt; |
| } |
| return it->second.type; |
| } |
| |
| ModulePtr module_object() const { |
| return module_value_; |
| } |
| CompilationUnit& class_compilation_unit() { |
| return module_object()->type()->compilation_unit(); |
| } |
| const CompilationUnit& class_compilation_unit() const { |
| return module_object()->type()->compilation_unit(); |
| } |
| |
| // so that C++ users can easily add methods |
| void define(const std::string& src, const ResolverPtr& resolver = nullptr); |
| |
| template <typename... Types> |
| IValue create_class(const c10::QualifiedName& name, Types&&... args) |
| const { |
| return create_class(name, {IValue(std::forward<Types>(args))...}); |
| } |
| |
| IValue create_class(const c10::QualifiedName& name, Stack stack) const; |
| |
| private: |
| std::pair<std::shared_ptr<Function>, std::vector<Slot>> |
| lower_first_class_method(Function* fn); |
| |
| void to_impl( |
| const c10::optional<at::Device>& device, |
| const c10::optional<at::ScalarType>& dtype, |
| bool non_blocking); |
| |
| static const char* toString(EntityType t) { |
| switch (t) { |
| case EntityType::MODULE: |
| return "module"; |
| case EntityType::PARAMETER: |
| return "parameter"; |
| case EntityType::ATTRIBUTE: |
| return "attribute"; |
| case EntityType::METHOD: |
| return "method"; |
| } |
| return nullptr; |
| } |
| |
| struct Entry { |
| EntityType type; |
| size_t offset; |
| }; |
| |
| size_t get_offset(const std::string& name, EntityType expected_type) const { |
| auto it = dict_.find(name); |
| if (it == dict_.end()) { |
| AT_ERROR(toString(expected_type), " '", name, "' is not defined."); |
| } |
| if (it->second.type != expected_type) { |
| AT_ERROR( |
| "The field '", |
| name, |
| "' is a ", |
| toString(it->second.type), |
| " but this call is" |
| " trying to use it as a ", |
| toString(expected_type)); |
| } |
| return it->second.offset; |
| } |
| at::optional<size_t> find_offset( |
| const std::string& name, |
| EntityType expected_type) const { |
| auto it = dict_.find(name); |
| if (it == dict_.end() || it->second.type != expected_type) { |
| return at::nullopt; |
| } |
| return it->second.offset; |
| } |
| |
| template <typename T> |
| T& insert( |
| const std::string& name, |
| std::vector<T>& list, |
| EntityType type, |
| T value) { |
| auto it = dict_.find(name); |
| if (it != dict_.end()) { |
| if (type != it->second.type) { |
| AT_ERROR( |
| "attempting to add ", |
| toString(type), |
| " '", |
| name, |
| "' but it already exists as a ", |
| toString(it->second.type)); |
| } else { |
| AT_ERROR(toString(type), " '", name, "' already defined."); |
| } |
| } |
| dict_[name] = Entry{type, list.size()}; |
| list.emplace_back(std::move(value)); |
| return list.back(); |
| } |
| |
| // add a new entry to the singleton object that represents this |
| // Module as a first-class value in code, and update the corresponding |
| // ClassType to match. |
| Slot appendSlot(const std::string& name, TypePtr typ, IValue value) { |
| const ClassTypePtr& type = module_value_->type(); |
| type->addAttribute(name, std::move(typ)); |
| auto slot_index = type->getAttributeSlot(name); |
| module_value_->setSlot(slot_index, std::move(value)); |
| return Slot(module_value_, slot_index); |
| } |
| |
| // modules have a single namespace, but spread over 4 different concepts: |
| // parameters, attributes, methods, and sub-modules |
| // we store individual lists of each concept, and a single map to |
| // unify the namespace and ensure fast lookup |
| |
| // invariant: to ensure initial_ivalues of Methods stay valid, |
| // it is only legal to _add_ new modules and parameters. |
| // removing them will allow initial_ivalues to point to invalid parameters |
| // no such restriction exists for methods |
| std::vector<std::shared_ptr<Module>> modules_; |
| std::vector<Slot> parameters_; |
| std::vector<Slot> attributes_; |
| std::vector<std::unique_ptr<Method>> methods_; |
| |
| std::unordered_map<std::string, Entry> dict_; |
| std::string name_; |
| |
| ModulePtr module_value_; |
| |
| // back reference to parent of this Module if present |
| Module* parent_ = nullptr; |
| |
| // Currently we are in a transitionary state |
| // where we construct such first class functions but we lower them |
| // to a form where the modules does not exist before execution. |
| |
| // So each Method is actually stored twice once in first-class Module |
| // form and once in lowered form. |
| |
| // first-class: module_value_->type().compilation_unit() holds Functions that |
| // treat modules as first class. |
| |
| mutable std::recursive_mutex create_method_guard_; |
| friend struct Method; |
| }; |
| |
| } // namespace script |
| } // namespace jit |
| } // namespace torch |