|  | 
 | #include <c10/cuda/CUDACachingAllocator.h> | 
 |  | 
 | #include <c10/cuda/CUDAException.h> | 
 | #include <c10/cuda/CUDAFunctions.h> | 
 | #include <c10/cuda/CUDAGuard.h> | 
 | #include <c10/util/UniqueVoidPtr.h> | 
 | #include <c10/util/flat_hash_map.h> | 
 | #include <c10/util/irange.h> | 
 | #include <c10/util/llvmMathExtras.h> | 
 |  | 
 | #include <cuda_runtime_api.h> | 
 | #include <algorithm> | 
 | #include <bitset> | 
 | #include <deque> | 
 | #include <iterator> | 
 | #include <map> | 
 | #include <memory> | 
 | #include <mutex> | 
 | #include <regex> | 
 | #include <set> | 
 | #include <vector> | 
 |  | 
 | namespace c10 { | 
 |  | 
 | C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); | 
 |  | 
 | namespace cuda { | 
 | namespace CUDACachingAllocator { | 
 |  | 
 | // | 
 | // Yet another caching allocator for CUDA device allocations. | 
 | // | 
 | // - Allocations are associated with a stream. Once freed, blocks can be | 
 | //   re-allocated on the same stream, but not on any other stream. | 
 | // - The allocator attempts to find the smallest cached block that will fit the | 
 | //   requested size. If the block is larger than the requested size, it may be | 
 | //   split. If no block is found, the allocator will delegate to cudaMalloc. | 
 | // - If the cudaMalloc fails, the allocator will attempt to free one cached | 
 | //   block of sufficient size that is not split and retry the allocation. | 
 | //   If this also fails, the allocator will attempt to free all cached blocks | 
 | //   that are not split and retry the allocation. | 
 | // - Large (>1MB) and small allocations are stored in separate pools. | 
 | //   Small requests are packed into 2MB buffers. Large requests will use the | 
 | //   smallest available free block or allocate a new block using cudaMalloc. | 
 | // - To reduce fragmentation, requests between 1MB and 10MB will allocate and | 
 | //   split a 20MB block, if no free block of sufficient size is available. | 
 | // - To further reduce fragmentation, blocks >= 200MB are not allowed to be | 
 | //   split. These oversize cached blocks will still satisfy requests within | 
 | //   20MB of the oversize cached block size. | 
 | // | 
 | // With this allocator, allocations and frees should logically be considered | 
 | // "usages" of the memory segment associated with streams, just like kernel | 
 | // launches. The programmer must insert the proper synchronization if memory | 
 | // segments are used from multiple streams. | 
 | // | 
 | // The library provides a recordStream() function to help insert the correct | 
 | // synchronization when allocations are used on multiple streams. This will | 
 | // ensure that the block is not reused before each recorded stream completes | 
 | // work. | 
 | // | 
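//
// Illustrative sketch only (not compiled here): a typical cross-stream use of
// an allocation obtained from this allocator. Names such as `side_stream` and
// `nbytes` are placeholders.
//
//   c10::DataPtr ptr = c10::cuda::CUDACachingAllocator::get()->allocate(nbytes);
//   // ... enqueue work on side_stream that reads or writes ptr.get() ...
//   c10::cuda::CUDACachingAllocator::recordStream(ptr, side_stream);
//   ptr.clear(); // the block will not be reused until side_stream's
//                // recorded work has completed
//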
 |  | 
 | /** | 
 |  * Note [Interaction with CUDA graph capture] | 
 |  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | 
 |  * Graph capture performs a dry run of a region of execution, freezing all CUDA | 
 |  * work (and virtual addresses used during that work) into a "graph." The graph | 
 |  * may be "replayed" like a single giant kernel, with greatly reduced CPU | 
 |  * overhead as well as modestly improved GPU performance. | 
 |  * | 
 |  * Because capture bakes in memory addresses, the memory used during capture | 
 |  * must be available for the graph to use during replay. DeviceCachingAllocator | 
 |  * assigns and frees memory eagerly and dynamically, so if we're not careful | 
 |  * about managing graphs' memory, at replay time those memory addresses could be | 
 * used by other tensors.
 |  * | 
 |  * To guarantee a graph's baked in addresses are safe to reuse in replay, | 
 |  * DeviceAllocator satisfies allocations from a graph-private memory pool during | 
 |  * capture, and doesn't begin cudaFreeing those addresses until the graph is | 
 |  * destroyed. | 
 |  * | 
 |  * Within the private pool, allocations are freed and reassigned as usual during | 
 |  * capture. Memory regions will be used in a consistent order during replay. So | 
 |  * a private pool doesn't use memory more wastefully than the default pools | 
 |  * during capture, but it does reserve its high-water mark of used memory away | 
 |  * from the default pools as long as the capture(s) it served survive | 
 * (regardless of whether those captures are idle or replaying).
 |  * | 
 |  * CUDAGraph's requests for private pools are mediated by | 
 |  * DeviceAllocator::notifyCaptureBegin, notifyCaptureEnd, and | 
 |  * notifyCaptureDestroy. | 
 |  */ | 
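
/**
 * Illustrative call order (a sketch of the mediation described above, not a
 * verbatim transcript of CUDAGraph's implementation):
 *
 *   notifyCaptureBegin(graph_id, mempool_id); // create or share a PrivatePool
 *   // ... cudaStreamBeginCapture, captured work, cudaStreamEndCapture ...
 *   notifyCaptureEnd(graph_id);               // capture finished; pool persists
 *   // ... any number of graph replays ...
 *   notifyCaptureDestroy(mempool_id);         // pool becomes freeable once its
 *                                             // use_count reaches zero
 */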
 |  | 
 | namespace { | 
 |  | 
 | using stream_set = ska::flat_hash_set<cuda::CUDAStream>; | 
 |  | 
 | constexpr size_t kMinBlockSize = | 
 |     512; // all sizes are rounded to at least 512 bytes | 
 | constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB | 
 | constexpr size_t kSmallBuffer = | 
 |     2097152; // "small" allocations are packed in 2 MiB blocks | 
 | constexpr size_t kLargeBuffer = | 
 |     20971520; // "large" allocations may be packed in 20 MiB blocks | 
 | constexpr size_t kMinLargeAlloc = | 
 |     10485760; // allocations between 1 and 10 MiB may use kLargeBuffer | 
 | constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB | 
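
// How these combine (with default settings): requests of at most kSmallSize go
// to the small pool and are backed by kSmallBuffer (2 MiB) segments; requests
// up to kMinLargeAlloc are backed by kLargeBuffer (20 MiB) segments; anything
// larger gets its own segment rounded up to a multiple of kRoundLarge. See
// round_size() and get_allocation_size() below.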
 |  | 
 | using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>; | 
 |  | 
 | void update_stat(Stat& stat, int64_t amount) { | 
 |   stat.current += amount; | 
 |  | 
 |   TORCH_INTERNAL_ASSERT_DEBUG_ONLY( | 
 |       stat.current >= 0, | 
 |       "Negative tracked stat in CUDA allocator (likely logic error)."); | 
 |  | 
 |   stat.peak = std::max(stat.current, stat.peak); | 
 |   if (amount > 0) { | 
 |     stat.allocated += amount; | 
 |   } | 
 |   if (amount < 0) { | 
 |     stat.freed += -amount; | 
 |   } | 
 | } | 
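
// For example, starting from a zeroed Stat, update_stat(s, 512) followed by
// update_stat(s, -512) leaves s.current == 0 and s.peak == 512, while
// s.allocated == 512 and s.freed == 512 keep accumulating across calls.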
 |  | 
 | void reset_accumulated_stat(Stat& stat) { | 
 |   stat.allocated = 0; | 
 |   stat.freed = 0; | 
 | } | 
 |  | 
 | void reset_peak_stat(Stat& stat) { | 
 |   stat.peak = stat.current; | 
 | } | 
 |  | 
 | template <typename Func> | 
 | void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { | 
 |   for (const auto stat_type : c10::irange(stat_types.size())) { | 
 |     if (stat_types[stat_type]) { | 
 |       f(stat_type); | 
 |     } | 
 |   } | 
 | } | 
 |  | 
 | void update_stat_array( | 
 |     StatArray& stat_array, | 
 |     int64_t amount, | 
 |     const StatTypes& stat_types) { | 
 |   for_each_selected_stat_type( | 
 |       stat_types, [&stat_array, amount](size_t stat_type) { | 
 |         update_stat(stat_array[stat_type], amount); | 
 |       }); | 
 | } | 
 |  | 
 | struct Block; | 
 | struct PrivatePool; | 
 | typedef bool (*Comparison)(const Block*, const Block*); | 
 |  | 
 | struct BlockPool { | 
 |   BlockPool( | 
 |       Comparison comparator, | 
 |       bool small, | 
 |       PrivatePool* private_pool = nullptr) | 
 |       : blocks(comparator), is_small(small), owner_PrivatePool(private_pool) {} | 
 |   std::set<Block*, Comparison> blocks; | 
 |   const bool is_small; | 
 |   PrivatePool* owner_PrivatePool; | 
 | }; | 
 |  | 
 | struct Block { | 
 |   int device; // gpu | 
 |   cudaStream_t stream; // allocation stream | 
 |   stream_set stream_uses; // streams on which the block was used | 
 |   size_t size; // block size in bytes | 
 |   BlockPool* pool; // owning memory pool | 
 |   void* ptr; // memory address | 
 |   bool allocated; // in-use flag | 
 |   Block* prev; // prev block if split from a larger allocation | 
 |   Block* next; // next block if split from a larger allocation | 
 |   int event_count; // number of outstanding CUDA events | 
 |   int gc_count; // counter for prioritizing older / less useful blocks for | 
 |                 // garbage collection | 
 |  | 
 |   Block( | 
 |       int device, | 
 |       cudaStream_t stream, | 
 |       size_t size, | 
 |       BlockPool* pool, | 
 |       void* ptr) | 
 |       : device(device), | 
 |         stream(stream), | 
 |         stream_uses(), | 
 |         size(size), | 
 |         pool(pool), | 
 |         ptr(ptr), | 
 |         allocated(0), | 
 |         prev(nullptr), | 
 |         next(nullptr), | 
 |         event_count(0), | 
 |         gc_count(0) {} | 
 |  | 
 |   // constructor for search key | 
 |   Block(int device, cudaStream_t stream, size_t size) | 
 |       : device(device), | 
 |         stream(stream), | 
 |         stream_uses(), | 
 |         size(size), | 
 |         pool(nullptr), | 
 |         ptr(nullptr), | 
 |         allocated(0), | 
 |         prev(nullptr), | 
 |         next(nullptr), | 
 |         event_count(0), | 
 |         gc_count(0) {} | 
 |  | 
 |   bool is_split() const { | 
 |     return (prev != nullptr) || (next != nullptr); | 
 |   } | 
 | }; | 
 |  | 
 | static bool BlockComparator(const Block* a, const Block* b) { | 
 |   if (a->stream != b->stream) { | 
 |     return (uintptr_t)a->stream < (uintptr_t)b->stream; | 
 |   } | 
 |   if (a->size != b->size) { | 
 |     return a->size < b->size; | 
 |   } | 
 |   return (uintptr_t)a->ptr < (uintptr_t)b->ptr; | 
 | } | 
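
// Blocks in a BlockPool are ordered by stream first, then by size, then by
// address. This lets get_free_block() use std::set::lower_bound with a
// (stream, size) search key to find the smallest cached block on the matching
// stream that is at least as large as the request.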
 |  | 
 | static std::string format_size(uint64_t size) { | 
 |   std::ostringstream os; | 
 |   os.precision(2); | 
 |   os << std::fixed; | 
 |   if (size <= 1024) { | 
 |     os << size << " bytes"; | 
 |   } else if (size <= 1048576) { | 
 |     os << (size / 1024.0); | 
 |     os << " KiB"; | 
 |   } else if (size <= 1073741824ULL) { | 
 |     os << size / 1048576.0; | 
 |     os << " MiB"; | 
 |   } else { | 
 |     os << size / 1073741824.0; | 
 |     os << " GiB"; | 
 |   } | 
 |   return os.str(); | 
 | } | 
 |  | 
 | struct AllocParams { | 
 |   AllocParams( | 
 |       int device, | 
 |       size_t size, | 
 |       cudaStream_t stream, | 
 |       BlockPool* pool, | 
 |       size_t alloc_size, | 
 |       DeviceStats& stats) | 
 |       : search_key(device, stream, size), | 
 |         pool(pool), | 
 |         alloc_size(alloc_size), | 
 |         block(nullptr), | 
 |         err(cudaSuccess) {} | 
 |  | 
 |   int device() const { | 
 |     return search_key.device; | 
 |   } | 
 |   cudaStream_t stream() const { | 
 |     return search_key.stream; | 
 |   } | 
 |   size_t size() const { | 
 |     return search_key.size; | 
 |   } | 
 |  | 
 |   Block search_key; | 
 |   BlockPool* pool; | 
 |   size_t alloc_size; | 
 |   Block* block; | 
 |   StatTypes stat_types = {false}; | 
 |   cudaError_t err; | 
 | }; | 
 |  | 
 | // CUDA graphs helper | 
 | struct PrivatePool { | 
 |   PrivatePool() | 
 |       : use_count(1), | 
 |         cudaMalloc_count(0), | 
 |         large_blocks(BlockComparator, /*is_small=*/false, this), | 
 |         small_blocks(BlockComparator, /*is_small=*/true, this) {} | 
 |   PrivatePool(const PrivatePool&) = delete; | 
 |   PrivatePool(PrivatePool&&) = delete; | 
 |   PrivatePool& operator=(const PrivatePool&) = delete; | 
 |   // Number of live graphs using this pool | 
 |   int use_count; | 
 |   // Number of unfreed cudaMallocs made for this pool. When use_count and | 
 |   // cudaMalloc_count drop to zero, we can delete this PrivatePool from | 
 |   // graph_pools. | 
 |   int cudaMalloc_count; | 
 |   // Instead of maintaining private BlockPools here, I could stuff all blocks | 
 |   // (private or no) into the top-level large_blocks and small_blocks, and | 
 |   // distinguish private blocks by adding a "pool id" check above the stream | 
  // check in BlockComparator. BlockComparator is performance-critical though,
  // so I'd rather not add more logic to it.
 |   BlockPool large_blocks; | 
 |   BlockPool small_blocks; | 
 | }; | 
 |  | 
 | struct MempoolIdHash { | 
 |   std::size_t operator()(const MempoolId_t& mempool_id) const noexcept { | 
 |     return mempool_id.first != 0 ? mempool_id.first : mempool_id.second; | 
 |   } | 
 | }; | 
 |  | 
 | cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) { | 
 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 | 
 |   if (at::cuda::currentStreamCaptureStatusMayInitCtx() == | 
 |       at::cuda::CaptureStatus::None) { | 
 | #endif | 
 |     return cudaMalloc(p, size); | 
 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 | 
 |   } else { | 
 |     // It's ok to capture cudaMallocs, as long as we never cudaFree those | 
 |     // addresses before replay. | 
 |     // Capturing cudaMalloc behaves nicely: it gives the graph new VA, | 
 |     // but is ignored (won't leakily allocate new memory) in replays. | 
 |     at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed}; | 
 |     return cudaMalloc(p, size); | 
 |   } | 
 | #endif | 
 | } | 
 |  | 
 | } // namespace | 
 |  | 
 | class CachingAllocatorConfig { | 
 |  public: | 
 |   static size_t max_split_size() { | 
 |     return instance().m_max_split_size; | 
 |   } | 
 |   static double garbage_collection_threshold() { | 
 |     return instance().m_garbage_collection_threshold; | 
 |   } | 
 |  | 
  // This is used to round up allocation sizes to the nearest power-of-2
  // division. See roundup_power2_next_division below for a detailed
  // description. As an example, to use 4 divisions between successive powers
  // of 2, set the env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
 |   static size_t roundup_power2_divisions() { | 
 |     return instance().m_roundup_power2_divisions; | 
 |   } | 
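
  // Several options may be combined in a single setting; an illustrative
  // example (values chosen only for illustration):
  //   PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,roundup_power2_divisions:4,garbage_collection_threshold:0.6
  // parseArgs() below splits the string on commas/whitespace and each option
  // on ':'.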
 |  | 
 |  private: | 
 |   static CachingAllocatorConfig& instance() { | 
 |     static CachingAllocatorConfig* s_instance = ([]() { | 
 |       auto inst = new CachingAllocatorConfig(); | 
 |       inst->parseArgs(); | 
 |       return inst; | 
 |     })(); | 
 |     return *s_instance; | 
 |   } | 
 |  | 
 |   CachingAllocatorConfig() | 
 |       : m_max_split_size(std::numeric_limits<size_t>::max()), | 
 |         m_roundup_power2_divisions(0), | 
 |         m_garbage_collection_threshold(0) {} | 
 |   size_t m_max_split_size; | 
 |   size_t m_roundup_power2_divisions; | 
 |   double m_garbage_collection_threshold; | 
 |  | 
 |   void parseArgs() { | 
 |     const char* val = getenv("PYTORCH_CUDA_ALLOC_CONF"); | 
 |     if (val != NULL) { | 
 |       const std::string config(val); | 
 |  | 
 |       std::regex exp("[\\s,]+"); | 
 |       std::sregex_token_iterator it(config.begin(), config.end(), exp, -1); | 
 |       std::sregex_token_iterator end; | 
 |       std::vector<std::string> options(it, end); | 
 |  | 
 |       for (auto option : options) { | 
 |         std::regex exp2("[:]+"); | 
 |         std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1); | 
 |         std::sregex_token_iterator end2; | 
 |         std::vector<std::string> kv(it2, end2); | 
 |         if (kv.size() >= 2) { | 
 |           /* Maximum split size in MB.  Limited to large size blocks */ | 
 |           if (kv[0].compare("max_split_size_mb") == 0) { | 
 |             size_t val2 = stoi(kv[1]); | 
 |             TORCH_CHECK( | 
 |                 val2 > kLargeBuffer / (1024 * 1024), | 
 |                 "CachingAllocator option max_split_size_mb too small, must be > ", | 
 |                 kLargeBuffer / (1024 * 1024), | 
 |                 ""); | 
 |             val2 = std::max(val2, kLargeBuffer / (1024 * 1024)); | 
 |             val2 = std::min( | 
 |                 val2, (std::numeric_limits<size_t>::max() / (1024 * 1024))); | 
 |             m_max_split_size = val2 * 1024 * 1024; | 
 |           } else if (kv[0].compare("roundup_power2_divisions") == 0) { | 
 |             size_t val2 = stoi(kv[1]); | 
 |             TORCH_CHECK( | 
 |                 llvm::isPowerOf2_64(val2), | 
                "For roundups, the number of divisions has to be a power of 2 ",
 |                 ""); | 
 |             m_roundup_power2_divisions = val2; | 
 |           } else if (kv[0].compare("garbage_collection_threshold") == 0) { | 
            /*
             * Perform garbage collection of GPU memory blocks to avoid
             * triggering the expensive sync-and-reclaim-all operation. Upon
             * setting the threshold (e.g., 0.8), the allocator starts
             * reclaiming blocks once GPU memory usage exceeds the threshold
             * (i.e., 80% of the total memory).
             * The values 0.0 and 1.0 are not allowed because they are not
             * meaningful.
             */
 |             double val2 = stod(kv[1]); | 
            TORCH_CHECK(
                val2 > 0,
                "garbage_collection_threshold too small, set it between 0.0 and 1.0",
                "");
            TORCH_CHECK(
                val2 < 1.0,
                "garbage_collection_threshold too big, set it between 0.0 and 1.0",
                "");
 |             m_garbage_collection_threshold = val2; | 
 |           } else { | 
 |             TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]); | 
 |           } | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 | }; | 
 |  | 
 | class DeviceCachingAllocator { | 
 |  private: | 
 |   // lock around all operations | 
 |   mutable std::recursive_mutex mutex; | 
 |  | 
 |   // device statistics | 
 |   DeviceStats stats; | 
 |  | 
 |   // unallocated cached blocks larger than 1 MB | 
 |   BlockPool large_blocks; | 
 |  | 
 |   // unallocated cached blocks 1 MB or smaller | 
 |   BlockPool small_blocks; | 
 |  | 
 |   // allocated or in use by a stream. Holds all active allocations, | 
 |   // whether they came from graph_pools or one of the BlockPools above. | 
 |   ska::flat_hash_set<Block*> active_blocks; | 
 |  | 
 |   // captures_underway tracks if a capture might be underway on any stream. | 
 |   // Most of the time it's zero, in which case malloc can avoid calling | 
 |   // cudaStreamGetCaptureInfo in the hot path. | 
 |   int captures_underway = 0; | 
 |   // See free() for this thing's purpose | 
 |   std::vector<Block*> needs_events_deferred_until_no_capture; | 
 |   // outstanding cuda events | 
 |   ska::flat_hash_map< | 
 |       cuda::CUDAStream, | 
 |       std::deque<std::pair<cudaEvent_t, Block*>>> | 
 |       cuda_events; | 
 |  | 
 |   // record used memory. | 
 |   size_t total_allocated_memory = 0; | 
 |  | 
 |   size_t allowed_memory_maximum = 0; | 
 |  | 
 |   bool set_fraction = false; | 
 |  | 
 |   // Members specific to CUDA graphs | 
 |  | 
 |   // Private pools for CUDA graphs | 
 |   ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash> | 
 |       graph_pools; | 
 |   // Pools no longer referenced by any graph. Their BlockPools are eligible for | 
 |   // free_blocks. Can't be a vector or deque because we might erase entries in | 
 |   // any order. Could be an std::list, but we don't care much, access and | 
 |   // insert/erase are rare. | 
 |   ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash> | 
 |       graph_pools_freeable; | 
 |  | 
 |   // Maps a capturing stream to its assigned private pool, | 
 |   // in case we want multiple captures to share the same pool | 
 |   ska::flat_hash_map<CaptureId_t, MempoolId_t> capture_to_pool_map; | 
 |  | 
 |  public: | 
 |   DeviceCachingAllocator() | 
 |       : large_blocks(BlockComparator, /*is_small=*/false), | 
 |         small_blocks(BlockComparator, /*is_small=*/true) { | 
 |     stats.max_split_size = CachingAllocatorConfig::max_split_size(); | 
 |   } | 
 |  | 
 |   // All public methods (except the above) acquire the allocator mutex. | 
 |   // Thus, do not call a public method from another public method. | 
 |  | 
 |   Block* malloc(int device, size_t size, cudaStream_t stream) { | 
 |     std::unique_lock<std::recursive_mutex> lock(mutex); | 
 |  | 
 |     if (C10_LIKELY(captures_underway == 0)) { | 
 |       // Processes end-of-life events for outstanding allocations used on | 
 |       // multiple streams (checks if their GPU-side uses are complete and | 
 |       // recycles their memory if so) | 
 |       // | 
 |       // Q. Why skip process_events if a capture might be underway? | 
 |       // A. process_events involves cudaEventQueries, illegal during CUDA graph | 
 |       // capture. | 
 |       //    Dumb simple solution: defer reclaiming these allocations until after | 
 |       //    capture. Cross-stream memory use is uncommon, so the deferral's | 
 |       //    effect on memory use during capture should be small. | 
 |       process_events(); | 
 |     } | 
 |  | 
 |     size = round_size(size); | 
 |     auto& pool = get_pool(size, stream); | 
 |     const size_t alloc_size = get_allocation_size(size); | 
 |     AllocParams params(device, size, stream, &pool, alloc_size, stats); | 
 |     params.stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; | 
 |     params.stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true; | 
 |  | 
 |     // First, try to get a block from the existing pool. | 
 |     bool block_found = | 
 |         // Search pool | 
 |         get_free_block(params) | 
 |         // Trigger callbacks and retry search | 
 |         || (trigger_free_memory_callbacks(params) && get_free_block(params)); | 
 |  | 
 |     // Can't reuse an existing block; try to get a new one. | 
 |     if (!block_found) { | 
 |       // Do garbage collection if the flag is set. | 
 |       if (C10_UNLIKELY( | 
 |               CachingAllocatorConfig::garbage_collection_threshold() > 0.0)) { | 
 |         garbage_collect_cached_blocks(); | 
 |       } | 
 |       // Attempt allocate | 
 |       block_found = alloc_block(params, false) | 
 |           // Free enough available cached blocks to satisfy alloc and retry | 
 |           // alloc. | 
 |           || (release_available_cached_blocks(params) && | 
 |               alloc_block(params, false)) | 
 |           // Free all non-split cached blocks and retry alloc. | 
 |           || (release_cached_blocks() && alloc_block(params, true)); | 
 |     } | 
 |  | 
 |     if (!block_found) { | 
 |       // For any error code other than cudaErrorMemoryAllocation, | 
 |       // alloc_block should have thrown an exception already. | 
 |       TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation); | 
 |  | 
 |       size_t device_free; | 
 |       size_t device_total; | 
 |       C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); | 
 |       std::string allowed_info; | 
 |  | 
 |       if (set_fraction) { | 
 |         allowed_info = format_size(allowed_memory_maximum) + " allowed; "; | 
 |       } | 
 |  | 
 |       stats.num_ooms += 1; | 
 |  | 
 |       // "total capacity": total global memory on GPU | 
      // "allowed": memory the program is allowed to use, as set by the
      //            fraction.
 |       // "already allocated": memory allocated by the program using the | 
 |       //                      caching allocator | 
 |       // "free": free memory as reported by the CUDA API | 
 |       // "cached": memory held by the allocator but not used by the program | 
 |       // | 
      // The "allocated" amount does not include memory allocated outside
 |       // of the caching allocator, such as memory allocated by other programs | 
 |       // or memory held by the driver. | 
 |       // | 
 |       // The sum of "allocated" + "free" + "cached" may be less than the | 
 |       // total capacity due to memory held by the driver and usage by other | 
 |       // programs. | 
 |       // | 
 |       // Note that at this point free_cached_blocks has already returned all | 
 |       // possible "cached" memory to the driver. The only remaining "cached" | 
 |       // memory is split from a larger block that is partially in-use. | 
 |       TORCH_CHECK_WITH( | 
 |           CUDAOutOfMemoryError, | 
 |           false, | 
 |           "CUDA out of memory. Tried to allocate ", | 
 |           format_size(alloc_size), | 
 |           " (GPU ", | 
 |           device, | 
 |           "; ", | 
 |           format_size(device_total), | 
 |           " total capacity; ", | 
 |           format_size( | 
 |               stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)] | 
 |                   .current), | 
 |           " already allocated; ", | 
 |           format_size(device_free), | 
 |           " free; ", | 
 |           allowed_info, | 
 |           format_size( | 
 |               stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)] | 
 |                   .current), | 
 |           " reserved in total by PyTorch)", | 
 |           " If reserved memory is >> allocated memory try setting max_split_size_mb to avoid" | 
 |           " fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF", | 
 |           ""); | 
 |     } | 
 |  | 
 |     TORCH_INTERNAL_ASSERT( | 
 |         params.err == cudaSuccess && params.block != nullptr && | 
 |         params.block->ptr != nullptr); | 
 |     Block* block = params.block; | 
 |     Block* remaining = nullptr; | 
 |  | 
 |     const bool already_split = block->is_split(); | 
 |     if (should_split(block, size)) { | 
 |       remaining = block; | 
 |  | 
 |       block = new Block(device, stream, size, &pool, block->ptr); | 
 |       block->prev = remaining->prev; | 
 |       if (block->prev) { | 
 |         block->prev->next = block; | 
 |       } | 
 |       block->next = remaining; | 
 |  | 
 |       remaining->prev = block; | 
 |       remaining->ptr = static_cast<char*>(remaining->ptr) + size; | 
 |       remaining->size -= size; | 
 |       bool inserted = pool.blocks.insert(remaining).second; | 
 |       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); | 
 |  | 
 |       if (already_split) { | 
 |         // An already-split inactive block is being shrunk by size bytes. | 
 |         update_stat_array( | 
 |             stats.inactive_split_bytes, -block->size, params.stat_types); | 
 |       } else { | 
 |         // A new split inactive block is being created from a previously unsplit | 
 |         // block, size remaining->size bytes. | 
 |         for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { | 
 |           update_stat(stats.inactive_split_bytes[stat_type], remaining->size); | 
 |           update_stat(stats.inactive_split[stat_type], 1); | 
 |         }); | 
 |       } | 
 |     } else if (already_split) { | 
 |       // An already-split block is becoming active | 
 |       for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { | 
 |         update_stat(stats.inactive_split_bytes[stat_type], -block->size); | 
 |         update_stat(stats.inactive_split[stat_type], -1); | 
 |       }); | 
 |     } | 
 |  | 
 |     block->allocated = true; | 
 |     bool inserted = active_blocks.insert(block).second; | 
 |     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); | 
 |  | 
 |     for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { | 
 |       update_stat(stats.allocation[stat_type], 1); | 
 |       update_stat(stats.allocated_bytes[stat_type], block->size); | 
 |       update_stat(stats.active[stat_type], 1); | 
 |       update_stat(stats.active_bytes[stat_type], block->size); | 
 |     }); | 
 |     if (block->size >= CachingAllocatorConfig::max_split_size()) | 
 |       update_stat(stats.oversize_allocations, 1); | 
 |  | 
 |     c10::reportMemoryUsageToProfiler( | 
 |         block->ptr, | 
 |         block->size, | 
 |         stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, | 
 |         stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, | 
 |         c10::Device(c10::DeviceType::CUDA, device)); | 
 |  | 
 |     return block; | 
 |   } | 
 |  | 
 |   void free(Block* block) { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |  | 
 |     block->allocated = false; | 
 |  | 
    // The following logic might modify the underlying Block, changing its
    // size. Save the original values ahead of time for reporting.
 |     auto orig_block_ptr = block->ptr; | 
 |     auto orig_block_size = block->size; | 
 |  | 
 |     StatTypes stat_types = {false}; | 
 |     stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; | 
 |     stat_types[static_cast<size_t>(get_stat_type_for_pool(*(block->pool)))] = | 
 |         true; | 
 |     for_each_selected_stat_type(stat_types, [&](size_t stat_type) { | 
 |       update_stat(stats.allocation[stat_type], -1); | 
 |       update_stat(stats.allocated_bytes[stat_type], -block->size); | 
 |     }); | 
 |     if (block->size >= CachingAllocatorConfig::max_split_size()) | 
 |       update_stat(stats.oversize_allocations, -1); | 
 |  | 
 |     if (!block->stream_uses.empty()) { | 
 |       if (C10_UNLIKELY(captures_underway)) { | 
 |         // It's forbidden to cudaEventQuery an event recorded during CUDA graph | 
 |         // capture. We conservatively defer recording end-of-life events until | 
 |         // the next call to process_events() (which won't happen until no | 
 |         // captures are underway) | 
 |         needs_events_deferred_until_no_capture.push_back(block); | 
 |       } else { | 
 |         insert_events(block); | 
 |       } | 
 |     } else { | 
 |       free_block(block); | 
 |     } | 
 |  | 
 |     c10::reportMemoryUsageToProfiler( | 
 |         orig_block_ptr, | 
 |         -orig_block_size, | 
 |         stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, | 
 |         stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, | 
 |         c10::Device(c10::DeviceType::CUDA, block->device)); | 
 |   } | 
 |  | 
 |   void* getBaseAllocation(Block* block, size_t* outSize) { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |     while (block->prev) { | 
 |       block = block->prev; | 
 |     } | 
 |     void* basePtr = block->ptr; | 
 |     if (outSize) { | 
 |       size_t size = 0; | 
 |       while (block) { | 
 |         size += block->size; | 
 |         block = block->next; | 
 |       } | 
 |       *outSize = size; | 
 |     } | 
 |     return basePtr; | 
 |   } | 
 |  | 
 |   void recordStream(Block* block, cuda::CUDAStream stream) { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |     if (stream.stream() == block->stream) { | 
 |       // ignore uses on the allocation stream, since those don't require any | 
 |       // special synchronization | 
 |       return; | 
 |     } | 
 |     block->stream_uses.insert(stream); | 
 |   } | 
 |  | 
 |   /** set memory fraction to limit maximum allocated memory **/ | 
 |   void setMemoryFraction(double fraction) { | 
 |     size_t device_free; | 
 |     size_t device_total; | 
 |     C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); | 
 |     allowed_memory_maximum = static_cast<size_t>(fraction * device_total); | 
 |     set_fraction = true; | 
 |   } | 
 |  | 
 |   /** returns cached blocks to the system allocator **/ | 
 |   void emptyCache() { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |     release_cached_blocks(); | 
 |   } | 
 |  | 
 |   /** Retrieves info (total size + largest block) of the memory cache **/ | 
 |   void cacheInfo(size_t* total, size_t* largest) { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |     if (*largest == | 
 |         0) { // make an initial guess if a zero *largest is passed in | 
 |       size_t tmp_bytes; | 
 |       cudaMemGetInfo( | 
 |           largest, // Use free memory as an optimistic initial guess of *largest | 
 |           &tmp_bytes); | 
 |     } | 
 |     cache_info_aux(large_blocks, total, largest); | 
 |     cache_info_aux(small_blocks, total, largest); | 
 |     for (const auto& gp : graph_pools) { | 
 |       cache_info_aux(gp.second->large_blocks, total, largest); | 
 |       cache_info_aux(gp.second->small_blocks, total, largest); | 
 |     } | 
 |   } | 
 |  | 
 |   /** Returns a copy of the memory allocator stats **/ | 
 |   DeviceStats getStats() { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |     return stats; | 
 |   } | 
 |  | 
 |   /** Resets the historical accumulation stats for the device **/ | 
 |   void resetAccumulatedStats() { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |  | 
 |     for (const auto statType : | 
 |          c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) { | 
 |       reset_accumulated_stat(stats.allocation[statType]); | 
 |       reset_accumulated_stat(stats.segment[statType]); | 
 |       reset_accumulated_stat(stats.active[statType]); | 
 |       reset_accumulated_stat(stats.inactive_split[statType]); | 
 |       reset_accumulated_stat(stats.allocated_bytes[statType]); | 
 |       reset_accumulated_stat(stats.reserved_bytes[statType]); | 
 |       reset_accumulated_stat(stats.active_bytes[statType]); | 
 |       reset_accumulated_stat(stats.inactive_split_bytes[statType]); | 
 |     } | 
 |  | 
 |     stats.num_alloc_retries = 0; | 
 |     stats.num_ooms = 0; | 
 |     reset_accumulated_stat(stats.oversize_allocations); | 
 |     reset_accumulated_stat(stats.oversize_segments); | 
 |   } | 
 |  | 
 |   /** Resets the historical peak stats for the device **/ | 
 |   void resetPeakStats() { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |  | 
 |     for (const auto statType : | 
 |          c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) { | 
 |       reset_peak_stat(stats.allocation[statType]); | 
 |       reset_peak_stat(stats.segment[statType]); | 
 |       reset_peak_stat(stats.active[statType]); | 
 |       reset_peak_stat(stats.inactive_split[statType]); | 
 |       reset_peak_stat(stats.allocated_bytes[statType]); | 
 |       reset_peak_stat(stats.reserved_bytes[statType]); | 
 |       reset_peak_stat(stats.active_bytes[statType]); | 
 |       reset_peak_stat(stats.inactive_split_bytes[statType]); | 
 |     } | 
 |     reset_peak_stat(stats.oversize_allocations); | 
 |     reset_peak_stat(stats.oversize_segments); | 
 |   } | 
 |  | 
 |   /** Dump a complete snapshot of the memory held by the allocator. Potentially | 
 |    * VERY expensive. **/ | 
 |   std::vector<SegmentInfo> snapshot() const { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |  | 
 |     std::vector<SegmentInfo> result; | 
 |     const auto all_blocks = get_all_blocks(); | 
 |  | 
 |     for (const Block* const head_block : all_blocks) { | 
 |       if (head_block->prev != nullptr) { | 
 |         continue; | 
 |       } | 
 |       result.emplace_back(); | 
 |       SegmentInfo& segment_info = result.back(); | 
 |       segment_info.device = head_block->device; | 
 |       segment_info.address = reinterpret_cast<int64_t>(head_block->ptr); | 
 |       segment_info.is_large = (!head_block->pool->is_small); | 
 |  | 
 |       const Block* block = head_block; | 
 |       while (block != nullptr) { | 
 |         segment_info.blocks.emplace_back(); | 
 |         BlockInfo& block_info = segment_info.blocks.back(); | 
 |  | 
 |         block_info.size = block->size; | 
 |         block_info.allocated = block->allocated; | 
 |         block_info.active = block->allocated || (block->event_count > 0) || | 
 |             !block->stream_uses.empty(); | 
 |  | 
 |         segment_info.total_size += block_info.size; | 
 |         if (block_info.allocated) { | 
 |           segment_info.allocated_size += block_info.size; | 
 |         } | 
 |         if (block_info.active) { | 
 |           segment_info.active_size += block_info.size; | 
 |         } | 
 |  | 
 |         block = block->next; | 
 |       } | 
 |     } | 
 |  | 
 |     std::sort( | 
 |         result.begin(), | 
 |         result.end(), | 
 |         [](const SegmentInfo& a, const SegmentInfo& b) { | 
 |           return a.address < b.address; | 
 |         }); | 
 |  | 
 |     return result; | 
 |   } | 
 |  | 
  // This function takes the size and number-of-divisions arguments and rounds
  // up the size argument to the nearest power-of-2 division.
  // For example, if we need to round up 1200 and the number of divisions is 4,
  // the size 1200 lies between 1024 and 2048; dividing that interval into 4
  // gives the values 1024, 1280, 1536, and 1792. So the function returns 1280
  // as the nearest power-of-2 division ceiling.
 |   static size_t roundup_power2_next_division(size_t size, size_t divisions) { | 
 |     if (C10_UNLIKELY(size <= 4 || divisions <= 1)) { | 
 |       return size; | 
 |     } | 
 |     if (llvm::isPowerOf2_64(size)) { | 
 |       return size; | 
 |     } | 
 |  | 
    // Divide the space between these two powers of 2 into equal divisions.
    // If the division size is zero, return the power-of-2 ceiling.
    size_t power2_floor = llvm::PowerOf2Floor(size);
    size_t power2_division =
        power2_floor >> (63 - llvm::countLeadingZeros(divisions));
    if (C10_UNLIKELY(power2_division == 0)) {
      return (power2_floor << 1);
    }
    size_t round_size_floor = size & (~(power2_division - 1));
    return (round_size_floor == size) ? size
                                      : round_size_floor + power2_division;
 |   } | 
 |  | 
 |   static size_t round_size(size_t size) { | 
 |     if (size < kMinBlockSize) { | 
 |       return kMinBlockSize; | 
 |     } else { | 
 |       auto divisions = CachingAllocatorConfig::roundup_power2_divisions(); | 
 |       if (divisions > 0 && size > (kMinBlockSize * divisions)) { | 
 |         return roundup_power2_next_division(size, divisions); | 
 |       } else { | 
 |         return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize); | 
 |       } | 
 |     } | 
 |   } | 
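
  // Worked examples: with default settings (roundup_power2_divisions == 0),
  // round_size(1) == 512 and round_size(1025) == 1536. With
  // roundup_power2_divisions:4, round_size(3000) == 3072, since 3000 lies
  // between 2048 and 4096 and the four divisions of that interval are 2048,
  // 2560, 3072, and 3584.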
 |  | 
 |   // See Note [Interaction with CUDA graph capture] | 
 |  | 
 |   // Called by CUDAGraph::capture_begin | 
 |   void notifyCaptureBegin(CaptureId_t graph_id, MempoolId_t mempool_id) { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |     captures_underway++; | 
 |     auto it = graph_pools.find(mempool_id); | 
 |     if (it == graph_pools.end()) { | 
 |       // mempool_id does not reference an existing pool. Make a new pool for | 
 |       // this capture. | 
 |       graph_pools.emplace(mempool_id, std::make_unique<PrivatePool>()); | 
 |     } else { | 
 |       // mempool_id references an existing pool, which the current capture will | 
 |       // share. Check this pool is live (at least one other capture already | 
 |       // references it). | 
 |       TORCH_INTERNAL_ASSERT(it->second->use_count > 0); | 
 |       it->second->use_count++; | 
 |     } | 
 |     // Maps this graph_id to mempool_id and makes sure this graph_id wasn't | 
 |     // somehow assigned a mempool_id already. Keeps essential effect (insert) | 
 |     // out of macro. | 
 |     bool inserted = capture_to_pool_map.insert({graph_id, mempool_id}).second; | 
 |     TORCH_INTERNAL_ASSERT(inserted); | 
 |   } | 
 |  | 
 |   // Called by CUDAGraph::capture_end | 
 |   void notifyCaptureEnd(CaptureId_t graph_id) { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |     captures_underway--; | 
 |     auto it = capture_to_pool_map.find(graph_id); | 
 |     TORCH_INTERNAL_ASSERT(it != capture_to_pool_map.end()); | 
 |     capture_to_pool_map.erase(it); | 
 |   } | 
 |  | 
 |   // Called by CUDAGraph::reset | 
 |   void notifyCaptureDestroy(MempoolId_t mempool_id) { | 
 |     std::lock_guard<std::recursive_mutex> lock(mutex); | 
 |     // The instantiated cudaGraphExec_t has been destroyed. We can't blindly | 
 |     // delete and cudaFree the mempool its capture used, because | 
 |     //  1. other graph(s) might share the same pool | 
 |     //  2. the user might still hold references to output tensors allocated | 
 |     //  during capture. | 
 |     // To handle 1 and 2, we track the number of graphs using this particular | 
 |     // mempool. When the count reaches 0, we tell free_cached_blocks it may now | 
 |     // cudaFree blocks from this graph's pool when it discovers they're unused | 
 |     // (unsplit). | 
 |     auto it = graph_pools.find(mempool_id); | 
 |     TORCH_INTERNAL_ASSERT(it != graph_pools.end()); | 
 |     auto uc = --(it->second->use_count); | 
 |     TORCH_INTERNAL_ASSERT(uc >= 0); | 
 |     if (uc == 0) { | 
 |       // Allows free_cached_blocks to begin cudaFreeing this pool's memory, | 
 |       // and makes sure this pool wasn't somehow made freeable already. | 
 |       bool inserted = | 
 |           graph_pools_freeable.insert({mempool_id, it->second.get()}).second; | 
 |       TORCH_INTERNAL_ASSERT(inserted); | 
 |     } | 
 |   } | 
 |  | 
 |  private: | 
 |   // All private methods do not acquire the allocator mutex. | 
 |  | 
 |   std::vector<const Block*> get_all_blocks() const { | 
 |     std::vector<const Block*> blocks; | 
 |     blocks.insert( | 
 |         blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end()); | 
 |     blocks.insert( | 
 |         blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end()); | 
 |     for (const auto& gp : graph_pools) { | 
 |       blocks.insert( | 
 |           blocks.end(), | 
 |           gp.second->small_blocks.blocks.begin(), | 
 |           gp.second->small_blocks.blocks.end()); | 
 |       blocks.insert( | 
 |           blocks.end(), | 
 |           gp.second->large_blocks.blocks.begin(), | 
 |           gp.second->large_blocks.blocks.end()); | 
 |     } | 
 |     blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end()); | 
 |     return blocks; | 
 |   } | 
 |  | 
 |   /** moves a block into a pool of cached free blocks */ | 
 |   void free_block(Block* block) { | 
 |     TORCH_INTERNAL_ASSERT( | 
 |         !block->allocated && block->event_count == 0 && | 
 |         block->stream_uses.empty()); | 
 |  | 
 |     size_t original_block_size = block->size; | 
 |  | 
 |     auto& pool = *block->pool; | 
 |     int64_t net_change_inactive_split_blocks = 0; | 
 |     int64_t net_change_inactive_split_size = 0; | 
 |  | 
 |     const std::array<Block*, 2> merge_candidates = {block->prev, block->next}; | 
 |     for (Block* merge_candidate : merge_candidates) { | 
 |       const int64_t subsumed_size = | 
 |           try_merge_blocks(block, merge_candidate, pool); | 
 |       if (subsumed_size > 0) { | 
 |         net_change_inactive_split_blocks -= 1; | 
 |         net_change_inactive_split_size -= subsumed_size; | 
 |       } | 
 |     } | 
 |  | 
 |     active_blocks.erase(block); | 
 |     // Makes sure the Block* isn't already present in the pool we're freeing it | 
 |     // back into. | 
 |     bool inserted = pool.blocks.insert(block).second; | 
 |     TORCH_INTERNAL_ASSERT(inserted); | 
 |  | 
 |     if (block->is_split()) { | 
 |       net_change_inactive_split_blocks += 1; | 
 |       net_change_inactive_split_size += block->size; | 
 |     } | 
 |  | 
 |     StatTypes stat_types = {false}; | 
 |     stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; | 
 |     stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true; | 
 |     for_each_selected_stat_type(stat_types, [&](size_t stat_type) { | 
 |       update_stat( | 
 |           stats.inactive_split[stat_type], net_change_inactive_split_blocks); | 
 |       update_stat( | 
 |           stats.inactive_split_bytes[stat_type], | 
 |           net_change_inactive_split_size); | 
 |       update_stat(stats.active[stat_type], -1); | 
 |       update_stat(stats.active_bytes[stat_type], -original_block_size); | 
 |     }); | 
 |   } | 
 |  | 
 |   /** combine previously split blocks. returns the size of the subsumed block, | 
 |    * or 0 on failure. */ | 
 |   size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) { | 
 |     if (!src || src->allocated || src->event_count > 0 || | 
 |         !src->stream_uses.empty()) { | 
 |       return 0; | 
 |     } | 
 |  | 
 |     AT_ASSERT(dst->is_split() && src->is_split()); | 
 |  | 
 |     if (dst->prev == src) { | 
 |       dst->ptr = src->ptr; | 
 |       dst->prev = src->prev; | 
 |       if (dst->prev) { | 
 |         dst->prev->next = dst; | 
 |       } | 
 |     } else { | 
 |       dst->next = src->next; | 
 |       if (dst->next) { | 
 |         dst->next->prev = dst; | 
 |       } | 
 |     } | 
 |  | 
 |     const size_t subsumed_size = src->size; | 
 |     dst->size += subsumed_size; | 
 |     auto erased = pool.blocks.erase(src); | 
 |     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1); | 
 |     delete src; | 
 |  | 
 |     return subsumed_size; | 
 |   } | 
 |  | 
 |   BlockPool& get_pool(size_t size, cudaStream_t stream) { | 
 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 | 
 |     // captures_underway is a conservative guess that the current stream may be | 
 |     // capturing. It's only > 0 if some thread has begun and not yet ended a | 
 |     // capture, so it's usually 0, and we can short-circuit | 
 |     // cudaStreamCaptureStatus (which does a TLS lookup). | 
 |     if (C10_UNLIKELY(captures_underway)) { | 
 |       CaptureId_t id; | 
 |       cudaStreamCaptureStatus status; | 
 |       C10_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &id)); | 
 |       if (status != cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) { | 
 |         TORCH_INTERNAL_ASSERT( | 
 |             status != | 
 |             cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated); | 
 |         // Retrieves the private pool assigned to this capture. | 
 |         auto it0 = capture_to_pool_map.find(id); | 
 |         TORCH_INTERNAL_ASSERT(it0 != capture_to_pool_map.end()); | 
 |         auto it1 = graph_pools.find(it0->second); | 
 |         TORCH_INTERNAL_ASSERT(it1 != graph_pools.end()); | 
 |         if (size <= kSmallSize) { | 
 |           return it1->second->small_blocks; | 
 |         } else { | 
 |           return it1->second->large_blocks; | 
 |         } | 
 |       } | 
 |     } | 
 | #endif | 
 |     if (size <= kSmallSize) { | 
 |       return small_blocks; | 
 |     } else { | 
 |       return large_blocks; | 
 |     } | 
 |   } | 
 |  | 
 |   StatType get_stat_type_for_pool(const BlockPool& pool) { | 
 |     return pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL; | 
 |   } | 
 |  | 
 |   bool should_split(const Block* block, size_t size) { | 
 |     size_t remaining = block->size - size; | 
 |     if (block->pool->is_small) { | 
 |       return remaining >= kMinBlockSize; | 
 |     } else { | 
 |       return (size < CachingAllocatorConfig::max_split_size()) && | 
 |           (remaining > kSmallSize); | 
 |     } | 
 |   } | 
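
  // Concretely: a small-pool block is split whenever at least kMinBlockSize
  // (512 B) would be left over, while a large-pool block is split only for
  // requests below max_split_size and only if more than kSmallSize (1 MiB)
  // would remain.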
 |  | 
 |   static size_t get_allocation_size(size_t size) { | 
 |     if (size <= kSmallSize) { | 
 |       return kSmallBuffer; | 
 |     } else if (size < kMinLargeAlloc) { | 
 |       return kLargeBuffer; | 
 |     } else { | 
 |       return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); | 
 |     } | 
 |   } | 
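
  // E.g. get_allocation_size(512) == 2 MiB, get_allocation_size(3 MiB) ==
  // 20 MiB, and get_allocation_size(11 MiB) == 12 MiB (rounded up to a
  // multiple of kRoundLarge).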
 |  | 
 |   bool get_free_block(AllocParams& p) { | 
 |     BlockPool& pool = *p.pool; | 
 |  | 
 |     if (C10_UNLIKELY( | 
 |             CachingAllocatorConfig::garbage_collection_threshold() > 0.0)) { | 
 |       // Track block reuse interval only when garbage collection is enabled. | 
 |       for (auto& b : pool.blocks) { | 
 |         ++b->gc_count; | 
 |       } | 
 |     } | 
 |     auto it = pool.blocks.lower_bound(&p.search_key); | 
 |     if (it == pool.blocks.end() || (*it)->stream != p.stream()) | 
 |       return false; | 
 |     // Do not return an oversized block for a large request | 
 |     if ((p.size() < CachingAllocatorConfig::max_split_size()) && | 
 |         ((*it)->size >= CachingAllocatorConfig::max_split_size())) | 
 |       return false; | 
 |     // Allow oversized block size to be rounded up but within a limit | 
 |     if ((p.size() >= CachingAllocatorConfig::max_split_size()) && | 
 |         ((*it)->size >= p.size() + kLargeBuffer)) | 
 |       return false; | 
 |     p.block = *it; | 
 |     (*it)->gc_count = 0; // Denote this block has been used | 
 |     pool.blocks.erase(it); | 
 |     return true; | 
 |   } | 
 |  | 
 |   bool trigger_free_memory_callbacks(AllocParams& p) { | 
 |     bool freed_memory = false; | 
 |     for (const auto& name : FreeCudaMemoryCallbacksRegistry()->Keys()) { | 
 |       freed_memory |= | 
 |           FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute(); | 
 |     } | 
 |     return freed_memory; | 
 |   } | 
 |  | 
 |   void garbage_collect_cached_blocks() { | 
 |     // Free unused cached blocks to reclaim GPU memory. | 
 |     // Unlike release_cached_blocks(), this does not enforce synchronization and | 
    // therefore should incur less overhead.
 |  | 
 |     size_t gc_threshold = static_cast<size_t>( | 
 |         CachingAllocatorConfig::garbage_collection_threshold() * | 
 |         allowed_memory_maximum); | 
 |     // No need to trigger GC yet | 
 |     if (total_allocated_memory <= gc_threshold) { | 
 |       return; | 
 |     } | 
 |     const auto target_size = total_allocated_memory - gc_threshold; | 
 |     size_t gc_reclaimed = 0; | 
 |  | 
 |     // Calculate the total age of the free-able blocks. We'll use it later to | 
 |     // get "avg age" threshold. | 
 |     double total_age = 0.0; | 
 |     int freeable_block_count = 0; | 
 |     for (auto& b : large_blocks.blocks) { | 
 |       if (!b->is_split()) { | 
 |         total_age += b->gc_count; | 
 |         ++freeable_block_count; | 
 |       } | 
 |     } | 
 |     // No free-able blocks? | 
 |     if (freeable_block_count == 0) { | 
 |       return; | 
 |     } | 
 |  | 
 |     // Repeat GC until we reach reclaim > target size. | 
 |     bool block_freed = true; | 
 |     while (gc_reclaimed < target_size && block_freed == true && | 
 |            freeable_block_count > 0) { | 
 |       // Free blocks exceeding this age threshold first. | 
 |       double age_threshold = total_age / freeable_block_count; | 
 |       // Stop iteration if we can no longer free a block. | 
 |       block_freed = false; | 
 |  | 
 |       // Free blocks of > avg age. Don't stop upon reaching the target_size, | 
 |       // we don't want this GC to be triggered frequently. | 
 |       auto it = large_blocks.blocks.begin(); | 
 |       while (it != large_blocks.blocks.end()) { | 
 |         Block* block = *it; | 
 |         ++it; | 
 |         if (!block->is_split() && block->gc_count >= age_threshold) { | 
 |           block_freed = true; | 
 |           gc_reclaimed += block->size; | 
 |           total_age -= block->gc_count; // Decrement the age | 
 |           freeable_block_count--; // One less block that can be freed | 
 |           release_block(block); | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 |  | 
 |   bool alloc_block(AllocParams& p, bool isRetry) { | 
 |     // Defensively checks for preexisting CUDA error state. | 
 |     C10_CUDA_CHECK(cudaGetLastError()); | 
 |  | 
 |     size_t size = p.alloc_size; | 
 |     void* ptr; | 
 |  | 
 |     if (isRetry) { | 
 |       stats.num_alloc_retries += 1; | 
 |     } | 
 |  | 
 |     if (set_fraction && | 
 |         total_allocated_memory + size > allowed_memory_maximum) { | 
 |       p.err = cudaErrorMemoryAllocation; | 
 |       return false; | 
 |     } else { | 
 |       p.err = cudaMallocMaybeCapturing(&ptr, size); | 
 |       if (p.err != cudaSuccess) { | 
 |         if (p.err == cudaErrorMemoryAllocation) { | 
          // If this is the first attempt (!isRetry), we can forgive and clear
          // CUDA's internal error state.
          // If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH
          // will take over to throw a helpful exception. The user can choose
          // to catch the exception, free some stuff in their script, and
          // attempt their allocation again. In this case, we can also forgive
          // and clear CUDA's internal error state.
 |           cudaGetLastError(); | 
 |         } else { | 
 |           // If the error's unrelated to memory allocation, we should throw | 
 |           // immediately. | 
 |           C10_CUDA_CHECK(p.err); | 
 |         } | 
 |         return false; | 
 |       } | 
 |     } | 
 |  | 
 |     if (p.pool->owner_PrivatePool) { | 
 |       // The block is for a CUDA graph's PrivatePool. | 
 |       p.pool->owner_PrivatePool->cudaMalloc_count++; | 
 |     } | 
 |  | 
 |     total_allocated_memory += size; | 
 |     p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); | 
 |     for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { | 
 |       update_stat(stats.segment[stat_type], 1); | 
 |       update_stat(stats.reserved_bytes[stat_type], size); | 
 |     }); | 
 |     if (size >= CachingAllocatorConfig::max_split_size()) | 
 |       update_stat(stats.oversize_segments, 1); | 
 |  | 
 |     // p.block came from new, not cudaMalloc. It should not be nullptr here. | 
 |     TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); | 
 |     return true; | 
 |   } | 
 |  | 
  /** Free one or more oversize blocks to the system allocator, but only enough
   * to satisfy the target size **/
 |   bool release_available_cached_blocks(const AllocParams& p) { | 
 |     if (CachingAllocatorConfig::max_split_size() == | 
 |         std::numeric_limits<size_t>::max()) | 
 |       return false; | 
 |     BlockPool& pool = *p.pool; | 
 |     Block key = p.search_key; | 
 |     key.size = (key.size < CachingAllocatorConfig::max_split_size()) | 
 |         ? CachingAllocatorConfig::max_split_size() | 
 |         : key.size; | 
 |     auto it = pool.blocks.lower_bound(&key); | 
 |     if (it == pool.blocks.end() || (*it)->stream != p.stream()) { | 
 |       // No single block is large enough; free multiple oversize blocks, | 
 |       // starting with the largest | 
 |       if (it == pool.blocks.begin()) | 
 |         return false; | 
 |       size_t totalReleased = 0; | 
 |       --it; // Back up one item.  Now on the largest block for the correct | 
 |             // stream | 
 |       while ((totalReleased < key.size) && | 
 |              ((*it)->size >= CachingAllocatorConfig::max_split_size()) && | 
 |              ((*it)->stream == p.stream())) { | 
 |         auto cur = it; | 
 |         totalReleased += (*it)->size; | 
 |         if (it != pool.blocks.begin()) { | 
 |           --it; | 
 |           release_block(*cur); | 
 |         } else { | 
 |           release_block(*cur); | 
 |           break; | 
 |         } | 
 |       } | 
 |       if (totalReleased < key.size) | 
 |         return false; | 
 |     } else { | 
 |       release_block(*it); | 
 |     } | 
 |     return true; | 
 |   } | 
 |  | 
 |   bool release_cached_blocks() { | 
 |     // First ensure that all blocks that can't currently be allocated due to | 
 |     // outstanding events are returned to the pool. | 
 |     synchronize_and_free_events(); | 
 |  | 
 |     // Free all non-split cached blocks to system allocator | 
 |     release_blocks(large_blocks); | 
 |     release_blocks(small_blocks); | 
 |  | 
 |     for (auto it = graph_pools_freeable.begin(); | 
 |          it != graph_pools_freeable.end();) { | 
 |       // See notifyCaptureDestroy for the strategy here. | 
 |       TORCH_INTERNAL_ASSERT(it->second->use_count == 0); | 
 |       release_blocks(it->second->small_blocks); | 
 |       release_blocks(it->second->large_blocks); | 
 |       if (it->second->cudaMalloc_count == 0) { | 
 |         auto erase_count = graph_pools.erase(it->first); | 
 |         TORCH_INTERNAL_ASSERT(erase_count == 1); | 
 |         it = graph_pools_freeable.erase(it); | 
 |       } else { | 
 |         ++it; | 
 |       } | 
 |     } | 
 |  | 
 |     return true; | 
 |   } | 
 |  | 
 |   void release_block(Block* block) { | 
 |     C10_CUDA_CHECK(cudaFree((void*)block->ptr)); | 
 |     total_allocated_memory -= block->size; | 
 |  | 
 |     auto* pool = block->pool; | 
 |     if (pool->owner_PrivatePool) { | 
 |       // The cudaFreed block belonged to a CUDA graph's PrivatePool. | 
 |       TORCH_INTERNAL_ASSERT(pool->owner_PrivatePool->cudaMalloc_count > 0); | 
 |       pool->owner_PrivatePool->cudaMalloc_count--; | 
 |     } | 
 |  | 
 |     StatTypes stat_types = {false}; | 
 |     stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; | 
 |     stat_types[static_cast<size_t>(get_stat_type_for_pool(*pool))] = true; | 
 |     for_each_selected_stat_type(stat_types, [&](size_t stat_type) { | 
 |       update_stat(stats.segment[stat_type], -1); | 
 |       update_stat(stats.reserved_bytes[stat_type], -block->size); | 
 |     }); | 
 |     if (block->size >= CachingAllocatorConfig::max_split_size()) | 
 |       update_stat(stats.oversize_segments, -1); | 
 |  | 
 |     pool->blocks.erase(block); | 
 |     delete block; | 
 |   } | 
 |  | 
 |   void release_blocks(BlockPool& pool) { | 
 |     // Frees all non-split blocks | 
 |     auto it = pool.blocks.begin(); | 
 |     while (it != pool.blocks.end()) { | 
 |       Block* block = *it; | 
 |       ++it; | 
 |       if (!block->prev && !block->next) { | 
 |         release_block(block); | 
 |       } | 
 |     } | 
 |   } | 
 |  | 
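  // Events are created with timing disabled; they are used purely for
  // synchronization, so timing would only add overhead.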
 |   cudaEvent_t create_event_internal() { | 
 |     cudaEvent_t event; | 
 |     C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); | 
 |     return event; | 
 |   } | 
 |  | 
 |   void free_event_internal(cudaEvent_t event) { | 
 |     C10_CUDA_CHECK(cudaEventDestroy(event)); | 
 |   } | 
 |  | 
 |   void synchronize_and_free_events() { | 
 |     // Synchronize on outstanding events and then free associated blocks. | 
 |  | 
    // This function syncs, so capture should not be underway. Might as well
    // make sure capture-deferred end-of-life events get processed too.
 |     TORCH_INTERNAL_ASSERT(captures_underway == 0); | 
 |     insert_events_deferred_until_no_capture(); | 
 |  | 
 |     for (auto& st : cuda_events) { | 
 |       for (auto& e : st.second) { | 
 |         cudaEvent_t event = e.first; | 
 |         Block* block = e.second; | 
 |  | 
 |         C10_CUDA_CHECK(cudaEventSynchronize(event)); | 
 |         free_event_internal(event); | 
 |  | 
 |         block->event_count--; | 
 |         if (block->event_count == 0) { | 
 |           free_block(block); | 
 |         } | 
 |       } | 
 |     } | 
 |  | 
 |     cuda_events.clear(); | 
 |   } | 
 |  | 
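  // Records an event on every stream that has used this block. The block is
  // returned to its pool only after all of these events have completed (see
  // process_events and synchronize_and_free_events).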
 |   void insert_events(Block* block) { | 
 |     int prev_device; | 
 |     C10_CUDA_CHECK(cudaGetDevice(&prev_device)); | 
 |  | 
 |     stream_set streams(std::move(block->stream_uses)); | 
 |     AT_ASSERT(block->stream_uses.empty()); | 
 |     for (auto& stream : streams) { | 
 |       C10_CUDA_CHECK(cudaSetDevice(stream.device_index())); | 
 |  | 
 |       cudaEvent_t event = create_event_internal(); | 
 |       C10_CUDA_CHECK(cudaEventRecord(event, stream.stream())); | 
 |  | 
 |       block->event_count++; | 
 |       cuda_events[stream].emplace_back(event, block); | 
 |     } | 
 |  | 
 |     C10_CUDA_CHECK(cudaSetDevice(prev_device)); | 
 |   } | 
 |  | 
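  // Event recording for these blocks was deferred while a graph capture was
  // underway; record the events now that no capture is in progress.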
 |   void insert_events_deferred_until_no_capture() { | 
 |     if (C10_UNLIKELY(needs_events_deferred_until_no_capture.size() > 0)) { | 
 |       for (auto* block : needs_events_deferred_until_no_capture) { | 
 |         TORCH_INTERNAL_ASSERT(!block->stream_uses.empty()); | 
 |         insert_events(block); | 
 |       } | 
 |       needs_events_deferred_until_no_capture.clear(); | 
 |     } | 
 |   } | 
 |  | 
 |   void process_events() { | 
 |     insert_events_deferred_until_no_capture(); | 
 |  | 
 |     // Process outstanding cudaEvents. Events that are completed are | 
 |     // removed from the queue, and the 'event_count' for the | 
 |     // corresponding allocation is decremented. We maintain a separate | 
    // list of events per stream to avoid head-of-line delays if one
    // or more streams have long-running operations.
 |     for (auto it = cuda_events.begin(); it != cuda_events.end();) { | 
 |       while (!it->second.empty()) { | 
 |         auto& e = it->second.front(); | 
 |         cudaEvent_t event = e.first; | 
 |         Block* block = e.second; | 
 |  | 
 |         cudaError_t err = cudaEventQuery(event); | 
 |         if (err == cudaErrorNotReady) { | 
 |           // ignore and clear the error if not ready | 
 |           cudaGetLastError(); | 
 |           break; | 
 |         } else if (err != cudaSuccess) { | 
 |           C10_CUDA_CHECK(err); | 
 |         } | 
 |  | 
 |         free_event_internal(event); | 
 |  | 
 |         block->event_count--; | 
 |         if (block->event_count == 0) { | 
 |           free_block(block); | 
 |         } | 
 |         it->second.pop_front(); | 
 |       } | 
 |  | 
 |       if (it->second.empty()) { | 
 |         it = cuda_events.erase(it); | 
 |       } else { | 
 |         it++; | 
 |       } | 
 |     } | 
 |   } | 
 |  | 
  // Accumulates the total size and the largest block size over all memory
  // blocks of the given pool for this device.
 |   void cache_info_aux(const BlockPool& pool, size_t* total, size_t* largest) { | 
 |     for (const auto& block : pool.blocks) { | 
 |       const auto blocksize = block->size; | 
 |       *total += blocksize; | 
 |       if (blocksize > *largest) { | 
 |         *largest = blocksize; | 
 |       } | 
 |     } | 
 |   } | 
 | }; | 
 |  | 
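// Front end over the per-device allocators: owns one DeviceCachingAllocator
// per device and maps raw device pointers back to the Block metadata needed
// to free them.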
 | class THCCachingAllocator { | 
 |  private: | 
 |   std::mutex mutex; | 
 |  | 
 |   // allocated blocks by device pointer | 
 |   ska::flat_hash_map<void*, Block*> allocated_blocks; | 
 |  | 
 |   // lock around calls to cudaFree (to prevent deadlocks with NCCL) | 
 |   mutable std::mutex cuda_free_mutex; | 
 |  | 
 |   void add_allocated_block(Block* block) { | 
 |     std::lock_guard<std::mutex> lock(mutex); | 
 |     allocated_blocks[block->ptr] = block; | 
 |   } | 
 |  | 
 |  public: | 
 |   std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocator; | 
 |  | 
 |   std::mutex* getCudaFreeMutex() const { | 
 |     return &cuda_free_mutex; | 
 |   } | 
 |  | 
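  // Returns the Block tracked for an outstanding allocation, or nullptr if
  // ptr was not allocated by this allocator. When remove is true, the entry
  // is also dropped from the map (used on the free path).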
 |   Block* get_allocated_block(void* ptr, bool remove = false) { | 
 |     std::lock_guard<std::mutex> lock(mutex); | 
 |     auto it = allocated_blocks.find(ptr); | 
 |     if (it == allocated_blocks.end()) { | 
 |       return nullptr; | 
 |     } | 
 |     Block* block = it->second; | 
 |     if (remove) { | 
 |       allocated_blocks.erase(it); | 
 |     } | 
 |     return block; | 
 |   } | 
 |  | 
 |   void init(int device_count) { | 
 |     const auto size = static_cast<int64_t>(device_allocator.size()); | 
 |     if (size < device_count) { | 
 |       device_allocator.resize(device_count); | 
 |       for (const auto i : c10::irange(size, device_count)) { | 
 |         device_allocator[i] = std::make_unique<DeviceCachingAllocator>(); | 
 |       } | 
 |     } | 
 |   } | 
 |  | 
 |   /** allocates a block which is safe to use from the provided stream */ | 
 |   void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) { | 
 |     TORCH_INTERNAL_ASSERT( | 
 |         0 <= device && static_cast<size_t>(device) < device_allocator.size(), | 
 |         "Allocator not initialized for device ", | 
 |         device, | 
 |         ": did you call init?"); | 
 |     Block* block = device_allocator[device]->malloc(device, size, stream); | 
 |     add_allocated_block(block); | 
 |     *devPtr = (void*)block->ptr; | 
 |   } | 
 |  | 
 |   void free(void* ptr) { | 
 |     if (!ptr) { | 
 |       return; | 
 |     } | 
 |     Block* block = get_allocated_block(ptr, true /* remove */); | 
 |     if (!block) { | 
 |       TORCH_CHECK(false, "invalid device pointer: ", ptr); | 
 |     } | 
 |     device_allocator[block->device]->free(block); | 
 |   } | 
 |  | 
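  // Caps how much memory the allocator for `device` may reserve; the
  // per-device allocator interprets the fraction relative to that device's
  // total memory. Note that the calling thread's active device is switched to
  // `device` here and not restored afterwards.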
 |   void setMemoryFraction(double fraction, int device) { | 
 |     TORCH_INTERNAL_ASSERT( | 
 |         0 <= device && static_cast<size_t>(device) < device_allocator.size(), | 
 |         "Allocator not initialized for device ", | 
 |         device, | 
 |         ": did you call init?"); | 
    TORCH_INTERNAL_ASSERT(
        0 <= fraction && fraction <= 1,
        "invalid fraction: ",
        fraction,
        ". Please set within [0, 1].");
    int activated_device;
    C10_CUDA_CHECK(cudaGetDevice(&activated_device));
    if (activated_device != device) {
      C10_CUDA_CHECK(cudaSetDevice(device));
    }
 |     device_allocator[device]->setMemoryFraction(fraction); | 
 |   } | 
 |  | 
 |   void emptyCache() { | 
 |     for (auto& da : device_allocator) | 
 |       da->emptyCache(); | 
 |   } | 
 |  | 
 |   void* getBaseAllocation(void* ptr, size_t* outSize) { | 
 |     Block* block = get_allocated_block(ptr); | 
 |     if (!block) { | 
 |       TORCH_CHECK(false, "invalid device pointer: ", ptr); | 
 |     } | 
 |     return device_allocator[block->device]->getBaseAllocation(block, outSize); | 
 |   } | 
 |  | 
 |   void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) { | 
    // Empty tensor's storage().data() might be a null ptr. As there are no
    // blocks associated with those tensors, it is fine to do nothing here.
 |     if (!ptr.get()) { | 
 |       return; | 
 |     } | 
 |  | 
    // If a tensor is not allocated by this instance, simply skip it.
    // This usually happens when CUDA tensors are shared across processes: we
    // have implemented a reference-counting-based sharing mechanism to
    // guarantee that tensors won't be accidentally freed by one process while
    // they are still being used in another.
 |     if (ptr.get_deleter() != &raw_delete) | 
 |       return; | 
 |  | 
 |     Block* block = get_allocated_block(ptr.get()); | 
    // block must not be null at this point
 |     TORCH_INTERNAL_ASSERT(block != nullptr, "No allocated block can be found"); | 
 |     device_allocator[block->device]->recordStream(block, stream); | 
 |   } | 
 |  | 
 |   std::vector<SegmentInfo> snapshot() { | 
 |     std::vector<SegmentInfo> result; | 
 |     for (auto& da : device_allocator) { | 
 |       auto snap = da->snapshot(); | 
 |       result.insert(result.end(), snap.begin(), snap.end()); | 
 |     } | 
 |  | 
 |     return result; | 
 |   } | 
 | }; | 
 |  | 
 | THCCachingAllocator caching_allocator; | 
 |  | 
 | // Returns whether to force all allocations to bypass the caching allocator and | 
 | // go straight to cudaMalloc.  This setting is useful when debugging GPU memory | 
 | // errors, since the caching allocator foils cuda-memcheck. | 
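// Usage sketch (the variable only needs to be present in the environment when
// the process is launched):
//   PYTORCH_NO_CUDA_MEMORY_CACHING=1 python script.py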
 | bool forceUncachedAllocator() { | 
 |   static bool force_uncached = | 
 |       getenv("PYTORCH_NO_CUDA_MEMORY_CACHING") != nullptr; | 
 |   return force_uncached; | 
 | } | 
 |  | 
 | static void uncached_delete(void* ptr) { | 
 |   C10_CUDA_CHECK(cudaFree(ptr)); | 
 | } | 
 |  | 
 | // NB: I decided not to fold this into THCCachingAllocator, because the latter | 
 | // has a lot more methods and it wasn't altogether clear that they should | 
 | // actually be publicly exposed | 
 | struct CudaCachingAllocator : public Allocator { | 
 |   DataPtr allocate(size_t size) const override { | 
 |     constexpr size_t one_exa_bytes = 1152921504606846976ULL; | 
 |     TORCH_CHECK_WITH( | 
 |         CUDAOutOfMemoryError, | 
 |         size < one_exa_bytes, | 
 |         "CUDA out of memory. Tried to allocate more than 1EB memory."); | 
 |     int device; | 
 |     C10_CUDA_CHECK(cudaGetDevice(&device)); | 
 |     void* r = nullptr; | 
 |     if (forceUncachedAllocator()) { | 
 |       // Deliberately don't use cudaMallocMaybeCapturing here, to force an error | 
 |       // if someone tries to use forceUncachedAllocator while capturing. | 
 |       C10_CUDA_CHECK(cudaMalloc(&r, size)); | 
 |       return {r, r, &uncached_delete, Device(DeviceType::CUDA, device)}; | 
 |     } | 
 |     if (size != 0) { | 
 |       caching_allocator.malloc( | 
 |           &r, device, size, cuda::getCurrentCUDAStream(device)); | 
 |     } | 
 |     return {r, r, &raw_delete, Device(DeviceType::CUDA, device)}; | 
 |   } | 
 |   DeleterFnPtr raw_deleter() const override { | 
 |     if (forceUncachedAllocator()) { | 
 |       return &uncached_delete; | 
 |     } else { | 
 |       return &raw_delete; | 
 |     } | 
 |   } | 
 | }; | 
 |  | 
 | CudaCachingAllocator device_allocator; | 
 |  | 
 | Allocator* get(void) { | 
 |   return &device_allocator; | 
 | } | 
 |  | 
 | void init(int device_count) { | 
 |   caching_allocator.init(device_count); | 
 | } | 
 |  | 
 | void setMemoryFraction(double fraction, int device) { | 
 |   caching_allocator.setMemoryFraction(fraction, device); | 
 | } | 
 |  | 
 | void emptyCache(void) { | 
 |   caching_allocator.emptyCache(); | 
 | } | 
 |  | 
 | void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) { | 
 |   caching_allocator.device_allocator[dev_id]->cacheInfo( | 
 |       cachedAndFree, largestBlock); | 
 | } | 
 |  | 
 | void* getBaseAllocation(void* ptr, size_t* size) { | 
 |   return caching_allocator.getBaseAllocation(ptr, size); | 
 | } | 
 |  | 
 | void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) { | 
 |   caching_allocator.recordStream(ptr, stream); | 
 | } | 
 |  | 
 | std::mutex* getFreeMutex() { | 
 |   return caching_allocator.getCudaFreeMutex(); | 
 | } | 
 |  | 
 | static inline void assertValidDevice(int device) { | 
 |   const auto device_num = caching_allocator.device_allocator.size(); | 
 |   TORCH_CHECK( | 
 |       0 <= device && device < static_cast<int64_t>(device_num), | 
 |       "Invalid device argument ", | 
 |       device, | 
 |       ": did you call init?"); | 
 | } | 
 |  | 
 | DeviceStats getDeviceStats(int device) { | 
 |   assertValidDevice(device); | 
 |   return caching_allocator.device_allocator[device]->getStats(); | 
 | } | 
 |  | 
 | void resetAccumulatedStats(int device) { | 
 |   assertValidDevice(device); | 
 |   caching_allocator.device_allocator[device]->resetAccumulatedStats(); | 
 | } | 
 |  | 
 | void resetPeakStats(int device) { | 
 |   assertValidDevice(device); | 
 |   caching_allocator.device_allocator[device]->resetPeakStats(); | 
 | } | 
 |  | 
 | std::vector<SegmentInfo> snapshot() { | 
 |   return caching_allocator.snapshot(); | 
 | } | 
 |  | 
 | // CUDAGraph interactions | 
 | void notifyCaptureBegin( | 
 |     int device, | 
 |     CaptureId_t graph_id, | 
 |     MempoolId_t mempool_id) { | 
 |   assertValidDevice(device); | 
 |   caching_allocator.device_allocator[device]->notifyCaptureBegin( | 
 |       graph_id, mempool_id); | 
 | } | 
 |  | 
 | void notifyCaptureEnd(int device, CaptureId_t graph_id) { | 
 |   assertValidDevice(device); | 
 |   caching_allocator.device_allocator[device]->notifyCaptureEnd(graph_id); | 
 | } | 
 |  | 
 | void notifyCaptureDestroy(int device, MempoolId_t mempool_id) { | 
 |   assertValidDevice(device); | 
 |   caching_allocator.device_allocator[device]->notifyCaptureDestroy(mempool_id); | 
 | } | 
 |  | 
 | // | 
// In CUDA IPC, the sender sends a tensor to the receiver; getIpcDevPtr
// is called by the receiving process to map the CUDA memory from the sending
// process into its own address space.
//
// CUDA IPC only allows sharing a big memory block associated with a
// cudaIpcMemHandle_t, and it can be opened only **once** per context per
// process. There can be multiple types of storage in the same IPC mem block,
// so we must cache the device pointer to construct typed storages as they
// come.
//
// ipcMemHandle_to_devptr maps a cudaIpcMemHandle_t to a device pointer in the
// receiving process that can be used to access the memory block in the sender
// process. It only saves a weak_ptr of the device pointer in the map; the
// shared_ptr will be used to reconstruct all storages in this cudaMalloc
// allocation, and it will be deleted via cudaIpcCloseMemHandle when its
// reference count drops to 0.
 | // | 
 | namespace { | 
 | std::mutex IpcMutex; | 
 | ska::flat_hash_map<std::string, std::weak_ptr<void>> ipcMemHandle_to_devptr; | 
 | } // namespace | 
 |  | 
 | std::shared_ptr<void> getIpcDevPtr(std::string handle) { | 
 |   std::lock_guard<std::mutex> lock(IpcMutex); | 
 |  | 
 |   auto iter = ipcMemHandle_to_devptr.find(handle); | 
 |   if (iter != ipcMemHandle_to_devptr.end()) { | 
 |     auto devptr = iter->second.lock(); | 
 |     if (devptr) | 
 |       return devptr; | 
 |   } | 
  // This ipcMemHandle hasn't been opened, or it has already expired; open it
  // to enable IPC access to that memory block.
 |   void* dev = nullptr; | 
 |   auto ipc_handle = reinterpret_cast<const cudaIpcMemHandle_t*>(handle.c_str()); | 
 |   C10_CUDA_CHECK( | 
 |       cudaIpcOpenMemHandle(&dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess)); | 
  // The device pointer has to be deleted on the same device on which it was
  // created.
 |   int curr_device; | 
 |   C10_CUDA_CHECK(cudaGetDevice(&curr_device)); | 
 |   auto sp = std::shared_ptr<void>(dev, [handle, curr_device](void* ptr) { | 
 |     cuda::CUDAGuard device_guard(curr_device); | 
 |     std::lock_guard<std::mutex> deleter_lock(IpcMutex); | 
 |     C10_CUDA_CHECK(cudaIpcCloseMemHandle(ptr)); | 
 |     ipcMemHandle_to_devptr.erase(handle); | 
 |   }); | 
 |   std::weak_ptr<void> wp = sp; | 
  // To eliminate an additional search, we can use insert().
  // It doesn't overwrite when the key already exists (i.e. the ptr expired),
  // but since the deleter for sp erased the stale entry, it is safe to insert
  // here.
 |   ipcMemHandle_to_devptr.insert(iter, {handle, wp}); | 
 |  | 
 |   return sp; | 
 | } | 
 |  | 
 | void* raw_alloc(size_t nbytes) { | 
 |   if (nbytes == 0) { | 
 |     return nullptr; | 
 |   } | 
 |   int device; | 
 |   C10_CUDA_CHECK(cudaGetDevice(&device)); | 
 |   void* r = nullptr; | 
 |   caching_allocator.malloc( | 
 |       &r, device, nbytes, cuda::getCurrentCUDAStream(device)); | 
 |   return r; | 
 | } | 
 |  | 
 | void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) { | 
 |   if (nbytes == 0) { | 
 |     return nullptr; | 
 |   } | 
 |   int device; | 
 |   C10_CUDA_CHECK(cudaGetDevice(&device)); | 
 |   void* r = nullptr; | 
 |   caching_allocator.malloc(&r, device, nbytes, stream); | 
 |   return r; | 
 | } | 
 |  | 
 | void raw_delete(void* ptr) { | 
 |   caching_allocator.free(ptr); | 
 | } | 
 |  | 
 | } // namespace CUDACachingAllocator | 
 |  | 
 | } // namespace cuda | 
 | } // namespace c10 |