| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // Temporary memories are used to allocate scratch space required by an |
| // operation about to be enqueued onto a stream. |
| // |
| // std::unique_ptr<TemporaryDeviceMemory<float>> temporary_memory = |
| // stream.AllocateTemporaryArray<float>(1024).ConsumeValueOrDie(); |
| // // ... enqueue stuff onto the stream using the temporary memory ... |
| // // Note that the memory is accessible via |
| // // temporary_memory->device_memory() and similar. |
| // |
| // // Finalize the temporary memory. The underlying device memory may |
| // // be released any time after this program point, as another thread may |
| // // call Stream::BlockHostUntilDone, causing synchronization. This |
| // // finalization also happens automatically for the user if the unique_ptr |
| // // goes out of scope. |
| // temporary_memory->Finalize(); |
| // |
| // WARNING: do NOT hold onto the device memory associated with temporary_memory |
| // after finalization. If temporary_memory->device_memory() is used after the |
| // temporary memory is finalized, it will cause a DCHECK failure. |
| // |
| // Note that standard usage takes advantage of the type-safe wrapper, |
| // TemporaryDeviceMemory<T>, defined below. |
| // |
| // Also see tests for executable sample usage. |
| |
| #ifndef TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ |
| #define TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ |
| |
| #include "tensorflow/stream_executor/device_memory.h" |
| |
| namespace stream_executor { |
| |
// Forward declarations. This header only stores a Stream* and befriends
// internal::TemporaryMemoryManager, so the full definitions are not needed
// here.
namespace internal {
class TemporaryMemoryManager;
}  // namespace internal

class Stream;
| |
| // Untyped base class (analogous to a void*) for temporary device memory |
| // allocations associated with a stream. |
| class TemporaryDeviceMemoryBase { |
| public: |
| // Marks the temporary memory as finalized if it is not already marked as |
| // such. |
| ~TemporaryDeviceMemoryBase(); |
| |
| // Precondition: !IsFinalized() |
| DeviceMemoryBase* mutable_device_memory(); |
| |
| // Precondition: !IsFinalized() |
| const DeviceMemoryBase& device_memory() const; |
| |
| // "Finalizes" this temporary memory, making it acceptable to release at the |
| // next stream synchronization point -- the device memory can be reclaimed at |
| // any time after the temporary memory is marked as finalized (e.g. if a |
| // separate thread is calls Stream::BlockHostUntilDone). This may only be |
| // called once -- see the precondition below. |
| // |
| // Precondition: !IsFinalized() |
| void Finalize(); |
| |
| // Returns true iff the temporary memory is finalized (that is, the user is |
| // done referring to the temporary device memory, and thus it can be released |
| // at the next stream synchronization point). |
| bool IsFinalized() const; |
| |
| // Returns true iff the temporary memory is still allocated. |
| // |
| // Note: this is a polling call, no guarantee is made that the temporary |
| // memory is still allocated after the call has completed. |
| bool IsAllocated() const; |
| |
| private: |
| friend class internal::TemporaryMemoryManager; |
| friend class TemporaryDeviceMemoryTest; |
| |
| // Note: construction DCHECKs that the memory is known-allocated in the |
| // stream's temporary-allocation-manager. |
| TemporaryDeviceMemoryBase(Stream* parent, DeviceMemoryBase device_memory, |
| uint64 allocation_generation); |
| |
| // The device memory region that has allocated. |
| DeviceMemoryBase device_memory_; |
| |
| // The generation counter value for the temporary memory record in the |
| // temporary memory manager. |
| uint64 allocation_generation_; |
| |
| // The stream that this temporary memory was allocated for. |
| Stream* parent_; |
| }; |
| |
| // Type-safe wrapper around the base type (which is analogous to a void*). |
| template <typename T> |
| class TemporaryDeviceMemory : public TemporaryDeviceMemoryBase { |
| public: |
| // Type-safe wrapper around TemporaryDeviceMemoryBase::mutable_device_memory. |
| DeviceMemory<T>* mutable_device_memory() { |
| StaticSlicingAssertionDummy(); |
| return reinterpret_cast<DeviceMemory<T>*>( |
| TemporaryDeviceMemoryBase::mutable_device_memory()); |
| } |
| |
| // Type-safe wrapper around TemporaryDeviceMemoryBase::device_memory. |
| const DeviceMemory<T>& device_memory() const { |
| StaticSlicingAssertionDummy(); |
| return reinterpret_cast<const DeviceMemory<T>&>( |
| TemporaryDeviceMemoryBase::device_memory()); |
| } |
| |
| private: |
| static void StaticSlicingAssertionDummy() { |
| static_assert( |
| sizeof(TemporaryDeviceMemory) == sizeof(TemporaryDeviceMemoryBase), |
| "derived class is simply a wrapper, no members may be added due to " |
| "slicing"); |
| } |
| }; |
| |
| } // namespace stream_executor |
| |
| #endif // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ |