Fix the test order dependency issue for gradient_checkpointing.

The root cause is probably related to the Keras global graph and a potential memory leak. Add a tearDown method that force-clears any Keras models created in the tests and runs a forced GC.
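
The pattern, as applied in the diff below (a minimal sketch against tf.test.TestCase; the surrounding test logic is elided):

    import gc
    import tensorflow as tf

    class GradientCheckpointTest(tf.test.TestCase):

      def tearDown(self):
        super(GradientCheckpointTest, self).tearDown()
        # Drop any models still registered in the global Keras graph so
        # state left over from one test cannot leak into the next.
        tf.keras.backend.clear_session()
        # Force a GC pass so freed Python objects release GPU memory now.
        gc.collect()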

Also clean up some assertion methods, making the test code more readable.
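
For example, self.assertLen and self.assertIsInstance report the offending value on failure, whereas the old assertTrue comparisons only say "False is not true" (snippets taken from the diff below):

    # Before: a failure prints only 'False is not true'.
    self.assertTrue(len(losses) == n_step)
    self.assertTrue(
        context.exception.__class__.__name__ == 'ResourceExhaustedError')
    # After: a failure prints the actual length or exception type.
    self.assertLen(losses, n_step)
    self.assertIsInstance(context.exception, tf.errors.ResourceExhaustedError)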

PiperOrigin-RevId: 320630906
Change-Id: Ib72a39c184613e431d6eab03b18f3b217cb8e506
diff --git a/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py b/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py
index 9d9e0a0..cc0daa4 100644
--- a/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py
+++ b/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py
@@ -16,6 +16,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import gc
+
 import tensorflow as tf
 layers = tf.keras.layers
 optimizers = tf.keras.optimizers
@@ -42,7 +44,7 @@
 
 def _get_split_cnn_model(img_dim, n_channels, num_partitions,
                          blocks_per_partition):
-  """Creates a test model that is split into `num_partitions` smaller models"""
+  """Creates a test model that is split into `num_partitions` smaller models."""
   models = [tf.keras.Sequential() for _ in range(num_partitions)]
   models[0].add(layers.Input(shape=(img_dim, img_dim, n_channels)))
   for i in range(num_partitions):
@@ -70,7 +72,7 @@
 
 
 def _limit_gpu_memory():
-  """Helper function to limit GPU memory for testing  """
+  """Helper function to limit GPU memory for testing."""
   gpus = tf.config.experimental.list_physical_devices('GPU')
   if gpus:
     tf.config.experimental.set_virtual_device_configuration(
@@ -143,15 +145,21 @@
       self.skipTest('No virtual GPUs found')
     with self.assertRaises(Exception) as context:
       _train_no_recompute(1)
-    self.assertTrue(
-        context.exception.__class__.__name__ == 'ResourceExhaustedError')
+    self.assertIsInstance(context.exception, tf.errors.ResourceExhaustedError)
 
   def test_does_not_raise_oom_exception(self):
     if not _limit_gpu_memory():
       self.skipTest('No virtual GPUs found')
     n_step = 2
     losses = _train_with_recompute(n_step)
-    self.assertTrue(len(losses) == n_step)
+    self.assertLen(losses, n_step)
+
+  def tearDown(self):
+    super(GradientCheckpointTest, self).tearDown()
+    # Make sure all the models created in Keras have been deleted and cleared
+    # from the global Keras graph, and force a GC to recycle the GPU memory.
+    tf.keras.backend.clear_session()
+    gc.collect()
 
 
 if __name__ == '__main__':