Scope instance keys under group keys for collective ops.
Before this change, we would store a global mapping from instance key ->
resolved instance runtime parameters. This prevented reusing the same key
across device groups.
After this change, instance keys are scoped under the group key. It is legal to
execute two collectives with the same instance key, as long as they have different
device groups. This enables the user to assign the same instance key to a
logical collective that is sharded across device groups.
PiperOrigin-RevId: 324324902
Change-Id: Ib994b68f96c8f6cf1cc634d5a7c4998d9f3fb96c
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index a0153a5..ba21abc 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -586,25 +586,32 @@
InstanceRec* irec = nullptr;
bool exit_outside_locks = false;
{
+ bool found_instance = false;
mutex_lock l(instance_mu_);
- auto it = instance_table_.find(cp->instance.instance_key);
- if (it != instance_table_.end()) {
- irec = it->second.get();
- {
- mutex_lock l(irec->in_mu);
- if (irec->is_init) {
- exit_outside_locks = true;
- } else {
- irec->init_waiters.push_back([this, done](InstanceRec* irec) {
- CallbackWithStatus(done, irec);
- });
- return;
+ auto group_it = instance_table_.find(gr->group.group_key);
+ if (group_it != instance_table_.end()) {
+ auto instance_it = group_it->second.find(cp->instance.instance_key);
+ if (instance_it != group_it->second.end()) {
+ irec = instance_it->second.get();
+ {
+ mutex_lock l(irec->in_mu);
+ if (irec->is_init) {
+ exit_outside_locks = true;
+ } else {
+ irec->init_waiters.push_back([this, done](InstanceRec* irec) {
+ CallbackWithStatus(done, irec);
+ });
+ return;
+ }
}
+ found_instance = true;
}
- } else {
+ }
+ if (!found_instance) {
// Create new InstanceRec.
irec = new InstanceRec;
- instance_table_[cp->instance.instance_key].reset(irec);
+ instance_table_[gr->group.group_key][cp->instance.instance_key].reset(
+ irec);
}
}
Status status;
@@ -890,8 +897,10 @@
std::vector<InstanceRec*> instances;
{
mutex_lock l(instance_mu_);
- for (const auto& item : instance_table_) {
- instances.push_back(item.second.get());
+ for (const auto& group_entry : instance_table_) {
+ for (const auto& item : group_entry.second) {
+ instances.push_back(item.second.get());
+ }
}
}
for (InstanceRec* ir : instances) {
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 2b7528d..40f0f00 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -241,8 +241,8 @@
gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
TF_GUARDED_BY(group_mu_);
mutex instance_mu_;
- gtl::FlatMap<int32, std::unique_ptr<InstanceRec>> instance_table_
- TF_GUARDED_BY(instance_mu_);
+ gtl::FlatMap<int32, gtl::FlatMap<int32, std::unique_ptr<InstanceRec>>>
+ instance_table_ TF_GUARDED_BY(instance_mu_);
mutex status_mu_;
Status status_ TF_GUARDED_BY(status_mu_);
};
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index bfcd5b8..650c52c 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -304,10 +304,15 @@
}
}
-bool CollectiveParamResolverDistributed::InstanceIsCached(int32 instance_key) {
+bool CollectiveParamResolverDistributed::InstanceIsCached(int32 group_key,
+ int32 instance_key) {
mutex_lock l(instance_mu_);
- const auto& it = instance_table_.find(instance_key);
- return it != instance_table_.end();
+ auto group_it = instance_table_.find(group_key);
+ if (group_it == instance_table_.end()) {
+ return false;
+ }
+ auto instance_it = group_it->second.find(instance_key);
+ return instance_it != group_it->second.end();
}
void CollectiveParamResolverDistributed::UpdateInstanceCache(
@@ -374,7 +379,7 @@
if (group_leader_.empty()) {
// This is the group leader so resolution is local.
return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
- } else if (InstanceIsCached(cp->instance.instance_key)) {
+ } else if (InstanceIsCached(gr->group.group_key, cp->instance.instance_key)) {
return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
} else {
CompleteInstanceCall* call = new CompleteInstanceCall(
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
index 7d30c3d..6848874 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
@@ -65,7 +65,8 @@
// Returns true iff there's an entry for this instance_key in the
// local instance_table_.
- bool InstanceIsCached(int32 instance_key) TF_LOCKS_EXCLUDED(instance_mu_);
+ bool InstanceIsCached(int32 group_key, int32 instance_key)
+ TF_LOCKS_EXCLUDED(instance_mu_);
// Updates instance_table_ with contents of resp.
void UpdateInstanceCache(const GroupRec* gr, CollectiveParams* cp,
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 51a5219..0230852 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -22,8 +22,10 @@
namespace {
-static string CollectiveKey(OpKernelContext* ctx, int32 instance_key) {
- return strings::StrCat(instance_key, ":", ctx->frame_iter().frame_id, ":",
+static string CollectiveKey(OpKernelContext* ctx, int32 group_key,
+ int32 instance_key) {
+ return strings::StrCat(group_key, ":", instance_key, ":",
+ ctx->frame_iter().frame_id, ":",
ctx->frame_iter().iter_id);
}
@@ -52,7 +54,8 @@
// A string encoding instance, frame and iter to be handed off to
// the implementation for use in generating RecvBuf keys.
string GetCollectiveKey(OpKernelContext* c) {
- return CollectiveKey(c, col_params_.instance.instance_key);
+ return CollectiveKey(c, col_params_.group.group_key,
+ col_params_.instance.instance_key);
}
// Returns false if calling invocation of ComputeAsync should return
@@ -557,7 +560,8 @@
<< " instance " << col_params->instance.instance_key;
col_exec->ExecuteAsync(
c, *col_params,
- CollectiveKey(c, col_params->instance.instance_key),
+ CollectiveKey(c, col_params->group.group_key,
+ col_params->instance.instance_key),
actual_done);
} else {
c->SetStatus(s);
diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py
index 25d9367..4225df7 100644
--- a/tensorflow/python/kernel_tests/collective_ops_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_test.py
@@ -36,6 +36,8 @@
self.assertEqual(len(cpus), 1)
config.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
+ context.LogicalDeviceConfiguration(),
+ context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration()
])
context.ensure_initialized()
@@ -78,6 +80,47 @@
for result in run_all_reduce_2cpus():
self.assertAllClose(result, [2.], rtol=1e-5, atol=1e-5)
+ @test_util.run_v2_only
+ def testInstanceKeyScopedUnderGroupKey(self):
+ self._setup_context()
+
+ @def_function.function
+ def single_all_reduce(in_value, group_size, group_key, instance_key):
+ return gen_collective_ops.collective_reduce_v2(
+ in_value, group_size, group_key, instance_key, merge_op='Add',
+ final_op='Id', communication_hint='auto')
+
+ @def_function.function
+ def run_all_reduce_4cpus_same_instance_key():
+ # Use a common instance key for both groups.
+ instance_key = constant_op.constant(0)
+ # We will create 2 groups each with 2 devices.
+ group_size = constant_op.constant(2)
+ # Group 0 comprises cpu:0 and cpu:1.
+ group0_key = constant_op.constant(0)
+ # Group 1 comprises cpu:2 and cpu:3.
+ group1_key = constant_op.constant(1)
+ collectives = []
+ with ops.device('/device:CPU:0'):
+ collectives.append(single_all_reduce(
+ constant_op.constant(1.), group_size, group0_key, instance_key))
+ with ops.device('/device:CPU:1'):
+ collectives.append(single_all_reduce(
+ constant_op.constant(2.), group_size, group0_key, instance_key))
+ with ops.device('/device:CPU:2'):
+ collectives.append(single_all_reduce(
+ constant_op.constant(3.), group_size, group1_key, instance_key))
+ with ops.device('/device:CPU:3'):
+ collectives.append(single_all_reduce(
+ constant_op.constant(4.), group_size, group1_key, instance_key))
+ return collectives
+
+ results = run_all_reduce_4cpus_same_instance_key()
+ self.assertAllClose(results[0], 3., rtol=1e-5, atol=1e-5)
+ self.assertAllClose(results[1], 3., rtol=1e-5, atol=1e-5)
+ self.assertAllClose(results[2], 7., rtol=1e-5, atol=1e-5)
+ self.assertAllClose(results[3], 7., rtol=1e-5, atol=1e-5)
+
if __name__ == '__main__':
test.main()