| #pragma once | 
 |  | 
 | #include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h> | 
 | #include <ATen/core/function.h> | 
 | #include <c10/util/Metaprogramming.h> | 
 | #include <c10/util/TypeTraits.h> | 
 | #include <c10/util/irange.h> | 
 |  | 
 | namespace torch { | 
 |  | 
 | namespace detail { | 
 | /** | 
 |  * In the Facebook internal build (using BUCK), this macro is enabled by | 
 |  * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer | 
 |  * binary. | 
 |  */ | 
 | #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE | 
 | TORCH_API void record_custom_class(std::string name); | 
 |  | 
 | /** | 
 |  * Record an instance of a custom class being loaded | 
 |  * grab portion of string after final '.' from qualified name | 
 |  * as this seemingly aligns with how users name their custom classes | 
 |  * example: __torch__.torch.classes.xnnpack.Conv2dOpContext | 
 |  */ | 
 | #define RECORD_CUSTOM_CLASS(NAME) \ | 
 |   auto name = std::string(NAME);  \ | 
 |   detail::record_custom_class(name.substr(name.find_last_of(".") + 1)); | 
 | #else | 
 | #define RECORD_CUSTOM_CLASS(NAME) | 
 | #endif | 
 | } // namespace detail | 
 |  | 
 | /// This struct is used to represent default values for arguments | 
 | /// when registering methods for custom classes. | 
 | ///     static auto register_foo = torch::class_<Foo>("myclasses", "Foo") | 
 | ///       .def("myMethod", &Foo::myMethod, {torch::arg("name") = name}); | 
 | struct arg { | 
 |   // Static method for representing a default value of None. This is meant to | 
 |   // be used like so: | 
 |   //     torch::arg("name") = torch::arg::none | 
 |   // and is identical to: | 
 |   //     torch::arg("name") = IValue() | 
 |   static c10::IValue none() { | 
 |     return c10::IValue(); | 
 |   } | 
 |  | 
 |   // Explicit constructor. | 
 |   explicit arg(std::string name) | 
 |       : name_(std::move(name)), value_(c10::nullopt) {} | 
 |   // Assignment operator. This enables the pybind-like syntax of | 
 |   // torch::arg("name") = value. | 
 |   arg& operator=(const c10::IValue& rhs) { | 
 |     value_ = rhs; | 
 |     return *this; | 
 |   } | 
 |  | 
 |   // The name of the argument. This is copied to the schema; argument | 
 |   // names cannot be extracted from the C++ declaration. | 
 |   std::string name_; | 
 |   // IValue's default constructor makes it None, which is not distinguishable | 
 |   // from an actual, user-provided default value that is None. This boolean | 
 |   // helps distinguish between the two cases. | 
 |   c10::optional<c10::IValue> value_; | 
 | }; | 
 |  | 
 | namespace detail { | 
 |  | 
 | // Argument type utilities | 
 | template <class R, class...> | 
 | struct types { | 
 |   using type = types; | 
 | }; | 
 |  | 
 | template <typename Method> | 
 | struct WrapMethod; | 
 |  | 
 | template <typename R, typename CurrClass, typename... Args> | 
 | struct WrapMethod<R (CurrClass::*)(Args...)> { | 
 |   WrapMethod(R (CurrClass::*m)(Args...)) : m(std::move(m)) {} | 
 |  | 
 |   R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) { | 
 |     return c10::guts::invoke(m, *cur, args...); | 
 |   } | 
 |  | 
 |   R (CurrClass::*m)(Args...); | 
 | }; | 
 |  | 
 | template <typename R, typename CurrClass, typename... Args> | 
 | struct WrapMethod<R (CurrClass::*)(Args...) const> { | 
 |   WrapMethod(R (CurrClass::*m)(Args...) const) : m(std::move(m)) {} | 
 |  | 
 |   R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) { | 
 |     return c10::guts::invoke(m, *cur, args...); | 
 |   } | 
 |  | 
 |   R (CurrClass::*m)(Args...) const; | 
 | }; | 
 |  | 
 | // Adapter for different callable types | 
 | template < | 
 |     typename CurClass, | 
 |     typename Func, | 
 |     std::enable_if_t< | 
 |         std::is_member_function_pointer<std::decay_t<Func>>::value, | 
 |         bool> = false> | 
 | WrapMethod<Func> wrap_func(Func f) { | 
 |   return WrapMethod<Func>(std::move(f)); | 
 | } | 
 |  | 
 | template < | 
 |     typename CurClass, | 
 |     typename Func, | 
 |     std::enable_if_t< | 
 |         !std::is_member_function_pointer<std::decay_t<Func>>::value, | 
 |         bool> = false> | 
 | Func wrap_func(Func f) { | 
 |   return f; | 
 | } | 
 |  | 
// Invokes `functor` with arguments unboxed from the top of `stack`.
//
// The functor's parameter types are obtained from
// c10::guts::infer_function_traits_t<Functor>::parameter_types. For each index
// i in `ivalue_arg_indices`, the IValue at position i among the last
// num_ivalue_args stack entries (torch::jit::peek) is converted to parameter
// i's type with c10::impl::ivalue_to_arg and passed to the functor, so the
// first parameter comes from the deepest of those entries.
//
// The stack entries are only read here, never popped; callers (e.g.
// BoxedProxy) drop them after the call returns.
//
// AllowDeprecatedTypes is forwarded unchanged to ivalue_to_arg. NOTE(review):
// presumably it controls whether deprecated argument types are accepted;
// confirm in c10/impl.
template <
    class Functor,
    bool AllowDeprecatedTypes,
    size_t... ivalue_arg_indices>
typename c10::guts::infer_function_traits_t<Functor>::return_type
call_torchbind_method_from_stack(
    Functor& functor,
    jit::Stack& stack,
    std::index_sequence<ivalue_arg_indices...>) {
  (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
                 // be unused and we have to silence the compiler warning.

  constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices);

  using IValueArgTypes =
      typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
  // TODO We shouldn't use c10::impl stuff directly here. We should use the
  // KernelFunction API instead.
  // Each parameter type is passed through decay_if_not_tensor before
  // conversion (presumably stripping references/cv-qualifiers from non-Tensor
  // types; TODO confirm), then the pack expands to one converted argument per
  // index.
  return (functor)(c10::impl::ivalue_to_arg<
                   typename c10::impl::decay_if_not_tensor<
                       c10::guts::typelist::
                           element_t<ivalue_arg_indices, IValueArgTypes>>::type,
                   AllowDeprecatedTypes>::
                       call(torch::jit::peek(
                           stack, ivalue_arg_indices, num_ivalue_args))...);
}

// Convenience overload: generates the index sequence 0..N-1, where N is the
// functor's number of parameters, and dispatches to the overload above.
template <class Functor, bool AllowDeprecatedTypes>
typename c10::guts::infer_function_traits_t<Functor>::return_type
call_torchbind_method_from_stack(Functor& functor, jit::Stack& stack) {
  constexpr size_t num_ivalue_args =
      c10::guts::infer_function_traits_t<Functor>::number_of_parameters;
  return call_torchbind_method_from_stack<Functor, AllowDeprecatedTypes>(
      functor, stack, std::make_index_sequence<num_ivalue_args>());
}
 |  | 
 | template <class RetType, class Func> | 
 | struct BoxedProxy; | 
 |  | 
 | template <class RetType, class Func> | 
 | struct BoxedProxy { | 
 |   void operator()(jit::Stack& stack, Func& func) { | 
 |     auto retval = call_torchbind_method_from_stack<Func, false>(func, stack); | 
 |     constexpr size_t num_ivalue_args = | 
 |         c10::guts::infer_function_traits_t<Func>::number_of_parameters; | 
 |     torch::jit::drop(stack, num_ivalue_args); | 
 |     stack.emplace_back(c10::ivalue::from(std::move(retval))); | 
 |   } | 
 | }; | 
 |  | 
 | template <class Func> | 
 | struct BoxedProxy<void, Func> { | 
 |   void operator()(jit::Stack& stack, Func& func) { | 
 |     call_torchbind_method_from_stack<Func, false>(func, stack); | 
 |     constexpr size_t num_ivalue_args = | 
 |         c10::guts::infer_function_traits_t<Func>::number_of_parameters; | 
 |     torch::jit::drop(stack, num_ivalue_args); | 
 |     stack.emplace_back(c10::IValue()); | 
 |   } | 
 | }; | 
 |  | 
 | inline bool validIdent(size_t i, char n) { | 
 |   return isalpha(n) || n == '_' || (i > 0 && isdigit(n)); | 
 | } | 
 |  | 
 | inline void checkValidIdent(const std::string& str, const char* type) { | 
 |   for (const auto i : c10::irange(str.size())) { | 
 |     TORCH_CHECK( | 
 |         validIdent(i, str[i]), | 
 |         type, | 
 |         " must be a valid Python/C++ identifier." | 
 |         " Character '", | 
 |         str[i], | 
 |         "' at index ", | 
 |         i, | 
 |         " is illegal."); | 
 |   } | 
 | } | 
 |  | 
// Exported, non-template base holding state for a custom-class registration.
// Its constructor and members are protected, so it is meant to be used via
// a derived class. NOTE(review): presumably the templated class_<T> builder;
// confirm in custom_class.h. Only declarations appear here; the definitions
// live out of line.
class TORCH_API class_base {
 protected:
  // namespaceName / className: names of the class being registered.
  // NOTE(review): presumably combined into qualClassName; TODO confirm in
  // the out-of-line definition.
  // doc_string: documentation string for the class (taken by value).
  // intrusivePtrClassTypeid / taggedCapsuleClass: runtime type identities
  // supplied by the caller; TODO confirm how they are recorded.
  explicit class_base(
      const std::string& namespaceName,
      const std::string& className,
      std::string doc_string,
      const std::type_info& intrusivePtrClassTypeid,
      const std::type_info& taggedCapsuleClass);

  // Builds a new FunctionSchema from `schema` and `default_args` (see struct
  // arg). NOTE(review): presumably applies the argument names and default
  // values from `default_args`; confirm against the definition.
  static c10::FunctionSchema withNewArguments(
      const c10::FunctionSchema& schema,
      std::initializer_list<arg> default_args);
  // Qualified name of the registered class.
  std::string qualClassName;
  // Type object describing the registered class.
  at::ClassTypePtr classTypePtr;
};
 |  | 
 | } // namespace detail | 
 |  | 
// Registers a custom class type. NOTE(review): presumably this is what makes
// it discoverable through getCustomClass() below; confirm in the definition.
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
// Registers a method belonging to a custom class. Ownership of `method` is
// transferred to the callee (it is passed as std::unique_ptr).
TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method);

// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
// the ClassType pointer to the Type that describes that custom class,
// or nullptr if no class by that name was found.
TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);

// Given an IValue, return true if the object contained in that IValue
// is a custom C++ class, otherwise return false.
TORCH_API bool isCustomClass(const c10::IValue& v);

// This API is for testing purposes ONLY. It should not be used in
// any load-bearing code. NOTE(review): judging by the name, it returns the
// schemas of registered custom-class methods for backward-compatibility
// checks; confirm in the definition.
TORCH_API std::vector<c10::FunctionSchema> customClassSchemasForBCCheck();
 |  | 
// Re-export the registration functions so they can also be spelled
// torch::jit::registerCustomClass / torch::jit::registerCustomClassMethod.
namespace jit {
using ::torch::registerCustomClass;
using ::torch::registerCustomClassMethod;
} // namespace jit
 |  | 
 | } // namespace torch |