#pragma once
#include <ATen/core/builtin_function.h>
#include <ATen/core/stack.h>
#include <torch/csrc/jit/backends/backend_interface.h>
#include <torch/csrc/jit/backends/backend_resolver.h>
#include <torch/csrc/jit/frontend/code_template.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/frontend/sugared_value.h>
#include <torch/csrc/jit/python/pybind.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/custom_class.h>
namespace torch {
namespace jit {
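
// The helpers below define the schemas and unboxed implementations of the
// preprocess, compile and execute methods that are registered on every
// backend custom class.

// Schema: preprocess(self: Any, mod: Any, method_compile_spec: Dict[str, Any]) -> Any.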
inline c10::FunctionSchema getPreprocessSchema() {
  c10::Argument self("self", c10::AnyType::get());
  c10::Argument mod("mod", c10::AnyType::get());
  c10::Argument method_compile_spec(
      "method_compile_spec",
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get()));

  c10::FunctionSchema preprocessor_schema(
      "preprocess",
      /*overload_name=*/"",
      /*arguments=*/{self, mod, method_compile_spec},
      /*returns=*/{mod});
  return preprocessor_schema;
}
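
// Unboxed implementation of preprocess: pops method_compile_spec, mod and self
// off the stack (in that order) and pushes the preprocessed module returned by
// the backend.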
template <typename TBackendInterface>
std::function<void(Stack&)> getPreprocessFunc() {
  return [](Stack& stack) {
    auto method_compile_spec = pop(stack).toGenericDict();
    auto mod = pop(stack);
    auto self = pop(stack).toCustomClass<TBackendInterface>();
    auto ret = self->preprocess(mod, method_compile_spec);
    push(stack, ret);
  };
}
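
// Schema: compile(self: Any, processed: Any, method_compile_spec: Dict[str, Any]) -> Dict[str, Any].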
inline c10::FunctionSchema getCompileSchema() {
  c10::Argument self("self", c10::AnyType::get());
  c10::Argument mod("processed", c10::AnyType::get());
  auto any_dict_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  c10::Argument method_compile_spec("method_compile_spec", any_dict_ty);
  c10::Argument handles("handles", any_dict_ty);

  c10::FunctionSchema compile_schema(
      "compile",
      /*overload_name=*/"",
      /*arguments=*/{self, mod, method_compile_spec},
      /*returns=*/{handles});
  return compile_schema;
}
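
// Unboxed implementation of compile: forwards the processed module and the
// method_compile_spec dict to the backend and pushes the dict of handles it
// returns.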
template <typename TBackendInterface>
std::function<void(Stack&)> getCompileFunc() {
  return [](Stack& stack) {
    auto method_compile_spec = pop(stack).toGenericDict();
    auto processed = pop(stack);
    auto self = pop(stack).toCustomClass<TBackendInterface>();
    auto ret = self->compile(processed, method_compile_spec);
    push(stack, ret);
  };
}
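
// Schema: execute(self: Any, handle: Any, input: List[Any]) -> List[Any].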
inline c10::FunctionSchema getExecuteSchema() {
  auto any_list_ty = c10::ListType::create(c10::AnyType::get());
  c10::Argument self("self", c10::AnyType::get());
  c10::Argument handle("handle", c10::AnyType::get());
  c10::Argument input("input", any_list_ty);
  c10::Argument output("output", any_list_ty);
  return c10::FunctionSchema(
      "execute",
      /*overload_name=*/"",
      /*arguments=*/{self, handle, input},
      /*returns=*/{output});
}
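
// Unboxed implementation of execute: runs the backend on an opaque handle and
// a list of inputs and pushes the list of outputs.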
template <typename TBackendInterface>
std::function<void(Stack&)> getExecuteFunc() {
  return [](Stack& stack) {
    auto args = pop(stack);
    auto handle = pop(stack);
    auto self = pop(stack);
    auto backend = self.toCustomClass<TBackendInterface>();
    auto res = backend->execute(handle, args.toList());
    push(stack, res);
  };
}
// Static registration API for backends.
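//
// A minimal usage sketch (MyBackend and "my_backend" are hypothetical names;
// see backend_interface.h for the exact virtual methods a backend must
// implement):
//
//   // Implements preprocess/compile/execute for some target runtime.
//   class MyBackend : public PyTorchBackendInterface { ... };
//
//   // Registers __torch__.torch.classes.__backends__.my_backend and makes a
//   // to-backend lowering function available via generateToBackendFn().
//   static const auto registration = torch::jit::backend<MyBackend>("my_backend");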
template <class TBackendInterface>
class backend {
  static_assert(
      std::is_base_of<PyTorchBackendInterface, TBackendInterface>::value,
      "torch::jit::backend<T> requires T to inherit from PyTorchBackendInterface");

  constexpr static auto kBackendsNamespace = "__backends__";
  std::string backend_name_;

 public:
  explicit backend(const std::string& name) : backend_name_(name) {
    static auto cls =
        torch::class_<TBackendInterface>(kBackendsNamespace, name)
            .def(torch::init<>())
            ._def_unboxed(
                "preprocess",
                getPreprocessFunc<TBackendInterface>(),
                getPreprocessSchema())
            ._def_unboxed(
                "compile",
                getCompileFunc<TBackendInterface>(),
                getCompileSchema())
            ._def_unboxed(
                "execute",
                getExecuteFunc<TBackendInterface>(),
                getExecuteSchema());
  }

  // Generates and returns a function that takes a Module and a lowering
  // specification in the form of a dictionary. The caller is responsible for
  // binding this into a CPython module.
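  //
  // A sketch of how the returned function might be exposed to Python from a
  // C++ extension (the extension, function and backend names below are
  // hypothetical):
  //
  //   static auto backend_reg = torch::jit::backend<MyBackend>("my_backend");
  //   PYBIND11_MODULE(my_backend_lowering, m) {
  //     m.def("to_my_backend", backend_reg.generateToBackendFn());
  //   }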
  std::function<Module(Module, py::dict)> generateToBackendFn() {
    const c10::QualifiedName qual_backend_name(
        {"__torch__", "torch", "classes", kBackendsNamespace, backend_name_});
    const std::string backend_name = qual_backend_name.name();

    return [=](Module orig_module, py::dict method_compile_spec) {
      // TODO: Validate method_compile_spec.
      // Clone orig_module to make sure the backend transformation is
      // functional.
      auto cloned_module = orig_module.clone();
      // Represents the type Dict[str, Any].
      auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());

      // Generate LoweredModule.
      Module loweredModule("torch.jit." + backend_name + "LoweredModule");

      // Generate attributes.
      // This is the original cloned and preprocessed module.
      loweredModule.register_attribute(
          "__processed_module",
          AnyType::get(),
          cloned_module._ivalue(),
          /*is_param=*/false,
          /*allow_any=*/true);
      // This is for the method_compile_spec passed in to to_<backend> or
      // loaded from an exported model.
      loweredModule.register_attribute(
          "__method_compile_spec",
          any_dict_ty,
          toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
          /*is_param=*/false,
          /*allow_any=*/true);
      // This is a pointer to a backend instance that is used to access
      // compile and execute functions.
      auto cls = getCustomClass(qual_backend_name.qualifiedName());
      TORCH_INTERNAL_ASSERT(cls);
      c10::intrusive_ptr<torch::CustomClassHolder> backend;
      loweredModule.register_attribute(
          "__backend", cls, IValue::make_capsule(backend));
      // This is the dict of opaque backend handles returned by
      // backend.compile, keyed by method name.
      loweredModule.register_attribute(
          "__handles",
          any_dict_ty,
          c10::impl::GenericDict(
              any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
          /*is_param=*/false,
          /*allow_any=*/true);

      // Methods.

      // This is a helper function for creating a new instance of the
      // backend class.
      static const auto create_backend_ct = CodeTemplate(R"(
            def __create_backend(self):
                self.__backend = $name()
            )");
      TemplateEnv create_backend_te;
      create_backend_te.s("name", qual_backend_name.qualifiedName());
      loweredModule.define(
          create_backend_ct.format(create_backend_te), loweredModuleResolver());

      // __getstate__ and __setstate__ are used for serialization and
      // deserialization of the LoweredModule.
      loweredModule.define(
          R"(
            def __getstate__(self):
                return self.__method_compile_spec, self.__processed_module
            )",
          loweredModuleResolver());

      loweredModule.define(
          R"(
            def __setstate__(self, state):
                self.__method_compile_spec = state[0]
                self.__processed_module = state[1]
                self.__create_backend()
                self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
            )",
          loweredModuleResolver());

      // __preprocess is only used once, below, while the LoweredModule is
      // being generated; it is never called during compilation or execution.
      // It is defined as a TorchScript method because we don't have access to
      // an instance of the backend as a C++ object with which to call
      // preprocess directly.
      loweredModule.define(
          R"(
            def __preprocess(self, mod: Any, method_compile_spec: Dict[str, Any]):
                self.__create_backend()
                self.__processed_module = self.__backend.preprocess(mod, method_compile_spec)
            )",
          loweredModuleResolver());

      // This loop generates one method on the LoweredModule for every key
      // in method_compile_spec.
      for (auto& e : method_compile_spec) {
        std::string method_name = py::cast<std::string>(e.first);
        static const auto method_ct = CodeTemplate(R"(
            def $method(self${,def_inputs}):
                typed_inputs: List[Any] = [${fwd_inputs,}]
                $ret, = self.__backend.execute(self.__handles["$method"], typed_inputs)
                ${refine,}
                return $ret
            )");
        TemplateEnv method_te;
        method_te.s("method", method_name);
        auto method = orig_module.get_method(method_name);
        auto& function = method.function();
        auto schema = function.getSchema();

        // Generate the inputs for the function signature (def_inputs) and
        // for passing to backend.execute (fwd_inputs).
        std::vector<std::string> def_inputs, fwd_inputs;
        for (const auto& arg : schema.arguments()) {
          auto name = arg.name();

          // Skip self since it is written explicitly in the method template
          // and should not be emitted again.
          if (name == "self") {
            continue;
          }

          auto default_value = arg.default_value();

          if (arg.kwarg_only()) {
            // If this is a kwarg, it needs to be emitted as keyword=value
            // in the definition and keyword=keyword in the call to
            // backend.execute.
            TORCH_INTERNAL_ASSERT(default_value.has_value());
            std::stringstream def_ss, fwd_ss;
            def_ss << name << "=";
            fwd_ss << name << "=" << name;
            default_value->repr(
                def_ss,
                [](std::ostream&, const IValue&) -> bool { return false; });
            def_inputs.emplace_back(def_ss.str());
            fwd_inputs.emplace_back(fwd_ss.str());
          } else {
            // If this is not a kwarg, it should be emitted as-is in both the
            // signature and the call to backend.execute.
            def_inputs.emplace_back(name);
            fwd_inputs.emplace_back(name);
          }
        }

        // Generate a comma-delimited list of identifiers to unpack outputs, as
        // well as a list of isinstance checks to make sure the backend returned
        // the types it was supposed to.
        std::stringstream out_ss, type_check_ss;
        std::vector<std::string> type_checks;
        TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
        auto out_ty = schema.returns().at(0).type();

        out_ss << "_0";
        type_check_ss << "assert isinstance(_0, ";

        if (auto out_tuple_ty = out_ty->cast<TupleType>()) {
          auto tuple_elements = out_tuple_ty->elements();
          type_check_ss << tuple_elements[0]->str() << ")";
          type_checks.emplace_back(type_check_ss.str());
          for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
            type_check_ss.str(std::string());
            type_check_ss.clear();
            out_ss << ", _" << i;
            type_check_ss << "assert isinstance(_" << i << ", "
                          << tuple_elements[i]->str() << ")";
            type_checks.emplace_back(type_check_ss.str());
          }
        } else {
          type_check_ss << out_ty->str() << ")";
          type_checks.emplace_back(type_check_ss.str());
        }

        method_te.v("def_inputs", def_inputs);
        method_te.v("fwd_inputs", fwd_inputs);
        method_te.v("refine", type_checks);
        method_te.s("ret", out_ss.str());

        loweredModule.define(
            method_ct.format(method_te), loweredModuleResolver());
      }

      // Run preprocess so that __processed_module is set correctly before
      // compilation.
      loweredModule.run_method(
          "__preprocess",
          cloned_module._ivalue(),
          toIValue(method_compile_spec, any_dict_ty).toGenericDict());

      // Call __setstate__ to ensure that the returned Module is ready to
      // run.
      auto state = at::ivalue::Tuple::create(
          toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
          loweredModule.attr("__processed_module"));
      loweredModule.run_method("__setstate__", state);

      return loweredModule;
    };
  }
};
} // namespace jit
} // namespace torch