blob: 49a7a23aa510856e69c8e4465c0cf9f1904daa22 [file] [log] [blame]
#pragma once
#include <Python.h>
#include <memory>
#include <utility>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "torch/csrc/autograd/python_cpp_function.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/THP.h"
namespace py = pybind11;
namespace pybind11 { namespace detail {
// handle Python <-> torch::autograd::Function conversions
template <> struct type_caster<std::shared_ptr<torch::autograd::Function>> {
public:
PYBIND11_TYPE_CASTER(std::shared_ptr<torch::autograd::Function>, _("std::shared_ptr<torch::autograd::Function>"));
bool load(handle src, bool) {
if (!THPFunction_Check(src.ptr())) return false;
value = THPFunction_asFunction((THPFunction*)src.ptr());
return true;
}
static handle cast(std::shared_ptr<torch::autograd::Function> src, return_value_policy /* policy */, handle /* parent */) {
auto fn = functionToPyObject(src);
return handle(fn);
}
};
}} // namespace pybind11::detail