[XLA] Add AllReduceScatter to HLO bindings.
PiperOrigin-RevId: 381092455
Change-Id: Id40276a45e2f0a99c5a3b32760569ec6fd87439b
diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc
index 4440ec0..6a11d63 100644
--- a/tensorflow/compiler/xla/python/ops.cc
+++ b/tensorflow/compiler/xla/python/ops.cc
@@ -77,6 +77,12 @@
py::arg("replica_groups") = py::list(),
py::arg("channel_id") = absl::nullopt,
py::arg("shape_with_layout") = absl::nullopt);
+ ops.def("AllReduceScatter", &AllReduceScatter, py::arg("operand"),
+ py::arg("computation"), py::arg("scatter_dimension"),
+ py::arg("shard_count"), py::arg("replica_groups") = py::list(),
+ py::arg("channel_id") = absl::nullopt,
+ py::arg("layout") = absl::nullopt,
+ py::arg("use_global_device_ids") = absl::nullopt);
ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
py::arg("concat_dimension"), py::arg("split_count"),
py::arg("replica_groups") = py::list(),
diff --git a/tensorflow/compiler/xla/python/xla_extension/ops.pyi b/tensorflow/compiler/xla/python/xla_extension/ops.pyi
index 9131ae5..30492f2 100644
--- a/tensorflow/compiler/xla/python/xla_extension/ops.pyi
+++ b/tensorflow/compiler/xla/python/xla_extension/ops.pyi
@@ -68,6 +68,15 @@
replica_groups: Sequence[_ReplicaGroup] = ...,
channel_id: Optional[ChannelHandle] = ...,
shape_with_layout: Optional[_Layout] = ...) -> XlaOp: ...
+def AllReduceScatter(
+ operand: XlaOp,
+ computation: XlaComputation,
+ scatter_dimension: int,
+ shard_count: int,
+ replica_groups: Sequence[_ReplicaGroup] = ...,
+ channel_id: Optional[ChannelHandle] = ...,
+ layout: Optional[_Layout] = ...,
+ use_global_device_ids: Optional[bool] = ...) -> XlaOp: ...
def AllToAll(
operand: XlaOp,
split_dimension: int,