Support non-float32 losses and gradients in LossScaleOptimizer.
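
get_scaled_loss and get_unscaled_gradients previously multiplied float16 (or
other non-float32) tensors by a float32 loss scale, which fails with a dtype
mismatch. The loss scale is now cast to each tensor's dtype first. A minimal
sketch of the resulting behavior, assuming the public tf.keras aliases for
these classes (values are illustrative):

    import tensorflow as tf

    opt = tf.keras.optimizers.SGD(2.0)
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        opt, loss_scale=2.)

    # A float16 loss: the loss scale is cast to float16 before multiplying.
    loss = tf.constant(5., dtype='float16')
    scaled_loss = opt.get_scaled_loss(loss)  # 10.0, dtype float16

    # Gradients are unscaled with the same per-tensor cast, so a mix of
    # float16 and float32 gradients is handled.
    scaled_grads = [tf.constant(-8., dtype='float16'), tf.constant(6.)]
    grads = opt.get_unscaled_gradients(scaled_grads)  # [-4.0, 3.0]
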
PiperOrigin-RevId: 264288451
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
index e8128f0..a68c6ff 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
@@ -22,6 +22,7 @@
from tensorflow.python.keras import backend
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util.tf_export import keras_export
@@ -166,9 +167,12 @@
"""
loss_scale = self._loss_scale()
if callable(loss):
- return lambda: loss() * loss_scale
+ def new_loss():
+ loss_val = loss()
+ return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
+ return new_loss
else:
- return loss * loss_scale
+ return loss * math_ops.cast(loss_scale, loss.dtype)
def get_unscaled_gradients(self, grads):
"""Unscales the gradients by the loss scale.
@@ -193,7 +197,8 @@
"""
loss_scale = self._loss_scale()
loss_scale_reciprocal = 1. / loss_scale
- return [g * loss_scale_reciprocal if g is not None else None for g in grads]
+ return [g * math_ops.cast(loss_scale_reciprocal, g.dtype) if g is not None
+ else None for g in grads]
def _compute_gradients(self, loss, var_list, grad_loss=None):
loss = self.get_scaled_loss(loss)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
index 1b1921f..320b30e 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
@@ -32,6 +32,7 @@
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import control_flow_v2_toggles
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -116,13 +117,23 @@
def testGetScaledLoss(self):
opt = gradient_descent.SGD(2.0)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.)
- self.assertEqual(10., self.evaluate(opt.get_scaled_loss(5.)))
+ loss = ops.convert_to_tensor(5.)
+ self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss)))
+ self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)()))
+ loss = ops.convert_to_tensor(5., dtype='float16')
+ self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss)))
+ self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)()))
@test_util.run_in_graph_and_eager_modes
def testGetUnscaledGradients(self):
opt = gradient_descent.SGD(2.0)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2)
- grads = opt.get_unscaled_gradients([3., None, -4.])
+ scaled_grads = [
+ ops.convert_to_tensor(3.),
+ None,
+ ops.convert_to_tensor(-4., dtype='float16')
+ ]
+ grads = opt.get_unscaled_gradients(scaled_grads)
grads = [self.evaluate(g) if g is not None else g for g in grads]
self.assertEqual([1.5, None, -2.], grads)
@@ -193,6 +204,28 @@
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
+ def testDynamicLossScaleWithFloat16Loss(self, strategy_fn):
+ strategy = strategy_fn()
+ learning_rate = 2.
+ with strategy.scope():
+ var = variables.Variable([5.0])
+ opt = gradient_descent.SGD(learning_rate)
+ loss_scale = loss_scale_module.DynamicLossScale(
+ initial_loss_scale=2, increment_period=1, multiplier=2)
+ opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+
+ def loss():
+ return math_ops.cast(var / strategy.num_replicas_in_sync, 'float16')
+ run_fn = lambda: opt.minimize(loss, var_list=[var])
+ run_op = strategy.experimental_run(run_fn)
+ self.evaluate(variables.global_variables_initializer())
+ self._run_if_in_graph_mode(run_op)
+      # The loss is var / num_replicas, so the gradient summed over replicas
+      # is 1 and the variable becomes init_val - grad * lr == 5 - 1 * 2 == 3.
+ self.assertAllClose([3.], self.evaluate(var))
+
+ @parameterized.named_parameters(*TESTCASES)
+ @test_util.run_in_graph_and_eager_modes
def testDynamicLossScaleWithSlots(self, strategy_fn):
strategy_obj = strategy_fn()
if (isinstance(strategy_obj, mirrored_strategy.MirroredStrategy) and