| #pragma once |
| #include <ATen/MemoryOverlap.h> |
| #include <ATen/Tensor.h> |
| #include <c10/core/DispatchKey.h> |
| #include <c10/core/DispatchKeySet.h> |
| #include <c10/core/MemoryFormat.h> |
| #include <c10/core/TensorImpl.h> |
| #include <c10/util/ArrayRef.h> |
| #include <c10/util/Exception.h> |
| #include <c10/util/Metaprogramming.h> |
| #include <c10/util/irange.h> |
| |
| namespace at::native { |
| struct NestedTensorImpl; |
| inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt); |
| int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor); |
| at::Tensor construct_nested_strides(const at::Tensor& nested_size); |
| at::Tensor construct_offsets(const at::Tensor& nested_size); |
| |
| struct TORCH_API NestedTensorImpl : public c10::TensorImpl { |
| explicit NestedTensorImpl( |
| Storage storage, |
| c10::DispatchKeySet key_set, |
| const caffe2::TypeMeta data_type, |
| at::Tensor nested_sizes, |
| at::Tensor nested_strides, |
| at::Tensor storage_offsets); |
| |
| explicit NestedTensorImpl( |
| const at::Tensor& buffer, |
| at::Tensor nested_sizes, |
| at::Tensor nested_strides, |
| at::Tensor storage_offsets); |
| // Assumes the buffer is contiguous, so `nested_strides` and `offsets` |
| // can be inferred from `nested_sizes` |
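| // Illustrative sketch (not part of the API): for a contiguous buffer |
| // holding constituents of sizes [2, 3] and [4, 3], `nested_sizes` is |
| // [[2, 3], [4, 3]], the inferred strides are the row-major strides |
| // [[3, 1], [3, 1]], and the inferred offsets are the running element |
| // counts [0, 6]. |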
| explicit NestedTensorImpl( |
| const at::Tensor& buffer, |
| const at::Tensor& nested_sizes); |
| |
| // This constructor is used when creating view tensors from nested tensors |
| explicit NestedTensorImpl( |
| c10::TensorImpl::ImplType impl_type, |
| const at::Tensor& base_tensor, |
| at::Tensor nested_sizes, |
| at::Tensor nested_strides, |
| at::Tensor storage_offsets); |
| |
| // TODO: don't expose private implementation details like this; in |
| // particular, resizing this tensor will mess up our dim() and |
| // callers cannot fix it. |
| const Tensor& get_nested_sizes() const { |
| return nested_sizes_; |
| } |
| // TODO: don't expose private implementation details like this |
| const Tensor& get_nested_strides() const { |
| return nested_strides_; |
| } |
| const Tensor& get_storage_offsets() const { |
| return storage_offsets_; |
| } |
| // Returns nullopt if the ith dimension is irregular. The ith dimension |
| // of a NestedTensor is regular if the unbound tensors match in |
| // size at the (i-1)th dimension. |
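| // For example (illustrative only), for a NestedTensor built from tensors |
| // of sizes [2, 3] and [4, 3]: opt_size(0) == 2 (the number of |
| // constituents), opt_size(1) == nullopt (sizes 2 and 4 differ, so that |
| // dimension is irregular), and opt_size(2) == 3 (all constituents match). |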
| std::optional<int64_t> opt_size(int64_t d) const; |
| |
| int64_t size(int64_t d) const { |
| std::optional<int64_t> optional_size = this->opt_size(d); |
| TORCH_CHECK( |
| optional_size.has_value(), |
| "Given dimension ", |
| d, |
| " is irregular and does not have a size."); |
| return *optional_size; |
| } |
| /** |
| * Return a view of the nested tensor as a 1-dimensional contiguous tensor. |
| * |
| * The buffer tensor created by this function shares the same storage_impl as |
| * the original nested tensor, and therefore can be seen as a view. |
| * |
| * @return A newly constructed view tensor |
| */ |
| at::Tensor get_buffer() const { |
| TORCH_CHECK( |
| nested_tensor_impl_is_contiguous(this), |
| "NestedTensor must be contiguous to get buffer."); |
| return get_unsafe_storage_as_tensor(); |
| } |
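| // Usage sketch (illustrative only; `nt` is assumed to be an existing |
| // contiguous nested tensor): |
| //   auto* impl = get_nested_tensor_impl(nt); |
| //   at::Tensor buf = impl->get_buffer(); // 1-D view sharing nt's storage |
| //   // buf.numel() equals impl->get_buffer_size() |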
| /** |
| * If possible, use get_buffer() instead. This function returns the storage |
| * as a tensor directly, which is not safe to use in general. Callers of |
| * this function must account for nested_sizes_, nested_strides_ and |
| * storage_offsets_ themselves. |
| * |
| * @return A newly constructed view tensor |
| */ |
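| // Sketch of what accounting for the metadata looks like (illustrative |
| // only; `impl` and the index `i` are assumptions, and the row-i sizes and |
| // strides would have to be copied out of the metadata tensors first): |
| //   at::Tensor storage_t = impl->get_unsafe_storage_as_tensor(); |
| //   // constituent i is recovered roughly as |
| //   //   storage_t.as_strided(sizes_row_i, strides_row_i, |
| //   //       impl->get_storage_offsets()[i].item<int64_t>()); |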
| at::Tensor get_unsafe_storage_as_tensor() const { |
| auto buffer_key_set_ = generate_buffer_key_set(); |
| const auto buffer_size = get_buffer_size(); |
| auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>( |
| c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_); |
| buffer_tensor_impl->set_sizes_contiguous( |
| c10::makeArrayRef(static_cast<int64_t>(buffer_size))); |
| return Tensor(buffer_tensor_impl); |
| } |
| |
| size_t get_buffer_size() const { |
| return storage_.nbytes() / data_type_.itemsize(); |
| } |
| |
| protected: |
| const char* tensorimpl_type_name() const override; |
| |
| // TODO: numel_custom and is_contiguous_custom can be profitably overridden |
| // with real implementations |
| int64_t numel_custom() const override; |
| c10::SymInt sym_numel_custom() const override; |
| bool is_contiguous_custom(MemoryFormat) const override; |
| int64_t size_custom(int64_t d) const override { |
| return this->size(d); |
| } |
| c10::SymInt sym_size_custom(int64_t d) const override { |
| return c10::SymInt{this->size(d)}; |
| } |
| IntArrayRef sizes_custom() const override; |
| c10::SymIntArrayRef sym_sizes_custom() const override; |
| IntArrayRef strides_custom() const override; |
| c10::SymIntArrayRef sym_strides_custom() const override; |
| |
| // Unlike the overrides above, this one has a real implementation. |
| int64_t dim_custom() const override; |
| |
| c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
| const c10::VariableVersion& version_counter, |
| bool allow_tensor_metadata_change) const override; |
| |
| c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
| c10::VariableVersion&& version_counter, |
| bool allow_tensor_metadata_change) const override; |
| |
| void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override { |
| copy_tensor_metadata( |
| /*src_impl=*/impl.get(), |
| /*dest_impl=*/this, |
| /*version_counter=*/version_counter(), |
| /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); |
| } |
| |
| private: |
| // Must be called after any changes to our dim() to sync the state |
| // to TensorImpl. |
| void refresh_dim(); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) |
| const at::Tensor nested_sizes_, nested_strides_; |
| // The starting positions of the underlying tensors in the contiguous |
| // buffer, i.e. the buffer memory offsets at which the underlying tensors |
| // begin. This metadata is kept because, without a strong enough |
| // constraint, it cannot be derived from `nested_sizes_` and |
| // `nested_strides_`: |
| // 1. when the buffer has blanks, e.g. [tensor1, blank, tensor2], |
| // which can happen e.g. after slicing a nested tensor |
| // 2. when multiple tensors share the same memory |
| // 3. when the nesting order is changed, e.g. [tensor1, tensor3, tensor2] |
| // A strong enough constraint would be: every underlying tensor is |
| // contiguous in memory && the tensors are nested in ascending buffer order |
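| // Example (illustrative): start from a contiguous nested tensor with |
| // sizes [[2, 3], [2, 3]] and offsets [0, 6]; narrowing each constituent |
| // to [2, 2] leaves the offsets at [0, 6], while offsets recomputed from |
| // the new sizes alone would be [0, 4]. |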
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) |
| const at::Tensor storage_offsets_; |
| // NOTE: -1 here means the size is missing |
| // Optional to allow it to be computed lazily from `nested_sizes_`. |
| // TODO: maybe we can remove this metadata since |
| // we can compute it from `nested_sizes_` |
| mutable std::optional<std::vector<int64_t>> opt_sizes_; |
| |
| template <typename VariableVersion> |
| c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core( |
| VariableVersion&& version_counter, |
| bool allow_tensor_metadata_change) const; |
| |
| /** |
| * Generates a non-nested key_set from a nested tensor. |
| * |
| * Many nested tensor kernel implementations generate a buffer tensor and |
| * redispatch it to a non-nested kernel; this function generates the key |
| * set used by that buffer tensor. |
| * |
| * @return Appropriate key set for a non-nested tensor |
| */ |
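| // For example (illustrative): for a CPU nested tensor whose key set |
| // contains {CPU, NestedTensor, AutogradNestedTensor}, the resulting |
| // buffer key set drops the two nested-tensor keys, keeps CPU, adds Dense, |
| // and re-adds the generic Autograd keys because an autograd key was |
| // present. |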
| inline c10::DispatchKeySet generate_buffer_key_set() const { |
| auto buffer_key_set = this->key_set(); |
| const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset); |
| // Remove nested tensor specific keys |
| buffer_key_set = buffer_key_set - |
| c10::DispatchKeySet{ |
| c10::DispatchKey::NestedTensor, |
| c10::DispatchKey::AutogradNestedTensor}; |
| |
| // Add dense tensor specific keys |
| buffer_key_set = |
| buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense}; |
| buffer_key_set = Autograd |
| ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set |
| : buffer_key_set; |
| |
| return buffer_key_set; |
| } |
| }; |
| |
| inline NestedTensorImpl* get_nested_tensor_impl_or_null( |
| const at::Tensor& tensor) { |
| if (tensor.is_nested()) { |
| return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl()); |
| } |
| return nullptr; |
| } |
| |
| inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) { |
| TORCH_CHECK( |
| tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor."); |
| return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl()); |
| } |
| |
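| // A nested tensor is considered contiguous when every constituent is laid |
| // out row-major (innermost stride 1, each outer stride equal to the |
| // product of the inner sizes) and the constituents are tightly packed in |
| // the buffer (offset 0 for the first, and each following offset equal to |
| // the previous offset plus the previous constituent's extent). |
| // Example (illustrative): sizes [[2, 3], [4, 3]], strides |
| // [[3, 1], [3, 1]], offsets [0, 6] is contiguous; changing the offsets to |
| // [0, 8] (a gap in the buffer) or the first strides to [1, 2] is not. |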
| inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) { |
| int64_t ntensors = nt->size(0); |
| if (ntensors == 0) { |
| return true; |
| } |
| const Tensor &sizemat = nt->get_nested_sizes(), |
| &stridemat = nt->get_nested_strides(); |
| const int64_t* offsets_ptr = |
| nt->get_storage_offsets().const_data_ptr<int64_t>(); |
| int64_t orig_dim = sizemat.size(1); |
| // nesting scalars |
| if (orig_dim == 0) { |
| // each scalar must be packed back to back: |
| // return false if there is blank memory between underlying scalars |
| for (int64_t i = 0; i < ntensors; i++) { |
| if (offsets_ptr[i] != i) { |
| return false; |
| } |
| } |
| } |
| // nesting tensors |
| else { |
| // return false if any underlying tensor is non-contiguous |
| const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(), |
| *stridemat_ptr = stridemat.const_data_ptr<int64_t>(); |
| for (int64_t i = 0; i < ntensors; i++) { |
| if (stridemat_ptr[orig_dim - 1] != 1) { |
| return false; |
| } |
| int64_t product = sizemat_ptr[orig_dim - 1]; |
| for (int64_t j = orig_dim - 2; j >= 0; j--) { |
| if (stridemat_ptr[j] != product) { |
| return false; |
| } |
| product *= sizemat_ptr[j]; |
| } |
| sizemat_ptr += orig_dim; |
| stridemat_ptr += orig_dim; |
| } |
| // return false if there is blank memory between underlying tensors |
| if (offsets_ptr[0] != 0) { |
| return false; |
| } |
| sizemat_ptr = sizemat.const_data_ptr<int64_t>(); |
| stridemat_ptr = stridemat.const_data_ptr<int64_t>(); |
| for (int64_t i = 1; i < ntensors; i++) { |
| if (offsets_ptr[i] != |
| offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) { |
| return false; |
| } |
| sizemat_ptr += orig_dim; |
| stridemat_ptr += orig_dim; |
| } |
| } |
| // everything is fine |
| return true; |
| } |
| |
| inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) { |
| return get_nested_tensor_impl(tensor)->get_nested_sizes(); |
| } |
| |
| } // namespace at::native |