expandable_segments <-> other allocator options (#134338)

Previously setting  garbage_collection_threshold or max_split_size_mb along with expandable_segments:True could cause the allocator to hit assert failures when running nearly out of memory. This PR ensures garbage_collection and max_split freeing do not accidentally try to release expandable segments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134338
Approved by: https://github.com/ezyang
diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
index bfab450..fa26a9f 100644
--- a/c10/cuda/CUDACachingAllocator.cpp
+++ b/c10/cuda/CUDACachingAllocator.cpp
@@ -2611,7 +2611,7 @@
       while (it != large_blocks.blocks.end()) {
         Block* block = *it;
         ++it;
-        if (!block->is_split() &&
+        if (!block->is_split() && !block->expandable_segment_ &&
             static_cast<double>(block->gc_count()) >= age_threshold) {
           block_freed = true;
           gc_reclaimed += block->size;
@@ -2754,7 +2754,8 @@
         ? CUDAAllocatorConfig::max_split_size()
         : key.size;
     auto it = pool.blocks.lower_bound(&key);
-    if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
+    if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
+        (*it)->expandable_segment_) {
       // No single block is large enough; free multiple oversize blocks,
       // starting with the largest
       if (it == pool.blocks.begin())
@@ -2766,12 +2767,15 @@
              ((*it)->size >= CUDAAllocatorConfig::max_split_size()) &&
              ((*it)->stream == p.stream())) {
         auto cur = it;
-        totalReleased += (*it)->size;
-        if (it != pool.blocks.begin()) {
+        bool is_first = cur == pool.blocks.begin();
+        if (!is_first) {
           --it;
+        }
+        if (!(*cur)->expandable_segment_) {
           release_block(*cur, context);
-        } else {
-          release_block(*cur, context);
+          totalReleased += (*cur)->size;
+        }
+        if (is_first) {
           break;
         }
       }
diff --git a/test/test_cuda.py b/test/test_cuda.py
index b2f1ef3..f53ae5f 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -4098,6 +4098,62 @@
         finally:
             torch.cuda.memory._record_memory_history(None)
 
+    def test_max_split_expandable(self):
+        torch.cuda.memory.empty_cache()
+        mb = 1024 * 1024
+        _, all_memory = torch.cuda.memory.mem_get_info()
+        total_allowed = 120 * mb
+        fraction_allowed = total_allowed / all_memory
+        assert int(fraction_allowed * all_memory) == total_allowed
+        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)
+
+        def alloc(n):
+            return torch.ones(n * mb, dtype=torch.int8, device="cuda")
+
+        torch.cuda.memory._set_allocator_settings(
+            "expandable_segments:False,max_split_size_mb:40"
+        )
+        a = alloc(40)
+        torch.cuda.memory._set_allocator_settings(
+            "expandable_segments:True,max_split_size_mb:40"
+        )
+        b = alloc(40)
+        torch.cuda.memory._set_allocator_settings(
+            "expandable_segments:False,max_split_size_mb:40"
+        )
+        c = alloc(40)
+        with self.assertRaises(torch.OutOfMemoryError):
+            alloc(40)
+        del a, b, c
+        # force release_cached_blocks to run with some expandable segments in the free list
+        alloc(120)
+
+    def test_garbage_collect_expandable(self):
+        torch.cuda.memory.empty_cache()
+        mb = 1024 * 1024
+        _, all_memory = torch.cuda.memory.mem_get_info()
+        total_allowed = 120 * mb
+        fraction_allowed = total_allowed / all_memory
+        assert int(fraction_allowed * all_memory) == total_allowed
+        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)
+
+        def alloc(n):
+            return torch.ones(n * mb, dtype=torch.int8, device="cuda")
+
+        torch.cuda.memory._set_allocator_settings(
+            "expandable_segments:False,garbage_collection_threshold:0.5"
+        )
+        a = alloc(40)
+        torch.cuda.memory._set_allocator_settings(
+            "expandable_segments:True,garbage_collection_threshold:0.5"
+        )
+        b = alloc(40)
+        del a, b
+        # causes GC to run. The expandable segment block will be split
+        # so GC would not attempt to free it anyway, but this at least makes sure
+        # expandable_segment blocks can be in the free list when this is called.
+        alloc(80)
+
     def test_allocator_settings(self):
         def power2_div(size, div_factor):
             pow2 = 1