blob: 62ab4c16cf12d3b44fcfc82bedf6e9657be14318 [file] [log] [blame]
#pragma once
#include <c10/core/TensorImpl.h>
#include <c10/core/MemoryFormat.h>
#include <c10/util/Exception.h>
namespace at {
// An "Opaque" TensorImpl -- there are no strides and (for now)
// even data() is not supported (thus no pointer arithmetic).
// NOTE: We could allow data() in the future, but would have to ensure pointer
// arithmetic code is properly guarded.
//
// NOTE: This does not support resize_ (and other metadata-changing ops) because of
// `shallow_copy_and_detach`. We would need to define an interface to "shallow copy"
// in order to add support.
template <typename OpaqueHandle>
struct CAFFE2_API OpaqueTensorImpl : public TensorImpl {
  // public constructor for now...
  /**
   * Construct an opaque tensor with the given sizes that holds `opaque_handle`.
   *
   * The handle is taken by value and moved into the member, so callers can
   * pass an rvalue to avoid an extra copy. Only the sizes (and the derived
   * numel) are recorded here; no strides or storage are set up.
   */
  OpaqueTensorImpl(
      at::TensorTypeId type_id,
      const caffe2::TypeMeta& data_type,
      c10::Device device,
      OpaqueHandle opaque_handle,
      c10::IntArrayRef sizes)
      : TensorImpl(type_id, data_type, device),
        opaque_handle_(std::move(opaque_handle)) {
    sizes_ = sizes.vec();
    refresh_numel();
  }

  /**
   * Release the base-class resources, then reset the handle to a
   * value-initialized `OpaqueHandle`.
   *
   * NOTE(review): presumably OpaqueHandle is an owning type (e.g. a smart
   * pointer), in which case this drops this tensor's reference to the
   * underlying object — confirm for each instantiation.
   */
  void release_resources() override {
    TensorImpl::release_resources();
    opaque_handle_ = {};
  }

  // ---------------------------------------------------------------------
  // Stride / storage / metadata-mutation APIs are unsupported: each of the
  // overrides below unconditionally raises via AT_ERROR. Parameter names are
  // commented out because they are intentionally unused.
  // ---------------------------------------------------------------------

  IntArrayRef strides() const override {
    AT_ERROR("opaque tensors do not have strides");
  }

  // The default argument is kept so that calls made through the derived type
  // without an argument continue to compile.
  bool is_contiguous(
      c10::MemoryFormat /*memory_format*/ = c10::MemoryFormat::Contiguous) const override {
    AT_ERROR("opaque tensors do not have is_contiguous");
  }

  int64_t stride(int64_t /*d*/) const override {
    AT_ERROR("opaque tensors do not have strides");
  }

  void resize_dim(int64_t /*ndim*/) override {
    AT_ERROR("opaque tensors do not have resize_dim");
  }

  void set_size(int64_t /*dim*/, int64_t /*new_size*/) override {
    AT_ERROR("opaque tensors do not have set_size");
  }

  void set_stride(int64_t /*dim*/, int64_t /*new_stride*/) override {
    AT_ERROR("opaque tensors do not have set_stride");
  }

  void set_storage_offset(int64_t /*storage_offset*/) override {
    AT_ERROR("opaque tensors do not have set_storage_offset");
  }

  TensorImpl* maybe_zero_dim(bool /*condition_when_zero_dim*/) override {
    AT_ERROR("opaque tensors do not support maybe_zero_dim");
  }

  /// Opaque tensors never have storage; storage() and storage_offset() throw.
  bool has_storage() const override {
    return false;
  }

  const Storage& storage() const override {
    AT_ERROR("opaque tensors do not have storage");
  }

  int64_t storage_offset() const override {
    AT_ERROR("opaque tensors do not have storage");
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
        type_id(), dtype(), device(), opaque_handle_, sizes_);
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/version_counter,
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

  /**
   * Shallow-copies data from another TensorImpl into this TensorImpl.
   *
   * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   *
   * `impl` must have a type id compatible with this tensor's (asserted below);
   * it is then treated as an OpaqueTensorImpl<OpaqueHandle>.
   */
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    AT_ASSERT(has_compatible_shallow_copy_type(impl->type_id()));
    auto opaque_impl = static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
    copy_tensor_metadata(
        /*src_impl=*/opaque_impl,
        /*dest_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
    refresh_numel();
  }

  /// Read-only access to the opaque handle; usable through a const tensor.
  const OpaqueHandle& opaque_handle() const {
    return opaque_handle_;
  }

  /// Mutable access to the opaque handle.
  /// NOTE(review): "unsafe" presumably means callers must keep the handle
  /// consistent with this tensor's sizes/dtype themselves — confirm contract.
  OpaqueHandle& unsafe_opaque_handle() {
    return opaque_handle_;
  }

 private:
  OpaqueHandle opaque_handle_;

  /**
   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset)
   * from one TensorImpl to another TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ].
   */
  static void copy_tensor_metadata(
      const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
      OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(src_opaque_impl, dest_opaque_impl, version_counter, allow_tensor_metadata_change);
    // OpaqueTensorImpl-specific fields.
    dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
  }
};
} // namespace at