# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import itertools
import os
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model.model_utils import mode_keys
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
class DistributedValuesTest(test.TestCase):
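  """Tests basic behavior of DistributedValues."""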
def testGetEager(self):
with ops.device("/device:CPU:0"):
one = constant_op.constant(1)
two = constant_op.constant(2)
device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
v = values.DistributedValues(device_map, (one, two))
self.assertEqual(two, v.get("/device:GPU:0"))
self.assertEqual(one, v.get())
with self.assertRaises(ValueError):
self.assertIsNone(v.get("/device:GPU:2"))
def testGetGraph(self):
with context.graph_mode(), \
ops.Graph().as_default(), \
ops.device("/device:CPU:0"):
one = constant_op.constant(1)
two = constant_op.constant(2)
device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
v = values.DistributedValues(device_map, (one, two))
self.assertEqual(two, v.get("/device:GPU:0"))
self.assertEqual(one, v.get())
with self.assertRaises(ValueError):
self.assertIsNone(v.get("/device:GPU:2"))
def testCanonicalization(self):
canonical_cpu = ("/job:localhost/replica:0/task:0/device:CPU:0",)
v = values.DistributedValues(values.SingleDeviceMap(""), (42,))
self.assertEqual(canonical_cpu, v.devices)
v = values.DistributedValues(values.SingleDeviceMap("/device:CPU:0"), (42,))
self.assertEqual(canonical_cpu, v.devices)
v = values.DistributedValues(values.SingleDeviceMap("/cpu:0"), (42,))
self.assertEqual(canonical_cpu, v.devices)
v = values.DistributedValues(values.SingleDeviceMap("/CPU:0"), (42,))
self.assertEqual(canonical_cpu, v.devices)
def testIsTensorLike(self):
with context.graph_mode(), \
ops.Graph().as_default(), \
ops.device("/device:CPU:0"):
one = constant_op.constant(1)
two = constant_op.constant(2)
device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
v = values.DistributedValues(device_map, (one, two))
self.assertEqual(two, v.get("/device:GPU:0"))
self.assertEqual(one, v.get())
self.assertTrue(v.is_tensor_like)
self.assertTrue(tensor_util.is_tensor(v))
def testIsTensorLikeWithAConstant(self):
with context.graph_mode(), \
ops.Graph().as_default(), \
ops.device("/device:CPU:0"):
one = constant_op.constant(1)
two = 2.0
device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
v = values.DistributedValues(device_map, (one, two))
self.assertEqual(two, v.get("/device:GPU:0"))
self.assertEqual(one, v.get())
self.assertFalse(v.is_tensor_like)
self.assertFalse(tensor_util.is_tensor(v))
class DistributedDelegateTest(test.TestCase):
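  """Tests attribute and operator forwarding on DistributedDelegate."""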
@test_util.run_in_graph_and_eager_modes
def testGetAttr(self):
with ops.device("/device:CPU:0"):
class Foo(object):
def __init__(self, x):
self.x = x
device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
v = values.DistributedDelegate(device_map, (Foo(7), Foo(8)))
self.assertEqual(7, v.x)
with self.assertRaises(AttributeError):
_ = v.y
@test_util.run_in_graph_and_eager_modes
def testOperatorOverride(self):
with ops.device("/device:CPU:0"):
device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
v = values.DistributedDelegate(device_map, (7, 8))
# v should act like int(7).
self.assertEqual(8, v + 1)
self.assertEqual(10, 3 + v)
self.assertEqual(14, v + v)
self.assertEqual(5, v - 2)
self.assertEqual(6, 13 - v)
self.assertEqual(0, v - v)
self.assertEqual(14, v * 2)
self.assertEqual(21, 3 * v)
self.assertEqual(49, v * v)
self.assertEqual(3.5, v / 2)
self.assertEqual(1.5, 10.5 / v)
self.assertEqual(3, v // 2)
self.assertEqual(2, 15 // v)
self.assertEqual(1, v % 2)
self.assertEqual(2, 16 % v)
self.assertTrue(v < 12)
self.assertTrue(v <= 12)
self.assertFalse(v > 12)
self.assertFalse(v >= 12)
self.assertFalse(12 < v)
self.assertFalse(12 <= v)
self.assertTrue(12 > v)
self.assertTrue(12 >= v)
self.assertEqual(3, v & 3)
self.assertEqual(3, 11 & v)
self.assertEqual(15, v | 8)
self.assertEqual(23, 16 | v)
self.assertEqual(4, v ^ 3)
self.assertEqual(12, 11 ^ v)
self.assertEqual(343, pow(v, 3))
self.assertEqual(3, pow(v, 3, 10))
self.assertEqual(128, pow(2, v))
self.assertEqual(-7, -v)
self.assertEqual(~7, ~v)
self.assertEqual(7, abs(v))
with self.assertRaises(TypeError):
_ = v[2]
def _device_str(d):
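  """Returns a GPU device string for device index `d`."""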
return "/device:GPU:" + str(d)
def _nested_value(d):
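  """Returns a nested structure of strings, each suffixed with `d`."""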
return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)
def _make_mirrored():
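  """Builds a MirroredVariable over GPU:0 and CPU:0 with SUM aggregation.

  Returns:
    A tuple of (component variables, device map, MirroredVariable).
  """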
v = []
devices = ["/device:GPU:0", "/device:CPU:0"]
for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
with ops.device(d):
v.append(variable_scope.get_variable(
name=n, initializer=init, use_resource=True))
device_map = values.ReplicaDeviceMap(devices)
mirrored = values.MirroredVariable(None, device_map, v,
variable_scope.VariableAggregation.SUM)
return v, device_map, mirrored
class RegroupAndSelectDeviceTest(test.TestCase):
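  """Tests for regroup(), select_replica() and select_device_mirrored()."""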
def _is_per_replica(self, result, expected, klass=values.PerReplica):
self.assertIsInstance(result, klass)
# We canonicalize the devices to match the device strings returned
# by PerReplica, which also does device string canonicalization.
devices = [device_util.canonicalize(_device_str(i))
for i in range(len(expected))]
self.assertEqual(set(devices), set(result.devices))
for i, d in enumerate(devices):
self.assertEqual(expected[i], result.get(d))
self.assertEqual(expected[i], result.get(_device_str(i)))
def testNested(self):
device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
result = values.regroup(device_map,
(_nested_value("1"), _nested_value("2")))
self.assertIsInstance(result, tuple)
self.assertEqual(3, len(result))
self._is_per_replica(result[0], ["a1", "a2"])
self._is_per_replica(result[2], ["h1", "h2"])
self.assertIsInstance(result[1], list)
self.assertEqual(3, len(result[1]))
self._is_per_replica(result[1][0], ["b1", "b2"])
self._is_per_replica(result[1][2], ["g1", "g2"])
self.assertIsInstance(result[1][1], dict)
self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
self._is_per_replica(result[1][1]["e"], ["f1", "f2"])
# Also test that we can undo the merge using select_replica()
self.assertEqual(_nested_value("1"),
values.select_replica(0, result))
self.assertEqual(_nested_value("2"),
values.select_replica(1, result))
# select_device_mirrored() should fail due to non-mirrored values
with self.assertRaises(TypeError):
values.select_device_mirrored(_device_str(0), result)
with self.assertRaises(TypeError):
values.select_device_mirrored(_device_str(1), result)
def testWrapClass(self):
# Normally a mirrored value would be the same across devices, but
# for a test it is convenient to be able to tell the values apart.
device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
result = values.regroup(device_map,
(_nested_value("1"), _nested_value("2")),
values.Mirrored)
self.assertIsInstance(result, tuple)
self.assertEqual(3, len(result))
self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)
self.assertIsInstance(result[1], list)
self.assertEqual(3, len(result[1]))
self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)
self.assertIsInstance(result[1][1], dict)
self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)
# Also test that we can undo the merge using select_replica()
self.assertEqual(_nested_value("1"),
values.select_replica(0, result))
self.assertEqual(_nested_value("2"),
values.select_replica(1, result))
# Values are marked as mirrored, so select_device_mirrored() is allowed.
self.assertEqual(_nested_value("1"),
values.select_device_mirrored(_device_str(0), result))
self.assertEqual(_nested_value("2"),
values.select_device_mirrored(_device_str(1), result))
def testWrapAListOfTwoTuples(self):
device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
result = values.regroup(device_map, [("1", "2"), ("3", "4")])
self.assertIsInstance(result, tuple)
self.assertEqual(2, len(result))
self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
self._is_per_replica(result[1], ("2", "4"), values.PerReplica)
def testMirroredContainer(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
v, device_map, mirrored = _make_mirrored()
result = values.regroup(device_map, v)
self.assertIs(mirrored, result)
def testSameId(self):
foo = object()
device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
result = values.regroup(device_map, (("a", foo), ("b", foo)))
self.assertIsInstance(result, tuple)
self.assertEqual(2, len(result))
self._is_per_replica(result[0], ["a", "b"])
self.assertIs(foo, result[1])
# Test select_replica(), should undo the merge done by regroup().
result_0 = values.select_replica(0, result)
self.assertIsInstance(result_0, tuple)
self.assertEqual(2, len(result_0))
self.assertEqual("a", result_0[0])
self.assertIs(foo, result_0[1])
result_1 = values.select_replica(1, result)
self.assertIsInstance(result_1, tuple)
self.assertEqual(2, len(result_1))
self.assertEqual("b", result_1[0])
self.assertIs(foo, result_1[1])
def testOneDevice(self):
device_map = values.ReplicaDeviceMap((_device_str(0),))
result = values.regroup(device_map, (_nested_value("1"),))
# On one device regroup() and select_replica() are basically identity.
self.assertEqual(_nested_value("1"), result)
self.assertEqual(_nested_value("1"),
values.select_replica(0, result))
# The one exception has to do with MirroredVariables.
d = "/device:CPU:0"
with ops.device(d):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
device_map = values.ReplicaDeviceMap((d,))
mirrored = values.MirroredVariable(None, device_map, (v,),
variable_scope.VariableAggregation.SUM)
result = values.regroup(device_map, (v,))
self.assertIs(mirrored, result)
def testNamedTuple(self):
# We include toy implementations of Scaffold and EstimatorSpec to
# avoid a dependency on Estimator here.
class Scaffold(object):
pass
class EstimatorSpec(collections.namedtuple(
"EstimatorSpec", ["mode", "loss", "train_op", "scaffold"])):
def __new__(cls, mode, loss, train_op, scaffold=None):
return super(EstimatorSpec, cls).__new__(
cls, mode=mode, loss=loss, train_op=train_op,
scaffold=scaffold or Scaffold())
with context.graph_mode(), ops.Graph().as_default():
devices = []
created_estimator_specs = []
for device_id in range(3):
spec = EstimatorSpec(
mode=mode_keys.EstimatorModeKeys.TRAIN,
loss=constant_op.constant(device_id / 2),
train_op=array_ops.identity(constant_op.constant(device_id)))
devices.append(_device_str(device_id))
created_estimator_specs.append(spec)
device_map = values.ReplicaDeviceMap(devices)
merged_estimator_spec = values.regroup(
device_map, created_estimator_specs)
self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
merged_estimator_spec.mode)
for device_id in range(3):
d = _device_str(device_id)
self.assertEqual(created_estimator_specs[device_id].loss,
merged_estimator_spec.loss.get(d))
self.assertEqual(created_estimator_specs[device_id].train_op,
merged_estimator_spec.train_op.get(d))
# Scaffold is populated by `EstimatorSpec.__new__`.
self.assertEqual(created_estimator_specs[device_id].scaffold,
merged_estimator_spec.scaffold.get(d))
self.assertIsInstance(created_estimator_specs[device_id].scaffold,
Scaffold)
# Also test that we can undo the merge using select_replica()
self.assertEqual(created_estimator_specs[device_id],
values.select_replica(device_id,
merged_estimator_spec))
class MirroredVariableTest(test.TestCase, parameterized.TestCase):
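  """Tests for MirroredVariable properties, saving/restoring and assignment."""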
config = config_pb2.ConfigProto()
config.allow_soft_placement = True
@test_util.run_in_graph_and_eager_modes(config=config)
def testProperties(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
v, _, mirrored = _make_mirrored()
self.assertEqual(v[0].name, mirrored.name)
self.assertEqual(v[0].dtype, mirrored.dtype)
self.assertEqual(v[0].shape, mirrored.shape)
@test_util.run_in_graph_and_eager_modes(config=config)
def testVariableOnAnotherDevice(self):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
mirrored = values.MirroredVariable(None, device_map, (v,),
variable_scope.VariableAggregation.MEAN)
self.assertEqual(v.name, mirrored.name)
self.assertEqual(v.dtype, mirrored.dtype)
self.assertEqual(v.shape, mirrored.shape)
def _assign_mirrored(self, devices, v, new):
for d, var, n in zip(devices, v, new):
with ops.device(d):
self.evaluate(var.assign(n))
def _save_return_saver(self, sess, var):
saver = saver_lib.Saver(var_list=[var])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
return saver.save(sess, prefix), saver
def _save(self, sess, var):
save_path, _ = self._save_return_saver(sess, var)
return save_path
@test_util.run_in_graph_and_eager_modes(config=config)
def testSaveAndRestoreMirroredOneGraph(self):
if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without a GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available but
      # the user requested to create a variable on GPU, the Placer ignores the
      # request and assigns the VarHandleOp to a CPU. This requires
      # soft_placement, which is on by default.
self.skipTest("A GPU is not available for this test in eager mode.")
with self.cached_session(config=self.config) as sess:
v, device_map, mirrored = _make_mirrored()
devices = device_map.all_devices
# Overwrite the initial values.
self._assign_mirrored(devices, v, [3., 4.])
# Saves the current value of v[0], 3.
save_path, saver = self._save_return_saver(sess, mirrored)
# Change the values between save and restore.
self._assign_mirrored(devices, v, [5., 6.])
# Restores the saved value of 3. to both variables.
saver.restore(sess, save_path)
self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
def _save_mirrored(self):
"""Save variables with mirroring, returns save_path."""
with self.session(graph=ops.Graph()) as sess:
v, device_map, mirrored = _make_mirrored()
devices = device_map.all_devices
# Overwrite the initial values.
self._assign_mirrored(devices, v, [3., 4.])
# Saves the current value of v[0], 3.
save_path = self._save(sess, mirrored)
# Change the values between save and restore.
self._assign_mirrored(devices, v, [5., 6.])
return save_path
def _save_normal(self):
"""Save variables without mirroring, returns save_path."""
with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
# Overwrite the initial value.
self.evaluate(var.assign(3.))
# Saves the current value of var, 3.
save_path = self._save(sess, var)
# Change the values between save and restore.
self.evaluate(var.assign(5.))
return save_path
def _restore_normal(self, save_path):
"""Restore to variables without mirroring in a fresh graph."""
with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=7., use_resource=True)
# Overwrite the initial value.
self.evaluate(var.assign(8.))
# Restores the saved value of 3. to `var`.
saver = saver_lib.Saver(var_list=[var])
saver.restore(sess, save_path)
self.assertEqual(3., self.evaluate(var))
def _restore_mirrored(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
with self.session(graph=ops.Graph()) as sess:
v, device_map, mirrored = _make_mirrored()
devices = device_map.all_devices
# Overwrite the initial values.
self._assign_mirrored(devices, v, [7., 8.])
# Restores the saved value of 3. to both variables.
saver = saver_lib.Saver(var_list=[mirrored])
saver.restore(sess, save_path)
self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
@test_util.run_in_graph_and_eager_modes(config=config)
def testSaveMirroredRestoreMirrored(self):
if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without a GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available but
      # the user requested to create a variable on GPU, the Placer ignores the
      # request and assigns the VarHandleOp to a CPU. This requires
      # soft_placement, which is on by default.
self.skipTest("A GPU is not available for this test in eager mode.")
save_path = self._save_mirrored()
self._restore_mirrored(save_path)
@test_util.run_in_graph_and_eager_modes(config=config)
def testSaveMirroredRestoreNormal(self):
if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without a GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available but
      # the user requested to create a variable on GPU, the Placer ignores the
      # request and assigns the VarHandleOp to a CPU. This requires
      # soft_placement, which is on by default.
self.skipTest("A GPU is not available for this test in eager mode.")
save_path = self._save_mirrored()
self._restore_normal(save_path)
@test_util.run_in_graph_and_eager_modes(config=config)
def testSaveNormalRestoreMirrored(self):
if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without a GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available but
      # the user requested to create a variable on GPU, the Placer ignores the
      # request and assigns the VarHandleOp to a CPU. This requires
      # soft_placement, which is on by default.
self.skipTest("A GPU is not available for this test in eager mode.")
save_path = self._save_normal()
self._restore_mirrored(save_path)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_gpu,
],
mode=["graph"]))
def testFetchAMirroredVariable(self, distribution):
with self.session(graph=ops.Graph()) as sess, distribution.scope():
with ops.device("/device:GPU:0"):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
mirrored = values.MirroredVariable(
distribution, values.ReplicaDeviceMap(("/device:GPU:0",)), (v,),
variable_scope.VariableAggregation.MEAN)
sess.run(variables_lib.global_variables_initializer())
sess.run({"complicated": mirrored})
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
mode=["graph", "eager"]))
def testAssignOutOfScope_mirrored(self, distribution):
with distribution.scope():
mirrored = variables_lib.Variable(1.)
if not isinstance(mirrored, values.MirroredVariable):
self.assertIsInstance(mirrored, values.TPUMirroredVariable)
self.evaluate(mirrored.assign(3.))
self.assertEqual(self.evaluate(mirrored.read_value()), 3.)
for component in mirrored.values:
self.assertEqual(self.evaluate(component.read_value()), 3.)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.central_storage_strategy_with_two_gpus
],
mode=["graph", "eager"]))
def testAssignOutOfScope_aggregating(self, distribution):
with distribution.scope():
aggregating = variables_lib.Variable(1.)
self.assertIsInstance(aggregating, values.AggregatingVariable)
self.evaluate(aggregating.assign(3.))
self.assertEqual(self.evaluate(aggregating.read_value()), 3.)
self.assertEqual(self.evaluate(aggregating._v.read_value()), 3.)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["graph", "eager"]))
def testExtendsVariable(self, distribution):
with distribution.scope():
v = variables_lib.Variable(1.)
self.assertIsInstance(v, variables_lib.Variable)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["graph", "eager"]))
def testCheckpointing(self, distribution):
with distribution.scope():
v = variables_lib.Variable(constant_op.constant([1., 2., 3., 4]))
self.evaluate(v.initializer)
before_save = self.evaluate(v.read_value())
    # Save the current weights into a checkpoint.
checkpoint = trackable_utils.Checkpoint(v=v)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
with self.test_session():
save_path = checkpoint.save(prefix)
    # Assign the values in reverse order.
self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
after_assign = self.evaluate(v.read_value())
self.assertNotAllClose(before_save, after_assign)
# Restore from the checkpoint.
with self.test_session():
checkpoint.restore(save_path).assert_consumed().run_restore_ops()
after_restore = self.evaluate(v)
self.assertAllClose(before_save, after_restore)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["graph"]))
def testTraceback(self, distribution):
with distribution.scope():
variable_scope.get_variable(
name="testVar", initializer=1., use_resource=True)
with self.assertRaisesRegex(
ValueError, "Variable testVar already exists"):
variable_scope.get_variable(
name="testVar", initializer=1., use_resource=True)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["eager"]))
def testInitializedToSameValueInsideEagerRun(self, distribution):
v = [None]
@def_function.function
def step():
def f():
if v[0] is None:
v[0] = variables_lib.Variable(random_ops.random_normal([]))
distribution.experimental_run_v2(f)
context.set_global_seed(None)
step()
vals = self.evaluate(v[0].values)
self.assertAllEqual(vals[0], vals[1])
_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
def _make_replica_local(method, strategy=None):
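  """Builds a replica-local (sync-on-read) variable with aggregation `method`.

  Uses the strategy's worker devices when `strategy` is given, otherwise
  GPU:0 and CPU:0. Returns a tuple of (component variables, variable).
  """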
if strategy is None:
devices = ("/device:GPU:0", "/device:CPU:0")
else:
devices = strategy.extended.worker_devices
device_map = values.ReplicaDeviceMap(devices)
v = []
for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
with ops.device(d):
v.append(variable_scope.get_variable(
name=n, initializer=init, use_resource=True))
if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
var_cls = values.TPUSyncOnReadVariable
else:
var_cls = values.SyncOnReadVariable
replica_local = var_cls(strategy, device_map, v, method)
return v, replica_local
class SyncOnReadVariablePropertiesTest(test.TestCase):
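  """Tests basic properties and tensor conversion of SyncOnReadVariable."""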
config = config_pb2.ConfigProto()
config.allow_soft_placement = True
@test_util.run_in_graph_and_eager_modes(config=config)
def testProperties(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.SUM)
self.assertEqual(v[0].name, replica_local.name)
self.assertEqual(v[0].dtype, replica_local.dtype)
self.assertEqual(v[0].shape, replica_local.shape)
self.assertEqual(variable_scope.VariableAggregation.SUM,
replica_local.aggregation)
@test_util.run_in_graph_and_eager_modes(config=config)
def testVariableOnAnotherDevice(self):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
replica_local = values.SyncOnReadVariable(
None, device_map, (v,), variable_scope.VariableAggregation.MEAN)
self.assertEqual(v.name, replica_local.name)
self.assertEqual(v.dtype, replica_local.dtype)
self.assertEqual(v.shape, replica_local.shape)
self.assertEqual(variable_scope.VariableAggregation.MEAN,
replica_local.aggregation)
def testTensorConversion(self):
with context.graph_mode():
_, replica_local = _make_replica_local(
variable_scope.VariableAggregation.SUM)
converted = ops.internal_convert_to_tensor(replica_local, as_ref=False)
self.assertIsInstance(converted, ops.Tensor)
self.assertEqual(converted.dtype, replica_local.dtype)
converted = ops.internal_convert_to_tensor(replica_local, as_ref=True)
      # Resource variables are converted to tensors even when as_ref is True.
self.assertIsInstance(converted, ops.Tensor)
self.assertEqual(converted.dtype, replica_local.dtype)
@test_util.run_v2_only
def testCanPassToDefFun(self):
@def_function.function
def add1(x):
return x + 1
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
replica_local = values.SyncOnReadVariable(
None, device_map, (v,), variable_scope.VariableAggregation.MEAN)
self.assertEqual(2., self.evaluate(add1(replica_local)))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
mode=["graph", "eager"]))
class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
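  """Tests saving/restoring, assignment and reads of SyncOnReadVariable."""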
def _assign_replica_local(self, v, new):
for var, n in zip(v, new):
with ops.device(var.device):
self.evaluate(var.assign(n))
def _save_return_saver(self, sess, var):
saver = saver_lib.Saver(var_list=[var])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
return saver.save(sess, prefix), saver
def _save(self, sess, var):
save_path, _ = self._save_return_saver(sess, var)
return save_path
def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
with self.cached_session() as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.SUM, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [3., 4.])
with distribution.scope():
# Saves the current value of v[0] + v[1], 7.
save_path, saver = self._save_return_saver(sess, replica_local)
# Change the values between save and restore.
self._assign_replica_local(v, [5., 6.])
        # Restores the saved value of 7., which gets divided equally
        # between the variables.
saver.restore(sess, save_path)
self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
with self.cached_session() as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.MEAN, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [3., 4.])
with distribution.scope():
# Saves the current value of (v[0] + v[1])/2, 3.5.
save_path, saver = self._save_return_saver(sess, replica_local)
# Change the values between save and restore.
self._assign_replica_local(v, [5., 6.])
# Restores the saved value of 3.5 to both variables.
saver.restore(sess, save_path)
self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
def _save_replica_local_mean(self, distribution):
"""Save variables with mirroring, returns save_path."""
with self.session(graph=ops.Graph()) as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.MEAN, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [3., 4.])
with distribution.scope():
# Saves the current value of (v[0] + v[1])/2, 3.5
save_path = self._save(sess, replica_local)
# Change the values between save and restore.
self._assign_replica_local(v, [5., 6.])
return save_path
def _save_replica_local_sum(self, distribution):
"""Save variables with mirroring, returns save_path."""
with self.session(graph=ops.Graph()) as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.SUM, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [1.5, 2.])
with distribution.scope():
# Saves the current value of v[0] + v[1], 3.5
save_path = self._save(sess, replica_local)
# Change the values between save and restore.
self._assign_replica_local(v, [5., 6.])
return save_path
def _save_normal(self):
"""Save variables without mirroring, returns save_path."""
with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
# Overwrite the initial value.
self.evaluate(var.assign(3.5))
# Saves the current value of var, 3.5.
save_path = self._save(sess, var)
# Change the values between save and restore.
self.evaluate(var.assign(5.))
return save_path
def _restore_normal(self, save_path):
"""Restore to variables without mirroring in a fresh graph."""
with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=7., use_resource=True)
# Overwrite the initial value.
self.evaluate(var.assign(8.))
# Restores the saved value of 3.5 to `var`.
saver = saver_lib.Saver(var_list=[var])
saver.restore(sess, save_path)
self.assertEqual(3.5, self.evaluate(var))
def _restore_replica_local_mean(self, save_path, distribution):
"""Restore to variables with mirroring in a fresh graph."""
with self.session(graph=ops.Graph()) as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.MEAN, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [7., 8.])
with distribution.scope():
# Restores the saved value of 3.5 to both variables.
saver = saver_lib.Saver(var_list=[replica_local])
saver.restore(sess, save_path)
self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
def _restore_replica_local_sum(self, save_path, distribution):
"""Restore to variables with mirroring in a fresh graph."""
with self.session(graph=ops.Graph()) as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.SUM, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [7., 8.])
with distribution.scope():
# Restores the saved value of 3.5 to both variables.
saver = saver_lib.Saver(var_list=[replica_local])
saver.restore(sess, save_path)
self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))
def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
save_path = self._save_replica_local_mean(distribution)
self._restore_replica_local_mean(save_path, distribution)
def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
save_path = self._save_replica_local_sum(distribution)
self._restore_replica_local_sum(save_path, distribution)
def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
save_path = self._save_replica_local_mean(distribution)
self._restore_normal(save_path)
def testSaveReplicaLocalSumRestoreNormal(self, distribution):
save_path = self._save_replica_local_sum(distribution)
self._restore_normal(save_path)
def testSaveNormalRestoreReplicaLocalMean(self, distribution):
save_path = self._save_normal()
self._restore_replica_local_mean(save_path, distribution)
def testSaveNormalRestoreReplicaLocalSum(self, distribution):
save_path = self._save_normal()
self._restore_replica_local_sum(save_path, distribution)
def testAssign(self, distribution):
def assign(fn, v, update_value, cross_replica):
update_fn = lambda: getattr(v, fn)(update_value)
if cross_replica:
return update_fn()
else:
return distribution.experimental_local_results(
distribution.experimental_run_v2(update_fn))
updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
aggregations = [
variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
options = ( # VariableAggregation.SUM in cross-replica mode is tested below
[x for x in itertools.product(updates, aggregations, [True, False])
if not(x[1] == variables_lib.VariableAggregation.SUM and x[2])])
for update, aggregation, cross_replica in options:
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
fn, update_value = update
self.evaluate(assign(fn, v, update_value, cross_replica))
for component in v._values:
self.assertAllEqual(self.evaluate(component.read_value()),
self.evaluate(array_ops.ones_like(component)))
def testAssignDtypeConversion(self, distribution):
def assign(fn, v, update_value, cross_replica):
update_fn = lambda: getattr(v, fn)(update_value)
if cross_replica:
return update_fn()
else:
return distribution.experimental_local_results(
distribution.experimental_run_v2(update_fn))
updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
aggregations = [
variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
options = ( # VariableAggregation.SUM in cross-replica mode is tested below
[x for x in itertools.product(updates, aggregations, [True, False])
if not(x[1] == variables_lib.VariableAggregation.SUM and x[2])])
for update, aggregation, cross_replica in options:
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
fn, update_value = update
self.evaluate(assign(fn, v, update_value, cross_replica))
for component in v._values:
self.assertAllEqual(self.evaluate(component.read_value()),
self.evaluate(array_ops.ones_like(component)))
def testAssignWithAggregationSum(self, distribution):
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.SUM)
self.evaluate(variables_lib.global_variables_initializer())
self.evaluate(v.assign(1. * distribution.num_replicas_in_sync))
for component in v._values:
self.assertAllEqual(self.evaluate(component.read_value()),
self.evaluate(array_ops.ones_like(component)))
def testAssignAddSubWithAggregationSum(self, distribution):
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.SUM)
self.evaluate(variables_lib.global_variables_initializer())
with self.assertRaisesRegex(
ValueError, "SyncOnReadVariable does not support "):
self.evaluate(v.assign_add(1.))
with self.assertRaisesRegex(
ValueError, "SyncOnReadVariable does not support "):
self.evaluate(v.assign_sub(1.))
def testReadValueInReplicaContext(self, distribution):
aggregations = [
variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
for aggregation in aggregations:
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
results = self.evaluate(distribution.experimental_local_results(
distribution.experimental_run_v2(v.read_value)))
for component, value in zip(v._values, results):
self.assertAllEqual(self.evaluate(component.read_value()), value)
def testReadValueInCrossReplicaContext(self, distribution):
aggregations = [
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
for aggregation in aggregations:
if isinstance(distribution, _TPU_STRATEGIES):
        resolver = tpu_cluster_resolver.TPUClusterResolver("")
tpu_strategy_util.initialize_tpu_system(resolver)
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
def assign(v=v):
ctx = distribution_strategy_context.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
return v.assign(math_ops.cast(replica_id, dtypes.float32))
self.evaluate(distribution.experimental_local_results(
distribution.experimental_run_v2(assign)))
result = self.evaluate(v.read_value())
num_replicas = distribution.num_replicas_in_sync
sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
if aggregation == variables_lib.VariableAggregation.SUM:
expected = sum_of_replica_values
elif aggregation == variables_lib.VariableAggregation.MEAN:
expected = sum_of_replica_values / num_replicas
else:
expected = 0
self.assertEqual(expected, result, aggregation)
def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.NONE)
self.evaluate(variables_lib.global_variables_initializer())
with self.assertRaisesRegex(
ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
self.evaluate(v.read_value())
def testInitializedToSameValueInsideEagerRun(self, distribution):
    if not context.executing_eagerly():
      self.skipTest("eager only")
v = [None]
@def_function.function
def step():
def f():
if v[0] is None:
v[0] = variables_lib.Variable(
random_ops.random_normal([]),
synchronization=variables_lib.VariableSynchronization.ON_READ)
distribution.experimental_run_v2(f)
context.set_global_seed(None)
step()
vals = self.evaluate(v[0].values)
self.assertAllEqual(vals[0], vals[1])
class PerReplicaTest(test.TestCase, parameterized.TestCase):
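  """Tests for PerReplica and its TypeSpec-based composite tensor support."""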
def testTypeSpec(self):
device_map = values.SingleDeviceMap("CPU")
vals = (constant_op.constant(1.),)
per_replica = values.PerReplica(device_map, vals)
spec = per_replica._type_spec
self.assertEqual(spec._value_specs,
(tensor_spec.TensorSpec([], dtypes.float32),))
self.assertEqual(spec._device_map, per_replica.device_map)
self.assertEqual(spec._logical_device, per_replica.logical_device)
def testTypeSpecRoundTrip(self):
device_map = values.SingleDeviceMap("CPU")
vals = (constant_op.constant(1.),)
per_replica = values.PerReplica(device_map, vals)
spec = per_replica._type_spec
tensor_list = spec._to_components(per_replica)
reconstructed = spec._from_components(tensor_list)
self.assertEqual(per_replica.device_map, reconstructed.device_map)
self.assertEqual(per_replica.logical_device, reconstructed.logical_device)
self.assertAllEqual(per_replica.values, reconstructed.values)
def testTypeSpecNest(self):
device_map = values.ReplicaDeviceMap(["CPU:0", "CPU:1"])
vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
per_replica = values.PerReplica(device_map, vals)
    # Note: nest.map_structure exercises nest.flatten and
# nest.pack_sequence_as.
result = nest.map_structure(lambda t: t + 10, per_replica,
expand_composites=True)
self.assertEqual(per_replica.device_map, result.device_map)
self.assertEqual(per_replica.logical_device, result.logical_device)
self.assertLen(result.values, 2)
self.assertAllEqual(result.values[0], 11.)
self.assertAllEqual(result.values[1], [15., 16.0])
@test_util.run_in_graph_and_eager_modes
def testIsGraphTensor(self):
per_replica = values.PerReplica(values.SingleDeviceMap("CPU"),
(constant_op.constant(1.),))
self.assertEqual(per_replica._is_graph_tensor,
not context.executing_eagerly())
def testDoesNotTriggerFunctionTracing(self):
traces = []
@def_function.function
def f(x):
traces.append(None) # Only happens on trace.
return x
per_replica = values.PerReplica(
values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
# Trace once.
f(per_replica)
self.assertNotEmpty(traces)
del traces[:]
per_replica_spec = per_replica._type_spec
for _ in range(5):
vals = per_replica_spec._to_components(per_replica)
vals = [v * 2 for v in vals]
per_replica = per_replica_spec._from_components(vals)
output = f(per_replica)
self.assertIsInstance(output, values.PerReplica)
self.assertAllEqual(output._values, per_replica._values)
self.assertAllEqual(output._device_map, per_replica._device_map)
self.assertAllEqual(output._logical_device, per_replica._logical_device)
self.assertEmpty(traces) # Make sure we're not re-tracing `f`.
def testFunctionCanReturnPerReplica(self):
f = def_function.function(lambda x: x)
x = values.PerReplica(
values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
y = f(x)
self.assertIsNot(x, y)
for a, b in zip(x._to_components(), y._to_components()):
self.assertAllEqual(a, b)
self.assertEqual(x._component_metadata(), y._component_metadata())
@test_util.run_in_graph_and_eager_modes
def testCondWithTensorValues(self):
device_map = values.SingleDeviceMap("CPU")
per_replica_1 = values.PerReplica(device_map, (constant_op.constant("a"),))
per_replica_2 = values.PerReplica(device_map,
(constant_op.constant(["b", "c"]),))
condition = array_ops.placeholder_with_default(True, [])
result = control_flow_ops.cond(
condition, lambda: per_replica_1, lambda: per_replica_2)
self.assertEqual(per_replica_1.device_map, result.device_map)
self.assertEqual(per_replica_1.logical_device, result.logical_device)
self.assertLen(result.values, 1)
self.assertAllEqual(result.values[0], "a")
@test_util.run_in_graph_and_eager_modes
def testCondWithValuesConvertibleToTensor(self):
device_map = values.SingleDeviceMap("CPU")
per_replica_1 = values.PerReplica(device_map, ("a",))
per_replica_2 = values.PerReplica(device_map, ("b",))
condition = array_ops.placeholder_with_default(True, [])
result = control_flow_ops.cond(
condition, lambda: per_replica_1, lambda: per_replica_2)
self.assertEqual(per_replica_1.device_map, result.device_map)
self.assertEqual(per_replica_1.logical_device, result.logical_device)
self.assertLen(result.values, 1)
self.assertAllEqual(result.values[0], "a")
@test_util.build_as_function_and_v1_graph
def testCondWithValuesNotConvertibleToTensor(self):
device_map = values.SingleDeviceMap("CPU")
per_replica_1 = values.PerReplica(device_map, (set(["a"]),))
per_replica_2 = values.PerReplica(device_map, (set(["b", "c"]),))
condition = array_ops.placeholder(dtypes.bool, [])
with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
control_flow_ops.cond(
condition, lambda: per_replica_1, lambda: per_replica_2)
class WorkerDeviceMapTest(test.TestCase):
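  """Tests for WorkerDeviceMap device lookup and replica selection."""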
class ReplicaContext(object):
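    """Minimal stand-in exposing only replica_id_in_sync_group."""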
def __init__(self, replica_id_in_sync_group):
self.replica_id_in_sync_group = replica_id_in_sync_group
def testBasic(self):
devices = [
"/job:worker/replica:0/task:0/device:CPU:0",
"/job:worker/replica:0/task:2/device:CPU:0"
]
device_map = values.WorkerDeviceMap(devices, 1)
self.assertAllEqual(devices, device_map.all_devices)
# pylint:disable=pointless-statement
with self.assertRaisesWithPredicateMatch(
ValueError, "`WorkerDeviceMap` is not indexed by replicas"):
device_map.devices_by_replica
self.assertEqual(1, device_map.num_logical_devices)
self.assertEqual(2, device_map.num_replicas_in_graph)
self.assertEqual(0, device_map.logical_device_from_values(["a", "b"]))
self.assertAllEqual(devices, device_map.logical_to_actual_devices(0))
replica_context = WorkerDeviceMapTest.ReplicaContext(1)
self.assertEqual(
"b", device_map.select_for_current_replica(["a", "b"], replica_context))
with self.assertRaisesWithPredicateMatch(
ValueError, "`WorkerDeviceMap` not indexed by replicas"):
device_map.replica_for_device(devices[1])
self.assertEqual("b", device_map.select_for_device(["a", "b"], devices[1]))
with self.assertRaisesWithPredicateMatch(
ValueError, "WorkerDeviceMap not indexed by replicas"):
device_map.is_device_in_replica(devices[1], 1)
self.assertEqual(
"WorkerDeviceMap(('/job:worker/replica:0/task:0/device:CPU:0', "
"'/job:worker/replica:0/task:2/device:CPU:0'), "
"num_replicas_per_worker=1)", repr(device_map))
def testMultipleReplicasPerWorker(self):
devices = [
"/job:worker/replica:0/task:0/device:CPU:0",
"/job:worker/replica:0/task:2/device:CPU:0"
]
device_map = values.WorkerDeviceMap(devices, 2)
replica_context = WorkerDeviceMapTest.ReplicaContext(3)
self.assertEqual(
"b", device_map.select_for_current_replica(["a", "b"], replica_context))
if __name__ == "__main__":
test.main()