| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for CollectiveAllReduceStrategy.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import functools |
| |
| from absl.testing import parameterized |
| import numpy as np |
| |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.core.protobuf import rewriter_config_pb2 |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib |
| from tensorflow.python.distribute import collective_all_reduce_strategy |
| from tensorflow.python.distribute import combinations |
| from tensorflow.python.distribute import distribute_lib |
| from tensorflow.python.distribute import distribute_utils |
| from tensorflow.python.distribute import distribution_strategy_context |
| from tensorflow.python.distribute import input_lib |
| from tensorflow.python.distribute import multi_worker_test_base |
| from tensorflow.python.distribute import multi_worker_util |
| from tensorflow.python.distribute import reduce_util |
| from tensorflow.python.distribute import strategy_combinations |
| from tensorflow.python.distribute import strategy_test_lib |
| from tensorflow.python.distribute import test_util |
| from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import config as tf_config |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import device as tf_device |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gen_math_ops |
| from tensorflow.python.ops import gradients |
| from tensorflow.python.ops import init_ops |
| from tensorflow.python.ops import init_ops_v2 |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| from tensorflow.python.training.server_lib import ClusterSpec |
| |
| |
| CollectiveAllReduceStrategy = ( |
| collective_all_reduce_strategy.CollectiveAllReduceStrategy) |
| CollectiveAllReduceExtended = ( |
| collective_all_reduce_strategy.CollectiveAllReduceExtended) |
| _CollectiveAllReduceStrategyExperimental = ( |
| collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental) |
| |
| |
| def create_test_objects(cluster_spec=None, |
| task_type=None, |
| task_id=None, |
| num_gpus=None): |
| sess_config = config_pb2.ConfigProto() |
| if num_gpus is None: |
| num_gpus = context.num_gpus() |
| |
| if cluster_spec and task_type and task_id is not None: |
| cluster_resolver = SimpleClusterResolver( |
| cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), |
| task_type=task_type, |
| task_id=task_id, |
| num_accelerators={'GPU': num_gpus}) |
| target = 'grpc://' + cluster_spec[task_type][task_id] |
| else: |
| cluster_resolver = SimpleClusterResolver( |
| ClusterSpec({}), num_accelerators={'GPU': num_gpus}) |
| target = '' |
| |
| strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( |
| cluster_resolver=cluster_resolver) |
| sess_config = strategy.update_config_proto(sess_config) |
| |
| return strategy, target, sess_config |
| |
| |
| class CollectiveAllReduceStrategyTestBase( |
| multi_worker_test_base.MultiWorkerTestBase): |
| |
| def setUp(self): |
| # We use a different key_base for each test so that collective keys won't be |
| # reused. |
| CollectiveAllReduceStrategy._collective_key_base += 100000 |
| super(CollectiveAllReduceStrategyTestBase, self).setUp() |
| |
| def _get_test_object(self, task_type, task_id, num_gpus=0): |
| strategy, target, session_config = create_test_objects( |
| cluster_spec=self._cluster_spec, |
| task_type=task_type, |
| task_id=task_id, |
| num_gpus=num_gpus) |
| return strategy, target, session_config |
| |
| def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): |
| d, master_target, config = self._get_test_object(task_type, task_id, |
| num_gpus) |
| with ops.Graph().as_default(), \ |
| self.cached_session(config=config, |
| target=master_target) as sess, \ |
| d.scope(): |
| initializer = functools.partial( |
| init_ops_v2.GlorotUniform(), (1, 1), dtype=dtypes.float32) |
| kernel = variables.Variable( |
| initial_value=initializer, |
| name='gpu_%d/kernel' % d.extended._num_gpus_per_worker, |
| trainable=True) |
| |
| def loss_fn(x): |
| y = array_ops.reshape( |
| gen_math_ops.mat_mul(x, kernel), []) - constant_op.constant(1.) |
| return y * y |
| |
| # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for |
| # multiple graphs (b/111216820). |
| def grad_fn(x): |
| loss = loss_fn(x) |
| var_list = ( |
| variables.trainable_variables() + ops.get_collection( |
| ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) |
| grads = gradients.gradients(loss, var_list) |
| ret = list(zip(grads, var_list)) |
| return ret |
| |
| def update(v, g): |
| return v.assign_sub(0.05 * g, use_locking=True) |
| |
| one = constant_op.constant([[1.]]) |
| |
| def step(): |
| """Perform one optimization step.""" |
| # Run forward & backward to get gradients, variables list. |
| g_v = d.extended.call_for_each_replica(grad_fn, args=[one]) |
| # Update the variables using the gradients and the update() function. |
| before_list = [] |
| after_list = [] |
| for g, v in g_v: |
| fetched = d.extended.read_var(v) |
| before_list.append(fetched) |
| with ops.control_dependencies([fetched]): |
| # TODO(yuefengz): support non-Mirrored variable as destinations. |
| g = d.extended.reduce_to( |
| reduce_util.ReduceOp.SUM, g, destinations=v) |
| with ops.control_dependencies( |
| d.extended.update(v, update, args=(g,), group=False)): |
| after_list.append(d.extended.read_var(v)) |
| return before_list, after_list |
| |
| before_out, after_out = step() |
| |
| if context.num_gpus() < d.extended._num_gpus_per_worker: |
| return True |
| |
| sess.run(variables.global_variables_initializer()) |
| |
| for i in range(10): |
| b, a = sess.run((before_out, after_out)) |
| if i == 0: |
| before, = b |
| after, = a |
| |
| error_before = abs(before - 1) |
| error_after = abs(after - 1) |
| # Error should go down |
| self.assertLess(error_after, error_before) |
| |
| def _test_variable_initialization(self, task_type, task_id, num_gpus): |
| distribution, master_target, config = self._get_test_object( |
| task_type, task_id, num_gpus) |
| with ops.Graph().as_default(), \ |
| self.cached_session(config=config, |
| target=master_target) as sess, \ |
| distribution.scope(): |
| |
| def model_fn(): |
| x = variable_scope.get_variable( |
| 'x', |
| shape=(2, 3), |
| initializer=init_ops.random_uniform_initializer( |
| 1.0, 10.0, dtype=dtypes.float32)) |
| return array_ops.identity(x) |
| |
| x = distribution.extended.call_for_each_replica(model_fn) |
| reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x, axis=None) |
| x = distribution.experimental_local_results(x)[0] |
| |
| sess.run(variables.global_variables_initializer()) |
| |
| x_value, reduced_x_value = sess.run([x, reduced_x]) |
| self.assertTrue( |
| np.allclose(x_value, reduced_x_value, atol=1e-5), |
| msg=('x_value = %r, reduced_x_value = %r' % (x_value, |
| reduced_x_value))) |
| |
| def _test_input_fn_iterator(self, |
| task_type, |
| task_id, |
| num_gpus, |
| input_fn, |
| expected_values, |
| test_reinitialize=True, |
| ignore_order=False): |
| distribution, master_target, config = self._get_test_object( |
| task_type, task_id, num_gpus) |
| devices = distribution.extended.worker_devices |
| |
| with ops.Graph().as_default(), \ |
| self.cached_session(config=config, |
| target=master_target) as sess: |
| iterator = distribution.make_input_fn_iterator(input_fn) |
| sess.run(iterator.initializer) |
| |
| for expected_value in expected_values: |
| next_element = iterator.get_next() |
| computed_value = sess.run([distribute_utils.select_replica( |
| r, next_element) for r in range(len(devices))]) |
| if ignore_order: |
| self.assertCountEqual(list(expected_value), list(computed_value)) |
| else: |
| self.assertEqual(list(expected_value), list(computed_value)) |
| |
| with self.assertRaises(errors.OutOfRangeError): |
| next_element = iterator.get_next() |
| sess.run([distribute_utils.select_replica(r, next_element) |
| for r in range(len(devices))]) |
| |
| # After re-initializing the iterator, should be able to iterate again. |
| if test_reinitialize: |
| sess.run(iterator.initializer) |
| |
| for expected_value in expected_values: |
| next_element = iterator.get_next() |
| computed_value = sess.run([ |
| distribute_utils.select_replica(r, next_element) |
| for r in range(len(devices))]) |
| if ignore_order: |
| self.assertCountEqual(list(expected_value), list(computed_value)) |
| else: |
| self.assertEqual(list(expected_value), list(computed_value)) |
| |
| |
| class DistributedCollectiveAllReduceStrategyTest( |
| CollectiveAllReduceStrategyTestBase, |
| strategy_test_lib.DistributionTestBase, |
| parameterized.TestCase): |
| |
| @classmethod |
| def setUpClass(cls): |
| """Create a local cluster with 3 workers.""" |
| cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( |
| num_workers=3, num_ps=0) |
| |
| @combinations.generate(combinations.combine(mode=['graph'])) |
| def test_num_replicas_in_sync(self): |
| distribution, _, _ = create_test_objects( |
| cluster_spec=self._cluster_spec, |
| task_type='worker', |
| task_id=0, |
| num_gpus=2) |
| num_workers = len(self._cluster_spec.get('chief', []) + |
| self._cluster_spec.get('worker', [])) |
| self.assertEqual(2 * num_workers, |
| distribution.num_replicas_in_sync) |
| |
| @combinations.generate(combinations.combine( |
| mode=['graph'], |
| prefetch_to_device=[None, True])) |
| def test_prefetch_to_device_dataset(self, prefetch_to_device): |
| distribution, _, _ = self._get_test_object( |
| task_type='worker', |
| task_id=0, |
| num_gpus=2) |
| if prefetch_to_device is None: |
| input_options = None |
| else: |
| input_options = distribute_lib.InputOptions( |
| experimental_prefetch_to_device=prefetch_to_device) |
| dataset = dataset_ops.Dataset.range(100) |
| dataset = dataset.batch(distribution.num_replicas_in_sync) |
| dataset = distribution.experimental_distribute_dataset( |
| dataset, options=input_options) |
| if isinstance(dataset, input_lib.DistributedDatasetV1): |
| item = dataset.make_initializable_iterator().get_next() |
| else: |
| self.skipTest('unsupported test combination') |
| device_types = { |
| tf_device.DeviceSpec.from_string(tensor.device).device_type for |
| tensor in item.values} |
| self.assertAllEqual(list(device_types), ['GPU']) |
| |
| @combinations.generate(combinations.combine(mode=['graph'])) |
| def test_prefetch_to_host_dataset(self): |
| distribution, _, _ = self._get_test_object( |
| task_type='worker', |
| task_id=0, |
| num_gpus=2) |
| input_options = distribute_lib.InputOptions( |
| experimental_prefetch_to_device=False) |
| dataset = dataset_ops.Dataset.range(100) |
| dataset = dataset.batch(distribution.num_replicas_in_sync) |
| dataset = distribution.experimental_distribute_dataset( |
| dataset, options=input_options) |
| if isinstance(dataset, input_lib.DistributedDatasetV1): |
| item = dataset.make_initializable_iterator().get_next() |
| else: |
| self.skipTest('unsupported test combination') |
| device_types = { |
| tf_device.DeviceSpec.from_string(tensor.device).device_type for |
| tensor in item.values} |
| self.assertAllEqual(list(device_types), ['CPU']) |
| |
| @combinations.generate( |
| combinations.combine(mode=['graph'], required_gpus=[0, 1, 2])) |
| def testMinimizeLossGraph(self, required_gpus): |
| self._run_between_graph_clients(self._test_minimize_loss_graph, |
| self._cluster_spec, required_gpus) |
| |
| @combinations.generate( |
| combinations.combine(mode=['graph'], required_gpus=[0, 1, 2])) |
| def testVariableInitialization(self, required_gpus): |
| self._run_between_graph_clients( |
| self._test_variable_initialization, |
| self._cluster_spec, |
| num_gpus=required_gpus) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=['graph'], required_gpus=[0, 1, 2], use_dataset=[True, False])) |
| def testMakeInputFnIterator(self, required_gpus, use_dataset): |
| def _worker_fn(task_type, task_id, required_gpus): |
| if use_dataset: |
| fn = lambda: dataset_ops.Dataset.range(20) |
| else: |
| def fn(): |
| dataset = dataset_ops.Dataset.range(20) |
| it = dataset_ops.make_one_shot_iterator(dataset) |
| return it.get_next |
| # We use CPU as the device when required_gpus = 0 |
| devices_per_worker = max(1, required_gpus) |
| expected_values = [[i+j for j in range(devices_per_worker)] |
| for i in range(0, 20, devices_per_worker)] |
| |
| input_fn = self._input_fn_to_test_input_context( |
| fn, |
| expected_num_replicas_in_sync=3*devices_per_worker, |
| expected_num_input_pipelines=3, |
| expected_input_pipeline_id=task_id) |
| self._test_input_fn_iterator( |
| task_type, |
| task_id, |
| required_gpus, |
| input_fn, |
| expected_values, |
| test_reinitialize=use_dataset, |
| ignore_order=not use_dataset) |
| |
| self._run_between_graph_clients(_worker_fn, self._cluster_spec, |
| required_gpus) |
| |
| @combinations.generate(combinations.combine(mode=['graph'])) |
| def testUpdateConfigProto(self): |
| strategy, _, _ = self._get_test_object( |
| task_type='worker', task_id=1, num_gpus=2) |
| |
| config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) |
| rewrite_options = config_proto.graph_options.rewrite_options |
| rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed') |
| |
| new_config = strategy.update_config_proto(config_proto) |
| |
| # Verify group leader |
| self.assertEqual('/job:worker/replica:0/task:0', |
| new_config.experimental.collective_group_leader) |
| |
| # Verify device filters. |
| self.assertEqual(['/job:worker/task:1'], new_config.device_filters) |
| |
| # Verify rewrite options. |
| new_rewrite_options = new_config.graph_options.rewrite_options |
| self.assertEqual(rewriter_config_pb2.RewriterConfig.ON, |
| new_rewrite_options.scoped_allocator_optimization) |
| self.assertEqual(['CollectiveReduce'], |
| new_rewrite_options.scoped_allocator_opts.enable_op) |
| |
| def _get_strategy_with_mocked_methods(self): |
| mock_called = [False] |
| |
| # pylint: disable=dangerous-default-value |
| def mock_enable_collective_ops(server_def, mock_called=mock_called): |
| self.assertEqual('worker', server_def.job_name) |
| self.assertEqual(1, server_def.task_index) |
| self.assertEqual('grpc', server_def.protocol) |
| mock_called[0] = True |
| |
| def mock_configure_collective_ops(*args, **kwargs): |
| del args, kwargs |
| |
| with test.mock.patch.object(context.context(), 'enable_collective_ops', |
| mock_enable_collective_ops), \ |
| test.mock.patch.object(context.context(), 'configure_collective_ops', |
| mock_configure_collective_ops): |
| strategy, _, _ = self._get_test_object( |
| task_type='worker', task_id=1, num_gpus=2) |
| |
| return strategy, mock_called |
| |
| @combinations.generate(combinations.combine(mode=['eager'])) |
| def testEnableCollectiveOps(self): |
| # We cannot enable check health with this test because it mocks |
| # enable_collective_ops. |
| CollectiveAllReduceExtended._enable_check_health = False |
| strategy, mock_called = self._get_strategy_with_mocked_methods() |
| CollectiveAllReduceExtended._enable_check_health = True |
| self.assertTrue(strategy.extended._std_server_started) |
| self.assertTrue(mock_called[0]) |
| |
| @combinations.generate(combinations.combine(mode=['eager'])) |
| def testEnableCollectiveOpsAndClusterResolver(self): |
| # We cannot enable check health with this test because it mocks |
| # enable_collective_ops. |
| CollectiveAllReduceExtended._enable_check_health = False |
| strategy, _ = self._get_strategy_with_mocked_methods() |
| CollectiveAllReduceExtended._enable_check_health = True |
| self.assertEqual(strategy.cluster_resolver.task_type, 'worker') |
| self.assertEqual(strategy.cluster_resolver.task_id, 1) |
| |
| |
| class DistributedCollectiveAllReduceStrategyTestWithChief( |
| CollectiveAllReduceStrategyTestBase, parameterized.TestCase): |
| |
| @classmethod |
| def setUpClass(cls): |
| """Create a local cluster with 3 workers and 1 chief.""" |
| cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( |
| num_workers=3, num_ps=0, has_chief=True) |
| |
| @combinations.generate( |
| combinations.combine(mode=['graph'], required_gpus=[0, 1, 2])) |
| def testMinimizeLossGraph(self, required_gpus): |
| self._run_between_graph_clients(self._test_minimize_loss_graph, |
| self._cluster_spec, required_gpus) |
| |
| @combinations.generate( |
| combinations.combine(mode=['graph'], required_gpus=[0, 1, 2])) |
| def testVariableInitialization(self, required_gpus): |
| self._run_between_graph_clients( |
| self._test_variable_initialization, |
| self._cluster_spec, |
| num_gpus=required_gpus) |
| |
| |
| class LocalCollectiveAllReduceStrategy( |
| CollectiveAllReduceStrategyTestBase, |
| strategy_test_lib.DistributionTestBase, |
| strategy_test_lib.TwoDeviceDistributionTestBase, |
| parameterized.TestCase): |
| |
| @combinations.generate( |
| combinations.combine(mode=['graph', 'eager'], required_gpus=[2, 4])) |
| def testMinimizeLoss(self, required_gpus): |
| # Collective ops doesn't support strategy with one device. |
| if context.executing_eagerly(): |
| strategy, _, _ = self._get_test_object(None, None, required_gpus) |
| self._test_minimize_loss_eager(strategy) |
| else: |
| self._test_minimize_loss_graph(None, None, required_gpus) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=['graph'], required_gpus=2, use_dataset=[True, False])) |
| def testMakeInputFnIterator(self, required_gpus, use_dataset): |
| if use_dataset: |
| fn = lambda: dataset_ops.Dataset.range(5 * required_gpus) |
| else: |
| def fn(): |
| dataset = dataset_ops.Dataset.range(5 * required_gpus) |
| it = dataset_ops.make_one_shot_iterator(dataset) |
| return it.get_next |
| |
| expected_values = [ |
| range(i, i + required_gpus) for i in range(0, 10, required_gpus) |
| ] |
| |
| input_fn = self._input_fn_to_test_input_context( |
| fn, |
| expected_num_replicas_in_sync=required_gpus, |
| expected_num_input_pipelines=1, |
| expected_input_pipeline_id=0) |
| self._test_input_fn_iterator( |
| None, |
| None, |
| required_gpus, |
| input_fn, |
| expected_values, |
| test_reinitialize=use_dataset, |
| ignore_order=not use_dataset) |
| |
| @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) |
| def testAllReduceSum(self, required_gpus): |
| distribution, target, config = self._get_test_object( |
| None, None, num_gpus=required_gpus) |
| with self.cached_session(config=config, target=target): |
| self._test_all_reduce_sum(distribution) |
| |
| @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) |
| def testAllReduceSumGradients(self, required_gpus): |
| distribution, target, config = self._get_test_object( |
| None, None, num_gpus=required_gpus) |
| with self.cached_session(config=config, target=target): |
| self._test_all_reduce_sum_gradients(distribution) |
| |
| @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) |
| def testAllReduceSumGradientTape(self, required_gpus): |
| distribution, target, config = self._get_test_object( |
| None, None, num_gpus=required_gpus) |
| with self.cached_session(config=config, target=target): |
| self._test_all_reduce_sum_gradient_tape(distribution) |
| |
| @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) |
| def testAllReduceMean(self, required_gpus): |
| distribution, target, config = self._get_test_object( |
| None, None, num_gpus=required_gpus) |
| with self.cached_session(config=config, target=target): |
| self._test_all_reduce_mean(distribution) |
| |
| @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) |
| def testAllReduceMeanGradients(self, required_gpus): |
| distribution, target, config = self._get_test_object( |
| None, None, num_gpus=required_gpus) |
| with self.cached_session(config=config, target=target): |
| self._test_all_reduce_mean_gradients(distribution) |
| |
| @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) |
| def testAllReduceMeanGradientTape(self, required_gpus): |
| distribution, target, config = self._get_test_object( |
| None, None, num_gpus=required_gpus) |
| with self.cached_session(config=config, target=target): |
| self._test_all_reduce_mean_gradient_tape(distribution) |
| |
| @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) |
| def testNumpyDataset(self, required_gpus): |
| strategy, target, config = self._get_test_object( |
| None, None, num_gpus=required_gpus) |
| self._test_numpy_dataset( |
| strategy, session=self.cached_session(config=config, target=target)) |
| |
| |
| class LogicalDeviceTest(test.TestCase, parameterized.TestCase): |
| |
| @combinations.generate(combinations.combine(mode=['eager'], required_gpus=1)) |
| def testKeepLogicalDevice(self): |
| # Cannot change logical device after the context initialization. |
| context._reset_context() # pylint: disable=protected-access |
| cluster_spec = multi_worker_test_base.create_cluster_spec( |
| has_chief=False, num_workers=1) |
| resolver = cluster_resolver_lib.SimpleClusterResolver( |
| cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), |
| task_type='worker', |
| task_id=0) |
| gpus = tf_config.list_physical_devices('GPU') |
| tf_config.set_logical_device_configuration(gpus[-1], [ |
| context.LogicalDeviceConfiguration(64), |
| context.LogicalDeviceConfiguration(64), |
| ]) |
| collective_all_reduce_strategy.CollectiveAllReduceStrategy( |
| cluster_resolver=resolver) |
| # Since we create two logical GPUs out of the last GPU, there should be one |
| # more logical GPUs than physical GPUs. |
| self.assertLen(tf_config.list_logical_devices('GPU'), len(gpus) + 1) |
| context._reset_context() # pylint: disable=protected-access |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| strategy=[ |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| strategy_combinations.multi_worker_mirrored_2x1_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| ], |
| mode=['eager'])) |
| class CollectiveAllReduceStrategyV2Test(test.TestCase, parameterized.TestCase): |
| |
| def test_replica_id_in_sync_group(self, strategy): |
| |
| def replica_fn(): |
| replica_ctx = distribution_strategy_context.get_replica_context() |
| return replica_ctx.replica_id_in_sync_group, replica_ctx._replica_id |
| |
| results = test_util.gather(strategy, strategy.run(replica_fn)) |
| self.assertAllEqual(list(range(strategy.extended._num_replicas_in_sync)), |
| results[0].numpy()) |
| self.assertAllEqual( |
| list(range(len(strategy.extended.worker_devices))) * |
| strategy.extended._num_workers, results[1].numpy()) |
| |
| |
| class ExperimentalCompatibilityTest(test.TestCase): |
| |
| def testIsInstance(self): |
| # It's not uncommon for people to special case MultiWorkerMirroredStrategy, |
| # so we need to make sure isinstance check works for combinations between |
| # the experimental and non-experimental endpoints. |
| strategy = CollectiveAllReduceStrategy() |
| experimental_strategy = _CollectiveAllReduceStrategyExperimental() |
| self.assertIsInstance(strategy, CollectiveAllReduceStrategy) |
| self.assertIsInstance(strategy, _CollectiveAllReduceStrategyExperimental) |
| self.assertIsInstance(experimental_strategy, CollectiveAllReduceStrategy) |
| self.assertIsInstance(experimental_strategy, |
| _CollectiveAllReduceStrategyExperimental) |
| |
| def testName(self): |
| # Estimator checks the __name__ to special case MultiWorkerMirroredStrategy. |
| self.assertEqual(CollectiveAllReduceStrategy.__name__, |
| 'CollectiveAllReduceStrategy') |
| self.assertEqual(_CollectiveAllReduceStrategyExperimental.__name__, |
| 'CollectiveAllReduceStrategy') |
| |
| |
| if __name__ == '__main__': |
| # TODO(b/172304955): enable logical devices. |
| test_util.main(config_logical_devices=False) |