blob: bbbd0cd7ec4c664e78906f67397ce6f39a06072e [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import math

import six

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
def _get_graph_key():
  """Returns the key under which per-graph loss scale weights are stored.

  Returns:
    None when executing eagerly; otherwise the `_graph_key` of the current
    default graph.
  """
  if context.executing_eagerly():
    return None
  return ops.get_default_graph()._graph_key  # pylint: disable=protected-access


@six.add_metaclass(abc.ABCMeta)
@tf_export('train.experimental.LossScale')
class LossScale(trackable.Trackable):
  """Loss scale base class.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  Instances of this class represent a loss scale. Calling instances of this
  class returns the loss scale as a scalar float32 tensor, while method
  `update()` updates the loss scale depending on the values of the gradients.
  Optimizers use instances of this class to scale loss and gradients.
  """

  def __init__(self):
    """Initializes the loss scale class."""
    # Maps (weight name, graph key) -> variable. The graph key is None for
    # weights created eagerly, so the same weight name can exist once per
    # graph.
    self._weights = {}

  @abc.abstractmethod
  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    pass

  @abc.abstractmethod
  def update(self, grads):
    """Updates the value of the loss scale.

    The loss scale will be potentially updated, based on the value of `grads`.
    The tensor returned by calling this class is only updated when this function
    is evaluated.

    In eager mode, this directly updates the loss scale, so that calling
    `__call__` will return the newly updated loss scale. In graph mode,
    this returns an op that, when evaluated, updates the loss scale.

    This function also returns a `should_apply_gradients` bool. If False,
    gradients should not be applied to the variables that step, as nonfinite
    gradients were found, and the loss scale has been updated to reduce the
    chance of finding nonfinite gradients in the next step. Some loss scale
    classes will always return True, as they cannot adjust themselves in
    response to nonfinite gradients.

    When a DistributionStrategy is used, this function may only be called in a
    cross-replica context.

    Args:
      grads: A nested structure of unscaled gradients, each of which is the
        gradient of the loss with respect to a weight. The gradients should have
        already been divided by the loss scale before being passed to this
        function. 'None' gradients are accepted, and are ignored.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
        scale.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
    """
    pass

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.

    Raises:
      RuntimeError: If a weight with `name` has already been added in the
        current graph (or eagerly, when executing eagerly).
    """
    key = (name, _get_graph_key())
    # Check for duplicates before creating the variable, so that a rejected
    # call does not leave an orphan variable behind.
    if key in self._weights:
      raise RuntimeError('Duplicate variables detected. {}'.format(key))
    variable = variable_scope.variable(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    return variable

  @property
  def _checkpoint_dependencies(self):
    """From Trackable. Gather graph-specific weights to save."""
    graph_key = _get_graph_key()
    # Sorted by weight name so the dependency order is deterministic.
    weights = [
        trackable.TrackableReference(name=name, ref=v)
        for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0])
        if g == graph_key
    ]
    return super(LossScale, self)._checkpoint_dependencies + weights

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(LossScale, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    return self._weights.get((name, _get_graph_key()), None)

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of this loss scale."""
    pass

  @classmethod
  def from_config(cls, config):
    """Creates the LossScale from its config."""
    return cls(**config)
def get_loss_scale_weights(loss_scale):
  """Returns a view of every weight `loss_scale` has registered.

  The view spans all graphs the weights were created in, not only the current
  one.

  Args:
    loss_scale: A LossScale instance.

  Returns:
    The values view of the loss scale's internal weight dictionary.
  """
  registered = loss_scale._weights  # pylint: disable=protected-access
  return registered.values()
@tf_export('train.experimental.FixedLossScale')
class FixedLossScale(LossScale):
  """Loss scale with a fixed value.

  The loss scale is not updated for the lifetime of instances of this class.
  A given instance of this class always returns the same number when called.
  """

  def __init__(self, loss_scale_value):
    """Creates the fixed loss scale.

    Args:
      loss_scale_value: A Python float. Its ideal value varies depending on the
        model being run. Choosing too small a loss scale might affect model
        quality; too big a loss scale might cause inf or nan. There is no single
        right loss scale to apply. There is no harm in choosing a relatively big
        number as long as no nan or inf is encountered in training.

    Raises:
      ValueError: If loss_scale_value is not a Python int or float, is NaN or
        infinite, or is less than 1.
    """
    super(FixedLossScale, self).__init__()
    if not isinstance(loss_scale_value, six.integer_types + (float,)):
      raise ValueError('loss_scale_value must be a Python int or float.')
    # `nan < 1` is False, so without this check a NaN would slip past the range
    # check below. An infinite loss scale would make every gradient nonfinite.
    if isinstance(loss_scale_value, float) and (
        math.isnan(loss_scale_value) or math.isinf(loss_scale_value)):
      raise ValueError('loss_scale_value must be finite.')
    if loss_scale_value < 1:
      raise ValueError('loss_scale_value must be at least 1.')
    # It's important we do not create tensors in the constructor, as such
    # tensors might be on a different device or tf.function vs when the tensor
    # is used. This would hurt performance. Therefore, we do not create a tensor
    # from loss_scale_value, but instead leave it as a Python float.
    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
    # constructor.
    self._loss_scale_value = float(loss_scale_value)

  def __call__(self):
    """Returns the fixed loss scale, converted to a tensor."""
    return ops.convert_to_tensor(self._loss_scale_value)

  def update(self, grads):
    """Does nothing; a fixed loss scale never changes.

    Args:
      grads: Ignored.

    Returns:
      A (no-op, True) pair. Gradients are always reported as applicable.
    """
    del grads
    return control_flow_ops.no_op(), True

  def get_config(self):
    """Returns the config used to recreate this loss scale."""
    return {'loss_scale_value': self._loss_scale_value}
def _is_all_finite(grads):
  """Checks that every non-None gradient contains only finite values.

  Args:
    grads: An iterable of gradient tensors; `None` entries are skipped.

  Returns:
    A scalar boolean tensor that is True iff every non-None gradient is finite.
  """
  per_grad_checks = []
  for grad in grads:
    if grad is None:
      continue
    per_grad_checks.append(math_ops.reduce_all(math_ops.is_finite(grad)))
  return math_ops.reduce_all(per_grad_checks)
def _op_in_graph_mode(tensor):
  """Picks the op behind `tensor` in graph mode, or `tensor` itself eagerly.

  Some call sites need an op rather than a tensor in graph mode. Eager
  execution has no ops, so the tensor is handed back unchanged there.

  Args:
    tensor: A tensor.

  Returns:
    `tensor.op` in graph mode; `tensor` in eager mode.
  """
  return tensor if context.executing_eagerly() else tensor.op
def _assign_if_finite(var, value):
  """Assigns `value` to `var` only when `value` is finite.

  Args:
    var: The variable to assign to.
    value: The candidate new value.

  Returns:
    The result of a `cond` whose true branch performs the assignment and whose
    false branch is a no-op.
  """
  value_is_finite = math_ops.is_finite(value)

  def _do_assign():
    return _op_in_graph_mode(var.assign(value))

  return control_flow_ops.cond(value_is_finite, _do_assign,
                               control_flow_ops.no_op)
@tf_export('train.experimental.DynamicLossScale')
class DynamicLossScale(LossScale):
  """Loss scale that adapts itself as training progresses.

  The aim is to keep the loss scale as large as possible without the gradients
  overflowing. While gradients stay finite, a larger loss scale never hurts.

  Training starts from an initial loss scale. After every `increment_period`
  consecutive steps with finite gradients, the loss scale is multiplied by
  `multiplier`. Whenever a NaN or Inf gradient appears, that step's gradients
  are not applied and the loss scale is divided by `multiplier`. This keeps the
  loss scale high while limiting gradient overflow.
  """

  def __init__(self,
               initial_loss_scale=2 ** 15,  # See docstring for why this is big.
               increment_period=2000,
               multiplier=2.):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: A Python float. The loss scale used at the start.
        Prefer a very high starting value: an overly high loss scale is lowered
        much faster than an overly low one is raised. The default, 2 ** 15, is
        roughly half the maximum float16 value.
      increment_period: The loss scale grows after this many consecutive steps
        with finite gradients. A nonfinite gradient resets the count to zero.
      multiplier: Factor applied when raising or lowering the loss scale.
    """
    super(DynamicLossScale, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._increment_period = int(increment_period)
    self._multiplier = float(multiplier)
    # Live loss scale, tracked as a float32 weight.
    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        initial_value=self._initial_loss_scale,
        dtype=dtypes.float32)
    # Consecutive steps with finite gradients since the most recent nonfinite
    # gradient or loss scale change.
    self._num_good_steps = self._add_weight(
        name='good_steps', initial_value=0, dtype=dtypes.int64)

  @property
  def initial_loss_scale(self):
    return self._initial_loss_scale

  @property
  def increment_period(self):
    return self._increment_period

  @property
  def multiplier(self):
    return self._multiplier

  def __call__(self):
    return ops.convert_to_tensor(self._current_loss_scale)

  def update(self, grads):
    """Adjusts the loss scale depending on whether this step's grads are finite.

    Args:
      grads: A nested structure of unscaled gradients. `None` entries are
        ignored.

    Returns:
      An (update_op, should_apply_gradients) pair. should_apply_gradients is a
      scalar boolean tensor that is True only when every gradient is finite.
      Under a DistributionStrategy, this means finite on every replica.
    """
    flat_grads = nest.flatten(grads)

    if not distribution_strategy_context.has_strategy():
      all_finite = _is_all_finite(flat_grads)
    else:
      strategy = distribution_strategy_context.get_cross_replica_context()

      def replica_finiteness(replica_grads):
        # Report 1.0 / 0.0 instead of a bool: DistributionStrategy cannot
        # reduce booleans.
        finite_here = _is_all_finite(replica_grads)
        return math_ops.cast(finite_here, dtypes.float32)

      per_replica_finite = strategy.extended.call_for_each_replica(
          replica_finiteness, args=(flat_grads,))
      finite_replica_count = strategy.reduce(
          reduce_util.ReduceOp.SUM, per_replica_finite, axis=None)
      all_finite = math_ops.equal(finite_replica_count,
                                  strategy.num_replicas_in_sync)

    def on_finite_grads():
      """Branch taken when all gradients are finite."""

      def grow_loss_scale():
        grown = self._current_loss_scale * self._multiplier
        return control_flow_ops.group(
            _assign_if_finite(self._current_loss_scale, grown),
            self._num_good_steps.assign(0))

      def count_good_step():
        return _op_in_graph_mode(self._num_good_steps.assign_add(1))

      period_reached = self._num_good_steps + 1 >= self._increment_period
      return control_flow_ops.cond(period_reached, grow_loss_scale,
                                   count_good_step)

    def on_nonfinite_grads():
      """Branch taken when some gradient is NaN or Inf."""
      shrunk = math_ops.maximum(
          self._current_loss_scale / self._multiplier, 1)
      return control_flow_ops.group(
          self._num_good_steps.assign(0),
          self._current_loss_scale.assign(shrunk))

    update_op = control_flow_ops.cond(all_finite, on_finite_grads,
                                      on_nonfinite_grads)
    return update_op, all_finite

  def get_config(self):
    return dict(
        initial_loss_scale=self.initial_loss_scale,
        increment_period=self.increment_period,
        multiplier=self.multiplier,
    )
def get(identifier):
  """Get a loss scale object.

  Args:
    identifier: One of:
      - a Python int or float, wrapped in a `FixedLossScale`;
      - the string 'dynamic', producing a `DynamicLossScale` with its default
        arguments;
      - a `LossScale` instance, returned unchanged;
      - None, returned unchanged.

  Returns:
    A `LossScale` instance, or None when `identifier` is None.

  Raises:
    ValueError: If `identifier` is not one of the accepted forms.
  """
  if isinstance(identifier, six.integer_types + (float,)):
    return FixedLossScale(identifier)
  # Only compare strings against 'dynamic'. Array-likes would otherwise compare
  # elementwise and yield an ambiguous truth value.
  if isinstance(identifier, six.string_types) and identifier == 'dynamic':
    return DynamicLossScale()
  if isinstance(identifier, LossScale):
    return identifier
  if identifier is None:
    return None
  # Wrap in a 1-tuple so a tuple identifier is not unpacked as several format
  # arguments. Unpacking it would raise TypeError instead of this ValueError.
  raise ValueError('Could not interpret loss scale identifier: %s' %
                   (identifier,))