#pragma once

#include "caffe2/core/context.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/python/dlpack.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <cstdint>
#include <vector>

|  | namespace caffe2 { | 
|  | namespace python { | 
|  |  | 
|  | namespace py = pybind11; | 
|  |  | 
|  | const DLDeviceType* CaffeToDLDeviceType(int device_type); | 
|  |  | 
|  | const DLDataType* CaffeToDLType(const TypeMeta meta); | 
|  |  | 
|  | const TypeMeta DLTypeToCaffe(const DLDataType& dl_type); | 
|  |  | 
|  | // TODO: remove context | 
|  | template <class Context> | 
|  | class DLPackWrapper { | 
|  | public: | 
|  | DLPackWrapper(Tensor* tensor, DeviceOption device_option) | 
|  | : tensor(tensor), device_option(device_option) {} | 
|  |  | 
|  | py::object data() { | 
|  | DLDevice tensor_context; | 
|  | auto device_type_ptr = CaffeToDLDeviceType(device_option.device_type()); | 
|  | CAFFE_ENFORCE( | 
|  | device_type_ptr, | 
|  | "Unsupported device type: ", | 
|  | device_option.device_type()); | 
|  | tensor_context.device_type = *device_type_ptr; | 
|  | tensor_context.device_id = device_option.device_id(); | 
|  |  | 
|  | if (tensor->numel() <= 0) { | 
|  | tensor->Resize(0); | 
|  | } | 
|  | if (tensor->dtype() == ScalarType::Undefined) { | 
|  | // treat uninitialized tensor as float tensor | 
|  | tensor->template mutable_data<float>(); | 
|  | } | 
|  | CAFFE_ENFORCE_GT(tensor->dim(), 0); | 
|  |  | 
|  | auto type_ptr = CaffeToDLType(tensor->dtype()); | 
|  | CAFFE_ENFORCE( | 
|  | type_ptr, | 
|  | "Tensor type is not supported in DLPack: ", | 
|  | tensor->dtype().name()); | 
|  | DLDataType tensor_type = *type_ptr; | 
|  |  | 
|  | DLTensor dlTensor; | 
|  | dlTensor.data = const_cast<void*>(tensor->raw_data()); | 
|  | dlTensor.device = tensor_context; | 
|  | dlTensor.ndim = tensor->dim(); | 
|  | dlTensor.dtype = tensor_type; | 
|  | dlTensor.shape = const_cast<int64_t*>(&(tensor->sizes()[0])); | 
|  | dlTensor.strides = nullptr; | 
|  | dlTensor.byte_offset = 0; | 
|  |  | 
|  | managed_tensor.dl_tensor = dlTensor; | 
|  | // C2 Tensor memory is managed by C2 | 
|  | managed_tensor.manager_ctx = nullptr; | 
|  | managed_tensor.deleter = [](DLManagedTensor*) {}; | 
|  |  | 
|  | return py::reinterpret_steal<py::object>( | 
|  | PyCapsule_New(&managed_tensor, "dltensor", nullptr)); | 
|  | } | 
|  |  | 
|  | void feed(py::object obj) { | 
|  | CAFFE_ENFORCE(PyCapsule_CheckExact(obj.ptr()), "Expected DLPack capsule"); | 
|  | DLManagedTensor* dlMTensor = | 
|  | (DLManagedTensor*)PyCapsule_GetPointer(obj.ptr(), "dltensor"); | 
|  | CAFFE_ENFORCE(dlMTensor, "Invalid DLPack capsule"); | 
|  | DLTensor* dlTensor = &dlMTensor->dl_tensor; | 
|  | auto device_type_ptr = CaffeToDLDeviceType(device_option.device_type()); | 
|  | CAFFE_ENFORCE( | 
|  | device_type_ptr, | 
|  | "Unsupported device type: ", | 
|  | device_option.device_type()); | 
|  | CAFFE_ENFORCE( | 
|  | dlTensor->device.device_type == *device_type_ptr, | 
|  | "DLPack tensor device type mismatch"); | 
|  | int dlpack_device_id = dlTensor->device.device_id; | 
|  | CAFFE_ENFORCE_EQ( | 
|  | dlpack_device_id, | 
|  | device_option.device_id(), | 
|  | "Expected same device id for DLPack and C2 tensors"); | 
|  |  | 
|  | std::vector<int64_t> dims; | 
|  | dims.reserve(dlTensor->ndim); | 
|  | for (int idx = 0; idx < dlTensor->ndim; ++idx) { | 
|  | dims.push_back(dlTensor->shape[idx]); | 
|  | } | 
|  |  | 
|  | if (dlTensor->strides) { | 
|  | int64_t stride = 1; | 
|  | for (int idx = dims.size() - 1; idx >= 0; --idx) { | 
|  | CAFFE_ENFORCE_EQ( | 
|  | stride, | 
|  | dlTensor->strides[idx], | 
|  | "Tensors with non-standard strides are not supported"); | 
|  | stride *= dims[idx]; | 
|  | } | 
|  | } | 
|  |  | 
|  | tensor->Resize(dims); | 
|  | caffe2::TypeMeta meta = DLTypeToCaffe(dlTensor->dtype); | 
|  | at::Device device = at::Device(tensor->GetDeviceType()); | 
|  | tensor->ShareExternalPointer( | 
|  | at::DataPtr( | 
|  | (void*)(((int8_t*)dlTensor->data) + dlTensor->byte_offset), | 
|  | static_cast<void*>(dlMTensor), | 
|  | [](void* t_ptr) -> void { | 
|  | DLManagedTensor* mt_ptr = static_cast<DLManagedTensor*>(t_ptr); | 
|  | if (mt_ptr->deleter) { | 
|  | mt_ptr->deleter(mt_ptr); | 
|  | } | 
|  | }, | 
|  | device), | 
|  | meta, | 
|  | 0); | 
|  | } | 
|  |  | 
|  | Tensor* tensor; | 
|  | DeviceOption device_option; | 
|  | DLManagedTensor managed_tensor; | 
|  | }; | 
|  |  | 
|  | } // namespace python | 
|  | } // namespace caffe2 |