| #pragma once |
| |
| #include <ATen/core/builtin_function.h> |
| #include <ATen/core/stack.h> |
| #include <torch/csrc/jit/backends/backend_detail.h> |
| #include <torch/csrc/jit/backends/backend_interface.h> |
| #include <torch/custom_class.h> |
| |
| namespace torch { |
| namespace jit { |
| |
| // Static registration API for backends. |
| template <class TBackendInterface> |
| class backend { |
| static_assert( |
| std::is_base_of<PyTorchBackendInterface, TBackendInterface>::value, |
| "torch::jit::backend_<T> requires T to inherit from PyTorchBackendInterface"); |
| std::string backend_name_; |
| |
| public: |
| explicit backend(const std::string& name) : backend_name_(name) { |
| static auto cls = |
| torch::class_<TBackendInterface>(detail::kBackendsNamespace, name) |
| .def(torch::init<>()) |
| ._def_unboxed( |
| "preprocess", |
| detail::getPreprocessFunc<TBackendInterface>(), |
| detail::getPreprocessSchema()) |
| ._def_unboxed( |
| "compile", |
| detail::getCompileFunc<TBackendInterface>(), |
| detail::getCompileSchema()) |
| ._def_unboxed( |
| "execute", |
| detail::getExecuteFunc<TBackendInterface>(), |
| detail::getExecuteSchema()); |
| |
| // Register the backend in the registry so that a Python binding is created |
| // for it later. |
| detail::registerBackend(backend_name_); |
| } |
| }; |
| |
| } // namespace jit |
| } // namespace torch |