Introduce an allocation test for TPU.

PiperOrigin-RevId: 396026895
Change-Id: I9268b87f449edbe97ebe774a7ed66c5003cdf3e5
diff --git a/tensorflow/python/distribute/integration_test/tpu_memory_test.py b/tensorflow/python/distribute/integration_test/tpu_memory_test.py
index 43dd56f..83933f8 100644
--- a/tensorflow/python/distribute/integration_test/tpu_memory_test.py
+++ b/tensorflow/python/distribute/integration_test/tpu_memory_test.py
@@ -185,6 +185,32 @@
       with self.assertRaises(tf.errors.ResourceExhaustedError):
         _ = train_step(iterator)
 
+  def testAutoDefragInBufferAllocation(self):
+    if not FLAGS.tpu_use_tfrt:
+      self.skipTest(
+          "TPU StreamExecutor does not support auto-defrag in allocation.")
+    with tf.device("TPU:0"):
+      # DF has ~15G HBM. Following 7 buffers will consume most HBM.
+      # pylint: disable=unused-variable
+      buffer_2g_1 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
+      buffer_2g_2 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
+      buffer_2g_3 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
+      buffer_2g_4 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
+      buffer_2g_5 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
+      buffer_2g_6 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
+      buffer_2g_7 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
+      #  pylint: enable=unused-variable
+
+      # Deallocate two buffers.
+      del buffer_2g_1, buffer_2g_3
+      gc.collect()
+
+      # The buffer we just deallocated doesn't provide enough contiguous region
+      # for allocating 4G. This allocation will trigger auto-defrag.
+      buffer_4g = tf.random.uniform((4, 256, 1024, 1024), dtype=tf.float32)
+
+    self.assertEndsWith(buffer_4g.device, "device:TPU:0")
+
 
 if __name__ == "__main__":
   tf.test.main()