blob: 6fb4de7610c927708a3ed6a9b07eee6456a41b30 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains functions to use mixed precision with the graph rewrite."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import config
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale_optimizer as loss_scale_optimizer_v1
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
def _wrap_optimizer(opt, loss_scale, use_v1_behavior):
  """Wraps an optimizer with the matching loss scale optimizer.

  A V1 `tf.train.Optimizer` is wrapped with a
  `MixedPrecisionLossScaleOptimizer`; a Keras `OptimizerV2` is wrapped with a
  Keras `LossScaleOptimizer`.

  Args:
    opt: The optimizer to wrap.
    loss_scale: The loss scale, passed unchanged to the wrapping optimizer.
    use_v1_behavior: Only affects the wording of the error raised for an
      unsupported `opt`. If True, the message also lists `tf.train.Optimizer`
      as an accepted type.

  Returns:
    The wrapped optimizer.

  Raises:
    ValueError: If `opt` is already a `MixedPrecisionLossScaleOptimizer` or a
      `LossScaleOptimizer`, or is neither a `tf.train.Optimizer` nor a Keras
      `OptimizerV2`.
  """
  if isinstance(opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer):
    raise ValueError('"opt" must not already be an instance of a '
                     'MixedPrecisionLossScaleOptimizer. '
                     '`enable_mixed_precision_graph_rewrite` will '
                     'automatically wrap the optimizer with a '
                     'MixedPrecisionLossScaleOptimizer.')
  # To avoid a circular dependency, we cannot depend on tf.keras. Because
  # LossScaleOptimizer is in Keras, we cannot use isinstance, so instead check
  # the class name.
  if opt.__class__.__name__ == 'LossScaleOptimizer':
    raise ValueError('"opt" must not already be an instance of a '
                     'LossScaleOptimizer. '
                     '`enable_mixed_precision_graph_rewrite` will '
                     'automatically wrap the optimizer with a '
                     'LossScaleOptimizer.')
  if isinstance(opt, optimizer.Optimizer):
    # For convenience, we allow the V2 version of this function to wrap the V1
    # optimizer, even though we do not document this.
    return loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer(opt,
                                                                    loss_scale)

  # Because we cannot depend on tf.keras, we see if `opt` is an instance of the
  # Keras OptimizerV2 class by checking the names of its base classes. (This
  # detects a Keras OptimizerV2, not a loss scale optimizer.)
  is_keras_optimizer_v2 = any(
      cls.__name__ == 'OptimizerV2'
      for cls in tf_inspect.getmro(opt.__class__))
  if is_keras_optimizer_v2:
    # Because we cannot depend on tf.keras, we cannot unconditionally do this
    # import. But since `opt` is a Keras OptimizerV2, we know keras is
    # importable, so it is safe to do this import. (Technically, it's possible
    # to have a dependency on OptimizerV2 and not LossScaleOptimizer, but this
    # is not done in practice).
    from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2  # pylint: disable=g-import-not-at-top
    return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale)

  # `opt` is wrapped in a 1-tuple so %-formatting always treats it as a single
  # value. A bare tuple `opt` would otherwise be unpacked, raising TypeError
  # (or formatting only its first element) instead of this ValueError.
  if use_v1_behavior:
    raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
                     'tf.keras.optimizers.Optimizer, but got: %s' % (opt,))
  raise ValueError('"opt" must be an instance of a '
                   'tf.keras.optimizers.Optimizer, but got: %s' % (opt,))
@tf_export('train.experimental.enable_mixed_precision_graph_rewrite', v1=[])
def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
  """Enables mixed precision in `tf.function`s by rewriting their graphs.

  Mixed precision means training a model with a combination of float16 and
  float32 so that it runs faster. Once enabled, the graph rewrite changes the
  dtype of certain operations in each `tf.function`'s graph from float32 to
  float16, which speeds up execution of `tf.function`s on a GPU.

  Float16 tensors can underflow during the backward pass. To prevent this, the
  given optimizer is also wrapped in a LossScaleOptimizer, which applies loss
  scaling. An optimizer must therefore be passed in.

  Compute and apply gradients only through the returned optimizer, using
  `opt.minimize()`. Do not compute them with a `tf.GradientTape`: the tape does
  not apply loss scaling, so the model may train to a worse quality.

  At present, mixed precision is only enabled on Volta GPUs and newer. TPU
  support is coming soon. CPUs are not supported, because they do not run
  float16 operations faster than float32 operations.

  WARNING: The rewrite silently affects the entire model and can have
  unintended consequences. For example, if a NaN occurs during dynamic loss
  scaling, that batch's data is silently dropped while the LossScaleOptimizer
  tries to find an appropriate scaling value on the next batch.

  Args:
    opt: An instance of a `tf.keras.optimizers.Optimizer`.
    loss_scale: The loss scale to use: an int/float, the string "dynamic", or
      an instance of a `tf.train.experimental.LossScale`. Keeping the default
      value of "dynamic" is recommended.

  Returns:
    A version of `opt` that uses loss scaling to prevent underflow.

  Raises:
    ValueError: If a Keras mixed precision policy has already been set, if
      `opt` is already a loss scale optimizer, or if `opt` is not a supported
      optimizer.
  """
  wrapped_opt = _enable_mixed_precision_graph_rewrite_base(
      opt=opt, loss_scale=loss_scale, use_v1_behavior=False)
  return wrapped_opt
@tf_export(v1=['train.experimental.enable_mixed_precision_graph_rewrite'])
def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
  """Enables mixed precision by rewriting the graph.

  Mixed precision means training a model with a combination of float16 and
  float32 so that it runs faster. Once enabled, the graph rewrite changes the
  dtype of certain operations in the graph from float32 to float16, which
  speeds up execution of the model on a GPU.

  Float16 tensors can underflow during the backward pass. To prevent this, the
  given optimizer is also wrapped in a LossScaleOptimizer, which applies loss
  scaling. An optimizer must therefore be passed in.

  Compute and apply gradients only through the returned optimizer: either call
  `opt.minimize()`, or call `opt.compute_gradients()` followed by
  `opt.apply_gradients()`. Do not compute gradients with `tf.gradients` or a
  `tf.GradientTape`. Neither applies loss scaling, so using them directly may
  cause the model to train to a worse quality.

  At present, mixed precision is only enabled on Volta GPUs and newer. TPU
  support is coming soon. CPUs are not supported, because they do not run
  float16 operations faster than float32 operations.

  Args:
    opt: An instance of a `tf.keras.optimizers.Optimizer` or a
      `tf.train.Optimizer`.
    loss_scale: The loss scale to use: an int/float, the string "dynamic", or
      an instance of a `tf.train.experimental.LossScale`. Keeping the default
      value of "dynamic" is recommended.

  Returns:
    A version of `opt` that uses loss scaling to prevent underflow.

  Raises:
    ValueError: If a Keras mixed precision policy has already been set, if
      `opt` is already a loss scale optimizer, or if `opt` is not a supported
      optimizer.
  """
  # TODO(reedwm): If a ConfigProto is passed to Session, either assert that
  # auto_mixed_precision is on or turn it on for the user.
  wrapped_opt = _enable_mixed_precision_graph_rewrite_base(
      opt=opt, loss_scale=loss_scale, use_v1_behavior=True)
  return wrapped_opt
def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
                                               use_v1_behavior):
  """Enables mixed precision. See `enable_mixed_precision_graph_rewrite`.

  Shared implementation of the V1 and V2 public entry points.

  Args:
    opt: The optimizer to wrap with loss scaling.
    loss_scale: The loss scale forwarded to the wrapping optimizer.
    use_v1_behavior: Whether the caller is the V1 API. It is forwarded to
      `_wrap_optimizer`, where it only changes the error message for an
      unsupported `opt`.

  Returns:
    The loss-scale-wrapped optimizer.

  Raises:
    ValueError: If a non-default Keras mixed precision policy has already been
      set, or if `_wrap_optimizer` rejects `opt`.
  """
  if not mixed_precision_global_state.using_default_mixed_precision_policy:
    raise ValueError(
        'The mixed precision graph rewrite cannot be enabled, because a keras '
        'mixed precision Policy has been set. At most, one of the following '
        'functions can be called:\n\n'
        '  1. tf.keras.mixed_precision.experimental.set_policy() (You called '
        'this first)\n'
        '  2. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
        '(You called this second)\n\n'
        'You called both functions, which is an error, because both functions '
        'enable you to use mixed precision. If in doubt which function to use, '
        'use the second, as it is currently more complete and easy to use. The '
        'second function enables mixed precision in the graph with a graph '
        'rewrite. However it is currently not very customizable, and does not '
        'support eager.')
  if mixed_precision_global_state.non_mixed_precision_session_created:
    # TODO(reedwm): Give the stacktrace of the existing Sessions. And if the
    # Sessions have already been closed, do not raise this error message.
    # `warning` is used instead of the deprecated `warn` alias.
    tf_logging.warning('You already have existing Sessions that do not use '
                       'mixed precision. enable_mixed_precision_graph_rewrite() '
                       'will not affect these Sessions.')
  # Wrap the optimizer before modifying any global state, so that if `opt` is
  # rejected the rewrite is not left half-enabled.
  opt = _wrap_optimizer(opt, loss_scale, use_v1_behavior=use_v1_behavior)
  config.set_optimizer_experimental_options({'auto_mixed_precision': True})
  mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = True
  return opt
@tf_export('train.experimental.disable_mixed_precision_graph_rewrite', v1=[])
def disable_mixed_precision_graph_rewrite():
  """Disables the mixed precision graph rewrite.

  After this is called, the mixed precision graph rewrite no longer runs for
  tf.functions, so float32 operations are no longer converted to float16.

  This does not undo the effects of loss scaling. Any optimizers wrapped with a
  LossScaleOptimizer will continue to do loss scaling. That scaling is no longer
  useful, because the graph rewrite no longer converts tf.functions to use
  float16.

  This function is useful for unit testing. A unit test can test using the mixed
  precision graph rewrite, then disable it so future unit tests continue using
  float32.

  If the rewrite is already disabled, a warning is logged and the rewrite is
  disabled again anyway.
  """
  if not mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
    # `warning` is used instead of the deprecated `warn` alias.
    tf_logging.warning('disable_mixed_precision_graph_rewrite() called when '
                       'mixed precision is already disabled.')
  config.set_optimizer_experimental_options({'auto_mixed_precision': False})
  mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = False
@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite'])
def disable_mixed_precision_graph_rewrite_v1():
  """Disables the mixed precision graph rewrite.

  After this is called, the mixed precision graph rewrite no longer runs for
  new Sessions, so float32 operations are no longer converted to float16 in
  those Sessions. Existing Sessions keep the graph rewrite enabled if they
  were created after `enable_mixed_precision_graph_rewrite` was called but
  before `disable_mixed_precision_graph_rewrite` was called.

  Loss scaling is not undone. Optimizers already wrapped with a
  LossScaleOptimizer keep doing loss scaling. In new Sessions that scaling is
  no longer useful, because the graph rewrite no longer converts the graph to
  use float16.

  This is useful for unit testing: a unit test can exercise the mixed
  precision graph rewrite and then disable it so later unit tests keep using
  float32. In that case, unit tests should not share a single Session, since
  `enable_mixed_precision_graph_rewrite` and
  `disable_mixed_precision_graph_rewrite` have no effect on existing Sessions.
  """
  # Behavior is identical to the V2 function; a separate V1 symbol exists only
  # so that its docstring can describe Sessions.
  disable_mixed_precision_graph_rewrite()