# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc

import tensorflow as tf

layers = tf.keras.layers
optimizers = tf.keras.optimizers

def _get_big_cnn_model(img_dim, n_channels, num_partitions,
                       blocks_per_partition):
  """Creates a test model whose activations are significantly larger than model size."""
  model = tf.keras.Sequential()
  model.add(layers.Input(shape=(img_dim, img_dim, n_channels)))
  for _ in range(num_partitions):
    for _ in range(blocks_per_partition):
      model.add(layers.Conv2D(10, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
      model.add(layers.Conv2D(40, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
      model.add(layers.Conv2D(20, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
  model.add(layers.Flatten())
  model.add(layers.Dense(32, activation=tf.nn.relu))
  model.add(layers.Dense(10))
  return model
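
# A rough sketch of why this model OOMs under the 1 GB virtual device set up
# in _limit_gpu_memory below (back-of-the-envelope numbers, not exact
# allocator behavior): with img_dim=256, batch_size=4, float32 activations,
# and the num_partitions=3, blocks_per_partition=2 configuration used by the
# training helpers, one 40-channel conv output is about
# 4 * 256 * 256 * 40 * 4 bytes ~= 42 MB. Each block keeps six feature maps
# alive for the backward pass (three convs plus three same-shape pools,
# ~147 MB per block), and 3 partitions * 2 blocks = 6 blocks, so forward
# activations alone approach the 1024 MB cap before gradients are allocated.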

def _get_split_cnn_model(img_dim, n_channels, num_partitions,
                         blocks_per_partition):
  """Creates a test model that is split into `num_partitions` smaller models."""
  models = [tf.keras.Sequential() for _ in range(num_partitions)]
  models[0].add(layers.Input(shape=(img_dim, img_dim, n_channels)))
  for i in range(num_partitions):
    model = models[i]
    if i > 0:
      last_shape = models[i - 1].layers[-1].output_shape
      model.add(layers.Input(shape=last_shape[1:]))
    for _ in range(blocks_per_partition):
      model.add(layers.Conv2D(10, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
      model.add(layers.Conv2D(40, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
      model.add(layers.Conv2D(20, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
  models[-1].add(layers.Flatten())
  models[-1].add(layers.Dense(32, activation=tf.nn.relu))
  models[-1].add(layers.Dense(10))
  return models
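
# Chaining the split submodels end to end reproduces the monolithic model
# above: each submodel after the first declares an Input matching the
# previous submodel's output shape, and the (1, 1) max-pools leave shapes
# unchanged, so the composition is layer-for-layer identical to
# _get_big_cnn_model.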

def _compute_loss(logits, labels):
  return tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=labels))

def _limit_gpu_memory():
  """Helper function to limit GPU memory for testing."""
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if gpus:
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
    return True
  return False
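
# The 1024 MB cap above is chosen so that the monolithic model's live
# activations (see the estimate near _get_big_cnn_model) do not fit, while
# the checkpointed variant, which only needs roughly one partition's
# activations at a time plus recompute buffers, does.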

def _get_dummy_data(img_dim, n_channels, batch_size):
  inputs = tf.ones([batch_size, img_dim, img_dim, n_channels])
  labels = tf.ones([batch_size], dtype=tf.int64)
  return inputs, labels

def _train_no_recompute(n_steps):
  """Trains a single large model without gradient checkpointing."""
  img_dim, n_channels, batch_size = 256, 1, 4
  x, y = _get_dummy_data(img_dim, n_channels, batch_size)
  model = _get_big_cnn_model(
      img_dim, n_channels, num_partitions=3, blocks_per_partition=2)
  optimizer = optimizers.SGD()
  losses = []
  tr_vars = model.trainable_variables
  for _ in range(n_steps):
    with tf.GradientTape() as tape:
      logits = model(x)
      loss = _compute_loss(logits, y)
      losses.append(loss)
    # Differentiate with respect to every trainable variable in the model.
    grads = tape.gradient(loss, tr_vars)
    optimizer.apply_gradients(zip(grads, tr_vars))
    del grads
  return losses
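
# The variant below wraps each submodel with tf.recompute_grad, which drops
# the submodel's intermediate activations after the forward pass and
# recomputes them on demand during backprop. A minimal sketch of the same
# idea on a plain function (illustrative only; `f`, `w`, and `x` are
# stand-in names, not part of this test):
#
#   f = tf.recompute_grad(lambda t: tf.nn.relu(t @ w))
#   with tf.GradientTape() as tape:
#     out = f(x)                   # activations inside f are not kept
#   grads = tape.gradient(out, w)  # f is re-run here to rebuild them
#
# This trades extra forward compute for a much smaller peak activation
# footprint.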

def _train_with_recompute(n_steps):
  """Trains a single large model with gradient checkpointing using tf.recompute_grad."""
  img_dim, n_channels, batch_size = 256, 1, 4
  x, y = _get_dummy_data(img_dim, n_channels, batch_size)
  # This is the same network as _get_big_cnn_model, but split into 3 parts.
  models = _get_split_cnn_model(
      img_dim, n_channels, num_partitions=3, blocks_per_partition=2)
  model1, model2, model3 = models
  # Apply gradient checkpointing to the submodels using tf.recompute_grad.
  model1_re = tf.recompute_grad(model1)
  model2_re = tf.recompute_grad(model2)
  model3_re = tf.recompute_grad(model3)
  optimizer = optimizers.SGD()
  tr_vars = (
      model1.trainable_variables + model2.trainable_variables +
      model3.trainable_variables)
  losses = []
  for _ in range(n_steps):
    with tf.GradientTape() as tape:
      logits1 = model1_re(x)
      logits2 = model2_re(logits1)
      logits3 = model3_re(logits2)
      loss = _compute_loss(logits3, y)
      losses.append(loss)
    # Differentiate with respect to the union of all three submodels' variables.
    grads = tape.gradient(loss, tr_vars)
    optimizer.apply_gradients(zip(grads, tr_vars))
    del grads
  return losses
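
# Both tests below skip themselves when no physical GPU is available, since
# OOM behavior under a virtual-device memory cap is exactly what is being
# exercised.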

class GradientCheckpointTest(tf.test.TestCase):

  def test_raises_oom_exception(self):
    if not _limit_gpu_memory():
      self.skipTest('No virtual GPUs found')
    with self.assertRaises(Exception) as context:
      _train_no_recompute(1)
    self.assertIsInstance(context.exception, tf.errors.ResourceExhaustedError)

  def test_does_not_raise_oom_exception(self):
    if not _limit_gpu_memory():
      self.skipTest('No virtual GPUs found')
    n_step = 2
    losses = _train_with_recompute(n_step)
    self.assertLen(losses, n_step)

  def tearDown(self):
    super(GradientCheckpointTest, self).tearDown()
    # Make sure all the models created in Keras have been deleted and cleared
    # from the global Keras graph; also force a GC pass to reclaim GPU memory.
    tf.keras.backend.clear_session()
    gc.collect()

if __name__ == '__main__':
  tf.test.main()