Save a non-distributed model correctly
We used to save in either the default replica context or the cross replica
context. In either the variables behave in the desired way, which is as if
there's no distribute strategy.
This change make this behavior explicitly implemented. Note that now you will be
able to optionally save a distributed version of the model by setting
experimental_variable_policy to EXPAND_DISTRIBUTED_VARIABLES in SaveOptions.
This change is somewhat messy due to the ongoing refactoring to
DistributedVariable, but the fix is important and we can clean up later.
PiperOrigin-RevId: 323746817
Change-Id: I5ec5db232d86be97c93a2d54c9d3b1ceb344b3df
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index a3279e8..947ec98 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -794,9 +794,13 @@
deps = [
":distribute_lib",
":reduce_util",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variable_scope",
+ "//tensorflow/python/saved_model:save_context",
+ "//tensorflow/python/saved_model:save_options",
],
)
@@ -806,6 +810,7 @@
deps = [
":packed_distributed_variable",
":values",
+ ":values_util",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops_gen",
@@ -1165,6 +1170,7 @@
":distribute_lib",
":distribute_utils",
":packed_distributed_variable",
+ ":parameter_server_strategy",
":strategy_combinations",
":test_util",
":tpu_strategy",
@@ -1172,6 +1178,7 @@
":values",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
@@ -1183,6 +1190,7 @@
"//tensorflow/python:saver",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tf2",
"//tensorflow/python:training",
@@ -1195,6 +1203,7 @@
"//tensorflow/python/eager:test",
"//tensorflow/python/saved_model:save_context",
"//tensorflow/python/saved_model:save_options",
+ "//tensorflow/python/saved_model/model_utils:mode_keys",
"//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/types",
"@absl_py//absl/testing:parameterized",
diff --git a/tensorflow/python/distribute/integration_test/saved_model_test.py b/tensorflow/python/distribute/integration_test/saved_model_test.py
index 60de590..4455e1f 100644
--- a/tensorflow/python/distribute/integration_test/saved_model_test.py
+++ b/tensorflow/python/distribute/integration_test/saved_model_test.py
@@ -68,25 +68,9 @@
# context and the cross-replica context. Saving happens in the cross replica
# context or the default startegy's replica context.
- def test_read_sync_on_read_variable_broken(self, strategy):
+ def test_read_sync_on_read_variable(self, strategy):
# synchronizaiton=ON_READ variables are typically used in Keras metrics and
# batch norm layers.
- #
- # This is broken now since the saved variable already has the aggregated
- # value, but the saved tf.function is traced under the cross-replica context
- # and contains the aggregation.
- #
- # Impacts:
- # - MirroredStrategy, TPUStrategy
- # - aggregation=NONE: error when saving.
- # - aggregation=SUM: incorrect results.
- # - aggregation=MEAN: slight computation overhead.
- # - aggregation=ONLY_FIRST_REPLICA: none.
- # - MultiWorkerMirroredStrategy:
- # - aggregation=NONE: error when saving
- # - aggregation=MEAN, SUM: error or hanging when using the loaded model.
- # - aggregation=ONLY_FIRST_REPLICA: none.
- # Note that batch norm uses aggregation=MEAN.
class Model(tf.Module):
@@ -113,9 +97,7 @@
loaded = tf.saved_model.load(export_dir)
# The variable already has the aggregated value.
self.assertEqual(self.evaluate(loaded.v.read_value()), 1.)
- # TODO(b/159752793): reading the variable aggregates the values again.
- # got 2., want 1.
- self.assertEqual(self.evaluate(loaded()), 2.)
+ self.assertEqual(self.evaluate(loaded()), 1.)
def test_read_mirrored_variable(self, strategy):
# synchronizaiton=ON_WRITE is the default variable created under
@@ -142,14 +124,10 @@
loaded = tf.saved_model.load(export_dir)
self.assertEqual(self.evaluate(loaded()), 1.)
- def test_update_sync_on_read_variable_broken(self, strategy):
+ def test_update_sync_on_read_variable(self, strategy):
# It's rare to update aggregation=ON_READ variables in serving, but it's
# possible that the SavedModel contains both serving and training graphs,
# and the training may contain metrics layers.
- #
- # This is now partially broken since assign_add() and assign_sub() are not
- # allowed in the cross-replica context if aggregation=SUM, which blocks
- # saving the model.
class Model(tf.Module):
@@ -167,23 +145,15 @@
export_dir = self.get_temp_dir()
with strategy.scope():
m = Model()
- # got error, want no error.
- with self.assertRaisesRegex(ValueError,
- "SyncOnReadVariable does not support"):
- tf.saved_model.save(m, export_dir)
+ tf.saved_model.save(m, export_dir)
- # TODO(b/159752793): Uncomment after fix.
- # loaded = tf.saved_model.load(export_dir)
- # loaded.update()
- # self.assertEqual(self.evaluate(loaded.v), 1.)
+ loaded = tf.saved_model.load(export_dir)
+ loaded.update()
+ self.assertEqual(self.evaluate(loaded.v), 1.)
- def test_update_mirrored_variable_broken(self, strategy):
+ def test_update_mirrored_variable(self, strategy):
# It's very rare to update aggregation=ON_WRITE variables in the forward
# path, and this test case is mainly for completeness.
- #
- # The saved tf.function updates each components of the distributed variable,
- # which effectively updates the variable in the saved model N times where N
- # equals the number of local replicas during training.
class Model(tf.Module):
@@ -205,9 +175,7 @@
loaded = tf.saved_model.load(export_dir)
self.assertEqual(self.evaluate(loaded.v), 0.)
loaded.update()
- # TODO(b/159752793): Change after fix.
- # got 2., want 1.
- self.assertEqual(self.evaluate(loaded.v), 2.)
+ self.assertEqual(self.evaluate(loaded.v), 1.)
def test_training_only_device(self, strategy):
# tf.distribute APIs may enter device scopes, but the saved model should not
@@ -314,12 +282,14 @@
# can workaround most issues since Keras loader restructs the layers with
# saved configs if possible, in which case the saved graph is not used.
- def test_read_sync_on_read_variable_broken(self, strategy):
- # Reading a synchronizaiton=ON_READ in the replica context should only read
- # the local value, however with a loaded model, reading in the replica
- # context triggers aggregation as well. While one may argue the behavior is
- # desirable, note that aggregation can cause hanging if the originall model
- # is trained with MultiWorkerMirroredStrategy.
+ def test_read_sync_on_read_variable(self, strategy):
+ # Reading a synchronizaiton=ON_READ in the replica context should just read
+ # the local value. Reading it in the cross replica context aggregates the
+ # value from all replicas. Both are true with a loaded model.
+ #
+ # Note that if aggregation=SUM, the value of each replica is the saved value
+ # divided by the number of replicas. In this way if you load a model and
+ # save it again, the values of the variables don't change.
class Model(tf.Module):
@@ -334,26 +304,45 @@
return self.v.read_value()
export_dir = self.get_temp_dir()
+ value = strategy.experimental_distribute_values_from_function(
+ lambda ctx: tf.identity([3., 7.][ctx.replica_id_in_sync_group]))
with strategy.scope():
m = Model()
- m.v.assign(1.)
+ strategy.run(m.v.assign, args=(value,))
self.assertAllEqual(
- self.evaluate(strategy.experimental_local_results(m.v)), [0.5, 0.5])
+ self.evaluate(strategy.experimental_local_results(m.v)), [3., 7.])
+ self.assertEqual(self.evaluate(m.v.read_value()), 10.)
tf.saved_model.save(m, export_dir)
+ del m
with strategy.scope():
loaded = tf.saved_model.load(export_dir)
- # After loading, reading in the replica context is the same as reading in
- # the cross-replica context.
- # TODO(b/159752793): change after fix.
+ # It's intended that we don't save the each replica, but just the aggregated
+ # value.
self.assertAllEqual(
self.evaluate(
strategy.experimental_local_results(strategy.run(loaded))),
- [1., 1.])
- self.assertEqual(self.evaluate(loaded.v.read_value()), 1.)
+ [5., 5.])
+ self.assertEqual(self.evaluate(loaded.v.read_value()), 10.)
- def test_update_sync_on_read_variable_broken(self, strategy):
- # Can't even save.
+ # save and load again.
+ export_dir2 = self.get_temp_dir()
+ tf.saved_model.save(loaded, export_dir2)
+ # loaded.v.read_value() is still 1., both with and without strategy.
+ loaded = tf.saved_model.load(export_dir2)
+ self.assertEqual(self.evaluate(loaded.v.read_value()), 10.)
+ with strategy.scope():
+ loaded = tf.saved_model.load(export_dir2)
+ self.assertEqual(self.evaluate(loaded.v.read_value()), 10.)
+
+ def test_update_sync_on_read_variable(self, strategy):
+ # Updating a synchronizaiton=ON_READ in the replica context should just
+ # update the local value. Updating it in the cross replica context updates
+ # each component of the variable. Both are true with a loaded model.
+ #
+ # Note that if assigning a variable whose aggregation=SUM in the cross
+ # replica context, each replica is assigned with the value divided by the
+ # number of replicas.
class Model(tf.Module):
@@ -363,19 +352,36 @@
synchronization=tf.VariableSynchronization.ON_READ,
aggregation=tf.VariableAggregation.SUM)
- @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
+ @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
def update(self, value):
self.v.assign_add(value)
export_dir = self.get_temp_dir()
+ value = strategy.experimental_distribute_values_from_function(
+ lambda ctx: tf.identity([3., 7.][ctx.replica_id_in_sync_group]))
with strategy.scope():
m = Model()
- # got error, want no error.
- with self.assertRaisesRegex(ValueError,
- "SyncOnReadVariable does not support"):
- tf.saved_model.save(m, export_dir)
+ tf.saved_model.save(m, export_dir)
+ self.evaluate(m.v.assign(10.))
+ self.assertAllEqual(
+ self.evaluate(strategy.experimental_local_results(m.v)), [5., 5.])
+ del m
+ # TODO(b/161488560): strategy.run doesn't work with tf.function with
+ # input_signature.
+ # self.evaluate(strategy.run(m.update, args=(value,)))
+ # self.assertAllEqual(
+ # self.evaluate(strategy.experimental_local_results(m.v)), [8., 12.])
- # TODO(b/159752793): Complete the test after the saving issue is fixed.
+ with strategy.scope():
+ loaded = tf.saved_model.load(export_dir)
+ self.evaluate(loaded.v.assign(10.))
+ self.assertAllEqual(
+ self.evaluate(strategy.experimental_local_results(loaded.v)),
+ [5., 5.])
+ self.evaluate(strategy.run(loaded.update, args=(value,)))
+ self.assertAllEqual(
+ self.evaluate(strategy.experimental_local_results(loaded.v)),
+ [8., 12.])
def test_read_mirrored_variable(self, strategy):
@@ -402,13 +408,18 @@
strategy.experimental_local_results(strategy.run(loaded))),
[1., 1.])
- def test_update_mirrored_variable_broken(self, strategy):
+ def test_update_mirrored_variable(self, strategy):
# This is also uncommon since most model parameters should be updated by
# optimizer, and this test case is for completeness.
#
- # It's broken the saved model may not contain the aggregation logic. Even if
- # it does, it's wrong since all inputs to the aggregation are the same
- # variable.
+ # In the cross replica context, assigning to the variable assigns the same
+ # value to all replicas. This is true with the loaded model as well.
+ #
+ # However in replica context, MirroredVariable (synchronization=ON_WRITE)
+ # in a loaded model behaves differently. Updating MirroredVariable only
+ # update the current replica's variable with the current replica's value.
+ # There's no aggregation. This doesn't affect variables that are updated
+ # through optimizer. This is work as intended but can be surprising.
class Model(tf.Module):
@@ -418,24 +429,28 @@
synchronization=tf.VariableSynchronization.ON_WRITE,
aggregation=tf.VariableAggregation.MEAN)
- @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
+ @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
def update(self, value):
- self.v.assign_add(value[0])
+ self.v.assign_add(value)
export_dir = self.get_temp_dir()
+ value = strategy.experimental_distribute_values_from_function(
+ lambda ctx: tf.identity([1., 2.][ctx.replica_id_in_sync_group]))
with strategy.scope():
m = Model()
tf.saved_model.save(m, export_dir)
+ del m
with strategy.scope():
loaded = tf.saved_model.load(export_dir)
- value = strategy.experimental_distribute_dataset(
- tf.data.Dataset.from_tensor_slices([1., 2.]).batch(2))
- strategy.run(loaded.update, args=(next(iter(value)),))
- # TODO(b/159752793): Change after fix.
- # got [2., 4.], want [1.5, 1.5].
self.assertAllEqual(
- self.evaluate(strategy.experimental_local_results(loaded.v)), [2., 4.])
+ self.evaluate(strategy.experimental_local_results(loaded.v)), [0., 0.])
+ self.evaluate(loaded.v.assign(1.))
+ self.assertAllEqual(
+ self.evaluate(strategy.experimental_local_results(loaded.v)), [1., 1.])
+ strategy.run(loaded.update, args=(value,))
+ self.assertAllEqual(
+ self.evaluate(strategy.experimental_local_results(loaded.v)), [2., 3.])
# TODO(crccw): add a test case that trains a saved model with optimizer.
diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py
index ce6d2e7..901b906 100644
--- a/tensorflow/python/distribute/tpu_values.py
+++ b/tensorflow/python/distribute/tpu_values.py
@@ -26,6 +26,7 @@
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import values
+from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
@@ -162,6 +163,8 @@
@property
def op(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary.op
return values.DistributedVarOp(self._primary.op.name,
self._primary.op.graph,
self._primary.op.traceback,
@@ -289,24 +292,38 @@
read_value=read_value)
def scatter_sub(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_sub(*args, **kwargs)
raise NotImplementedError
def scatter_add(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_add(*args, **kwargs)
raise NotImplementedError
def scatter_max(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_max(*args, **kwargs)
raise NotImplementedError
def scatter_min(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_min(*args, **kwargs)
raise NotImplementedError
def scatter_mul(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_mul(*args, **kwargs)
raise NotImplementedError
def scatter_div(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_div(*args, **kwargs)
raise NotImplementedError
def scatter_update(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_update(*args, **kwargs)
raise NotImplementedError
def _is_mirrored(self):
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 7dedbee..e6b77ad 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -487,6 +487,8 @@
The op that evaluates to True or False depending on if all the
component variables are initialized.
"""
+ if values_util.is_saving_non_distributed():
+ return self._primary.is_initialized()
if self._use_packed_variable():
return self._packed_var.is_initialized()
result = self._primary.is_initialized()
@@ -502,6 +504,8 @@
@property
def initializer(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary.initializer
if self._initializer_op:
init_op = self._initializer_op
else:
@@ -567,6 +571,8 @@
@property
def handle(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary.handle
replica_id = values_util.get_current_replica_id_as_int()
if replica_id is None:
raise ValueError("`handle` is not available outside the replica context"
@@ -610,6 +616,8 @@
@property
def op(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary.op
# We want cross-replica code that does some var.op.X calls
# to work (even if the current device isn't in self._devices), but
# other uses of var.op in a cross-replica context to fail.
@@ -630,6 +638,8 @@
def _get(self):
"""Returns the value for the current device or raises a ValueError."""
+ if values_util.is_saving_non_distributed():
+ return self._primary
replica_id = values_util.get_current_replica_id_as_int()
if replica_id is None:
return self._get_cross_replica()
@@ -638,6 +648,8 @@
def _get_on_device_or_primary(self):
"""Returns value in same replica or device if possible, else the _primary."""
+ if values_util.is_saving_non_distributed():
+ return self._primary
replica_id = values_util.get_current_replica_id_as_int()
if replica_id is None:
# Try to find a value on the current device.
@@ -654,6 +666,8 @@
return array_ops.identity(self._get())
def value(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary.value()
if self._policy:
return self._policy.value(self)
return self._get_on_device_or_primary().value()
@@ -666,6 +680,8 @@
"numpy() is only available when eager execution is enabled.")
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+ if values_util.is_saving_non_distributed():
+ return self._primary.assign_sub(value, use_locking, name, read_value)
if self._policy:
return self._policy.assign_sub(
self,
@@ -677,6 +693,8 @@
self, value, use_locking=use_locking, name=name, read_value=read_value)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
+ if values_util.is_saving_non_distributed():
+ return self._primary.assign_add(value, use_locking, name, read_value)
if self._policy:
return self._policy.assign_add(
self,
@@ -688,6 +706,8 @@
self, value, use_locking=use_locking, name=name, read_value=read_value)
def assign(self, value, use_locking=False, name=None, read_value=True):
+ if values_util.is_saving_non_distributed():
+ return self._primary.assign(value, use_locking, name, read_value)
if self._policy:
return self._policy.assign(
self,
@@ -699,6 +719,8 @@
self, value, use_locking=use_locking, name=name, read_value=read_value)
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_sub(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_sub(
self, sparse_delta, use_locking=use_locking, name=name)
@@ -706,6 +728,8 @@
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_add(self, sparse_delta, use_locking=False, name=None):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_add(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_add(
self, sparse_delta, use_locking=use_locking, name=name)
@@ -713,6 +737,8 @@
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_mul(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_mul(
self, sparse_delta, use_locking=use_locking, name=name)
@@ -720,6 +746,8 @@
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_div(self, sparse_delta, use_locking=False, name=None):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_div(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_div(
self, sparse_delta, use_locking=use_locking, name=name)
@@ -727,6 +755,8 @@
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_min(self, sparse_delta, use_locking=False, name=None):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_min(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_min(
self, sparse_delta, use_locking=use_locking, name=name)
@@ -734,6 +764,8 @@
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_max(self, sparse_delta, use_locking=False, name=None):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_max(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_max(
self, sparse_delta, use_locking=use_locking, name=name)
@@ -741,6 +773,8 @@
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_update(self, sparse_delta, use_locking=False, name=None):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_update(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_update(
self, sparse_delta, use_locking=use_locking, name=name)
@@ -763,12 +797,16 @@
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
def _as_graph_element(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary._as_graph_element() # pylint: disable=protected-access
if self._policy:
return self._policy._as_graph_element(self) # pylint: disable=protected-access
raise NotImplementedError("No policy set for calling _as_graph_element.")
def _get_cross_replica(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary
if self._policy:
return self._policy._get_cross_replica(self) # pylint: disable=protected-access
@@ -827,6 +865,8 @@
Updated variable or `tf.Operation`.
"""
+ if values_util.is_saving_non_distributed():
+ return update_fn(self._primary, value, **kwargs)
with ds_context.enter_or_assert_strategy(self.distribute_strategy):
if ds_context.in_cross_replica_context():
update_replica_id = distribute_lib.get_update_replica_id()
@@ -919,6 +959,8 @@
return _on_write_update_replica(self, update_fn, value, **kwargs)
def scatter_min(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_min(*args, **kwargs)
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(values_util.scatter_error_msg.format(
@@ -926,20 +968,26 @@
return super(MirroredVariable, self).scatter_min(*args, **kwargs)
def scatter_max(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_max(*args, **kwargs)
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(values_util.scatter_error_msg.format(
- op_name="scatter_min", aggregation=self._aggregation))
+ op_name="scatter_max", aggregation=self._aggregation))
return super(MirroredVariable, self).scatter_max(*args, **kwargs)
def scatter_update(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_update(*args, **kwargs)
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(values_util.scatter_error_msg.format(
- op_name="scatter_min", aggregation=self._aggregation))
+ op_name="scatter_update", aggregation=self._aggregation))
return super(MirroredVariable, self).scatter_update(*args, **kwargs)
def _get_cross_replica(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary.read_value()
# Return identity, to avoid directly exposing the variable to the user and
# allowing it to be modified by mistake.
return array_ops.identity(Mirrored._get_cross_replica(self))
@@ -1022,6 +1070,8 @@
# TODO(b/154017756): Make assign behaivor in cross replica context consistent
# with MirroredVariable.
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+ if values_util.is_saving_non_distributed():
+ return self._primary.assign_sub(value, use_locking, name, read_value)
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context() and not _in_update_replica():
return values_util.on_read_assign_sub_cross_replica(
@@ -1031,6 +1081,8 @@
self).assign_sub(value, use_locking, name, read_value)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
+ if values_util.is_saving_non_distributed():
+ return self._primary.assign_add(value, use_locking, name, read_value)
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context() and not _in_update_replica():
return values_util.on_read_assign_add_cross_replica(
@@ -1040,6 +1092,8 @@
self).assign_add(value, use_locking, name, read_value)
def assign(self, value, use_locking=False, name=None, read_value=True):
+ if values_util.is_saving_non_distributed():
+ return self._primary.assign(value, use_locking, name, read_value)
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context() and not _in_update_replica():
return values_util.on_read_assign_cross_replica(
@@ -1054,27 +1108,43 @@
method)
def scatter_sub(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_sub(*args, **kwargs)
self._scatter_not_implemented("scatter_sub")
def scatter_add(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_add(*args, **kwargs)
self._scatter_not_implemented("scatter_add")
def scatter_mul(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_mul(*args, **kwargs)
self._scatter_not_implemented("scatter_mul")
def scatter_div(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_div(*args, **kwargs)
self._scatter_not_implemented("scatter_div")
def scatter_min(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_min(*args, **kwargs)
self._scatter_not_implemented("scatter_min")
def scatter_max(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_max(*args, **kwargs)
self._scatter_not_implemented("scatter_max")
def scatter_update(self, *args, **kwargs):
+ if values_util.is_saving_non_distributed():
+ return self._primary.scatter_update(*args, **kwargs)
self._scatter_not_implemented("scatter_update")
def value(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary.value()
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context() and not _in_update_replica():
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
@@ -1085,6 +1155,8 @@
return self._get_on_device_or_primary().value()
def _get_cross_replica(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary.read_value()
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
# Consider returning a tensor value here to make the return value of
# _get_cross_replica consistent.
@@ -1097,6 +1169,8 @@
axis=None)
def _as_graph_element(self):
+ if values_util.is_saving_non_distributed():
+ return self._primary._as_graph_element() # pylint: disable=protected-access
# pylint: disable=protected-access
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context():
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index e445c11..48b6b97 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -30,6 +30,7 @@
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import packed_distributed_variable as packed
+from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
@@ -43,9 +44,11 @@
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
@@ -578,6 +581,207 @@
distribution.extended.update(v, read_assign_fn, args=(value,)))
self.assertAllEqual(self.evaluate(v.values), [3., 6.])
+ def testSaveNonDistributed(self, distribution, synchronization, aggregation):
+ # This test verifies that the DistributedVariable behave like the primary
+ # variable when saving a non-distributed version of the model (the default).
+ # The test asserts that the function traced under SaveContext has no device
+ # annotations and only reference the primary component of the variable. Note
+ # that please avoid capturing other eager tensors in this test to make the
+ # assertion easy.
+
+ if isinstance(distribution.extended,
+ parameter_server_strategy.ParameterServerStrategyExtended):
+ self.skipTest("b/148689177: AggregatingVariable doesn't "
+ "conform to Variable interface well")
+
+ # tf.function requires the return value to be Tensors, which is not always
+ # case for properties and methods of Variable, so we simply discard the
+ # return values.
+ def _discard_return(f):
+ f()
+ return
+
+ def _test(f, v):
+ # This verifies that the function under SaveContext:
+ # - contains no device annotations.
+ # - only references the primary component of the variable.
+ g = def_function.function(lambda: _discard_return(f))
+ options = save_options.SaveOptions(
+ experimental_variable_policy=save_options.VariablePolicy.NONE)
+ with save_context.save_context(options):
+ # The graph should contain no device.
+ graph = g.get_concrete_function().graph
+ for op in graph.get_operations():
+ self.assertEqual(op.device, "", msg=str(op))
+ # The function should only capture the primary variable. Note that it
+ # may not have captures, e.g. v.aggregation.
+ captures = list(graph.captures)
+ self.assertLessEqual(len(captures), 1)
+ if graph.captures:
+ self.assertIs(captures[0][0], v._primary.handle)
+
+ def _assert(cond):
+ return control_flow_ops.Assert(cond, [cond])
+
+ with distribution.scope():
+ # We use four variables for convenience reasons. They have no special
+ # meaning.
+ # - v is used whenever possible, and for the methods that require the
+ # dtype to be integer.
+ # - w is used for scatter and gather, which require the variable to be
+ # non-scalar.
+ # - y is used when the dtype needs to be float.
+ v = variables_lib.Variable(
+ 0,
+ synchronization=synchronization,
+ aggregation=aggregation,
+ trainable=True)
+ w = variables_lib.Variable([0., 0., 0.],
+ synchronization=synchronization,
+ aggregation=aggregation,
+ trainable=True)
+ y = variables_lib.Variable(
+ 7.,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+ # pylint: disable=g-long-lambda
+
+ # tf.Variable properties.
+ _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
+ _test(lambda: self.assertIs(v.constraint, None), v)
+ # TODO(crccw): should we raise an error instead?
+ _test(lambda: self.assertEqual(v.device, v._primary.device), v)
+ _test(lambda: self.assertEqual(v.dtype, dtypes.int32), v)
+ if not context.executing_eagerly():
+ _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
+ if not context.executing_eagerly():
+ _test(lambda: _assert(v.initial_value == 0), v)
+ _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
+ _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
+ if not context.executing_eagerly():
+ _test(lambda: self.assertIs(v.op, v._primary.op), v)
+ _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())), v)
+ _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
+ _test(lambda: self.assertTrue(v.trainable, True), v)
+
+ # tf.Variable methods.
+ _test(lambda: check_ops.assert_equal_v2(v.assign(1), 1), v)
+ _test(lambda: check_ops.assert_equal_v2(v.assign_add(1), 2), v)
+ _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1), 1), v)
+ # TODO(b/148689177): Implement batch_scatter_update.
+ # count_up_to() is skipped since it's deprecated.
+ # eval() is skipped since it shouldn't called in a tf.function.
+ # experimental_ref() is skipped since it's deprecated.
+ # from_proto() is skipped since it shouldn't called in a tf.function.
+ # TODO(b/148689177): Implement gather_nd.
+ _test(
+ lambda: check_ops.assert_equal_v2(v.get_shape(),
+ tensor_shape.TensorShape(())), v)
+ # initialized_value() is skipped since it shouldn't called in a tf.function.
+ # load() is skipped since it shouldn't called in a tf.function.
+ _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1), v)
+ # ref() is skipped since it shouldn't called in a tf.function.
+ _test(
+ lambda: check_ops.assert_equal_v2(
+ w.scatter_add(_make_index_slices(values=[1., 2.], indices=[0, 2])),
+ [1., 0., 2.]), w)
+ _test(
+ lambda: check_ops.assert_equal_v2(
+ w.scatter_div(_make_index_slices(values=[4., 2.], indices=[0, 2])),
+ [0.25, 0., 1.]), w)
+ _test(
+ lambda: check_ops.assert_equal_v2(
+ w.scatter_max(_make_index_slices(values=[1., 0.5], indices=[1, 2])),
+ [0.25, 1., 1.]), w)
+ _test(
+ lambda: check_ops.assert_equal_v2(
+ w.scatter_min(_make_index_slices(values=[1., 0.5], indices=[0, 1])),
+ [0.25, 0.5, 1.]), w)
+ _test(
+ lambda: check_ops.assert_equal_v2(
+ w.scatter_mul(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
+ [0.5, 0.25, 1.]), w)
+ # TODO(b/148689177): Implement scatter_nd_*
+ _test(
+ lambda: check_ops.assert_equal_v2(
+ w.scatter_sub(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
+ [-1.5, -0.25, 1.]), w)
+ _test(
+ lambda: check_ops.assert_equal_v2(
+ w.scatter_update(
+ _make_index_slices(values=[2., 0.5], indices=[0, 1])),
+ [2., 0.5, 1.]), w)
+ # set_shape() is skipped since ResourceVariable doesn't implement it.
+ # to_proto() is skipped since it shouldn't called in a tf.function.
+ _test(lambda: check_ops.assert_equal_v2(v.value(), 1), v)
+
+ # DistributedVariable should be treated as ResourceVariable, so it needs to
+ # conform to ResourceVariable interface as well.
+ _test(lambda: self.assertIs(v.handle, v._primary.handle), v)
+
+ # Convert to tensor.
+ _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1), v)
+
+ # Control dependency.
+ def _with_control_dep():
+ with ops.control_dependencies([v.assign(1)]):
+ return array_ops.identity(1)
+
+ _test(_with_control_dep, v)
+
+ # Operator overloads.
+ _test(lambda: check_ops.assert_equal_v2(v.assign(7), 7), v)
+ _test(lambda: check_ops.assert_equal_v2(v + 1, 8), v)
+ _test(lambda: check_ops.assert_equal_v2(3 + v, 10), v)
+ _test(lambda: check_ops.assert_equal_v2(v + v, 14), v)
+ _test(lambda: check_ops.assert_equal_v2(v - 2, 5), v)
+ _test(lambda: check_ops.assert_equal_v2(v - v, 0), v)
+ _test(lambda: check_ops.assert_equal_v2(v * 2, 14), v)
+ _test(lambda: check_ops.assert_equal_v2(3 * v, 21), v)
+ _test(lambda: check_ops.assert_equal_v2(v * v, 49), v)
+ _test(
+ lambda: check_ops.assert_equal_v2(
+ math_ops.cast(v / 2, dtypes.float32), 3.5), v)
+ _test(
+ lambda: check_ops.assert_equal_v2(
+ math_ops.cast(14 / v, dtypes.float32), 2.), v)
+ _test(lambda: check_ops.assert_equal_v2(v // 2, 3), v)
+ _test(lambda: check_ops.assert_equal_v2(15 // v, 2), v)
+ _test(lambda: check_ops.assert_equal_v2(v % 2, 1), v)
+ _test(lambda: check_ops.assert_equal_v2(16 % v, 2), v)
+ _test(lambda: _assert(v < 12), v)
+ _test(lambda: _assert(v <= 12), v)
+ _test(lambda: _assert(not v > 12), v)
+ _test(lambda: _assert(not v >= 12), v)
+ _test(lambda: _assert(not 12 < v), v)
+ _test(lambda: _assert(not 12 <= v), v)
+ _test(lambda: _assert(12 > v), v)
+ _test(lambda: _assert(12 >= v), v)
+ # XLA doesn't implement pow() with integers.
+ _test(lambda: check_ops.assert_near_v2(pow(y, 3.), 343.), y)
+ _test(lambda: check_ops.assert_near_v2(pow(2., y), 128.), y)
+ _test(lambda: check_ops.assert_equal_v2(abs(v), 7), v)
+ _test(lambda: check_ops.assert_equal_v2(v & 3, 3), v)
+ _test(lambda: check_ops.assert_equal_v2(3 & v, 3), v)
+ _test(lambda: check_ops.assert_equal_v2(v | 8, 15), v)
+ _test(lambda: check_ops.assert_equal_v2(16 | v, 23), v)
+ _test(lambda: check_ops.assert_equal_v2(v ^ 3, 4), v)
+ _test(lambda: check_ops.assert_equal_v2(11 ^ v, 12), v)
+ _test(lambda: check_ops.assert_equal_v2(-v, -7), v)
+ _test(lambda: check_ops.assert_equal_v2(~v, ~7), v)
+
+ # Index.
+ if isinstance(distribution.extended, tpu_strategy.TPUExtended):
+ # TODO(b/161572567): slice assignment doesn't work for TPU.
+ _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
+ else:
+ _test(lambda: check_ops.assert_equal_v2(w[0].assign(1.), [1., 0.5, 1.]),
+ w)
+ _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)
+
+ # pylint: enable=g-long-lambda
+
@combinations.generate(
combinations.combine(
diff --git a/tensorflow/python/distribute/values_util.py b/tensorflow/python/distribute/values_util.py
index 5909bdd..099184d 100644
--- a/tensorflow/python/distribute/values_util.py
+++ b/tensorflow/python/distribute/values_util.py
@@ -26,6 +26,8 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.saved_model import save_context
+from tensorflow.python.saved_model import save_options
def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
@@ -247,3 +249,20 @@
"variable (variable created within certain "
"`tf.distribute.Strategy` scope) with NONE or "
"`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")
+
+
+def is_saving_non_distributed():
+ """Returns whether we're saving a non-distributed version of the model.
+
+ It returns True iff we are in saving context and are saving a non-distributed
+ version of the model. That is, SaveOptions.experimental_variable_policy is
+ NONE.
+
+ Returns:
+ A boolean.
+ """
+ if not save_context.in_save_context():
+ return False
+ options = save_context.get_save_options()
+ return (options is not None and options.experimental_variable_policy !=
+ save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)
diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py
index a5171f3..28b8fa9 100644
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -580,8 +580,7 @@
else:
self.assertIsNone(v1)
self.assertEmpty(v0.device)
- # TODO(b/159752793): There should be only one input here.
- self.assertLen(saved_function.signature.input_arg, 2)
+ self.assertLen(saved_function.signature.input_arg, 1)
def test_expand_distributed_variables_not_allowed(self):
root = tracking.AutoTrackable()