| #pragma once |
| |
| #include <c10/core/Allocator.h> |
| #include <c10/cuda/CUDAGraphsC10Utils.h> |
| #include <c10/cuda/CUDAMacros.h> |
| #include <c10/cuda/CUDAStream.h> |
| |
| #include <c10/cuda/CUDACachingAllocator.h> |
| |
| #include <mutex> |
| |
| namespace torch::cuda::CUDAPluggableAllocator { |
| |
| using MallocFuncType = void*(size_t, int, cudaStream_t); |
| using FreeFuncType = void(void*, size_t, int, cudaStream_t); |
| |
| // A CUDAPluggableAllocatorDeleterContext object is used as the `ctx` |
| // argument for DataPtr. We need context because a user can use |
| // multiple allocators in the same PyTorch program, and |
| // the allocators can have different free functions, such as: |
| // free, cudaFree, cudaFreeAsync, ncclMemFree etc. |
| struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext { |
| explicit CUDAPluggableAllocatorDeleterContext( |
| std::function<FreeFuncType> free_fn, |
| void* data, |
| size_t size, |
| int device, |
| cudaStream_t stream); |
| |
| void free(); |
| |
| private: |
| std::function<FreeFuncType> free_fn_; |
| void* data_; |
| size_t size_; |
| int device_; |
| cudaStream_t stream_; |
| }; |
| |
| #if defined(TORCH_HIP_VERSION) |
| using streamType = c10::hip::HIPStream; |
| #else |
| using streamType = c10::cuda::CUDAStream; |
| #endif |
| |
| TORCH_CUDA_CPP_API std::shared_ptr< |
| c10::cuda::CUDACachingAllocator::CUDAAllocator> |
| getCurrentAllocator(); |
| TORCH_CUDA_CPP_API std::shared_ptr< |
| c10::cuda::CUDACachingAllocator::CUDAAllocator> |
| createCustomAllocator( |
| std::function<MallocFuncType> alloc_fn, |
| std::function<FreeFuncType> free_fn); |
| TORCH_CUDA_CPP_API void changeCurrentAllocator( |
| const std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>& |
| allocator); |
| |
| struct _AllocationMetadata { |
| _AllocationMetadata(); |
| _AllocationMetadata( |
| size_t size, |
| c10::DeviceIndex device_idx, |
| cudaStream_t stream); |
| size_t size; |
| c10::DeviceIndex device_idx; |
| cudaStream_t stream; |
| }; |
| |
| struct TORCH_CUDA_CPP_API CUDAPluggableAllocator |
| : public c10::cuda::CUDACachingAllocator::CUDAAllocator { |
| CUDAPluggableAllocator( |
| std::function<MallocFuncType> alloc_fn, |
| std::function<FreeFuncType> free_fn); |
| |
| CUDAPluggableAllocator(CUDAPluggableAllocator& other); |
| CUDAPluggableAllocator& operator=(CUDAPluggableAllocator& other) = delete; |
| |
| void set_init_fn(std::function<void(int)> init_fn); |
| |
| void set_reset_fn(std::function<void()> reset_fn); |
| |
| void set_memory_fraction_fn( |
| std::function<void(double, int)> memory_fraction_fn); |
| |
| void set_base_alloc_fn(std::function<void*(void*, size_t*)> base_alloc_fn); |
| |
| void set_record_stream_fn( |
| std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn); |
| |
| void set_begin_allocate_to_pool( |
| std::function< |
| void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)> |
| capture_begin_fn); |
| |
| void set_end_allocate_to_pool_fn( |
| std::function<void(int, c10::cuda::MempoolId_t)> capture_about_to_end_fn); |
| |
| void set_release_pool( |
| std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn); |
| |
| void* malloc(size_t size, c10::DeviceIndex device, cudaStream_t stream); |
| |
| c10::DataPtr allocate(size_t size) override; |
| c10::DeleterFnPtr raw_deleter() const override; |
| |
| void* raw_alloc(size_t nbytes) override; |
| void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override; |
| void raw_delete(void* ptr) override; |
| void init(int device_count) override; |
| bool initialized() override; |
| void setMemoryFraction(double fraction, c10::DeviceIndex device) override; |
| void emptyCache() override; |
| void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override; |
| void* getBaseAllocation(void* ptr, size_t* size) override; |
| |
| void recordStream(const c10::DataPtr&, streamType stream) override; |
| |
| c10::cuda::CUDACachingAllocator::DeviceStats getDeviceStats( |
| c10::DeviceIndex device) override; |
| void resetAccumulatedStats(c10::DeviceIndex device) override; |
| void resetPeakStats(c10::DeviceIndex device) override; |
| c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override; |
| void beginAllocateToPool( |
| c10::DeviceIndex device, |
| c10::cuda::MempoolId_t mempool_id, |
| std::function<bool(cudaStream_t)>) override; |
| void endAllocateToPool( |
| c10::DeviceIndex device, |
| c10::cuda::MempoolId_t mempool_id) override; |
| void releasePool(c10::DeviceIndex device, c10::cuda::MempoolId_t mempool_id) |
| override; |
| std::shared_ptr<void> getIpcDevPtr(std::string handle) override; |
| c10::cuda::CUDACachingAllocator::ShareableHandle shareIpcHandle( |
| void*) override; |
| void recordHistory( |
| bool enabled, |
| c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder, |
| size_t alloc_trace_max_entries, |
| c10::cuda::CUDACachingAllocator::RecordContext when) override; |
| void attachOutOfMemoryObserver( |
| c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) override; |
| void attachAllocatorTraceTracker( |
| c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) override; |
| std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> |
| getCheckpointState(c10::DeviceIndex device, at::cuda::MempoolId_t id) |
| override; |
| c10::cuda::CUDACachingAllocator::CheckpointDelta setCheckpointPoolState( |
| c10::DeviceIndex device, |
| std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps) |
| override; |
| void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) |
| override; |
| cudaError_t memcpyAsync( |
| void* dst, |
| int dstDevice, |
| const void* src, |
| int srcDevice, |
| size_t count, |
| cudaStream_t stream, |
| bool p2p_enabled) override; |
| std::string name() override; |
| void copy_data(void* dest, const void* src, std::size_t count) const final; |
| |
| protected: |
| std::function<MallocFuncType> alloc_fn_; |
| std::function<FreeFuncType> free_fn_; |
| std::function<void(int)> init_fn_; |
| std::function<void()> reset_fn_; |
| std::function<void(double, int)> memory_fraction_fn_; |
| std::function<void*(void*, size_t*)> base_alloc_fn_; |
| std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn_; |
| std::function< |
| void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)> |
| begin_allocate_to_pool_fn_; |
| std::function<void(int, c10::cuda::MempoolId_t)> end_allocate_to_pool_fn_; |
| std::function<void(int, c10::cuda::MempoolId_t)> relase_pool_fn_; |
| std::mutex allocator_mutex_; |
| // We do the bookeeping here in order to simplify custom allocators |
| std::unordered_map<void*, _AllocationMetadata> allocation_metadata_; |
| |
| bool initialized_ = false; |
| }; |
| } // namespace torch::cuda::CUDAPluggableAllocator |