| #pragma once |
| |
| #include <ATen/core/builtin_function.h> |
| #include <ATen/core/stack.h> |
| #include <torch/csrc/jit/backends/backend_interface.h> |
| #include <torch/csrc/jit/backends/backend_resolver.h> |
| #include <torch/csrc/jit/frontend/code_template.h> |
| #include <torch/csrc/jit/frontend/resolver.h> |
| #include <torch/csrc/jit/frontend/sugared_value.h> |
| #include <torch/csrc/jit/python/pybind.h> |
| #include <torch/csrc/utils/pybind.h> |
| #include <torch/custom_class.h> |
| |
| namespace torch { |
| namespace jit { |
| |
| c10::FunctionSchema getPreprocessSchema() { |
| c10::Argument self("self", c10::AnyType::get()); |
| c10::Argument mod("mod", c10::AnyType::get()); |
| c10::Argument method_compile_spec( |
| "method_compile_spec", |
| c10::DictType::create(c10::StringType::get(), c10::AnyType::get())); |
| |
| c10::FunctionSchema preprocessor_schema( |
| "preprocess", |
| /*overload_name=*/"", |
| /*arguments=*/{self, mod, method_compile_spec}, |
| /*returns=*/{mod}); |
| return preprocessor_schema; |
| } |
| |
| template <typename TBackendInterface> |
| std::function<void(Stack&)> getPreprocessFunc() { |
| return [](Stack& stack) { |
| auto method_compile_spec = pop(stack).toGenericDict(); |
| auto mod = pop(stack); |
| auto self = pop(stack).toCustomClass<TBackendInterface>(); |
| auto ret = self->preprocess(mod, method_compile_spec); |
| push(stack, ret); |
| }; |
| } |
| |
| c10::FunctionSchema getCompileSchema() { |
| c10::Argument self("self", c10::AnyType::get()); |
| c10::Argument mod("processed", c10::AnyType::get()); |
| auto any_dict_ty = |
| c10::DictType::create(c10::StringType::get(), c10::AnyType::get()); |
| c10::Argument method_compile_spec("method_compile_spec", any_dict_ty); |
| c10::Argument handles("handles", any_dict_ty); |
| |
| c10::FunctionSchema compile_schema( |
| "compile", |
| /*overload_name=*/"", |
| /*arguments=*/{self, mod, method_compile_spec}, |
| /*returns=*/{handles}); |
| return compile_schema; |
| } |
| |
| template <typename TBackendInterface> |
| std::function<void(Stack&)> getCompileFunc() { |
| return [](Stack& stack) { |
| auto method_compile_spec = pop(stack).toGenericDict(); |
| auto processed = pop(stack); |
| auto self = pop(stack).toCustomClass<TBackendInterface>(); |
| auto ret = self->compile(processed, method_compile_spec); |
| push(stack, ret); |
| }; |
| } |
| |
| c10::FunctionSchema getExecuteSchema() { |
| auto any_list_ty = c10::ListType::create(c10::AnyType::get()); |
| c10::Argument self("self", c10::AnyType::get()); |
| c10::Argument handle("handle", c10::AnyType::get()); |
| c10::Argument input("input", any_list_ty); |
| c10::Argument output("output", any_list_ty); |
| return c10::FunctionSchema( |
| "execute", |
| /*overload_name=*/"", |
| /*arguments=*/{self, handle, input}, |
| /*returns=*/{output}); |
| } |
| |
| template <typename TBackendInterface> |
| std::function<void(Stack&)> getExecuteFunc() { |
| return [](Stack& stack) { |
| auto args = pop(stack); |
| auto handle = pop(stack); |
| auto self = pop(stack); |
| auto backend = self.toCustomClass<TBackendInterface>(); |
| auto res = backend->execute(handle, args.toList()); |
| push(stack, res); |
| }; |
| } |
| |
| // Static registration API for backends. |
| template <class TBackendInterface> |
| class backend { |
| static_assert( |
| std::is_base_of<PyTorchBackendInterface, TBackendInterface>::value, |
| "torch::jit::backend_<T> requires T to inherit from PyTorchBackendInterface"); |
| constexpr static auto kBackendsNamespace = "__backends__"; |
| std::string backend_name_; |
| |
| public: |
| explicit backend(const std::string& name) : backend_name_(name) { |
| static auto cls = torch::class_<TBackendInterface>(kBackendsNamespace, name) |
| .def(torch::init<>()) |
| ._def_unboxed( |
| "preprocess", |
| getPreprocessFunc<TBackendInterface>(), |
| getPreprocessSchema()) |
| ._def_unboxed( |
| "compile", |
| getCompileFunc<TBackendInterface>(), |
| getCompileSchema()) |
| ._def_unboxed( |
| "execute", |
| getExecuteFunc<TBackendInterface>(), |
| getExecuteSchema()); |
| } |
| |
| // Generates and returns a function that takes a Module and a lowering |
| // specification in the form of a dictionary. The caller is responsible for |
| // binding this into a CPython module. |
| std::function<Module(Module, py::dict)> generateToBackendFn() { |
| const c10::QualifiedName qual_backend_name( |
| {"__torch__", "torch", "classes", kBackendsNamespace, backend_name_}); |
| const std::string backend_name = qual_backend_name.name(); |
| |
| return [=](Module orig_module, py::dict method_compile_spec) { |
| // TODO: Validate method_compile_spec. |
| |
| // Clone orig_module to make sure backend transformation is |
| // functional. |
| auto cloned_module = orig_module.clone(); |
| |
| // Represents of a Type of Dict[str, Any]. |
| auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
| |
| // Generate LoweredModule. |
| Module loweredModule("torch.jit." + backend_name + "LoweredModule"); |
| |
| // Generate attributes. |
| // This is the original cloned and preprocessed module. |
| loweredModule.register_attribute( |
| "__processed_module", |
| AnyType::get(), |
| cloned_module._ivalue(), |
| /*is_param=*/false, |
| /*allow_any=*/true); |
| |
| // This is for the method_compile_spec passed in to to_<backend> or |
| // loaded from an exported model. |
| loweredModule.register_attribute( |
| "__method_compile_spec", |
| any_dict_ty, |
| toIValue(method_compile_spec, any_dict_ty).toGenericDict(), |
| /*is_param=*/false, |
| /*allow_any=*/true); |
| |
| // This is a pointer to a backend instance that is used to access |
| // compile and execute functions. |
| auto cls = getCustomClass(qual_backend_name.qualifiedName()); |
| TORCH_INTERNAL_ASSERT(cls); |
| c10::intrusive_ptr<torch::CustomClassHolder> backend; |
| loweredModule.register_attribute( |
| "__backend", cls, IValue::make_capsule(backend)); |
| |
| // This is the list of opaque backend handles returned by |
| // backend.compile. |
| loweredModule.register_attribute( |
| "__handles", |
| any_dict_ty, |
| c10::impl::GenericDict( |
| any_dict_ty->getKeyType(), any_dict_ty->getValueType()), |
| /*is_param=*/false, |
| /*allow_any=*/true); |
| |
| // Methods. |
| |
| // This is a helper function for creating a new instance of the |
| // backend class. |
| static const auto create_backend_ct = CodeTemplate(R"( |
| def __create_backend(self): |
| self.__backend = $name() |
| )"); |
| TemplateEnv create_backend_te; |
| create_backend_te.s("name", qual_backend_name.qualifiedName()); |
| loweredModule.define( |
| create_backend_ct.format(create_backend_te), loweredModuleResolver()); |
| |
| // getstate and setstate are for serialization/deserialization of the |
| // LoweredModule. |
| loweredModule.define( |
| R"( |
| def __getstate__(self): |
| return self.__method_compile_spec, self.__processed_module |
| )", |
| loweredModuleResolver()); |
| |
| loweredModule.define( |
| R"( |
| def __setstate__(self, state): |
| self.__method_compile_spec = state[0] |
| self.__processed_module = state[1] |
| self.__create_backend() |
| self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec) |
| )", |
| loweredModuleResolver()); |
| |
| // This is never called during compilation or execution, but is needed |
| // to generate the LoweredModule because we don't have access to an |
| // instance of the backend as a C++ object with which to call |
| // preprocess. |
| loweredModule.define( |
| R"( |
| def __preprocess(self, mod: Any, method_compile_spec: Dict[str, Any]): |
| self.__create_backend() |
| self.__processed_module = self.__backend.preprocess(mod, method_compile_spec) |
| )", |
| loweredModuleResolver()); |
| |
| // This loop generates one method on the LoweredModule for every key |
| // in method_compile_spec. |
| for (auto& e : method_compile_spec) { |
| std::string method_name = py::cast<std::string>(e.first); |
| static const auto method_ct = CodeTemplate(R"( |
| def $method(self${,def_inputs}): |
| typed_inputs: List[Any] = [${fwd_inputs,}] |
| $ret, = self.__backend.execute(self.__handles["$method"], typed_inputs) |
| ${refine,} |
| return $ret |
| )"); |
| |
| TemplateEnv method_te; |
| method_te.s("method", method_name); |
| auto method = orig_module.get_method(method_name); |
| auto& function = method.function(); |
| auto schema = function.getSchema(); |
| |
| // Generate the inputs for the function signature (def_inputs) and |
| // for passing to backend.execute (fwd_inputs). |
| std::vector<std::string> def_inputs, fwd_inputs; |
| for (const auto& arg : schema.arguments()) { |
| auto name = arg.name(); |
| |
| // Skip self since that is only and always present in the |
| // signature. |
| if (name == "self") { |
| continue; |
| } |
| |
| auto default_value = arg.default_value(); |
| |
| if (arg.kwarg_only()) { |
| // If this is a kwarg, it needs to be emitted as keyword=value |
| // in the definition and keyword=keyword in the call to |
| // backend_execute. |
| TORCH_INTERNAL_ASSERT(default_value.has_value()); |
| std::stringstream def_ss, fwd_ss; |
| def_ss << name << "="; |
| fwd_ss << name << "=" << name; |
| default_value->repr( |
| def_ss, |
| [](std::ostream&, const IValue&) -> bool { return false; }); |
| def_inputs.emplace_back(def_ss.str()); |
| fwd_inputs.emplace_back(fwd_ss.str()); |
| } else { |
| // If this is not a kwarg, it should be emitted as is in the |
| // signature and the call to backend_execute. |
| def_inputs.emplace_back(name); |
| fwd_inputs.emplace_back(name); |
| } |
| } |
| |
| // Generate a comma-delimited list of identifiers to unpack outputs, as |
| // well as a list of isinstance checks to make sure the backend returned |
| // the types it was supposed to. |
| std::stringstream out_ss, type_check_ss; |
| std::vector<std::string> type_checks; |
| TORCH_INTERNAL_ASSERT(schema.returns().size() == 1); |
| auto out_ty = schema.returns().at(0).type(); |
| |
| out_ss << "_0"; |
| type_check_ss << "assert isinstance(_0, "; |
| |
| if (auto out_tuple_ty = out_ty->cast<TupleType>()) { |
| auto tuple_elements = out_tuple_ty->elements(); |
| type_check_ss << tuple_elements[0]->str() << ")"; |
| type_checks.emplace_back(type_check_ss.str()); |
| for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) { |
| type_check_ss.str(std::string()); |
| type_check_ss.clear(); |
| out_ss << ", _" << i; |
| type_check_ss << "assert isinstance(_" << i << ", " |
| << tuple_elements[i]->str() << ")"; |
| type_checks.emplace_back(type_check_ss.str()); |
| } |
| } else { |
| type_check_ss << out_ty->str() << ")"; |
| type_checks.emplace_back(type_check_ss.str()); |
| } |
| |
| method_te.v("def_inputs", def_inputs); |
| method_te.v("fwd_inputs", fwd_inputs); |
| method_te.v("refine", type_checks); |
| method_te.s("ret", out_ss.str()); |
| |
| loweredModule.define( |
| method_ct.format(method_te), loweredModuleResolver()); |
| } |
| |
| // Run preprocess so that __processed_module is set correctly before |
| // compilation. |
| loweredModule.run_method( |
| "__preprocess", |
| cloned_module._ivalue(), |
| toIValue(method_compile_spec, any_dict_ty).toGenericDict()); |
| |
| // Call __setstate__ to ensure that the returned Module is ready to |
| // run. |
| auto state = at::ivalue::Tuple::create( |
| toIValue(method_compile_spec, any_dict_ty).toGenericDict(), |
| loweredModule.attr("__processed_module")); |
| loweredModule.run_method("__setstate__", state); |
| return loweredModule; |
| }; |
| } |
| }; |
| |
| } // namespace jit |
| } // namespace torch |