Add AllPermute interface to Dtensor to support halo exchange.
PiperOrigin-RevId: 332947906
Change-Id: Ib7c80eb201441d4a4f9399797bc25b26a5c85661
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index b009ce0..89f90f5 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -71,6 +71,9 @@
case GATHER_COLLECTIVE:
return "RingGather";
+ case PERMUTE_COLLECTIVE:
+ return "Permute";
+
default:
return "undef";
}
diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc
index e1a655e..df19752 100644
--- a/tensorflow/core/framework/collective.cc
+++ b/tensorflow/core/framework/collective.cc
@@ -85,6 +85,8 @@
other.impl_details.subdiv_source_rank.begin(),
other.impl_details.subdiv_source_rank.end());
impl_details.dependencies = other.impl_details.dependencies;
+ devices.assign(other.devices.begin(), other.devices.end());
+ permutation.assign(other.permutation.begin(), other.permutation.end());
}
return *this;
}
@@ -125,8 +127,18 @@
strings::StrAppend(&v, r, ",");
}
strings::StrAppend(&v, "}");
+ } // all subdivs
+ if (type == PERMUTE_COLLECTIVE) {
+ strings::StrAppend(&v, "}, permute_devices {");
+ for (const auto& d : devices) {
+ strings::StrAppend(&v, d, ",");
+ }
+ strings::StrAppend(&v, "}, permute_permutation {");
+ for (const auto& p : permutation) {
+ strings::StrAppend(&v, p, ",");
+ }
+ strings::StrAppend(&v, "}");
}
- strings::StrAppend(&v, "}"); // all subdivs
return v;
}