blob: d6a786aa4e4561b3739f8e29e75b34cc07c90ce0 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the loss scaling optimizer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import one_device_strategy
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.keras import backend
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.mixed_precision.experimental import loss_scale as keras_loss_scale_module
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export
class _UnwrapPreventer(object):
"""Wrapper that DistributionStrategy will not unwrap.
Typically, DistributionStrategy will unwrap values when going from a cross-
replica context to a replica context via `call_for_each_replica`. This class
is a wrapper that DistributionStrategy will not unwrap, so it can be used to
prevent it from unwrapping a value.
TODO(reedwm): Find/implement a better way of preventing values from being
unwrapped by DistributionStrategy
"""
def __init__(self, value):
self.value = value
class _DelegatingTrackableMixin(object):
  """A mixin that delegates all Trackable methods to another trackable object.

  This class must be used with multiple inheritance. A class that subclasses
  Trackable can also subclass this class, which causes all Trackable methods to
  be delegated to the trackable object passed in the constructor.

  A subclass can use this mixin to appear as if it were the trackable passed to
  the constructor, from a Checkpoint's perspective. LossScaleOptimizer uses this
  mixin, so that the checkpoint format for a LossScaleOptimizer is identical to
  the checkpoint format for a normal optimizer. This allows a model to be saved
  with a normal Optimizer and restored with a LossScaleOptimizer, or vice versa.
  The only difference in checkpoint format is that the loss scale is also saved
  with a LossScaleOptimizer.
  """

  def __init__(self, trackable_obj):
    # The Trackable object every method/property below forwards to.
    self._trackable = trackable_obj

  # pylint: disable=protected-access
  @property
  def _setattr_tracking(self):
    return self._trackable._setattr_tracking

  @_setattr_tracking.setter
  def _setattr_tracking(self, value):
    self._trackable._setattr_tracking = value

  @property
  def _update_uid(self):
    return self._trackable._update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    self._trackable._update_uid = value

  @property
  def _unconditional_checkpoint_dependencies(self):
    return self._trackable._unconditional_checkpoint_dependencies

  @property
  def _unconditional_dependency_names(self):
    return self._trackable._unconditional_dependency_names

  @property
  def _name_based_restores(self):
    return self._trackable._name_based_restores

  def _maybe_initialize_trackable(self):
    return self._trackable._maybe_initialize_trackable()

  @property
  def _object_identifier(self):
    return self._trackable._object_identifier

  @property
  def _tracking_metadata(self):
    return self._trackable._tracking_metadata

  def _no_dependency(self, value):
    return self._trackable._no_dependency(value)

  def _name_based_attribute_restore(self, checkpoint):
    return self._trackable._name_based_attribute_restore(checkpoint)

  @property
  def _checkpoint_dependencies(self):
    return self._trackable._checkpoint_dependencies

  @property
  def _deferred_dependencies(self):
    return self._trackable._deferred_dependencies

  def _lookup_dependency(self, name):
    # Bug fix: the delegated result must be propagated to the caller. The
    # previous code discarded it, so this method always returned None,
    # breaking dependency lookup by name during checkpoint restoration.
    return self._trackable._lookup_dependency(name)

  def _add_variable_with_custom_getter(self,
                                       name,
                                       shape=None,
                                       dtype=dtypes.float32,
                                       initializer=None,
                                       getter=None,
                                       overwrite=False,
                                       **kwargs_for_getter):
    return self._trackable._add_variable_with_custom_getter(
        name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)

  def _preload_simple_restoration(self, name, shape):
    return self._trackable._preload_simple_restoration(name, shape)

  def _track_trackable(self, trackable, name, overwrite=False):  # pylint: disable=redefined-outer-name
    return self._trackable._track_trackable(trackable, name, overwrite)

  def _handle_deferred_dependencies(self, name, trackable):  # pylint: disable=redefined-outer-name
    return self._trackable._handle_deferred_dependencies(name, trackable)

  def _restore_from_checkpoint_position(self, checkpoint_position):
    return self._trackable._restore_from_checkpoint_position(
        checkpoint_position)

  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
                                                   visit_queue):
    return self._trackable._single_restoration_from_checkpoint_position(
        checkpoint_position, visit_queue)

  def _gather_saveables_for_checkpoint(self):
    return self._trackable._gather_saveables_for_checkpoint()

  def _list_extra_dependencies_for_serialization(self, serialization_cache):
    return self._trackable._list_extra_dependencies_for_serialization(
        serialization_cache)

  def _list_functions_for_serialization(self, serialization_cache):
    return self._trackable._list_functions_for_serialization(
        serialization_cache)
  # pylint: enable=protected-access
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used. By
  multiplying the loss, each intermediate gradient will have the same multiplier
  applied.

  The loss scale can either be a fixed constant, chosen by the user, or be
  dynamically determined. Dynamically determining the loss scale is convenient
  as a loss scale does not have to be explicitly chosen. However it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are
  computed, either through `minimize()` or `get_gradients()`. The loss scale is
  updated via `LossScale.update()` whenever gradients are applied, either
  through `minimize()` or `apply_gradients()`. For example:

  >>> opt = tf.keras.optimizers.SGD(0.25)
  >>> opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt,
  ...                                                                "dynamic")
  >>> var = tf.Variable(1.)
  >>> loss_fn = lambda: var ** 2
  >>> # 'minimize' applies loss scaling to the loss and updates the loss scale.
  >>> opt.minimize(loss_fn, var_list=var)
  >>> var.numpy()
  0.5

  If a `tf.GradientTape` is used to compute gradients instead of
  `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, the loss
  and gradients must be scaled manually. This can be done by calling
  `LossScaleOptimizer.get_scaled_loss` before passing the loss to
  `tf.GradientTape`, and `LossScaleOptimizer.get_unscaled_gradients` after
  computing the gradients with `tf.GradientTape`. For example:

  >>> with tf.GradientTape() as tape:
  ...   loss = loss_fn()
  ...   scaled_loss = opt.get_scaled_loss(loss)
  >>> scaled_grad = tape.gradient(scaled_loss, var)
  >>> (grad,) = opt.get_unscaled_gradients([scaled_grad])
  >>> opt.apply_gradients([(grad, var)])  # Loss scale is updated here
  >>> var.numpy()
  0.25
  """

  # Tells Keras that this optimizer aggregates gradients itself (via
  # `apply_gradients`' `experimental_aggregate_gradients` argument).
  _HAS_AGGREGATE_GRAD = True

  def __init__(self, optimizer, loss_scale):
    """Initializes this loss scale optimizer.

    Args:
      optimizer: The Optimizer instance to wrap.
      loss_scale: The loss scale to scale the loss and gradients. This can
        either be an int/float to use a fixed loss scale, the string "dynamic"
        to use dynamic loss scaling, or an instance of a LossScale. The string
        "dynamic" is equivalent to passing `DynamicLossScale()`, and passing an
        int/float is equivalent to passing a FixedLossScale with the given loss
        scale.

    Raises:
      ValueError: If `optimizer` is not an `OptimizerV2`, if it has a
        `clipnorm` or `clipvalue` set, if `loss_scale` resolves to None, or if
        the current distribution strategy does not support loss scaling.
    """
    if not isinstance(optimizer, optimizer_v2.OptimizerV2):
      raise ValueError('"optimizer" must be an instance of OptimizerV2, but '
                       'got: %s' % optimizer)
    # Gradient clipping on the inner optimizer would run on still-scaled
    # gradients, so it is rejected up front.
    if optimizer.clipnorm is not None:
      raise ValueError('LossScaleOptimizer does not support wrapping '
                       'optimizers with a clipnorm. Optimizer %s has clipnorm '
                       '%s' % (optimizer, optimizer.clipnorm))
    if optimizer.clipvalue is not None:
      raise ValueError('LossScaleOptimizer does not support wrapping '
                       'optimizers with a clipvalue. Optimizer %s has '
                       'clipvalue %s' % (optimizer, optimizer.clipvalue))
    self._raise_if_strategy_unsupported()
    self.clipnorm = None
    self.clipvalue = None

    self._optimizer = optimizer
    self._loss_scale = keras_loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None.')

    # We don't call super().__init__, since we do not want to call OptimizerV2's
    # constructor.
    _DelegatingTrackableMixin.__init__(self, self._optimizer)

    for weight in loss_scale_module.get_loss_scale_weights(self._loss_scale):
      # We cannot call `track_variable` in the LossScale class itself, because a
      # file outside of Keras cannot depend on a Keras file. Calling it here
      # instead is OK, because a variable only needs to be tracked if used with
      # a Keras class, and the only way to use LossScale with a Keras class is
      # through the LossScaleOptimizer.
      backend.track_variable(weight)
    self._track_trackable(self._loss_scale, 'loss_scale')

    # Needed because the superclass's __getattribute__ checks this.
    self._hyper = {}

    # To support restoring TensorFlow 2.2 checkpoints.
    self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
                          'base_optimizer')

  @property
  def loss_scale(self):
    """The `LossScale` instance associated with this optimizer."""
    return self._loss_scale

  def get_scaled_loss(self, loss):
    """Scales the loss by the loss scale.

    This method is only needed if you compute gradients manually, e.g. with
    `tf.GradientTape`. In that case, call this method to scale the loss before
    passing the loss to `tf.GradientTape`. If you use
    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
    scaling is automatically applied and this method is unneeded.

    If this method is called, `get_unscaled_gradients` should also be called.
    See the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for
    an example.

    Args:
      loss: The loss, which will be multiplied by the loss scale. Can either be
        a tensor or a callable returning a tensor.

    Returns:
      `loss` multiplied by `LossScaleOptimizer.loss_scale()`.
    """
    loss_scale = self._loss_scale()
    if callable(loss):
      # Defer both the loss computation and the scaling until the returned
      # callable is invoked (e.g. inside `minimize`).
      def new_loss():
        loss_val = loss()
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def get_unscaled_gradients(self, grads):
    """Unscales the gradients by the loss scale.

    This method is only needed if you compute gradients manually, e.g. with
    `tf.GradientTape`. In that case, call this method to unscale the gradients
    after computing them with `tf.GradientTape`. If you use
    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
    scaling is automatically applied and this method is unneeded.

    If this method is called, `get_scaled_loss` should also be called. See
    the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for an
    example.

    Args:
      grads: A list of tensors, each which will be divided by the loss scale.
        Can have None values, which are ignored.

    Returns:
      A new list the same size as `grads`, where every non-None value in `grads`
      is divided by `LossScaleOptimizer.loss_scale()`.
    """
    loss_scale = self._loss_scale()
    # Multiply by the reciprocal rather than dividing each gradient.
    loss_scale_reciprocal = 1. / loss_scale
    return [
        _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None
        for g in grads
    ]

  def _compute_gradients(self, loss, var_list, grad_loss=None):
    """Scales the loss, computes gradients, then unscales the gradients."""
    loss = self.get_scaled_loss(loss)
    grads_and_vars = self._optimizer._compute_gradients(loss, var_list,  # pylint: disable=protected-access
                                                        grad_loss)
    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self.get_unscaled_gradients(grads)
    return list(zip(unscaled_grads, variables))

  def get_gradients(self, loss, params):
    """Like `OptimizerV2.get_gradients`, but applies loss (un)scaling."""
    loss = self.get_scaled_loss(loss)
    grads = self._optimizer.get_gradients(loss, params)
    return self.get_unscaled_gradients(grads)

  def _create_all_weights(self, var_list):
    # Slot/iteration variables live on the wrapped optimizer, so delegate.
    self._optimizer._create_all_weights(var_list)  # pylint: disable=protected-access

  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Applies gradients and updates the loss scale.

    Must be called in a replica context; the actual work is performed in a
    cross-replica context via `merge_call` so the loss scale update happens
    exactly once per step.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')
    # We check for the strategy here despite already checking in the constructor
    # as frequently the optimizer is created outside the strategy's scope.
    self._raise_if_strategy_unsupported()
    grads_and_vars = tuple(grads_and_vars)
    return distribution_strategy_context.get_replica_context().merge_call(
        self._apply_gradients_cross_replica,
        args=(grads_and_vars, name, experimental_aggregate_gradients))

  def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name,
                                     experimental_aggregate_gradients):
    """Updates the loss scale and conditionally applies the gradients."""
    grads = [g for g, _ in grads_and_vars]
    # `update()` returns the loss-scale update op and a bool/tensor gating
    # whether the gradients should be applied this step (see LossScale.update).
    loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads)

    def apply_fn():
      # We do not want DistributionStrategy to unwrap any MirroredVariables in
      # grads_and_vars, because even in a replica context, the wrapped optimizer
      # expects mirrored variables. So we wrap the variables with an
      # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
      # MirroredVariables.
      wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])
      return distribution.extended.call_for_each_replica(
          self._apply_gradients,
          args=(grads, wrapped_vars, name, experimental_aggregate_gradients))

    # Note: We must call this cond() in a cross-replica context.
    # DistributionStrategy does not support having a cond in a replica context
    # with a branch that calls `merge_call`, and self._optimizer.apply_gradients
    # calls `merge_call`.
    maybe_apply_op = smart_cond.smart_cond(should_apply_grads,
                                           apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)

  def _apply_gradients(self, grads, wrapped_vars, name,
                       experimental_aggregate_gradients):
    """Applies gradients in a replica context via the wrapped optimizer."""
    # TODO(reedwm): This will raise a fairly cryptic error message if
    # self._optimizer.apply_gradients does not take
    # experimental_aggregate_gradients.
    return self._optimizer.apply_gradients(
        list(zip(grads, wrapped_vars.value)), name,
        experimental_aggregate_gradients=experimental_aggregate_gradients)

  def get_config(self):
    """Returns the config of this optimizer, for serialization."""
    serialized_optimizer = optimizers.serialize(self._optimizer)
    serialized_loss_scale = keras_loss_scale_module.serialize(self._loss_scale)
    return {
        'optimizer': serialized_optimizer,
        'loss_scale': serialized_loss_scale,
    }

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a LossScaleOptimizer from its config (see `get_config`)."""
    config = config.copy()  # Make a copy, since we mutate config
    config['optimizer'] = optimizers.deserialize(
        config['optimizer'], custom_objects=custom_objects)
    config['loss_scale'] = keras_loss_scale_module.deserialize(
        config['loss_scale'], custom_objects=custom_objects)
    return cls(**config)

  def _raise_if_strategy_unsupported(self):
    """Raises ValueError if the current Strategy cannot do loss scaling."""
    if not strategy_supports_loss_scaling():
      strategy = distribution_strategy_context.get_strategy()
      if isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
        raise ValueError(
            'Loss scaling is not supported with TPUStrategy. Loss scaling is '
            'unnecessary with TPUs, since they support bfloat16 instead of '
            'float16 and bfloat16 does not require loss scaling. You should '
            'remove the use of the LossScaleOptimizer when TPUs are used.')
      else:
        raise ValueError('Loss scaling is not supported with the '
                         'tf.distribute.Strategy: %s. Try using a different '
                         'Strategy, e.g. a MirroredStrategy' %
                         strategy.__class__.__name__)

  # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer
  # below.

  @property
  def iterations(self):
    return self._optimizer.iterations

  @iterations.setter
  def iterations(self, variable):
    self._optimizer.iterations = variable

  def get_slot_names(self):
    return self._optimizer.get_slot_names()

  def variables(self):
    return self._optimizer.variables()

  @property
  def weights(self):
    return self._optimizer.weights

  def get_weights(self):
    return self._optimizer.get_weights()

  def set_weights(self, weights):
    return self._optimizer.set_weights(weights)

  def _aggregate_gradients(self, grads_and_vars):
    return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    return self._optimizer._restore_slot_variable(slot_name, variable,  # pylint: disable=protected-access
                                                  slot_variable)

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
        slot_variable_position, slot_name, variable)

  def get_slot(self, var, slot_name):
    return self._optimizer.get_slot(var, slot_name)

  def add_slot(self, var, slot_name, initializer='zeros'):
    return self._optimizer.add_slot(var, slot_name, initializer)

  # For the most part, we only expose methods in the base OptimizerV2, not
  # individual subclasses like Adam. However, although "learning_rate" and "lr"
  # properties are not part of the base OptimizerV2 class, they are part of most
  # subclasses, so we expose them here for convenience.
  @property
  def learning_rate(self):
    return self._optimizer.learning_rate

  @learning_rate.setter
  def learning_rate(self, lr):
    self._optimizer.learning_rate = lr

  @property
  def lr(self):
    return self._optimizer.lr

  @lr.setter
  def lr(self, lr):
    self._optimizer.lr = lr

  # We do not override some OptimizerV2 methods. For each, we describe why we do
  # not delegate them to self._optimizer:
  # * get_updates: get_updates() calls get_gradients(). Since we override
  #   get_gradients(), we cannot delegate get_updates() to self._optimizer,
  #   otherwise the overridden get_gradients() method would not be called.
  #   Luckily, get_updates() does not access any OptimizerV2 fields, so
  #   inheriting the OptimizerV2 version works fine.
  # * minimize: We don't delegate for a similar reason as get_updates(): it
  #   calls both self._compute_gradients() and self.apply_gradients(), and both
  #   need to have the LossScaleOptimizer version called.

  # TODO(reedwm): Maybe merge this class's functionality into OptimizerV2.

  # TODO(reedwm): Maybe throw an error if mixed precision is used without this
  # optimizer being used.

  # Trackable delegations: Delegate all Trackable methods to the wrapped
  # optimizer. This is so the checkpoint format for a LossScaleOptimizer is
  # identical to the checkpoint format for a normal optimizer, except the loss
  # scale is stored in the checkpoint.
class FakeOptimizerForRestoration(trackable.Trackable):
  """A placeholder trackable used to restore TensorFlow 2.2 checkpoints.

  The checkpoint format for LossScaleOptimizers changed after TF 2.2. In TF
  2.2, LossScaleOptimizer recorded a dependency on the wrapped optimizer by
  running the following in its `__init__`:

  ```
  self._track_trackable(self._optimizer, 'base_optimizer')
  ```

  so TF 2.2 checkpoints contain a 'base_optimizer' dependency. Newer versions
  instead make the LossScaleOptimizer's checkpoint layout identical to a plain
  optimizer's (plus the loss scale) by delegating every Trackable method to
  the wrapped optimizer, so that dependency no longer exists.

  To keep TF 2.2 checkpoints restorable, LossScaleOptimizer tracks an instance
  of this class under the name 'base_optimizer' in place of the inner
  optimizer. When such a checkpoint is loaded, this class forwards
  slot-variable restoration to the inner optimizer. It owns no variables of
  its own, so it does not change what gets written to new checkpoints.
  """

  def __init__(self, optimizer):
    # The real optimizer whose slot variables get restored through us.
    self._optimizer = optimizer

  def get_slot_names(self):
    # Expose the real optimizer's slot names to the checkpointing machinery.
    return self._optimizer.get_slot_names()

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    # Restore the slot variable onto the wrapped optimizer instead of here.
    restore = self._optimizer._create_or_restore_slot_variable  # pylint: disable=protected-access
    return restore(slot_variable_position, slot_name, variable)
# pylint: disable=protected-access
# Register LossScaleOptimizer as the loss-scaling wrapper class for
# OptimizerV2 instances. NOTE(review): presumably this lets the
# mixed_precision module wrap optimizers automatically — confirm against
# `mixed_precision._register_wrapper_optimizer_cls`'s callers.
mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
                                                LossScaleOptimizer)
def _multiply_gradient(gradient, scale):
  """Multiply a (possibly sparse) gradient by the given scale factor."""
  cast_scale = math_ops.cast(scale, gradient.dtype)
  # Dense gradients are scaled directly.
  if not isinstance(gradient, ops.IndexedSlices):
    return gradient * cast_scale
  # Sparse gradients: scale only the values; indices and dense shape are
  # carried over unchanged.
  return ops.IndexedSlices(
      gradient.values * cast_scale,
      gradient.indices,
      dense_shape=gradient.dense_shape)
def strategy_supports_loss_scaling():
  """Returns True if the current Strategy supports loss scaling."""
  if not distribution_strategy_context.has_strategy():
    # No strategy is active, so loss scaling works trivially.
    return True
  # Strategies are supported if either there is only one replica or if
  # variables are replicated per device. Otherwise, the current model.fit()
  # implementation and most custom training loops incorrectly unscale the
  # gradients. Currently, gradients are unscaled once per compute replica, but
  # they should be unscaled once per variable replica. When there is one
  # variable replica for each compute replica, this works fine, but otherwise
  # issues will occur.
  # TODO(reedwm): Support all strategies.
  supported_strategy_types = (
      collective_all_reduce_strategy.CollectiveAllReduceStrategy,
      collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
      one_device_strategy.OneDeviceStrategy,
      one_device_strategy.OneDeviceStrategyV1,
      mirrored_strategy.MirroredStrategy,
      mirrored_strategy.MirroredStrategyV1,
  )
  current_strategy = distribution_strategy_context.get_strategy()
  return isinstance(current_strategy, supported_strategy_types)