[XLA] Minor fix to AllToAll verifier; add a comment
PiperOrigin-RevId: 429163928
Change-Id: I8086e3d8747855f925f60f3c65ddcabf66e3c3bf
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index b16a432..5fc02f3 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -745,6 +745,9 @@
const absl::optional<Layout>& layout = absl::nullopt,
const absl::optional<bool> use_global_device_ids = absl::nullopt);
+ // TODO(b/219961627): Add overload that accepts one operand per replica (i.e.
+ // with no split_dimension provided). Also, allow the replica_groups to be
+ // inferred (one group containing all replicas).
XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups,
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index f7d9eee..05f9e6d 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -469,12 +469,6 @@
TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode));
TF_RET_CHECK(all_to_all != nullptr);
- if (all_to_all->split_dimension()) {
- if (hlo->replica_groups().empty()) {
- return InternalError(
- "An array all-to-all must have an explicit replica_groups config");
- }
- }
// The size of each replica group must be the same (checked in
// CheckReplicaGroups). This is the split count of the operation). In case the