|  |  | 
|  | #pragma once | 
|  |  | 
|  | #include <ATen/core/function_schema.h> | 
|  | #include <ATen/core/ivalue.h> | 
|  | #include <ATen/core/jit_type.h> | 
|  | #include <ATen/core/op_registration/op_registration.h> | 
|  | #include <ATen/core/stack.h> | 
|  | #include <c10/util/C++17.h> | 
|  | #include <c10/util/Metaprogramming.h> | 
|  | #include <c10/util/TypeList.h> | 
|  | #include <pybind11/pybind11.h> | 
|  | #include <torch/csrc/jit/operator.h> | 
|  | #include <torch/csrc/jit/pybind_utils.h> | 
|  | #include <torch/csrc/jit/script/compilation_unit.h> | 
|  | #include <torch/csrc/jit/tracer.h> | 
|  | #include <torch/csrc/utils/variadic.h> | 
|  | #include <iostream> | 
|  | #include <sstream> | 
|  |  | 
|  |  | 
|  | namespace py = pybind11; | 
|  | namespace torch { | 
|  | namespace jit { | 
|  |  | 
|  | static std::vector<c10::RegisterOperators> registeredOps; | 
|  |  | 
namespace detail {
// Compile-time description of a callable's signature: `R` is the return type
// and the remaining template arguments are the parameter types. `hasRet` is
// false exactly when `R` is void. `types<void, Ts...>` doubles as the tag
// returned by `init<Ts...>()` to select the constructor overload of `def`.
template <class R, class...>
struct types {
  constexpr static bool hasRet = true;
  using type = types;
};
template <class... args>
struct types<void, args...> {
  constexpr static bool hasRet = false;
  using type = types;
};

// Maps a pointer-to-member-function type to the matching `types<R, Args...>`
// (the owning class is dropped). Only member-function pointers are supported.
template <class Sig>
struct args;
template <class R, class CurClass, class... Args>
struct args<R (CurClass::*)(Args...)> : types<R, Args...> {};
// const-qualified member functions (e.g. getters) map to the same signature,
// so they can also be bound through class_::def.
template <class R, class CurClass, class... Args>
struct args<R (CurClass::*)(Args...) const> : types<R, Args...> {};

// Shorthand: the `types<...>` descriptor for member-function pointer type Sig.
template <class Sig>
using args_t = typename args<Sig>::type;
} // namespace detail
|  | template <class... Types> | 
|  | detail::types<void, Types...> init() { return detail::types<void, Types...>{}; } | 
|  |  | 
// To bind custom classes into TorchScript, use an API very similar to
// pybind11's. It currently exposes one class, `torch::jit::class_<T>`, with two
// `def` overloads, plus the `torch::jit::init<...>()` helper:
// - Constructing `torch::jit::class_<Foo>("Foo")` registers `Foo` in Python and
//   TorchScript, and puts it under `torch.classes.Foo` in Python.
// - `torch::jit::class_<Foo>("Foo").def("method1", &Foo::method1)` uses template
//   metaprogramming to introspect the method's type and registers an operator
//   for it so it can be called from TorchScript.
// - `torch::jit::class_<Foo>("Foo").def(torch::jit::init<int64_t, int64_t>())`
//   registers the `Foo(int64_t, int64_t)` constructor.
// See test/custom_operator/classes.cpp and
// test/custom_operator/test_custom_classes.py for example usages.
|  |  | 
|  | template <class CurClass> | 
|  | class class_ { | 
|  | std::string className; | 
|  | std::string qualClassName; | 
|  | c10::optional<py::class_<CurClass>> pyClass = c10::nullopt; | 
|  | std::shared_ptr<script::CompilationUnit> classCu = nullptr; | 
|  | ClassTypePtr classTypePtr; | 
|  |  | 
|  | const std::string parentModule = "classes"; | 
|  | const std::string topModule = "__torch__.torch"; | 
|  |  | 
|  | public: | 
|  | class_(string className_) : className(std::move(className_)) { | 
|  | // Currently we register everything as a python class just for convenience. | 
|  | // We'll want to remove this at some point to get rid of the python | 
|  | // dependency. It would require significant changes to class registration, | 
|  | // (I think)? | 
|  | qualClassName = topModule + "." + parentModule + "." + className; | 
|  |  | 
|  | auto obj = py::module::import("torch").attr(parentModule.c_str()); | 
|  | pyClass = py::class_<CurClass>(obj, className.c_str()); | 
|  | pyClass->attr("qualified_name") = py::str(qualClassName); | 
|  | auto newClass = | 
|  | py::module::import("torch.jit") | 
|  | .attr("_add_script_class")(*pyClass, qualClassName.c_str()); | 
|  |  | 
|  | auto castToPython = [](void* objPtr) -> PyObject* { | 
|  | CurClass x = *static_cast<CurClass*>(objPtr); | 
|  | auto py_object = py::cast(x); | 
|  | PyObject* rawPyObj = py_object.release().ptr(); | 
|  | return rawPyObj; | 
|  | }; | 
|  | getClassConverter()[qualClassName] = castToPython; | 
|  |  | 
|  | // We currently represent custom classes as torchscript classes with a | 
|  | // capsule attribute | 
|  | classCu = torch::jit::get_python_cu(); | 
|  | classTypePtr = | 
|  | ClassType::create(c10::QualifiedName(qualClassName), classCu); | 
|  | classTypePtr->addAttribute("capsule", CapsuleType::get()); | 
|  |  | 
|  | c10::getCustomClassTypeMap().insert({typeid(c10::intrusive_ptr<CurClass>).name(), | 
|  | StrongTypePtr(classCu, classTypePtr)}); | 
|  | c10::getCustomClassTypeMap().insert({typeid(c10::tagged_capsule<CurClass>).name(), | 
|  | StrongTypePtr(classCu, classTypePtr)}); | 
|  |  | 
|  | classCu->register_type(classTypePtr); | 
|  | } | 
|  |  | 
|  | template <typename... Types> | 
|  | class_& def(detail::types<void, Types...>) { // Used in combination with | 
|  | // torch::jit::init<...>() | 
|  | pyClass->def(py::init<Types...>()); | 
|  |  | 
|  | auto func = [](c10::tagged_capsule<CurClass> self, Types... args) { | 
|  | auto classObj = c10::make_intrusive<CurClass>(args...); | 
|  | auto genericPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(classObj); | 
|  | auto capsule = IValue(genericPtr); | 
|  | auto object = self.ivalue.toObject(); | 
|  | object->setSlot(0, capsule); | 
|  | }; | 
|  |  | 
|  | defineMethod<void>("__init__", std::move(func), false); | 
|  | return *this; | 
|  | } | 
|  | template <typename Func> | 
|  | class_& def(string name, Func f) { | 
|  | auto res = def_(name, f, detail::args_t<decltype(f)>{}); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | private: | 
|  | template <class T> | 
|  | struct addInput { | 
|  | static Value* call(std::shared_ptr<Graph> graph) { | 
|  | return graph->addInput()->setType(getTypePtr<T>()); | 
|  | } | 
|  | }; | 
|  | template <class Func, size_t... arg_indices> | 
|  | std::vector<Value*> addInputs_( | 
|  | Func f, | 
|  | std::shared_ptr<Graph> graph, | 
|  | guts::index_sequence<arg_indices...>) { | 
|  | using argTypes = | 
|  | typename guts::infer_function_traits_t<Func>::parameter_types; | 
|  | std::vector<Value*> res = { | 
|  | addInput<guts::typelist::element_t<arg_indices, argTypes>>::call( | 
|  | graph)...}; | 
|  | return res; | 
|  | } | 
|  | template <class Func> | 
|  | std::vector<Value*> addInputs(Func f, std::shared_ptr<Graph> graph) { | 
|  | constexpr auto numArgs = | 
|  | guts::infer_function_traits_t<Func>::number_of_parameters; | 
|  | return addInputs_(f, graph, guts::make_index_sequence<numArgs>()); | 
|  | } | 
|  |  | 
|  | template <typename Last> | 
|  | std::string type_name() { | 
|  | return std::string(typeid(Last).name()); | 
|  | } | 
|  | template <typename First, typename Second, typename... Rest> | 
|  | std::string type_name() { | 
|  | return type_name<First>() + "_" + type_name<Second, Rest...>(); | 
|  | } | 
|  |  | 
|  | template <class T> | 
|  | void addType(Value* v) { | 
|  | v->setType(getTypePtr<T>()); | 
|  | } | 
|  | template<typename R, typename Func> | 
|  | void defineMethod(std::string name, Func func, bool hasRet) { | 
|  | auto graph = std::make_shared<Graph>(); | 
|  | auto qualFuncName = className + "::" + name; | 
|  | registeredOps.push_back( | 
|  | torch::RegisterOperators().op(qualFuncName, std::move(func))); | 
|  |  | 
|  |  | 
|  | std::vector<Value*> inputs = addInputs(func, graph); | 
|  | auto methodCall = graph->insertNode(graph->create( | 
|  | Symbol::fromQualString(qualFuncName), inputs, hasRet)); | 
|  | Value* res; | 
|  | if (hasRet) { | 
|  | res = methodCall->output(); | 
|  | addType<R>(res); | 
|  | } else { | 
|  | res = graph->insertConstant(IValue())->setType(NoneType::get()); | 
|  | } | 
|  | graph->registerOutput(res); | 
|  |  | 
|  | auto method = classCu->create_function(qualClassName + "." + name, graph); | 
|  | classTypePtr->addMethod(method); | 
|  | } | 
|  | template <typename Func, typename R, typename... Types> | 
|  | class_& def_(string name, Func f, detail::types<R, Types...> funcInfo) { | 
|  | pyClass->def(name.c_str(), f); | 
|  |  | 
|  | auto func = [f](c10::intrusive_ptr<CurClass> cur, Types... args) { | 
|  | return guts::invoke(f, *cur, args...); | 
|  | }; | 
|  | defineMethod<R>(name, std::move(func), funcInfo.hasRet); | 
|  | return *this; | 
|  | } | 
|  | }; | 
|  |  | 
|  | } // namespace jit | 
|  |  | 
|  | } // namespace torch |