Fix the test order dependency issue for gradient_checkpointing.
The root cause is probably related to the Keras global graph and a potential memory leak. Add a teardown method to force-clean any Keras model created in the test, with a forced GC.
Also clean up some assertion methods to make the test code more readable.
PiperOrigin-RevId: 320630906
Change-Id: Ib72a39c184613e431d6eab03b18f3b217cb8e506
diff --git a/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py b/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py
index 9d9e0a0..cc0daa4 100644
--- a/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py
+++ b/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py
@@ -16,6 +16,8 @@
from __future__ import division
from __future__ import print_function
+import gc
+
import tensorflow as tf
layers = tf.keras.layers
optimizers = tf.keras.optimizers
@@ -42,7 +44,7 @@
def _get_split_cnn_model(img_dim, n_channels, num_partitions,
blocks_per_partition):
- """Creates a test model that is split into `num_partitions` smaller models"""
+ """Creates a test model that is split into `num_partitions` smaller models."""
models = [tf.keras.Sequential() for _ in range(num_partitions)]
models[0].add(layers.Input(shape=(img_dim, img_dim, n_channels)))
for i in range(num_partitions):
@@ -70,7 +72,7 @@
def _limit_gpu_memory():
- """Helper function to limit GPU memory for testing """
+ """Helper function to limit GPU memory for testing."""
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
tf.config.experimental.set_virtual_device_configuration(
@@ -143,15 +145,21 @@
self.skipTest('No virtual GPUs found')
with self.assertRaises(Exception) as context:
_train_no_recompute(1)
- self.assertTrue(
- context.exception.__class__.__name__ == 'ResourceExhaustedError')
+ self.assertIsInstance(context.exception, tf.errors.ResourceExhaustedError)
def test_does_not_raise_oom_exception(self):
if not _limit_gpu_memory():
self.skipTest('No virtual GPUs found')
n_step = 2
losses = _train_with_recompute(n_step)
- self.assertTrue(len(losses) == n_step)
+ self.assertLen(losses, n_step)
+
+ def tearDown(self):
+ super(GradientCheckpointTest, self).tearDown()
# Make sure all the models created in keras have been deleted and cleared
# from the global keras graph; also do a forced GC to recycle the GPU memory.
+ tf.keras.backend.clear_session()
+ gc.collect()
if __name__ == '__main__':