Return an error if keras metric is declared outside of TPU device scope and then
updated per replica.
PiperOrigin-RevId: 280267869
Change-Id: Ida6e5f18f3da2d1f353660a99a08913abb056859
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index ed58968..3918c7b 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -216,6 +216,7 @@
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:ragged_util",
"//tensorflow/python/profiler:traceme",
+ "//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/training/tracking:data_structures",
"//tensorflow/tools/docs:doc_controls",
"@six_archive//:six",
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 632fdb3..276431b 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -1202,6 +1202,9 @@
# ignored, following the default path for adding updates.
not call_context.saving):
# Updates don't need to be run in a cross-replica context.
+ # TODO(b/142574744): Relax this restriction so that metrics/variables
+ # created outside of a strategy scope can be updated in the cross-replica
+ # context.
if (ops.executing_eagerly_outside_functions() and
not base_layer_utils.is_in_keras_graph()):
raise RuntimeError( # pylint: disable=g-doc-exception
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 7e6f180..7abd44b 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -24,6 +24,7 @@
import numpy as np
import six
+from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
@@ -253,11 +254,10 @@
initializer=None,
dtype=None):
"""Adds state variable. Only for use by subclasses."""
- from tensorflow.python.distribute import distribution_strategy_context as ds_context # pylint:disable=g-import-not-at-top
from tensorflow.python.keras.distribute import distributed_training_utils # pylint:disable=g-import-not-at-top
- if ds_context.has_strategy():
- strategy = ds_context.get_strategy()
+ if distribute_ctx.has_strategy():
+ strategy = distribute_ctx.get_strategy()
else:
strategy = None
diff --git a/tensorflow/python/keras/utils/metrics_utils.py b/tensorflow/python/keras/utils/metrics_utils.py
index 3431c55..3eaf006 100644
--- a/tensorflow/python/keras/utils/metrics_utils.py
+++ b/tensorflow/python/keras/utils/metrics_utils.py
@@ -39,6 +39,7 @@
from tensorflow.python.ops.losses import util as tf_losses_utils
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
+from tensorflow.python.tpu import tpu
from tensorflow.python.util import tf_decorator
NEG_INF = -1e10
@@ -71,6 +72,19 @@
def decorated(metric_obj, *args, **kwargs):
"""Decorated function with `add_update()`."""
+ strategy = distribution_strategy_context.get_strategy()
+ # TODO(b/142574744): Remove this check if a better solution is found for
+ # declaring keras Metric outside of TPUStrategy and then updating it per
+ # replica.
+
+ for weight in metric_obj.weights:
+ if (tpu.is_tpu_strategy(strategy) and
+ not strategy.extended.variable_created_in_scope(weight)
+ and not distribution_strategy_context.in_cross_replica_context()):
+ raise ValueError(
+ 'Trying to run metric.update_state in replica context when '
+ 'the metric was not created in TPUStrategy scope. '
+            'Make sure the keras Metric is created in TPUStrategy scope.')
with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
update_op = update_state_fn(*args, **kwargs)