[XLA] Add AllReduceScatter to HLO bindings.
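
For reference, a minimal usage sketch of the new binding through the
xla_client Python API. The reduction-computation helper, shapes, and
shard count below are illustrative assumptions, not part of this change:

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    ops = xla_client.ops

    def make_add_computation(dtype):
      # Scalar add computation used as the reduction, as with AllReduce.
      b = xla_client.XlaBuilder("add")
      shape = xla_client.Shape.scalar_shape(np.dtype(dtype))
      ops.Add(ops.Parameter(b, 0, shape), ops.Parameter(b, 1, shape))
      return b.build()

    builder = xla_client.XlaBuilder("all_reduce_scatter_example")
    x = ops.Parameter(
        builder, 0,
        xla_client.Shape.array_shape(np.dtype(np.float32), (8, 128)))
    # Reduce across 8 participants, then scatter the reduced result along
    # dimension 0, so each participant keeps a (1, 128) slice.
    out = ops.AllReduceScatter(
        x,
        computation=make_add_computation(np.float32),
        scatter_dimension=0,
        shard_count=8)
    computation = builder.build(out)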

PiperOrigin-RevId: 381092455
Change-Id: Id40276a45e2f0a99c5a3b32760569ec6fd87439b
diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc
index 4440ec0..6a11d63 100644
--- a/tensorflow/compiler/xla/python/ops.cc
+++ b/tensorflow/compiler/xla/python/ops.cc
@@ -77,6 +77,12 @@
       py::arg("replica_groups") = py::list(),
       py::arg("channel_id") = absl::nullopt,
       py::arg("shape_with_layout") = absl::nullopt);
+  ops.def("AllReduceScatter", &AllReduceScatter, py::arg("operand"),
+          py::arg("computation"), py::arg("scatter_dimension"),
+          py::arg("shard_count"), py::arg("replica_groups") = py::list(),
+          py::arg("channel_id") = absl::nullopt,
+          py::arg("layout") = absl::nullopt,
+          py::arg("use_global_device_ids") = absl::nullopt);
   ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
           py::arg("concat_dimension"), py::arg("split_count"),
           py::arg("replica_groups") = py::list(),
diff --git a/tensorflow/compiler/xla/python/xla_extension/ops.pyi b/tensorflow/compiler/xla/python/xla_extension/ops.pyi
index 9131ae5..30492f2 100644
--- a/tensorflow/compiler/xla/python/xla_extension/ops.pyi
+++ b/tensorflow/compiler/xla/python/xla_extension/ops.pyi
@@ -68,6 +68,15 @@
     replica_groups: Sequence[_ReplicaGroup] = ...,
     channel_id: Optional[ChannelHandle] = ...,
     shape_with_layout: Optional[_Layout] = ...) -> XlaOp: ...
+def AllReduceScatter(
+    operand: XlaOp,
+    computation: XlaComputation,
+    scatter_dimension: int,
+    shard_count: int,
+    replica_groups: Sequence[_ReplicaGroup] = ...,
+    channel_id: Optional[ChannelHandle] = ...,
+    layout: Optional[_Layout] = ...,
+    use_global_device_ids: Optional[bool] = ...) -> XlaOp: ...
 def AllToAll(
     operand: XlaOp,
     split_dimension: int,