Raise an error if a Keras metric is declared outside of a TPUStrategy scope
and then updated per replica.
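
For example, a pattern like the following now raises a ValueError (a minimal
sketch; the TPU cluster resolver setup is omitted and `resolver` is
illustrative):

    import tensorflow as tf

    strategy = tf.distribute.experimental.TPUStrategy(resolver)
    # Metric created outside of the TPUStrategy scope.
    metric = tf.keras.metrics.Mean('loss')

    @tf.function
    def step(value):
      metric.update_state(value)  # Per-replica update -> ValueError.

    strategy.experimental_run_v2(step, args=(tf.constant(1.0),))

Creating the metric under the strategy scope avoids the error:

    with strategy.scope():
      metric = tf.keras.metrics.Mean('loss')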

PiperOrigin-RevId: 280267869
Change-Id: Ida6e5f18f3da2d1f353660a99a08913abb056859
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index ed58968..3918c7b 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -216,6 +216,7 @@
         "//tensorflow/python/ops/ragged:ragged_tensor",
         "//tensorflow/python/ops/ragged:ragged_util",
         "//tensorflow/python/profiler:traceme",
+        "//tensorflow/python/tpu:tpu_lib",
         "//tensorflow/python/training/tracking:data_structures",
         "//tensorflow/tools/docs:doc_controls",
         "@six_archive//:six",
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 632fdb3..276431b 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -1202,6 +1202,9 @@
         # ignored, following the default path for adding updates.
         not call_context.saving):
       # Updates don't need to be run in a cross-replica context.
+      # TODO(b/142574744): Relax this restriction so that metrics/variables
+      # created outside of a strategy scope can be updated in the cross-replica
+      # context.
       if (ops.executing_eagerly_outside_functions() and
           not base_layer_utils.is_in_keras_graph()):
         raise RuntimeError(  # pylint: disable=g-doc-exception
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 7e6f180..7abd44b 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -24,6 +24,7 @@
 import numpy as np
 import six
 
+from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
@@ -253,11 +254,10 @@
                  initializer=None,
                  dtype=None):
     """Adds state variable. Only for use by subclasses."""
-    from tensorflow.python.distribute import distribution_strategy_context as ds_context  # pylint:disable=g-import-not-at-top
     from tensorflow.python.keras.distribute import distributed_training_utils  # pylint:disable=g-import-not-at-top
 
-    if ds_context.has_strategy():
-      strategy = ds_context.get_strategy()
+    if distribute_ctx.has_strategy():
+      strategy = distribute_ctx.get_strategy()
     else:
       strategy = None
 
diff --git a/tensorflow/python/keras/utils/metrics_utils.py b/tensorflow/python/keras/utils/metrics_utils.py
index 3431c55..3eaf006 100644
--- a/tensorflow/python/keras/utils/metrics_utils.py
+++ b/tensorflow/python/keras/utils/metrics_utils.py
@@ -39,6 +39,7 @@
 from tensorflow.python.ops.losses import util as tf_losses_utils
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged import ragged_util
+from tensorflow.python.tpu import tpu
 from tensorflow.python.util import tf_decorator
 
 NEG_INF = -1e10
@@ -71,6 +72,19 @@
 
   def decorated(metric_obj, *args, **kwargs):
     """Decorated function with `add_update()`."""
+    strategy = distribution_strategy_context.get_strategy()
+
+    # TODO(b/142574744): Remove this check if a better solution is found for
+    # declaring a keras Metric outside of TPUStrategy and then updating it
+    # per replica.
+    for weight in metric_obj.weights:
+      if (tpu.is_tpu_strategy(strategy) and
+          not strategy.extended.variable_created_in_scope(weight) and
+          not distribution_strategy_context.in_cross_replica_context()):
+        raise ValueError(
+            'Trying to run metric.update_state in replica context when '
+            'the metric was not created in TPUStrategy scope. '
+            'Make sure the keras Metric is created in TPUStrategy scope.')
 
     with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
       update_op = update_state_fn(*args, **kwargs)
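
Note that the new check is skipped in the cross-replica context, so a metric
created outside of the strategy scope can still be updated after per-replica
values have been reduced. A minimal sketch of that allowed pattern (assuming
`step_fn` returns a per-replica value):

    per_replica = strategy.experimental_run_v2(step_fn, args=(inputs,))
    reduced = strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica, axis=None)
    metric.update_state(reduced)  # Cross-replica context: no error raised.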