Support non-float32 losses in LossScaleOptimizer.

PiperOrigin-RevId: 264288451
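
The change casts the loss scale to the dtype of the loss (and the reciprocal loss scale to each gradient's dtype) before multiplying, so get_scaled_loss and get_unscaled_gradients now accept float16 values as well as float32. A minimal usage sketch in eager mode, mirroring the updated tests (the constant values are illustrative only):

    from tensorflow.python.framework import ops
    from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
    from tensorflow.python.keras.optimizer_v2 import gradient_descent

    opt = gradient_descent.SGD(2.0)
    opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.)

    # A float16 loss is now valid: the loss scale is cast to float16 first.
    loss = ops.convert_to_tensor(5., dtype='float16')
    scaled_loss = opt.get_scaled_loss(loss)  # 10.0, dtype float16

    # Likewise for float16 gradients; None entries are passed through.
    grads = opt.get_unscaled_gradients(
        [ops.convert_to_tensor(3., dtype='float16'), None])  # [1.5, None]
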
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
index e8128f0..a68c6ff 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
@@ -22,6 +22,7 @@
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.training.experimental import loss_scale as loss_scale_module
 from tensorflow.python.util.tf_export import keras_export
 
@@ -166,9 +167,12 @@
     """
     loss_scale = self._loss_scale()
     if callable(loss):
-      return lambda: loss() * loss_scale
+      def new_loss():
+        loss_val = loss()
+        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
+      return new_loss
     else:
-      return loss * loss_scale
+      return loss * math_ops.cast(loss_scale, loss.dtype)
 
   def get_unscaled_gradients(self, grads):
     """Unscales the gradients by the loss scale.
@@ -193,7 +197,8 @@
     """
     loss_scale = self._loss_scale()
     loss_scale_reciprocal = 1. / loss_scale
-    return [g * loss_scale_reciprocal if g is not None else None for g in grads]
+    return [g * math_ops.cast(loss_scale_reciprocal, g.dtype) if g is not None
+            else None for g in grads]
 
   def _compute_gradients(self, loss, var_list, grad_loss=None):
     loss = self.get_scaled_loss(loss)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
index 1b1921f..320b30e 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
@@ -32,6 +32,7 @@
 from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.ops import control_flow_v2_toggles
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -116,13 +117,23 @@
   def testGetScaledLoss(self):
     opt = gradient_descent.SGD(2.0)
     opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.)
-    self.assertEqual(10., self.evaluate(opt.get_scaled_loss(5.)))
+    loss = ops.convert_to_tensor(5.)
+    self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss)))
+    self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)()))
+    loss = ops.convert_to_tensor(5., dtype='float16')
+    self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss)))
+    self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)()))
 
   @test_util.run_in_graph_and_eager_modes
   def testGetUnscaledGradients(self):
     opt = gradient_descent.SGD(2.0)
     opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2)
-    grads = opt.get_unscaled_gradients([3., None, -4.])
+    scaled_grads = [
+        ops.convert_to_tensor(3.),
+        None,
+        ops.convert_to_tensor(-4., dtype='float16')
+    ]
+    grads = opt.get_unscaled_gradients(scaled_grads)
     grads = [self.evaluate(g) if g is not None else g for g in grads]
     self.assertEqual([1.5, None, -2.], grads)
 
@@ -193,6 +204,28 @@
 
   @parameterized.named_parameters(*TESTCASES)
   @test_util.run_in_graph_and_eager_modes
+  def testDynamicLossScaleWithFloat16Loss(self, strategy_fn):
+    strategy = strategy_fn()
+    learning_rate = 2.
+    with strategy.scope():
+      var = variables.Variable([5.0])
+      opt = gradient_descent.SGD(learning_rate)
+      loss_scale = loss_scale_module.DynamicLossScale(
+          initial_loss_scale=2, increment_period=1, multiplier=2)
+      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+
+      def loss():
+        return math_ops.cast(var / strategy.num_replicas_in_sync, 'float16')
+      run_fn = lambda: opt.minimize(loss, var_list=[var])
+      run_op = strategy.experimental_run(run_fn)
+      self.evaluate(variables.global_variables_initializer())
+      self._run_if_in_graph_mode(run_op)
+      # The loss is the identity of the variable. Therefore the gradient is 1,
+      # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3
+      self.assertAllClose([3.], self.evaluate(var))
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
   def testDynamicLossScaleWithSlots(self, strategy_fn):
     strategy_obj = strategy_fn()
     if (isinstance(strategy_obj, mirrored_strategy.MirroredStrategy) and