| #include <array> |
| |
| #include <ATen/Functions.h> |
| #include <ATen/Utils.h> |
| #include <c10/core/Allocator.h> |
| |
| namespace at { |
| |
| Tensor TensorMaker::make_tensor() { |
| AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove. |
| tracer::impl::NoTracerDispatchMode tracer_guard{}; |
| |
| check_size_nonnegative(sizes_); |
| |
| TORCH_CHECK_VALUE( |
| !deleter_ || !ctx_, |
| "The deleter and context arguments are mutually exclusive."); |
| |
| if (device_ == nullopt) { |
| device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type()); |
| } |
| |
| if (opts_.device().has_index()) { |
| // clang-format off |
| TORCH_CHECK_VALUE( |
| opts_.device() == *device_, |
| "Specified device ", opts_.device(), " does not match device of data ", *device_); |
| // clang-format on |
| } |
| |
| std::size_t size_bytes = computeStorageSize(); |
| |
| DataPtr data_ptr{}; |
| if (deleter_) { |
| data_ptr = makeDataPtrFromDeleter(); |
| } else { |
| data_ptr = makeDataPtrFromContext(); |
| } |
| |
| TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()"); |
| Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizeable=*/resizeable_}; |
| |
| Tensor tensor = detail::make_tensor<TensorImpl>( |
| std::move(storage), opts_.computeDispatchKey(), opts_.dtype()); |
| |
| TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); |
| if (strides_) { |
| tensor_impl->set_sizes_and_strides(sizes_, *strides_); |
| } else { |
| tensor_impl->set_sizes_contiguous(sizes_); |
| } |
| if (storage_offset_) { |
| tensor_impl->set_storage_offset(*storage_offset_); |
| } |
| |
| return tensor; |
| } |
| |
| std::size_t TensorMaker::computeStorageSize() const noexcept { |
| std::size_t itemsize = opts_.dtype().itemsize(); |
| |
| if (strides_) { |
| auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize); |
| if (storage_offset_) { |
| storage_size += storage_offset_.value(); |
| } |
| return storage_size; |
| } |
| |
| std::size_t size = 1; |
| for (std::int64_t s : sizes_) { |
| size *= static_cast<std::size_t>(s); |
| } |
| auto storage_size = size * itemsize; |
| if (storage_offset_) { |
| storage_size += storage_offset_.value(); |
| } |
| return storage_size; |
| } |
| |
| inline DataPtr TensorMaker::makeDataPtrFromDeleter() const { |
| return InefficientStdFunctionContext::makeDataPtr(data_, deleter_, *device_); |
| } |
| |
| inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept { |
| return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_}; |
| } |
| |
| IntArrayRef TensorMaker::makeTempSizes() const noexcept { |
| static std::int64_t zeros[5] = {0, 0, 0, 0, 0}; |
| if (opts_.has_memory_format()) { |
| MemoryFormat format = *opts_.memory_format_opt(); |
| if (format == MemoryFormat::ChannelsLast) { |
| return IntArrayRef(zeros, 4); |
| } |
| if (format == MemoryFormat::ChannelsLast3d) { |
| return IntArrayRef(zeros, 5); |
| } |
| } |
| return IntArrayRef(zeros, 1); |
| } |
| |
| } // namespace at |