| #pragma once |
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/argument_spec.h"
#include <ATen/optional.h>

#include <functional>
#include <memory>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
| |
| // This file contains classes which assist in desugaring Python style |
| // modules and their methods into flattened graphs which don't have any |
| // function calls. |
| |
| namespace torch { namespace jit { namespace script { |
| |
| // A method in a module, e.g. f in: |
| // |
| // class M(ScriptModule): |
| // @script_method |
| // def f(self, x): |
| // ... |
| // Note: because Method/Module are exposed to python these |
| // classes use python method naming conventions |
| |
| struct SourceRange; |
| |
| struct Method { |
| Method(std::string name, bool optimize, |
| std::shared_ptr<Graph> graph, |
| std::vector<at::Tensor*> initial_members, |
| std::function<void(Method&)> method_creator) |
| : name_(std::move(name)) |
| , graph_(std::move(graph)) |
| , optimize(optimize) |
| , member_inputs(std::move(initial_members)) |
| , method_creator(method_creator) { |
| JIT_ASSERT(graph_->inputs().size() >= member_inputs.size()); |
| int i = graph_->inputs().size() - member_inputs.size(); |
| for(at::Tensor* member : member_inputs) { |
| member_input_index[member] = i++; |
| } |
| } |
| |
| variable_tensor_list run(variable_tensor_list && inputs) { |
| std::call_once(executor_init, [&]{ |
| executor = GraphExecutor(graph(), optimize); |
| }); |
| for(auto tp : member_inputs) { |
| inputs.push_back(*tp); |
| } |
| return executor.run(std::move(inputs)); |
| } |
| std::shared_ptr<Graph> graph() const { |
| return graph_; |
| } |
| |
| const std::string & name() const { |
| return name_; |
| } |
| // emit a function call by inlining the callees Graph into this one |
| // adding any extra parameters necessary to do this call |
| |
| // defined here to keep details of member_input handling confined to this class |
| std::vector<Value*> emit_call_to(SourceRange loc, Method & callee, ArrayRef<Value*> inputs); |
| // if this isn't yet defined, run its method_creator function |
| void ensure_defined(); |
| |
| |
| size_t num_inputs() const { |
| return graph()->inputs().size() - member_inputs.size(); |
| } |
| Value * get_or_add_parameter(at::Tensor* slot) { |
| auto it = member_input_index.find(slot); |
| if(it != member_input_index.end()) { |
| return graph()->inputs().at(it->second); |
| } |
| // add it as a new parameter |
| member_inputs.push_back(slot); |
| member_input_index[slot] = graph()->inputs().size(); |
| return graph()->addInput(); |
| } |
| |
| std::shared_ptr<Graph> propagate_shapes(std::vector<at::Tensor> inputs, bool with_grad=false) { |
| auto retval = graph_->copy(); |
| for (auto inp : member_inputs) { |
| inputs.push_back(*inp); |
| } |
| PropagateInputShapes(*retval, ArgumentSpec(with_grad, variable_tensor_list(std::move(inputs)))); |
| return retval; |
| } |
| |
| std::vector<at::Tensor*> params() { |
| return member_inputs; |
| } |
| |
| private: |
| std::string name_; |
| std::shared_ptr<Graph> graph_; // for debugging and for inlining |
| bool optimize; |
| GraphExecutor executor; // for execution |
| // member_inputs are a list of additional arguments appended to graph that are |
| // inputs that come from the members of the Module or its submodules. |
| // each is a pointer to a slot in the module that owns this parameter |
| // parameters and submodules can only be _added_ to script Modules to ensure |
| // these pointers always stay valid |
| std::vector<at::Tensor*> member_inputs; |
| |
| // map from a at::Tensor* in member_inputs to the offset it appears at |
| // in graph. used to accelerate get_or_add_parameter |
| std::unordered_map<at::Tensor*, size_t> member_input_index; |
| |
| // TODO: support that case where we allow _writes_ to parameters from |
| // compiled functions. |
| // This requires more sophisticated tracking of ssa values in Graphs so that |
| // stores to all modules can be lifted to the end of a graph execution. |
| // It also adds more complexity to adding actual module invocations |
| // to the executor, so currently it is not done. |
| // std::vector<at::Tensor*> member_outputs; |
| |
| std::once_flag executor_init; |
| |
| // an optional function that actually creates the method when emit_call_to(this,...) |
| // is first called. |
| // this is used by the compiler so that it can construct methods out of order |
| std::function<void(Method&)> method_creator; |
| }; |
| |
struct Module;

// A submodule together with the name it was registered under.
struct NamedModule {
  // registration name of the submodule
  std::string name;
  // the submodule itself; ownership is shared via std::shared_ptr
  std::shared_ptr<Module> module;
};
| |
| struct NamedParameter { |
| NamedParameter(std::string name, at::Tensor tensor, bool is_buffer) |
| : name(std::move(name)) |
| , is_buffer(is_buffer) |
| , parameter(new at::Tensor(std::move(tensor))) {} |
| |
| const std::string name; |
| bool is_buffer; // buffers are part of the module state but |
| // are not modified by optimizers during SGD |
| at::Tensor* slot() const { |
| return parameter.get(); |
| } |
| private: |
| // the extra level of indirection allows Methods to safely store pointers |
| // to the slots where parameters are kept while also allow parameters |
| // to be reassigned |
| std::unique_ptr<at::Tensor> parameter; |
| }; |
| |
| // simple ordered dict used only in Module |
| // contains only the minimum necessary functionality for Module |
| template<typename T> |
| struct OrderedDict { |
| OrderedDict(const char * what) |
| : what(what) {} |
| // note: slight difference from python here. |
| // we do not allow for insertion of an already existing value, |
| // because we not allow allow methods or submodules to be updated |
| // once created |
| T& insert(const std::string& name, T&& value) { |
| if(index_.count(name) != 0) { |
| std::stringstream ss; |
| ss << "module " << what << "'" << name << "' already defined."; |
| throw std::runtime_error(ss.str()); |
| } |
| values_.push_back(std::move(value)); |
| index_[name] = values_.size() - 1; |
| return values_.back(); |
| } |
| at::optional<T&> find(const std::string& str) { |
| auto it = index_.find(str); |
| if(it == index_.end()) |
| return at::nullopt; |
| return at::optional<T&>(values_.at(it->second)); |
| } |
| at::optional<const T&> find(const std::string& str) const { |
| auto it = index_.find(str); |
| if(it == index_.end()) |
| return at::nullopt; |
| return at::optional<const T&>(values_.at(it->second)); |
| } |
| T& get(const std::string& name) { |
| if(auto v = find(name)) { |
| return *v; |
| } |
| std::stringstream ss; |
| ss << "module " << what << "'" << name << "' is not defined."; |
| throw std::runtime_error(ss.str()); |
| } |
| const T& get(const std::string& name) const { |
| if(auto v = find(name)) { |
| return *v; |
| } |
| std::stringstream ss; |
| ss << "module " << what << "'" << name << "' is not defined."; |
| throw std::runtime_error(ss.str()); |
| } |
| const std::vector<T>& values() const { |
| return values_; |
| } |
| private: |
| std::unordered_map<std::string, size_t> index_; |
| std::vector<T> values_; |
| const char * what; |
| }; |
| |
| struct Module : public std::enable_shared_from_this<Module> { |
| TH_DISALLOW_COPY_AND_ASSIGN(Module); |
| Module() |
| : modules("modules") |
| , parameters("parameters") |
| , methods("methods") |
| , optimize(true) {} |
| |
| // note this doesn't change the flags of existing methods just ones |
| // added afterward. |
| void set_optimized(bool o) { |
| optimize = o; |
| } |
| |
| void register_parameter(const std::string & name, autograd::Variable v, bool is_buffer) { |
| if(auto p = parameters.find(name)){ |
| *p->slot() = v; |
| p->is_buffer = is_buffer; |
| return; |
| } |
| parameters.insert(name, NamedParameter(name, std::move(v), is_buffer)); |
| } |
| void register_module(const std::string& name, std::shared_ptr<Module> module) { |
| modules.insert(name, {name, std::move(module)}); |
| } |
| |
| Method& create_method(const std::string & name, std::shared_ptr<Graph> graph, std::vector<at::Tensor*> member_inputs) { |
| JIT_ASSERT(graph); |
| std::unique_ptr<Method> method(new Method(name, optimize, std::move(graph), std::move(member_inputs), nullptr)); |
| return *methods.insert(name, std::move(method)); |
| } |
| |
| Method& create_method(const std::string & name, std::function<void(Method&)> creator) { |
| std::unique_ptr<Method> method(new Method(name, optimize, std::make_shared<Graph>(), {}, creator)); |
| return *methods.insert(name, std::move(method)); |
| } |
| |
| at::Tensor* parameter_slot(const std::string & name) const { |
| return parameters.get(name).slot(); |
| } |
| |
| void set_parameter(const std::string & name, at::Tensor v) { |
| *parameter_slot(name) = std::move(v); |
| } |
| |
| autograd::Variable get_parameter(const std::string& name) const { |
| return static_cast<autograd::Variable&>(*parameter_slot(name)); |
| } |
| |
| // each module owns its method. The reference returned here |
| // is guarenteed to stay valid until this module has been destoryed |
| Method& get_method(const std::string& name) const { |
| return *methods.get(name); |
| } |
| |
| std::shared_ptr<Module> get_module(const std::string& name) const { |
| return modules.get(name).module; |
| } |
| |
| const std::vector<NamedModule>& get_modules() const { |
| return modules.values(); |
| } |
| const std::vector<NamedParameter>& get_parameters() const { |
| return parameters.values(); |
| } |
| const std::vector<std::unique_ptr<Method>>& get_methods() const { |
| return methods.values(); |
| } |
| |
| |
| at::optional<NamedParameter&> find_parameter(const std::string& name) { |
| return parameters.find(name); |
| } |
| at::optional<NamedModule&> find_module(const std::string& name) { |
| return modules.find(name); |
| } |
| at::optional<Method&> find_method(const std::string& name) { |
| if(auto pm = methods.find(name)) |
| return at::optional<Method&>(**pm); |
| return at::nullopt; |
| } |
| |
| |
| private: |
| |
| // invariant: to ensure member_inputs of Methods stay valid, |
| // it is only legal to _add_ new modules and parameters. |
| // removing them will allow member_inputs to point to invalid parameters |
| // no such restriction exists for methods |
| OrderedDict<NamedModule> modules; |
| OrderedDict<NamedParameter> parameters; |
| OrderedDict<std::unique_ptr<Method>> methods; |
| bool optimize; |
| }; |
| |
| }}} |