blob: 3fb1ed3ac503fb7477162f392116c531cb268a2b [file] [log] [blame]
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for V2 Collective Operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import threading
import time
from absl.testing import parameterized
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.experimental.ops import testing as dataset_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import collective_ops as _collective_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
class CollectiveOpsV1(object):
all_reduce = _collective_ops.all_reduce
all_gather = _collective_ops.all_gather
broadcast_send = _collective_ops.broadcast_send
broadcast_recv = _collective_ops.broadcast_recv
class CollectiveOpsV2(object):
@staticmethod
def all_reduce(t, group_size, group_key, instance_key, *args, **kwargs):
group_size = array_ops.identity(group_size)
group_key = array_ops.identity(group_key)
instance_key = array_ops.identity(instance_key)
return _collective_ops.all_reduce_v2(t, group_size, group_key, instance_key,
*args, **kwargs)
@staticmethod
def all_gather(t, group_size, group_key, instance_key, *args, **kwargs):
group_size = array_ops.identity(group_size)
group_key = array_ops.identity(group_key)
instance_key = array_ops.identity(instance_key)
return _collective_ops.all_gather_v2(t, group_size, group_key, instance_key,
*args, **kwargs)
@staticmethod
def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
*args, **kwargs):
group_size = array_ops.identity(group_size)
group_key = array_ops.identity(group_key)
instance_key = array_ops.identity(instance_key)
return _collective_ops.broadcast_send_v2(t, group_size, group_key,
instance_key, *args, **kwargs)
@staticmethod
def broadcast_recv(shape, dtype, group_size, group_key, instance_key, *args,
**kwargs):
group_size = array_ops.identity(group_size)
group_key = array_ops.identity(group_key)
instance_key = array_ops.identity(instance_key)
shape = array_ops.identity(shape)
return _collective_ops.broadcast_recv_v2(
shape, dtype, group_size, group_key, instance_key, *args, **kwargs)
device_combination = (
combinations.combine(device='CPU', communication='RING', required_gpus=0) +
combinations.combine(
device='GPU', communication=['RING', 'NCCL'], required_gpus=2))
collective_op_combinations = combinations.times(
combinations.combine(
collective_op=[
combinations.NamedObject('all_reduce', CollectiveOpsV1.all_reduce),
combinations.NamedObject('all_reduce_v2',
CollectiveOpsV2.all_reduce),
combinations.NamedObject('all_gather', CollectiveOpsV1.all_gather),
combinations.NamedObject('all_gather_v2',
CollectiveOpsV2.all_gather),
],
mode='eager'), device_combination)
@combinations.generate(
combinations.times(
combinations.combine(
collective_ops=[
combinations.NamedObject('v1', CollectiveOpsV1),
combinations.NamedObject('v2', CollectiveOpsV2)
],
mode='eager'), device_combination))
class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
def setUp(self):
_setup_context()
super().setUp()
def testReduce(self, collective_ops, device, communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
@def_function.function
def run_all_reduce_1device():
with ops.device(dev0):
in_value = constant_op.constant([1.])
group_size = 1
group_key = 1
instance_key = 1
return collective_ops.all_reduce(
in_value,
group_size,
group_key,
instance_key,
communication_hint=communication)
@def_function.function
def run_all_reduce_2devices():
in_value = constant_op.constant([1.])
group_size = 2
group_key = 2
instance_key = 2
collectives = []
with ops.device(dev0):
collectives.append(
collective_ops.all_reduce(
in_value,
group_size,
group_key,
instance_key,
communication_hint=communication))
with ops.device(dev1):
collectives.append(
collective_ops.all_reduce(
in_value,
group_size,
group_key,
instance_key,
communication_hint=communication))
return collectives
self.assertAllClose(run_all_reduce_1device(), [1.], rtol=1e-5, atol=1e-5)
for result in run_all_reduce_2devices():
self.assertAllClose(result, [2.], rtol=1e-5, atol=1e-5)
def testGather(self, collective_ops, device, communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
@def_function.function
def run_all_gather_1device():
with ops.device(dev0):
in_value = constant_op.constant([1.])
group_size = 1
group_key = 1
instance_key = 1
return collective_ops.all_gather(
in_value,
group_size,
group_key,
instance_key,
communication_hint=communication)
@def_function.function
def run_all_gather_2devices():
in_value = constant_op.constant([1.])
group_size = 2
group_key = 2
instance_key = 2
collectives = []
with ops.device(dev0):
collectives.append(
collective_ops.all_gather(
in_value,
group_size,
group_key,
instance_key,
communication_hint=communication))
with ops.device(dev1):
collectives.append(
collective_ops.all_gather(
in_value,
group_size,
group_key,
instance_key,
communication_hint=communication))
return collectives
self.assertAllClose(run_all_gather_1device(), [1.], rtol=1e-5, atol=1e-5)
for result in run_all_gather_2devices():
self.assertAllClose(result, [1., 1.], rtol=1e-5, atol=1e-5)
def testBroadcast(self, collective_ops, device, communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
@def_function.function
def run_broadcast_2devices():
shape = [3]
in_value = constant_op.constant([1., 2., 3.], shape=shape)
group_size = 2
group_key = 2
instance_key = 2
collectives = []
with ops.device(dev0):
collectives.append(
collective_ops.broadcast_send(
in_value,
shape,
in_value.dtype,
group_size,
group_key,
instance_key,
communication_hint=communication))
with ops.device(dev1):
collectives.append(
collective_ops.broadcast_recv(
shape,
in_value.dtype,
group_size,
group_key,
instance_key,
communication_hint=communication))
return collectives
for result in run_broadcast_2devices():
self.assertAllClose(result, [1., 2., 3.], rtol=1e-5, atol=1e-5)
def testInstanceKeyScopedUnderGroupKey(self, collective_ops, device,
communication):
if device == 'GPU' and context.num_gpus() < 4:
self.skipTest('not enough GPU')
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
dev2 = '/device:%s:2' % device
dev3 = '/device:%s:3' % device
@def_function.function
def run_all_reduce_4devices_same_instance_key():
# Use a common instance key for both groups.
instance_key = 0
# We will create 2 groups each with 2 devices.
group_size = 2
# Group 0 comprises dev0 and dev1.
group0_key = 0
# Group 1 comprises dev2 and dev3.
group1_key = 1
collectives = []
with ops.device(dev0):
collectives.append(
collective_ops.all_reduce(
constant_op.constant(1.), group_size, group0_key, instance_key))
with ops.device(dev1):
collectives.append(
collective_ops.all_reduce(
constant_op.constant(2.), group_size, group0_key, instance_key))
with ops.device(dev2):
collectives.append(
collective_ops.all_reduce(
constant_op.constant(3.), group_size, group1_key, instance_key))
with ops.device(dev3):
collectives.append(
collective_ops.all_reduce(
constant_op.constant(4.), group_size, group1_key, instance_key))
return collectives
results = run_all_reduce_4devices_same_instance_key()
self.assertAllClose(results[0], 3., rtol=1e-5, atol=1e-5)
self.assertAllClose(results[1], 3., rtol=1e-5, atol=1e-5)
self.assertAllClose(results[2], 7., rtol=1e-5, atol=1e-5)
self.assertAllClose(results[3], 7., rtol=1e-5, atol=1e-5)
def testCollectiveGroupSizeOne(self, collective_ops, device, communication):
dev0 = '/device:%s:0' % device
group_size = 1
group_key = 100
in_value = [1., 2., 3., 4.]
in_tensor = constant_op.constant(in_value)
with ops.device(dev0):
reduced_tensor = collective_ops.all_reduce(
in_tensor,
group_size,
group_key,
instance_key=100,
communication_hint=communication)
self.assertAllEqual(in_value, reduced_tensor.numpy())
with ops.device(dev0):
gathered_tensor = collective_ops.all_gather(
in_tensor,
group_size,
group_key,
instance_key=200,
communication_hint=communication)
self.assertAllEqual(in_value, gathered_tensor.numpy())
def testCollectiveInvalidKey(self, collective_ops, device, communication):
dev0 = '/device:%s:0' % device
group_size = 1
group_key = 100
instance_key = 100
in_value = [1., 2., 3., 4.]
in_tensor = constant_op.constant(in_value)
with ops.device(dev0):
reduced_tensor = collective_ops.all_reduce(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
self.assertAllEqual(in_value, reduced_tensor.numpy())
with self.assertRaisesRegex(
errors.InternalError, 'instance 100 expected type 0 and data_type 1 but'
' got type 2 and data_type 1'):
with ops.device(dev0):
collective_ops.all_gather(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
def testMultipleGroups(self, collective_ops, device, communication):
if device == 'GPU' and context.num_gpus() < 4:
self.skipTest('not enough GPU')
num_elements = 4
@def_function.function
def run_all_reduce(group_size, group_key):
instance_key = group_key
input_value = [float(group_key) for i in range(num_elements)]
collectives = []
for device_idx in range(group_size):
with ops.device('/{}:{}'.format(device, device_idx)):
input_tensor = constant_op.constant(input_value)
collectives.append(
collective_ops.all_reduce(
input_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication))
return collectives
def run_and_assert(group_size, group_key):
for reduced_tensor in run_all_reduce(group_size, group_key):
self.assertAllEqual(
[float(group_key) * group_size for i in range(num_elements)],
reduced_tensor.numpy())
run_and_assert(group_size=2, group_key=1)
run_and_assert(group_size=3, group_key=2)
@combinations.generate(collective_op_combinations)
class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
def setUp(self):
_setup_context()
super().setUp()
def testAbortGroupParamsResolution(self, collective_op, device,
communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
def abort_fn():
time.sleep(2)
context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')
t = threading.Thread(target=abort_fn)
t.start()
with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
# This hangs on params resolution since we're only launching one
# collective for a group size of 2.
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
# After abortion, subsequent collectives should fail immediately.
with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
t.join()
# Reset the context in order to reset the collective executor.
_setup_context()
# After reset non-NCCL collectives should work.
def collective_fn():
for device in [dev0, dev1]:
with ops.device(device):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
def_function.function(collective_fn)()
def testAbortInstanceParamsResolution(self, collective_op, device,
communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
def collective_fn():
for device in [dev0, dev1]:
with ops.device(device):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
# First perform a normal all-reduce to complete the group resolution.
def_function.function(collective_fn)()
def abort_fn():
time.sleep(2)
context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')
t = threading.Thread(target=abort_fn)
t.start()
# Use a different instance key to trigger another instance resolution.
instance_key = 101
with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
# This hangs on params resolution since we're only launching one
# collective for a group size of 2.
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
# After abortion, subsequent collectives should fail immediately.
with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
context._reset_context() # pylint: disable=protected-access
t.join()
# Reset the context in order to reset the collective executor.
_setup_context()
# After reset non-NCCL collectives should work.
def_function.function(collective_fn)()
def testAbortCommunication(self, collective_op, device, communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
# First perform a normal collective to finish resolution.
def collective_fn():
for device in [dev0, dev1]:
with ops.device(device):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
def_function.function(collective_fn)()
# Launch a collective that hangs, and abort the collective executor after
# the launch.
def abort_fn():
time.sleep(2)
context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')
t = threading.Thread(target=abort_fn)
t.start()
with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
# After abortion, subsequent collectives should fail immediately.
with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
# Reset the context in order to reset the collective executor.
t.join()
_setup_context()
def_function.function(collective_fn)()
class OpCancellationTest(test.TestCase, parameterized.TestCase):
def setUp(self):
_setup_context()
super().setUp()
@combinations.generate(
combinations.times(
combinations.combine(
collective_op=[
combinations.NamedObject('all_reduce',
CollectiveOpsV1.all_reduce),
combinations.NamedObject('all_reduce_v2',
CollectiveOpsV2.all_reduce),
combinations.NamedObject('all_gather',
CollectiveOpsV1.all_gather),
combinations.NamedObject('all_gather_v2',
CollectiveOpsV2.all_gather),
],
mode='eager'), device_combination))
def testOpErrorNotAbortIfNoCollective(self, collective_op, device,
communication):
# Do not abort if there's no active collective ops. There could be
# exceptions like EOF which we expect users to catch, aborting collective
# ops on all op errors intervenes with this workflow.
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
dataset = dataset_ops.Dataset.from_tensors([1.])
@def_function.function
def collective_fn(in_tensor):
for device in [dev0, dev1]:
with ops.device(device):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
@def_function.function
def f():
iterator = iter(dataset)
collective_fn(next(iterator))
# This next(iterator) should raise EOF.
collective_fn(next(iterator))
with self.assertRaises(errors.OutOfRangeError):
f()
collective_fn(constant_op.constant([1.]))
@combinations.generate(
combinations.times(
combinations.combine(
collective_op=[
combinations.NamedObject('all_reduce',
CollectiveOpsV1.all_reduce),
combinations.NamedObject('all_gather',
CollectiveOpsV1.all_gather),
],
mode='eager'), device_combination))
def testOpErrorAbortWithCollective(self, collective_op, device,
communication):
# Abort v1 collective ops if there're active collective ops at the time of
# an op error. This is due to the inability to cancel collective ops, and op
# errors may cause running collective ops to hang.
dev0 = '/device:%s:0' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
# Make the dataset sleep a while so that the collective is being executed
# when the EOF happens.
dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
dataset_testing.sleep(sleep_microseconds=200))
@def_function.function
def f():
# Launch a collective op that won't be able to finish to test abortion
# when other ops error.
with ops.device(dev0):
ret = collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
iterator = iter(dataset)
next(iterator)
# This should raise EOF.
next(iterator)
return ret
with self.assertRaises(errors.OutOfRangeError):
f()
# Now collective ops is aborted, subsequent collective ops should fail with
# the previous error.
with self.assertRaises(errors.CancelledError):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
@combinations.generate(
combinations.times(
combinations.combine(
collective_op=[
combinations.NamedObject('all_reduce_v2',
CollectiveOpsV2.all_reduce),
combinations.NamedObject('all_gather_v2',
CollectiveOpsV2.all_gather),
],
mode='eager'), device_combination))
def testOpErrorNotAbortWithCollective(self, collective_op, device,
communication):
# Do not abort v2 collective ops even if there're active collective ops at
# the time of an op error. We rely cancellation to terminate active
# collective ops.
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
@def_function.function
def collective_fn():
for device in [dev0, dev1]:
with ops.device(device):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
# Local params resolution cannot be cancelled yet, so we perform a normal
# collective so that the group is resolved.
collective_fn()
# Make the dataset sleep a while so that the collective is being executed
# when the EOF happens.
dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
dataset_testing.sleep(sleep_microseconds=200))
@def_function.function
def f():
# Launch a collective op that won't be able to finish to test cancellation
# when other ops error.
with ops.device(dev0):
ret = collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
iterator = iter(dataset)
next(iterator)
# This should raise EOF.
next(iterator)
return ret
with self.assertRaises(errors.OutOfRangeError):
f()
# Collective ops shouldn't be aborted and new collectives should be able to
# proceed.
collective_fn()
@combinations.generate(
combinations.times(
combinations.combine(
collective_op=[
combinations.NamedObject('all_reduce_v2',
CollectiveOpsV2.all_reduce),
combinations.NamedObject('all_gather_v2',
CollectiveOpsV2.all_gather),
],
mode='eager'), device_combination))
def testCancelDuringParamResolution(self, collective_op, device,
communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
t1_cancellation_manager = cancellation.CancellationManager()
t2_cancellation_manager = cancellation.CancellationManager()
@def_function.function
def _collective_fn(x):
# Run an assertion to crash one of the two function executions running
# collectives. We explicitly cancel the other in response.
assert_op = check_ops.assert_equal(x, in_tensor)
with ops.control_dependencies([assert_op]):
return collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
collective_concrete = _collective_fn.get_concrete_function(in_tensor)
finish_mu = threading.Lock()
finishes = 0
def _placement_wrapper(device, x, my_cancellation, other_cancellation):
try:
with ops.device(device):
cancelable_collective = my_cancellation.get_cancelable_function(
collective_concrete)
return cancelable_collective(x)
except errors.InvalidArgumentError:
# `assert_equal` failed for this execution of the function. The other
# function would deadlock without cancellation.
other_cancellation.start_cancel()
except errors.CancelledError:
pass
nonlocal finishes
with finish_mu:
finishes += 1
t1 = threading.Thread(
target=_placement_wrapper,
args=(dev0, constant_op.constant([1.]), t1_cancellation_manager,
t2_cancellation_manager))
t2 = threading.Thread(
target=_placement_wrapper,
# Will cause the assertion to fail
args=(dev1, constant_op.constant([2.]), t2_cancellation_manager,
t1_cancellation_manager))
t1.start()
t2.start()
t1.join()
t2.join()
self.assertEqual(finishes, 2)
@combinations.generate(collective_op_combinations)
class TimeoutTest(test.TestCase, parameterized.TestCase):
def setUp(self):
_setup_context()
super().setUp()
def testTimeout(self, collective_op, device, communication):
timeout = 1.5
@def_function.function
def run(group_size, reported_group_size=None):
group_key = 20
instance_key = 30
tensor = [1., 2., 3., 4.]
results = []
if reported_group_size is None:
reported_group_size = group_size
for i in range(group_size):
with ops.device('/{}:{}'.format(device, i)):
input_data = constant_op.constant(tensor)
result = collective_op(
input_data,
group_size=reported_group_size,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication,
timeout=timeout)
results.append(result)
return results
run(2, 2)
start_time = time.time()
with self.assertRaisesRegex(errors.DeadlineExceededError,
'Collective has timed out during execution'):
run(1, 2)
elapsed = time.time() - start_time
self.assertAllGreaterEqual(elapsed, timeout)
def testParamResolutionAfterTimeout(self, collective_op, device,
communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
timeout = 1.5
group_key = 20
instance_key = 30
input_data = constant_op.constant([1., 2., 3., 4.])
# This timeout comes from param solution.
with self.assertRaisesRegex(
errors.DeadlineExceededError,
'Collective has timed out waiting for other workers'):
with ops.device(dev0):
collective_op(
input_data,
group_size=2,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication,
timeout=timeout)
# We launch the second device after the first device times out. This is to
# simulate the situation when other workers are slow and the timeout is
# short. It should error immediately.
with self.assertRaisesRegex(
errors.DeadlineExceededError,
'Collective has timed out waiting for other workers'):
with ops.device(dev1):
collective_op(
input_data,
group_size=2,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication)
def testExecutionAfterTimeout(self, collective_op, device, communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
timeout = 1.5
group_key = 20
instance_key = 30
input_data = constant_op.constant([1., 2., 3., 4.])
@def_function.function
def run():
for device in [dev0, dev1]:
with ops.device(device):
collective_op(
input_data,
group_size=2,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication,
timeout=timeout)
# Run a normal all-reduce to complete param resolution.
run()
with self.assertRaisesRegex(errors.DeadlineExceededError,
'Collective has timed out during execution'):
with ops.device(dev0):
collective_op(
input_data,
group_size=2,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication,
timeout=timeout)
# We launch the second device after the first device times out. This is to
# simulate the situation when other workers are slow and the timeout is
# short. It should error immediately.
with self.assertRaisesRegex(errors.DeadlineExceededError,
'Collective has timed out during execution'):
with ops.device(dev1):
# No timeout.
collective_op(
input_data,
group_size=2,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication)
@combinations.generate(
combinations.times(
combinations.combine(
collective_op=[
combinations.NamedObject('all_reduce_v2',
CollectiveOpsV2.all_reduce),
combinations.NamedObject('all_gather_v2',
CollectiveOpsV2.all_gather),
],
mode='eager'), device_combination))
class OrderingTest(test.TestCase, parameterized.TestCase):
def setUp(self):
_setup_context()
super().setUp()
def testOrdering(self, collective_op, device, communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
with ops.device(dev0):
token0 = resource_variable_ops.ResourceVariable(0.)
with ops.device(dev1):
token1 = resource_variable_ops.ResourceVariable(0.)
@def_function.function
def f():
# Launch the first collective with token.
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
ordering_token=token0.handle)
with ops.device(dev1):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
ordering_token=token1.handle)
# Launch the second collective without token.
with ops.device(dev0):
collective_op(in_tensor, group_size, group_key, instance_key)
with ops.device(dev1):
collective_op(in_tensor, group_size, group_key, instance_key)
# Launch the third collective with token.
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
ordering_token=token0.handle)
with ops.device(dev1):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
ordering_token=token1.handle)
graph = f.get_concrete_function().graph
for device in [dev0, dev1]:
# Try to find the third collective, which should have the first collective
# as a control input.
third = None
for op in graph.get_operations():
if (op.type.startswith('Collective') and op.device.endswith(device) and
op.control_inputs and
op.control_inputs[0].type.startswith('Collective')):
self.assertIsNone(third)
third = op
self.assertIsNotNone(third)
# Verify it's not the second collective by looking at the inputs.
self.assertTrue(any(v.dtype == dtypes.resource for v in third.inputs))
first = third.control_inputs[0]
self.assertEqual(third.device, first.device)
# Verify it's not the second collective by looking at the inputs.
self.assertTrue(any(v.dtype == dtypes.resource for v in first.inputs))
self.assertEmpty(first.control_inputs)
def _setup_context():
context._reset_context()
test_util.set_logical_devices_to_at_least('CPU', 4)
context.ensure_initialized()
if __name__ == '__main__':
os.environ['NCCL_DEBUG'] = 'INFO'
v2_compat.enable_v2_behavior()
test.main()