blob: acf6861b9a277ea5c4e9a2e27dfb35853e8b282b [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test MirroredVariable in MirroredStrategy and MultiWorkerMirroredStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.training.tracking import util as tracking_util
def _replica_id():
  """Returns the current replica id as a `Tensor`."""
  rid = ds_context.get_replica_context().replica_id_in_sync_group
  # In graph mode the id may already be a tensor; otherwise wrap it so
  # callers can uniformly treat the result as a tensor.
  if isinstance(rid, ops.Tensor):
    return rid
  return constant_op.constant(rid)
def _mimic_two_cpus():
  """Splits the first physical CPU into two logical CPU devices."""
  first_cpu = config.list_physical_devices("CPU")[0]
  logical_devices = [context.LogicalDeviceConfiguration() for _ in range(2)]
  config.set_logical_device_configuration(first_cpu, logical_devices)
@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            combinations.NamedDistribution(
                "Collective2CPUs",
                # pylint: disable=g-long-lambda
                lambda: collective_all_reduce_strategy.
                CollectiveAllReduceStrategy._from_local_devices((
                    "/device:CPU:0", "/device:CPU:1")),
                required_gpus=0)
        ],
        mode=["graph", "eager"]))
class MirroredVariableCreationTest(test.TestCase):
  """Base class that tests mirrored variable creator.

  Currently it assumes all strategy objects have two replicas.
  """

  @classmethod
  def setUpClass(cls):
    # Split the single physical CPU into two logical devices so the
    # two-replica assumption above holds even on CPU-only hosts.
    _mimic_two_cpus()

  def assertAllDifferent(self, objs):
    """Asserts that every pair of objects in `objs` is a distinct instance."""
    for i in range(len(objs)):
      for j in range(len(objs)):
        if i == j:
          continue
        self.assertIsNot(objs[i], objs[j])

  # TODO(priyag): Modify more tests to use this helper and check more
  # properties.
  def _test_mv_properties(self, var, name, strategy):
    """Checks mirrored-ness, name, strategy and per-device placement of `var`."""
    self.assertTrue(distribute_utils.is_mirrored(var))
    self.assertEqual(name, var.name)
    self.assertIs(strategy, var.distribute_strategy)
    for i, d in enumerate(var._devices):
      # Each component variable must live on the device it was created for
      # and carry a reference back to the owning strategy.
      self.assertEqual(d, strategy.experimental_local_results(var)[i].device)
      self.assertIs(
          strategy,
          strategy.experimental_local_results(var)[i]._distribute_strategy)  # pylint: disable=protected-access

  def testVariableInFuncGraph(self, distribution):
    """Mirrored variables can be created inside a FuncGraph."""

    def model_fn():
      v = variable_scope.variable(2.0, name="bar")
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with func_graph.FuncGraph("fg").as_default(), distribution.scope():
      v1 = variable_scope.variable(1.0, name="foo")
      v2 = distribution.extended.call_for_each_replica(model_fn)

    self._test_mv_properties(v1, "foo:0", distribution)
    self._test_mv_properties(v2, "bar:0", distribution)

  def testVariableWithTensorInitialValueInFunction(self, distribution):
    """Variable with a tensor initial value works inside `tf.function`."""
    if not context.executing_eagerly():
      self.skipTest("`tf.function` is an eager-only feature")

    v = [None]

    def model_fn():
      if v[0] is None:
        init_val = array_ops.zeros([])
        v[0] = variables.Variable(init_val)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v[0]

    @def_function.function(autograph=False)
    def make_v1():
      return distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn))

    self.assertAllEqual([0, 0], make_v1())

  def testSingleVariable(self, distribution):
    """A single variable is created once and mirrored across replicas."""

    def model_fn():
      # This variable should be created only once across the threads because of
      # special variable_creator functions used by
      # `distribution.extended.call_for_each_replica`.
      v = variable_scope.variable(1.0, name="foo")
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self._test_mv_properties(result, "foo:0", distribution)

  def testUnnamedVariable(self, distribution):
    """An unnamed variable gets the default name and is mirrored."""

    def model_fn():
      v = variable_scope.variable(1.0)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self._test_mv_properties(result, "Variable:0", distribution)

  def testMultipleVariables(self, distribution):
    """Several variables created in one replica fn are each mirrored."""

    def model_fn():
      vs = []
      for i in range(5):
        vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return vs

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      for i, v in enumerate(result):
        self._test_mv_properties(v, "foo" + str(i) + ":0", distribution)

  def testMultipleVariablesWithSameCanonicalName(self, distribution):
    """Names that canonicalize identically stay distinct across replicas."""

    def model_fn():
      vs = []
      vs.append(variable_scope.variable(1.0, name="foo/bar"))
      vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
      vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
      vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return vs

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      for v in result:
        self.assertTrue(distribute_utils.is_mirrored(v))
      self.assertEqual(4, len(result))
      self.assertEqual("foo/bar:0", result[0].name)
      self.assertEqual("foo_1/bar:0", result[1].name)
      self.assertEqual("foo_1/bar_1:0", result[2].name)
      self.assertEqual("foo/bar_1:0", result[3].name)

  def testVariableWithSameCanonicalNameAcrossThreads(self, distribution):
    """Per-replica names differ; the mirrored wrapper uses replica 0's name."""

    def model_fn():
      replica_id = self.evaluate(_replica_id())
      v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertTrue(distribute_utils.is_mirrored(result))
      # The resulting mirrored variable will use the name from the first device.
      self.assertEqual("foo_0:0", result.name)

  def testWithVariableAndVariableScope(self, distribution):
    """`variable_scope.variable` respects scopes and sync/aggregation args."""

    def model_fn():
      v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.variable(1.0, name="var1")
        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
        v2 = variable_scope.variable(
            1.0,
            name="var2",
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.variable(
            1.0,
            name="var3",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3

    with distribution.scope():
      v = variable_scope.variable(1.0, name="var-main0")
      self.assertEqual("var-main0:0", v.name)

      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertEqual(4, len(result))
      v0, v1, v2, v3 = result
      self.assertTrue(distribute_utils.is_mirrored(v0))
      self.assertEqual("var0:0", v0.name)
      self.assertTrue(distribute_utils.is_mirrored(v1))
      self.assertEqual("common/var1:0", v1.name)
      self.assertTrue(distribute_utils.is_sync_on_read(v2))
      self.assertEqual("common/var2:0", v2.name)
      self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
      self.assertTrue(distribute_utils.is_mirrored(v3))
      self.assertEqual("common/var3:0", v3.name)
      self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation)

  def testWithGetVariableAndVariableScope(self, distribution):
    """`get_variable` respects enclosing scopes and sync/aggregation args."""

    def model_fn():
      v0 = variable_scope.get_variable("var0", [1])
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.get_variable("var1", [1])
        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
        v2 = variable_scope.get_variable(
            "var2", [1],
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.get_variable(
            "var3", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3

    with distribution.scope():
      with variable_scope.variable_scope("main"):
        v = variable_scope.get_variable("var-main0", [1])
        self.assertEqual("main/var-main0:0", v.name)

        result = distribution.extended.call_for_each_replica(model_fn)
        self.assertEqual(4, len(result))
        v0, v1, v2, v3 = result
        self.assertTrue(distribute_utils.is_mirrored(v0))
        self.assertEqual("main/var0:0", v0.name)
        self.assertTrue(distribute_utils.is_mirrored(v1))
        self.assertEqual("main/common/var1:0", v1.name)
        self.assertTrue(distribute_utils.is_sync_on_read(v2))
        self.assertEqual("main/common/var2:0", v2.name)
        self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
        self.assertTrue(distribute_utils.is_mirrored(v3))
        self.assertEqual("main/common/var3:0", v3.name)
        self.assertEqual(variable_scope.VariableAggregation.MEAN,
                         v3.aggregation)

  def testOnlyFirstReplicaUpdatesVariables(self, distribution):
    """With ONLY_FIRST_REPLICA aggregation, only replica 0's write wins
    for on-write variables, while on-read variables diverge per replica."""

    def create_fn():
      aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
      v0 = variable_scope.variable(
          2.0,
          name="on_read",
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=aggregation)
      v1 = variable_scope.variable(
          3.0,
          name="on_write",
          synchronization=variable_scope.VariableSynchronization.ON_WRITE,
          aggregation=aggregation)
      return v0, v1

    with distribution.scope():
      v0, v1 = distribution.extended.call_for_each_replica(create_fn)
      self.evaluate(v0.initializer)
      self.assertEqual(
          2.0, self.evaluate(distribution.experimental_local_results(v0)[0]))
      self.assertEqual(
          2.0, self.evaluate(distribution.experimental_local_results(v0)[1]))
      self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
      self.evaluate(v1.initializer)
      self.assertEqual(
          3.0, self.evaluate(distribution.experimental_local_results(v1)[0]))
      self.assertEqual(
          3.0, self.evaluate(distribution.experimental_local_results(v1)[1]))
      self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))

      def replica_id_plus_one():
        # Scale factor that differs per replica (1.0 on replica 0, 2.0 on 1).
        return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32)

      # Update using the assign_add member function.
      def update_member_fn():
        update0 = v0.assign_add(5.0 * replica_id_plus_one())
        update1 = v1.assign_add(7.0 * replica_id_plus_one())
        return update0, update1

      update0a, update1a = distribution.extended.call_for_each_replica(
          update_member_fn)

      # Update "sync on read" variable.
      self.evaluate(distribution.group(update0a))
      local_results = self.evaluate(distribution.experimental_local_results(v0))
      self.assertEqual(2.0 + 5.0, local_results[0])
      # Writes are not synchronized for "sync on read" variables,
      # so device[1] can end up with a different value.
      self.assertEqual(2.0 + 2 * 5.0, local_results[1])
      # Always reads from device 0.
      self.assertEqual(2.0 + 5.0,
                       self.evaluate(distribution.extended.read_var(v0)))

      # Update "sync on write" variable.
      self.evaluate(distribution.group(update1a))
      local_results1 = self.evaluate(
          distribution.experimental_local_results(v1))
      self.assertEqual(3.0 + 7.0, local_results1[0])
      # Writes are synchronized for v1, only the argument to assign_add on
      # device[0] is used.
      self.assertEqual(3.0 + 7.0, local_results1[1])
      self.assertEqual(3.0 + 7.0,
                       self.evaluate(distribution.extended.read_var(v1)))

      # Update using state_ops.assign_add global function.
      def update_state_ops_fn():
        update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one())
        update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one())
        return update0, update1

      update0b, update1b = distribution.extended.call_for_each_replica(
          update_state_ops_fn)
      self.evaluate(distribution.group(update0b))

      # Update "sync on read" variable.
      local_results = self.evaluate(distribution.experimental_local_results(v0))
      self.assertEqual(2.0 + 5.0 + 11.0, local_results[0])
      self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0, local_results[1])
      self.assertEqual(2.0 + 5.0 + 11.0,
                       self.evaluate(distribution.extended.read_var(v0)))

      # Update "sync on write" variable.
      self.evaluate(distribution.group(update1b))
      local_results1 = self.evaluate(
          distribution.experimental_local_results(v1))
      self.assertEqual(3.0 + 7.0 + 13.0, local_results1[0])
      self.assertEqual(3.0 + 7.0 + 13.0, local_results1[1])
      self.assertEqual(3.0 + 7.0 + 13.0,
                       self.evaluate(distribution.extended.read_var(v1)))

  def testNoneSynchronizationWithGetVariable(self, distribution):
    """`NONE` synchronization is rejected by `get_variable` under a strategy."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "`NONE` variable synchronization mode is not "
          "supported with "):
        variable_scope.get_variable(
            "v", [1],
            synchronization=variable_scope.VariableSynchronization.NONE)

  def testNoneSynchronizationWithVariable(self, distribution):
    """`NONE` synchronization is rejected by `variable` under a strategy."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "`NONE` variable synchronization mode is not "
          "supported with "):
        variable_scope.variable(
            1.0,
            name="v",
            synchronization=variable_scope.VariableSynchronization.NONE)

  def testInvalidSynchronizationWithVariable(self, distribution):
    """An unrecognized synchronization value raises ValueError."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "Invalid variable synchronization mode: Invalid for "
          "variable: v"):
        variable_scope.variable(1.0, name="v", synchronization="Invalid")

  def testInvalidAggregationWithGetVariable(self, distribution):
    """An unrecognized aggregation value raises ValueError in `get_variable`."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "Invalid variable aggregation mode: invalid for "
          "variable: v"):
        variable_scope.get_variable(
            "v", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation="invalid")

  def testInvalidAggregationWithVariable(self, distribution):
    """An unrecognized aggregation value raises ValueError in `variable`."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "Invalid variable aggregation mode: invalid for "
          "variable: v"):
        variable_scope.variable(
            1.0,
            name="v",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation="invalid")

  def testNonMatchingVariableCreation(self, distribution):
    """Creating differently-named variables across replicas is an error."""

    def model_fn(name):
      v = variable_scope.variable(1.0, name=name)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      names = values.PerReplica(("foo", "bar"))
      with self.assertRaises(RuntimeError):
        _ = distribution.extended.call_for_each_replica(model_fn, args=(names,))

  def testSyncOnReadVariable(self, distribution):
    """SUM and MEAN sync-on-read variables aggregate correctly on read."""
    if context.executing_eagerly():
      self.skipTest("Skip the test due to b/137400477.")

    all_v_sum = {}
    all_v_mean = {}
    components_sum = {}
    components_mean = {}

    def model_fn():
      replica_id = self.evaluate(_replica_id())
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      v_mean = variable_scope.variable(
          4.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.MEAN)
      self.assertTrue(distribute_utils.is_sync_on_read(v_sum))
      self.assertTrue(distribute_utils.is_sync_on_read(v_mean))
      updates = [
          v_sum.assign_add(2.0 + replica_id),
          v_mean.assign(6.0 * replica_id)
      ]
      all_v_sum[replica_id] = v_sum
      all_v_mean[replica_id] = v_mean

      c_sum = v_sum._get()
      c_mean = v_mean._get()
      components_sum[replica_id] = c_sum
      components_mean[replica_id] = c_mean
      self.assertIsNot(v_sum, c_sum)
      self.assertIsNot(v_mean, c_mean)
      return updates, v_sum, v_mean, c_sum, c_mean

    with distribution.scope():
      # Create "sum" and "mean" versions of SyncOnReadVariables.
      ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
          distribution.extended.call_for_each_replica(model_fn))
      # Should see the same wrapping instance in all replicas.
      self.assertIs(all_v_sum[0], ret_v_sum)
      self.assertIs(all_v_mean[0], ret_v_mean)
      self.assertIs(all_v_sum[0], all_v_sum[1])
      self.assertIs(all_v_mean[0], all_v_mean[1])

      # Regroup should recover the same wrapper.
      self.assertIs(ret_v_sum, regrouped_sum)
      self.assertIs(ret_v_mean, regrouped_mean)
      self.assertIsNot(components_sum[0], components_sum[1])
      self.assertIsNot(components_mean[0], components_mean[1])

      # Apply updates
      self.evaluate(variables.global_variables_initializer())
      self.evaluate([
          y for x in ret_ops  # pylint: disable=g-complex-comprehension
          for y in distribution.experimental_local_results(x)
      ])
      expected_sum = 0.0
      expected_mean = 0.0
      for i, _ in enumerate(distribution.extended.worker_devices):
        # Should see different values on different devices.
        v_sum_value = self.evaluate(
            distribution.experimental_local_results(ret_v_sum)[i].read_value())
        v_mean_value = self.evaluate(
            distribution.experimental_local_results(ret_v_mean)[i].read_value())
        expected = i + 3.0
        self.assertEqual(expected, v_sum_value)
        expected_sum += expected
        expected = i * 6.0
        self.assertEqual(expected, v_mean_value)
        expected_mean += expected
      expected_mean /= len(distribution.extended.worker_devices)

      # Without get(device), should return the value you get by
      # applying the reduction across all replicas (whether you use
      # read_var(), get(), or nothing).
      self.assertEqual(expected_sum, self.evaluate(
          distribution.extended.read_var(ret_v_sum)))
      self.assertEqual(expected_mean, self.evaluate(
          distribution.extended.read_var(ret_v_mean)))
      self.assertEqual(expected_sum, self.evaluate(ret_v_sum._get()))
      self.assertEqual(expected_mean, self.evaluate(ret_v_mean._get()))
      self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
      self.assertEqual(expected_mean, self.evaluate(ret_v_mean))

  # TODO(priyag): Update this test to work in eager mode as well.
  def testDynamicRnnVariables(self, distribution):
    """RNN-created variables are distributed, with replica-1 ops prefixed."""

    def model_fn():
      inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
      cell_fw = rnn_cell_impl.LSTMCell(300)
      cell_bw = rnn_cell_impl.LSTMCell(300)
      (outputs, _) = rnn.bidirectional_dynamic_rnn(
          cell_fw, cell_bw, inputs, dtype=dtypes.float32)
      return outputs

    with context.graph_mode(), distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      # Two variables are created by the RNN layer.
      self.assertEqual(2, len(result))
      for v in result:
        self.assertIsInstance(v, values.DistributedValues)
        _, v1 = distribution.experimental_local_results(v)
        self.assertStartsWith(v1._op.name, "replica_1/")

  def testSyncOnReadVariableUpdate(self, distribution):
    """`extended.update` assigns each component of a sync-on-read variable."""
    if context.executing_eagerly():
      self.skipTest("Skip the test due to b/137400477.")

    def model_fn():
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      self.assertTrue(distribute_utils.is_sync_on_read(v_sum))
      return v_sum

    def update(var, value):
      return var.assign(value)

    with distribution.scope():
      ret_v_sum = distribution.extended.call_for_each_replica(model_fn)

      # Initialize variables.
      self.evaluate(variables.global_variables_initializer())
      # Assert that the aggregated value of the sync on read var is the sum
      # of the individual values before running the update ops.
      self.assertEqual(
          1.0,
          self.evaluate(
              distribution.experimental_local_results(ret_v_sum)
              [0].read_value()))
      self.assertEqual(2.0, self.evaluate(ret_v_sum))

      # Apply updates.
      update_ops = distribution.extended.update(
          ret_v_sum, update, args=(5.0,), group=False)
      self.evaluate(update_ops)
      # Assert that the aggregated value of the sync on read vars is the sum
      # of the individual values after running the update ops.
      self.assertEqual(
          5.0,
          self.evaluate(
              distribution.experimental_local_results(ret_v_sum)
              [0].read_value()))
      self.assertEqual(10.0, self.evaluate(ret_v_sum))

  def testVarDistributeStrategy(self, distribution):
    """Variables created under a scope record the owning strategy."""
    with distribution.scope():
      mirrored = variable_scope.variable(1.0)
      sync_on_read = variable_scope.variable(
          1.0, synchronization=variable_scope.VariableSynchronization.ON_READ)
      self.assertIs(distribution, mirrored.distribute_strategy)
      self.assertIs(distribution, sync_on_read.distribute_strategy)

  def testInitializer(self, distribution, mode):
    """Variables restored from a SavedModel still expose an initializer."""
    if mode == "graph":
      self.skipTest("Skip graph mode")

    temp_dir = self.get_temp_dir()

    class Model(tracking_util.Checkpoint):

      def __init__(self):
        self._v = variables.Variable(1.0)

    with distribution.scope():
      m = Model()
      save.save(m, temp_dir)

    g = ops.Graph()
    with g.as_default():
      with distribution.scope():
        load.load(temp_dir)

        for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
          self.assertIsNotNone(v.initializer)
# Run all parameterized test combinations when executed as a script.
if __name__ == "__main__":
  test.main()