blob: fd3af2e86597d7ba5f7d21433879784805591d79 [file] [log] [blame]
#pragma once
#include <Python.h>
#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "torch/csrc/DynamicTypes.h"
// Short alias for the pybind11 namespace. Because this is declared at global
// scope in a header, it is visible in every translation unit that includes it.
namespace py = pybind11;
namespace pybind11 { namespace detail {

// pybind11 type caster for at::Tensor: lets bound functions accept at::Tensor
// arguments (Python -> C++ via torch::createTensor) and return at::Tensor
// values (C++ -> Python via torch::createPyObject).
template <> struct type_caster<at::Tensor> {
  // Declares the loaded `value` member and the name shown for this type in
  // pybind11-generated signatures/docstrings. Leaves access at `public:`.
  PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor"));

  /// @brief Python -> C++: convert `src` into an at::Tensor stored in `value`.
  /// @param src  Python object to convert.
  /// @param convert  pybind11's implicit-conversion flag; not used here.
  /// @return true on success; false if torch::createTensor throws any
  ///         std::exception, so pybind11 can try other overloads.
  bool load(handle src, bool /* convert */) {
    try {
      value = torch::createTensor(src.ptr());
    } catch (const std::exception&) {
      // NOTE(review): if createTensor can set a Python error indicator before
      // throwing, it may need to be cleared here before returning false —
      // confirm against its implementation.
      return false;
    }
    return true;
  }

  /// @brief C++ -> Python: wrap `src` in a Python object.
  /// Taken by const reference to avoid copying the tensor handle.
  /// @return handle around the PyObject* from torch::createPyObject; pybind11
  ///         treats a caster's returned handle as a new (owned) reference —
  ///         presumably createPyObject returns one; confirm.
  static handle cast(const at::Tensor& src, return_value_policy /* policy */, handle /* parent */) {
    return handle(torch::createPyObject(src));
  }
};

}} // namespace pybind11::detail