# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MirroredStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import sys
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import server_lib
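# True iff the test binary name (sys.argv[0]) contains "test_gpu", i.e. this
# is the GPU variant of the test target.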
GPU_TEST = "test_gpu" in sys.argv[0]
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
mode=["graph", "eager"]))
class MirroredTwoDeviceDistributionTest(
strategy_test_lib.DistributionTestBase,
strategy_test_lib.TwoDeviceDistributionTestBase,
parameterized.TestCase):
def testMinimizeLoss(self, distribution):
if context.executing_eagerly():
self._test_minimize_loss_eager(distribution)
else:
self._test_minimize_loss_graph(distribution)
def testReplicaId(self, distribution):
self._test_replica_id(distribution)
def testNumReplicasInSync(self, distribution):
self.assertEqual(2, distribution.num_replicas_in_sync)
def testCallAndMergeExceptions(self, distribution):
self._test_call_and_merge_exceptions(distribution)
def testRunRegroupError(self, distribution):
def run_fn():
replica_id = int(self.evaluate(_replica_id()))
# Generates lists of different lengths on different devices.
# Will fail in _regroup() (if more than one device).
return list(range(replica_id))
with distribution.scope(), self.assertRaises(AssertionError):
distribution.extended.call_for_each_replica(run_fn)
def testReduceToCpu(self, distribution):
with distribution.scope():
result = distribution.extended.call_for_each_replica(_replica_id)
reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=None)
expected = sum(range(distribution.num_replicas_in_sync))
self.assertEqual(expected, self.evaluate(reduced))
def reduce_axis_helper(self, distribution, replica_squared_fn):
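"""Reduces the per-replica output of `replica_squared_fn` along axis 0 with
SUM and MEAN and checks the resulting values."""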
with distribution.scope():
num_replicas = distribution.num_replicas_in_sync
result = distribution.extended.call_for_each_replica(replica_squared_fn)
# sum
reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=0)
expected = sum(x * (x + 1) for x in range(num_replicas))
self.assertNear(expected, self.evaluate(reduced), 0.00001)
# mean
reduced = distribution.reduce(reduce_util.ReduceOp.MEAN, result, axis=0)
expected /= sum(x + 1 for x in range(num_replicas))
self.assertNear(expected, self.evaluate(reduced), 0.00001)
def testReduceAxisToCpu(self, distribution):
for dtype in (dtypes.float32, dtypes.int32):
def replica_squared_fn(dtype=dtype):
# Lists with different lengths on different replicas.
replica_id = _replica_id_as_int()
return math_ops.cast([replica_id] * (replica_id + 1), dtype)
self.reduce_axis_helper(distribution, replica_squared_fn)
def set_v2_tensorshape(self, v2):
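"""Enables or disables TensorShape V2 behavior for the duration of a test."""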
if v2:
tensor_shape.enable_v2_tensorshape()
else:
tensor_shape.disable_v2_tensorshape()
def testReduceAxisToCpuUnknownShape(self, distribution):
original_v2 = tensor_shape._TENSORSHAPE_V2_OVERRIDE # pylint: disable=protected-access
try:
for v2 in (False, True):
self.set_v2_tensorshape(v2)
for dtype in (dtypes.float32, dtypes.int32):
for shape in ((None,), None): # Test both unknown size and rank.
def replica_squared_fn(dtype=dtype, shape=shape):
# Lists with different lengths on different replicas.
replica_id = _replica_id_as_int()
tensor = math_ops.cast([replica_id] * (replica_id + 1), dtype)
# Erase shape information
return array_ops.placeholder_with_default(tensor, shape=shape)
self.reduce_axis_helper(distribution, replica_squared_fn)
finally:
self.set_v2_tensorshape(original_v2)
def testReplicateDataset(self, distribution):
dataset_fn = lambda: dataset_ops.Dataset.range(10)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
input_fn = self._input_fn_to_test_input_context(
dataset_fn,
expected_num_replicas_in_sync=2,
expected_num_input_pipelines=1,
expected_input_pipeline_id=0)
self._test_input_fn_iterable(distribution, input_fn, expected_values)
def testMakeInputFnIteratorWithDataset(self, distribution):
dataset_fn = lambda: dataset_ops.Dataset.range(10)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
input_fn = self._input_fn_to_test_input_context(
dataset_fn,
expected_num_replicas_in_sync=2,
expected_num_input_pipelines=1,
expected_input_pipeline_id=0)
iterator = distribution.make_input_fn_iterator(input_fn)
self._test_input_fn_iterator(iterator, distribution.extended.worker_devices,
expected_values)
def testMakeInputFnIteratorWithCallable(self, distribution):
def fn():
dataset = dataset_ops.Dataset.range(2).interleave(
(lambda _: dataset_ops.Dataset.range(10)), cycle_length=2)
it = dataset.make_one_shot_iterator()
return it.get_next
expected_values = [[i, i] for i in range(0, 10)]
input_fn = self._input_fn_to_test_input_context(
fn,
expected_num_replicas_in_sync=2,
expected_num_input_pipelines=1,
expected_input_pipeline_id=0)
iterator = distribution.make_input_fn_iterator(input_fn)
self._test_input_fn_iterator(iterator, distribution.extended.worker_devices,
expected_values, test_reinitialize=False,
ignore_order=True)
def testNumpyDataset(self, distribution):
self._test_numpy_dataset(distribution)
def testGlobalStepUpdate(self, distribution):
self._test_global_step_update(distribution)
def testRun(self, distribution):
self._test_run(distribution)
def testAllReduceSum(self, distribution):
self._test_all_reduce_sum(distribution)
def testAllReduceSumGradients(self, distribution):
self._test_all_reduce_sum_gradients(distribution)
def testAllReduceSumGradientTape(self, distribution):
self._test_all_reduce_sum_gradient_tape(distribution)
def testAllReduceMean(self, distribution):
self._test_all_reduce_mean(distribution)
def testAllReduceMeanGradients(self, distribution):
self._test_all_reduce_mean_gradients(distribution)
def testAllReduceMeanGradientTape(self, distribution):
self._test_all_reduce_mean_gradient_tape(distribution)
def testSummaryForReplicaZeroOnly(self, distribution):
self._test_summary_for_replica_zero_only(distribution)
def testTrainableVariables(self, distribution):
self._test_trainable_variable(distribution)
def one_device_combinations():
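"""Returns test combinations for single-device MirroredStrategy in graph and
eager modes."""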
return combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_one_gpu,
],
mode=["graph", "eager"])
@combinations.generate(one_device_combinations())
class MirroredOneDeviceDistributionTest(
strategy_test_lib.DistributionTestBase,
strategy_test_lib.OneDeviceDistributionTestBase,
parameterized.TestCase):
def testMinimizeLoss(self, distribution):
if context.executing_eagerly():
self._test_minimize_loss_eager(distribution)
else:
self._test_minimize_loss_graph(distribution)
def testReplicaId(self, distribution):
self._test_replica_id(distribution)
def testCallAndMergeExceptions(self, distribution):
self._test_call_and_merge_exceptions(distribution)
def testRun(self, distribution):
self._test_run(distribution)
def testAllReduceSum(self, distribution):
self._test_all_reduce_sum(distribution)
def testAllReduceSumGradients(self, distribution):
self._test_all_reduce_sum_gradients(distribution)
def testAllReduceSumGradientTape(self, distribution):
self._test_all_reduce_sum_gradient_tape(distribution)
def testAllReduceMean(self, distribution):
self._test_all_reduce_mean(distribution)
def testAllReduceMeanGradients(self, distribution):
self._test_all_reduce_mean_gradients(distribution)
def testAllReduceMeanGradientTape(self, distribution):
self._test_all_reduce_mean_gradient_tape(distribution)
class MirroredStrategyVariableCreatorStackTest(
test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph"]))
def testCreatorStacksAreThreadLocal(self, distribution):
def model_fn():
replica_id_str = str(self.evaluate(_replica_id()))
def thread_creator_fn(next_creator, *args, **kwargs):
return next_creator(*args, **kwargs) + ":thread_" + replica_id_str
with variable_scope.variable_creator_scope(thread_creator_fn):
# Create a variable in this scope.
v = variable_scope.variable(1.0)
# This will pause the current thread, and execute the other thread.
ds_context.get_replica_context().merge_call(lambda _: _)
return v
def main_thread_creator(next_creator, *args, **kwargs):
# For test purposes, we do not use the underlying next_creator.
del next_creator, args, kwargs
return "main_thread"
with context.graph_mode(), \
distribution.scope(), \
variable_scope.variable_creator_scope(main_thread_creator):
result = distribution.extended.call_for_each_replica(model_fn)
result = distribution.experimental_local_results(result)
expected = ("main_thread:thread_0", "main_thread:thread_1")
self.assertEqual(expected, result)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class MirroredStrategyCallForEachReplicaTest(test.TestCase):
def testExecutingEagerlyOutsideFunction(self, distribution):
"""Verify we preserve the value of executing_eagerly_outside_functions()."""
def model_fn():
return ops.executing_eagerly_outside_functions()
originally = ops.executing_eagerly_outside_functions()
with distribution.scope():
in_scope = ops.executing_eagerly_outside_functions()
in_model_fn = distribution.extended.call_for_each_replica(model_fn)
unwrapped = distribution.experimental_local_results(in_model_fn)
self.assertEqual(in_scope, unwrapped[0])
self.assertEqual(in_scope, originally)
# Verify this all again, but this time in a FuncGraph.
with func_graph.FuncGraph("fg").as_default(), distribution.scope():
in_scope = ops.executing_eagerly_outside_functions()
in_model_fn = distribution.extended.call_for_each_replica(model_fn)
unwrapped = distribution.experimental_local_results(in_model_fn)
self.assertEqual(in_scope, unwrapped[0])
self.assertEqual(in_scope, originally)
def testFunctionInCallForEachReplica(self, distribution):
traces = []
@def_function.function
def model_fn():
traces.append(1)
return ds_context.get_replica_context().replica_id_in_sync_group
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual((0, 1), self.evaluate(result.values))
self.assertLen(traces, distribution.num_replicas_in_sync)
def testFunctionInCallForEachReplicaInsideAnotherFunction(self, distribution):
traces = []
@def_function.function
def model_fn():
traces.append(1)
return ds_context.get_replica_context().replica_id_in_sync_group
@def_function.function
def step():
return distribution.extended.call_for_each_replica(model_fn)
with distribution.scope():
result = step()
self.assertEqual((0, 1), self.evaluate(result.values))
self.assertLen(traces, distribution.num_replicas_in_sync)
def testFunctionInCallForEachReplicaWithMergeCall(self, distribution):
def merge_fn(_):
pass
@def_function.function
def model_fn():
ds_context.get_replica_context().merge_call(merge_fn)
return 0.
with distribution.scope():
with self.assertRaisesRegexp(
RuntimeError, "`merge_call` called while defining a new graph."):
distribution.extended.call_for_each_replica(model_fn)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph"]))
class MirroredStrategyNameScopeTest(test.TestCase):
# NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
# testing this in eager mode.
def testNameScope(self, distribution):
def model_fn():
with ops.name_scope("foo"):
a = constant_op.constant(1.0, name="a")
ds_context.get_replica_context().merge_call(lambda _: _)
b = constant_op.constant(1.0, name="b")
return a, b
with context.graph_mode(), distribution.scope():
with ops.name_scope("main"):
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(2, len(result))
for v, name in zip(result, ["a", "b"]):
self.assertIsInstance(v, values.DistributedValues)
v0, v1 = distribution.experimental_local_results(v)
self.assertEqual("main/foo/" + name + ":0", v0.name)
self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name)
def testWithDefaultName(self, distribution):
def model_fn():
with ops.name_scope(None, "foo"):
a = constant_op.constant(1.0, name="a")
ds_context.get_replica_context().merge_call(lambda _: _)
b = constant_op.constant(2.0, name="b")
return a, b
with context.graph_mode(), distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(2, len(result))
for v, name in zip(result, ["a", "b"]):
self.assertIsInstance(v, values.DistributedValues)
v0, v1 = distribution.experimental_local_results(v)
self.assertEqual("foo/" + name + ":0", v0.name)
self.assertEqual("replica_1/foo/" + name + ":0", v1.name)
# variable_scope.variable() respects name scopes when creating
# variables. On the other hand variable_scope.get_variable() ignores name
# scopes but respects variable scope when creating variables. We test both
# methods of creating variables to make sure that we have the same
# variable names in both cases.
def testNameScopeWithVariable(self, distribution):
def in_cross_replica(_):
c = variable_scope.variable(1.0, name="c")
return c
def model_fn():
b = variable_scope.variable(1.0, name="b")
with ops.name_scope("foo"):
c = ds_context.get_replica_context().merge_call(in_cross_replica)
return b, c
with context.graph_mode(), distribution.scope():
with ops.name_scope("main"):
a = variable_scope.variable(1.0, name="a")
result = distribution.extended.call_for_each_replica(model_fn)
result_b = result[0]
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
self.assertIsInstance(result_c, values.DistributedValues)
a0, a1 = distribution.experimental_local_results(a)
b0, b1 = distribution.experimental_local_results(result_b)
c0, c1 = distribution.experimental_local_results(result_c)
self.assertEqual("main/a:0", a0.name)
self.assertEqual("main/a/replica_1:0", a1.name)
self.assertEqual("main/b:0", b0.name)
self.assertEqual("main/b/replica_1:0", b1.name)
self.assertEqual("main/foo/c:0", c0.name)
self.assertEqual("main/foo/c/replica_1:0", c1.name)
def testNameScopeWithGetVariable(self, distribution):
def in_cross_replica(_):
c = variable_scope.get_variable("c", [1])
return c
def model_fn():
b = variable_scope.get_variable("b", [1])
with ops.name_scope("foo"):
c = ds_context.get_replica_context().merge_call(in_cross_replica)
return b, c
with context.graph_mode(), distribution.scope():
with ops.name_scope("main"):
a = variable_scope.get_variable("a", [1])
result = distribution.extended.call_for_each_replica(model_fn)
result_b = result[0]
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
self.assertIsInstance(result_c, values.DistributedValues)
a0, a1 = distribution.experimental_local_results(a)
b0, b1 = distribution.experimental_local_results(result_b)
c0, c1 = distribution.experimental_local_results(result_c)
self.assertEqual("a:0", a0.name)
self.assertEqual("a/replica_1:0", a1.name)
self.assertEqual("b:0", b0.name)
self.assertEqual("b/replica_1:0", b1.name)
self.assertEqual("c:0", c0.name)
self.assertEqual("c/replica_1:0", c1.name)
def testVariableScopeWithGetVariable(self, distribution):
def in_cross_replica(_):
c = variable_scope.get_variable("c", [1])
return c
def model_fn():
b = variable_scope.get_variable("b", [1])
with variable_scope.variable_scope("foo"):
c = ds_context.get_replica_context().merge_call(in_cross_replica)
return b, c
with context.graph_mode(), distribution.scope():
with variable_scope.variable_scope("main"):
a = variable_scope.get_variable("a", [1])
result = distribution.extended.call_for_each_replica(model_fn)
result_b = result[0]
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
self.assertIsInstance(result_c, values.DistributedValues)
a0, a1 = distribution.experimental_local_results(a)
b0, b1 = distribution.experimental_local_results(result_b)
c0, c1 = distribution.experimental_local_results(result_c)
self.assertEqual("main/a:0", a0.name)
self.assertEqual("main/a/replica_1:0", a1.name)
self.assertEqual("main/b:0", b0.name)
self.assertEqual("main/b/replica_1:0", b1.name)
self.assertEqual("main/foo/c:0", c0.name)
self.assertEqual("main/foo/c/replica_1:0", c1.name)
@combinations.generate(
combinations.combine(
distribution=[
combinations.NamedDistribution(
"Mirrored3Devices",
# pylint: disable=g-long-lambda
lambda: mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]),
required_gpus=2)
],
mode=["graph", "eager"]))
class MirroredThreeDeviceDistributionTest(
strategy_test_lib.DistributionTestBase,
parameterized.TestCase):
def testThreeDevices(self, distribution):
def model_fn():
v = variable_scope.variable(1.0, name="foo")
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self.assertIsInstance(result, values.MirroredVariable)
self.assertEqual("foo:0", result.name)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class MirroredVariableUpdateTest(test.TestCase):
# The following tests check assign, assign_add and assign_sub on Mirrored
# variables in replica and cross replica context.
def testAssignMirroredVarReplicaContextWithoutAggregationType(self,
distribution):
# Test that we always have an aggregation type set on the mirrored variable
# if we assign to it in replica mode.
def var_fn():
v = variable_scope.variable(1.0, name="foo")
return v
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
def model_fn():
return mirrored_var.assign(5.0)
with self.assertRaisesRegexp(
ValueError, "You must specify an aggregation method to update a "
"MirroredVariable in Replica Context. You can do so by"):
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
def testAssignMirroredVarReplicaContextWithSum(self, distribution):
# Test that we don't reduce a non-per-replica value with the "sum"
# aggregation type.
def var_fn():
v = variable_scope.variable(
1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM)
return v
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
def model_fn():
return mirrored_var.assign(5.0)
with self.assertRaisesRegexp(
ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
"with the given reduce op ReduceOp.SUM."):
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
def testAssignMirroredVarCrossDeviceContext(self, distribution):
def var_fn():
return variable_scope.variable(1.0, name="foo")
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
self.assertEqual(6.0, mirrored_var_result)
def testAssignMirroredVarReplicaContext(self, distribution):
def var_fn():
return variable_scope.variable(
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
def model_fn():
value = math_ops.cast(
ds_context.get_replica_context().replica_id_in_sync_group,
mirrored_var.dtype)
return mirrored_var.assign(value)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(0.5, self.evaluate(mirrored_var))
def testAssignMirroredVarReplicaContextWithSingleValue(self, distribution):
def var_fn():
return variable_scope.variable(
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
def model_fn():
return mirrored_var.assign(5.0)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(5.0, self.evaluate(mirrored_var))
def testAssignAddMirroredVarCrossDeviceContext(self, distribution):
def var_fn():
return variable_scope.variable(1.0, name="foo")
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
# read_value == True
mirrored_var_result = self.evaluate(
mirrored_var.assign_add(6.0, read_value=True))
self.assertEqual(7.0, mirrored_var_result)
self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
# read_value == False
self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
def testAssignAddMirroredVarReplicaContext(self, distribution):
def var_fn():
return variable_scope.variable(
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
def model_fn():
value = math_ops.cast(
ds_context.get_replica_context().replica_id_in_sync_group,
mirrored_var.dtype)
return mirrored_var.assign_add(value)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(1.5, self.evaluate(mirrored_var))
def testAssignAddMirroredVarReplicaContextWithSingleValue(self, distribution):
def var_fn():
return variable_scope.variable(
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
def model_fn():
return mirrored_var.assign_add(5.0)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(6.0, self.evaluate(mirrored_var))
def testAssignSubMirroredVarCrossDeviceContext(self, distribution):
def var_fn():
return variable_scope.variable(5.0, name="foo")
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(5.0, self.evaluate(mirrored_var))
mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
self.assertEqual(3.0, mirrored_var_result)
self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
def testAssignSubMirroredVarReplicaContext(self, distribution):
def var_fn():
return variable_scope.variable(
5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(5.0, self.evaluate(mirrored_var))
def model_fn():
value = math_ops.cast(
ds_context.get_replica_context().replica_id_in_sync_group,
mirrored_var.dtype)
return mirrored_var.assign_sub(value)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(4.5, self.evaluate(mirrored_var))
def testAssignSubMirroredVarReplicaContextWithSingleValue(self, distribution):
def var_fn():
return variable_scope.variable(
5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(5.0, self.evaluate(mirrored_var))
def model_fn():
return mirrored_var.assign_sub(1.0)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(4.0, self.evaluate(mirrored_var))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class MirroredAndSyncOnReadVariableInitializerTest(test.TestCase):
def testAssignMirroredVarInitializer(self, distribution):
# This test is not eager compatible since in eager mode variables are
# initialized upon construction instead of once the initialization op is run.
with context.graph_mode():
def var_fn():
v = variable_scope.variable(1.0, name="foo")
return v
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.assertFalse(self.evaluate(mirrored_var.is_initialized()))
self.evaluate(mirrored_var.initializer)
self.assertTrue(self.evaluate(mirrored_var.is_initialized()))
def testAssignReplicaLocalVarInitializer(self, distribution):
# This test is not eager compatible since in eager mode variables are
# initialized upon construction instead of once the initialization op is run.
with context.graph_mode():
def model_fn():
v_sum = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
self.assertIsInstance(v_sum, values.SyncOnReadVariable)
return v_sum
with distribution.scope():
sync_on_read_var = distribution.extended.call_for_each_replica(
model_fn)
self.assertIsInstance(sync_on_read_var, values.SyncOnReadVariable)
self.assertFalse(self.evaluate(sync_on_read_var.is_initialized()))
self.evaluate(sync_on_read_var.initializer)
self.assertTrue(self.evaluate(sync_on_read_var.is_initialized()))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class SyncOnReadVariableAssignTest(test.TestCase):
def testAssignReplicaLocalVarSumAggregation(self, distribution):
def model_fn():
v_sum = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
return v_sum
with distribution.scope():
sync_on_read_var = distribution.extended.call_for_each_replica(model_fn)
self.assertIsInstance(sync_on_read_var, values.SyncOnReadVariable)
self.evaluate(variables.global_variables_initializer())
# Each replica has a value of 1.0 assigned to it in replica context.
# When we read the value using `read_var` we should see the SUM of the
# values on all of the replicas.
self.assertEqual(2.0, self.evaluate(
distribution.extended.read_var(sync_on_read_var)))
# Assigning 6.0 in cross replica context will assign a value of
# 6.0/num_replicas to each replica.
tlv_ops = sync_on_read_var.assign(6.0)
self.evaluate(tlv_ops)
# On reading the sync on read var we should get the assigned value back.
# The values on all the replicas are added before being returned by
# `read_var`.
self.assertEqual(6.0, self.evaluate(
distribution.extended.read_var(sync_on_read_var)))
def testAssignReplicaLocalVarMeanAggregation(self, distribution):
def model_fn():
v_sum = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.MEAN)
return v_sum
with distribution.scope():
sync_on_read_var = distribution.extended.call_for_each_replica(model_fn)
self.assertIsInstance(sync_on_read_var, values.SyncOnReadVariable)
self.evaluate(variables.global_variables_initializer())
# Each replica has a value of 1.0 assigned to it in replica context.
# When we read the value using `read_var` we should see the MEAN of the values
# on all replicas, which is the value assigned in replica context.
self.assertEqual(1.0, self.evaluate(
distribution.extended.read_var(sync_on_read_var)))
tlv_ops = sync_on_read_var.assign(6.0)
self.evaluate(tlv_ops)
# On reading the sync on read var we should get the MEAN of all values
# which is equal to the value assigned.
self.assertEqual(6.0, self.evaluate(
distribution.extended.read_var(sync_on_read_var)))
class MockModel(object):
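"""A trivial model holding one or two variables; calling it returns
`factor * variables[0]` plus `variables[1]` when present."""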
def __init__(self, two_variables=False):
self.variables = []
self.variables.append(variable_scope.variable(1.25, name="dummy_var1"))
if two_variables:
self.variables.append(variable_scope.variable(2.0, name="dummy_var2"))
def __call__(self, factor=2):
x = factor * self.variables[0]
if len(self.variables) > 1:
x += self.variables[1]
return x
class MiniModel(keras_training.Model):
"""Minimal model for mnist.
Useful for testing and debugging on slow TPU simulators.
"""
def __init__(self):
super(MiniModel, self).__init__(name="")
self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones",
bias_initializer="ones")
def call(self, inputs, training=True):
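# The passed-in `inputs` are ignored and replaced with a fixed batch of ones.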
inputs = array_ops.ones([1, 10])
return self.fc(inputs)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class MirroredStrategyDefunTest(test.TestCase):
def _call_and_check(self, distribution, model_fn, inputs, expected_result,
defuns, two_variables=False):
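"""Runs `model_fn` per replica and checks the result on each device, then
re-traces each defun to obtain its per-device concrete function."""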
cpu_dev = device_util.canonicalize("CPU:0")
gpu_dev = device_util.canonicalize("GPU:0")
devices = [cpu_dev, gpu_dev]
with distribution.scope():
mock_model = MockModel(two_variables)
self.evaluate(variables.global_variables_initializer())
result = distribution.extended.call_for_each_replica(
model_fn, args=[mock_model] + inputs)
for r in range(len(devices)):
device_result = values.select_replica(r, result)
device_expected_result = values.select_replica(r, expected_result)
self.assertAllClose(device_expected_result,
self.evaluate(device_result))
for defun in defuns:
# `Function`s are specialized to the current device stack, so
# call_for_each has one trace per device. To check that the expected set
# of variables was accessed on each trace, we first retrieve each
# device-specific graph function.
per_replica_graph_functions = (
distribution.extended.call_for_each_replica(
defun.get_concrete_function, args=[mock_model] + inputs))
for device in devices:
graph_function = per_replica_graph_functions.get(device=device)
# TODO(b/129555712): re-enable an assertion here that the two sets of
# variables are the same.
# self.assertEqual(set(graph_function.graph.variables),
# set(mock_model.variables))
del graph_function
def testVariableInDefun(self, distribution):
@function.defun
def times_two(mock_model):
return mock_model()
def model_fn(mock_model):
return times_two(mock_model)
self._call_and_check(distribution, model_fn, [], 2.5, [times_two])
def testVariableInNestedDefun(self, distribution):
@function.defun
def times_two(mock_model):
return mock_model()
@function.defun
def two_x_plus_one(mock_model):
return times_two(mock_model) + 1
def model_fn(mock_model):
return two_x_plus_one(mock_model)
self._call_and_check(distribution, model_fn, [], 3.5,
[times_two, two_x_plus_one])
def testTwoVariablesInNestedDefun(self, distribution):
@function.defun
def fn1(mock_model):
return mock_model()
@function.defun
def fn2(mock_model):
return fn1(mock_model) + 1
def model_fn(mock_model):
return fn2(mock_model)
self._call_and_check(distribution, model_fn, [], 5.5, [fn1, fn2],
two_variables=True)
def testGradientTapeOverNestedDefuns(self, distribution):
@function.defun
def fn1(mock_model):
return mock_model()
@function.defun
def fn2(mock_model):
return fn1(mock_model) + 1
def model_fn(mock_model):
with backprop.GradientTape(persistent=True) as gtape:
result = fn2(mock_model)
grads = gtape.gradient(result,
[v.get() for v in mock_model.variables])
return grads
self._call_and_check(distribution, model_fn, [], [2.0, 1.0], [fn1, fn2],
two_variables=True)
def testPassPerReplica(self, distribution):
@function.defun
def fn1(mock_model, factor):
return mock_model(factor)
device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
factors = values.PerReplica(device_map, (5.0, 3.0))
expected_result = values.PerReplica(device_map, (5.0 * 1.25, 3.0 * 1.25))
self._call_and_check(distribution, fn1, [factors], expected_result, [fn1])
def testTrain(self, distribution):
with distribution.scope():
mock_model = MiniModel()
mock_model.call = function.defun(mock_model.call)
def loss_fn(ctx):
del ctx
return mock_model(array_ops.ones([1, 10]))
gradients_fn = backprop.implicit_grad(loss_fn)
gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
grads_and_vars = distribution.extended.call_for_each_replica(
gradients_fn, args=(None,))
optimizer = gradient_descent.GradientDescentOptimizer(0.25)
update_ops = optimizer._distributed_apply(distribution, grads_and_vars) # pylint: disable=protected-access
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
self.evaluate(update_ops)
updated_var_values = self.evaluate(mock_model.variables)
# All variables start at 1.0 and get two updates of 0.25.
self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
self.assertAllEqual([0.5], updated_var_values[1])
@combinations.generate(
combinations.combine(
distribution=[
combinations.NamedDistribution(
"Mirrored",
# pylint: disable=g-long-lambda
lambda: mirrored_strategy.MirroredStrategy(
devices=mirrored_strategy.all_local_devices(),
cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce([
"/job:worker/task:0", "/job:worker/task:1"
], context.num_gpus())),
required_gpus=1)
],
mode=["graph"]))
class MultiWorkerMirroredStrategyTest(
multi_worker_test_base.MultiWorkerTestBase,
strategy_test_lib.DistributionTestBase):
def _configure_distribution_strategy(self, distribution):
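"""Configures `distribution` with a two-task worker cluster spec."""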
cluster_spec = server_lib.ClusterSpec({
"worker": ["/job:worker/task:0", "/job:worker/task:1"]
})
distribution.configure(cluster_spec=cluster_spec)
def test_num_replicas_in_sync(self, distribution):
self._configure_distribution_strategy(distribution)
# We calculate the total number of GPUs across the two workers specified in
# the cluster spec.
self.assertEqual(context.num_gpus() * 2, distribution.num_replicas_in_sync)
def testMinimizeLossGraph(self, distribution):
self._configure_distribution_strategy(distribution)
self._test_minimize_loss_graph(distribution, learning_rate=0.05)
def testDeviceScope(self, distribution):
"""Test the device scope of multi-worker MirroredStrategy."""
self._configure_distribution_strategy(distribution)
with distribution.scope():
a = constant_op.constant(1.)
with ops.device("/cpu:0"):
b = constant_op.constant(1.)
self.assertEqual(a.device, "/job:worker/task:0")
self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0")
def testMakeInputFnIteratorWithDataset(self, distribution):
self._configure_distribution_strategy(distribution)
dataset_fn = lambda: dataset_ops.Dataset.range(100)
num_gpus = context.num_gpus()
num_workers = 2
expected_values = [[i+j for j in range(num_gpus)] * num_workers
for i in range(0, 100, num_gpus)]
with context.graph_mode(), self.cached_session() as sess:
# `expected_input_pipeline_id` is None because the input_fn will be called
# multiple times, each with a different input_pipeline_id.
input_fn = self._input_fn_to_test_input_context(
dataset_fn,
expected_num_replicas_in_sync=num_workers*num_gpus,
expected_num_input_pipelines=num_workers,
expected_input_pipeline_id=None)
iterator = distribution.make_input_fn_iterator(input_fn)
self._test_input_fn_iterator(
iterator, distribution.extended.worker_devices, expected_values, sess)
def testMakeInputFnIteratorWithCallable(self, distribution):
self._configure_distribution_strategy(distribution)
def fn():
dataset = dataset_ops.Dataset.range(100)
it = dataset.make_one_shot_iterator()
return it.get_next
num_gpus = context.num_gpus()
num_workers = 2
expected_values = []
for i in range(0, 100, num_gpus):
expected_values.append([i+j for j in range(num_gpus)] * num_workers)
with context.graph_mode(), self.cached_session() as sess:
# `expected_input_pipeline_id` is None because the input_fn will be called
# multiple times, each with a different input_pipeline_id.
input_fn = self._input_fn_to_test_input_context(
fn,
expected_num_replicas_in_sync=num_workers*num_gpus,
expected_num_input_pipelines=num_workers,
expected_input_pipeline_id=None)
iterator = distribution.make_input_fn_iterator(input_fn)
self._test_input_fn_iterator(
iterator, distribution.extended.worker_devices, expected_values, sess,
test_reinitialize=False, ignore_order=True)
def testUpdateConfigProto(self, distribution):
distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]})
config_proto = config_pb2.ConfigProto()
new_config = distribution.update_config_proto(config_proto)
# Verify isolate_session_state
self.assertTrue(new_config.isolate_session_state)
class MultiWorkerMirroredStrategyTestWithChief(
multi_worker_test_base.MultiWorkerTestBase,
strategy_test_lib.DistributionTestBase):
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers and 1 chief."""
cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=2, num_ps=0, has_chief=True)
cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]
def _make_cross_device_ops(self):
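"""Returns MultiWorkerAllReduce ops spanning the chief and both workers."""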
return cross_device_ops_lib.MultiWorkerAllReduce(
["/job:chief/task:0", "/job:worker/task:0", "/job:worker/task:1"],
context.num_gpus())
def testMinimizeLossGraph(self):
strategy = mirrored_strategy.MirroredStrategy(
cross_device_ops=self._make_cross_device_ops())
strategy.configure(cluster_spec=self._cluster_spec)
self._test_minimize_loss_graph(strategy, learning_rate=0.05)
def testMinimizeLossGraphMirroredStrategy(self):
strategy = mirrored_strategy.MirroredStrategy(
mirrored_strategy.all_local_devices(),
cross_device_ops=self._make_cross_device_ops())
strategy.configure(cluster_spec=self._cluster_spec)
self._test_minimize_loss_graph(strategy, learning_rate=0.05)
def testMinimizeLossGraphMirroredStrategyWithOneNode(self):
cluster_spec = {}
cluster_spec["chief"] = self._cluster_spec["chief"]
tf_config = {"cluster": cluster_spec}
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(tf_config)}):
strategy = mirrored_strategy.MirroredStrategy()
self.assertIsInstance(strategy.extended._inferred_cross_device_ops,
cross_device_ops_lib.NcclAllReduce)
self.skipTest("b/130551176, run the following once fixed.")
self._test_minimize_loss_graph(strategy, learning_rate=0.05)
def testInitializeFromTFConfig(self):
tf_config = {"cluster": self._cluster_spec}
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(tf_config)}):
strategy = mirrored_strategy.MirroredStrategy(
cross_device_ops=self._make_cross_device_ops())
self.assertEqual(
max(context.num_gpus(), 1) * 3, strategy.num_replicas_in_sync)
def testSummaryForReplicaZeroOnly(self):
strategy = mirrored_strategy.MirroredStrategy(
mirrored_strategy.all_local_devices(),
cross_device_ops=self._make_cross_device_ops())
strategy.configure(cluster_spec=self._cluster_spec)
self._test_summary_for_replica_zero_only(strategy)
class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_one_gpu,
],
mode=["graph"]))
def testMirroredVariableAsStopGradient(self, distribution):
with distribution.scope():
inp = constant_op.constant(1.0)
x = variables.Variable(1.0)
y = inp*x
grads = gradients.gradients(x, y, stop_gradients=x)
self.assertIsNone(grads[0])
def _replica_id():
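"""Returns the current replica id as a scalar Tensor."""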
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
if not isinstance(replica_id, ops.Tensor):
replica_id = constant_op.constant(replica_id)
return replica_id
def _replica_id_as_int():
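"""Returns the current replica id as a Python integer."""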
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
if isinstance(replica_id, ops.Tensor):
replica_id = tensor_util.constant_value(replica_id)
return replica_id
if __name__ == "__main__":
test.main()