[XLA:Python] Split bindings for XLA ops into a separate file. No functional changes.

This is partially to make xla.cc shorter and partially to parallelize its build time.

PiperOrigin-RevId: 313307447
Change-Id: I4f6de5723dbef4464599813bc9284b4ac9e271d7
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 5b4182b..3dcdc46 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -186,6 +186,32 @@
     ],
 )
 
+cc_library(
+    name = "ops",
+    srcs = ["ops.cc"],
+    hdrs = ["ops.h"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    features = ["-use_header_modules"],
+    deps = [
+        ":types",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/client/lib:comparators",
+        "//tensorflow/compiler/xla/client/lib:math",
+        "//tensorflow/compiler/xla/client/lib:qr",
+        "//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
+        "//tensorflow/compiler/xla/client/lib:sorting",
+        "//tensorflow/compiler/xla/client/lib:svd",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+        "@pybind11",
+    ],
+)
+
 config_setting(
     name = "enable_gpu",
     values = {"define": "xla_python_enable_gpu=true"},
@@ -205,6 +231,7 @@
     deps = [
         ":bfloat16",
         ":dlpack",
+        ":ops",
         ":python_ref_manager",
         ":types",
         "@com_google_absl//absl/base",
@@ -228,12 +255,6 @@
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
-        "//tensorflow/compiler/xla/client/lib:comparators",
-        "//tensorflow/compiler/xla/client/lib:math",
-        "//tensorflow/compiler/xla/client/lib:qr",
-        "//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
-        "//tensorflow/compiler/xla/client/lib:sorting",
-        "//tensorflow/compiler/xla/client/lib:svd",
         "//tensorflow/compiler/xla/pjrt:cpu_device",
         "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
         "//tensorflow/compiler/xla/pjrt:pjrt_client",
diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc
new file mode 100644
index 0000000..89891d3
--- /dev/null
+++ b/tensorflow/compiler/xla/python/ops.cc
@@ -0,0 +1,356 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/ops.h"
+
+#include <string>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "pybind11/attr.h"
+#include "pybind11/pybind11.h"
+#include "tensorflow/compiler/xla/client/lib/comparators.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
+#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
+#include "tensorflow/compiler/xla/client/lib/svd.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/python/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+namespace py = pybind11;
+
+void BuildOpsSubmodule(py::module* m) {
+  // ops submodule, containing free functions that add operators to an
+  // XlaBuilder.
+  py::module ops = m->def_submodule("ops", "XLA operations");
+
+  py::enum_<TriangularSolveOptions::Transpose>(
+      ops, "TriangularSolveOptions_Transpose")
+      .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
+      .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
+      .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
+      .value("ADJOINT", TriangularSolveOptions::ADJOINT);
+
+  ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
+  ops.def(
+      "AllReduce",
+      static_cast<XlaOp (*)(
+          XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
+          const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
+          &AllReduce),
+      py::arg("operand"), py::arg("computation"),
+      py::arg("replica_groups") = py::list(),
+      py::arg("channel_id") = absl::nullopt,
+      py::arg("shape_with_layout") = absl::nullopt);
+  ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
+          py::arg("concat_dimension"), py::arg("split_count"),
+          py::arg("replica_groups") = py::list(),
+          py::arg("layout") = absl::nullopt);
+  ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
+          py::arg("source_target_pairs"));
+  ops.def("CreateToken", &CreateToken, py::arg("builder"));
+  ops.def("CrossReplicaSum",
+          static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
+              &CrossReplicaSum),
+          py::arg("operand"), py::arg("replica_groups") = py::list());
+  ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
+          py::arg("new_element_type"));
+  ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
+  ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
+          py::arg("shape"), py::arg("broadcast_dimensions"));
+  ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
+          py::arg("operands"));
+  ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
+  ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
+  ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
+  ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
+          py::arg("dimension"));
+  ops.def("Conditional",
+          static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
+                                absl::Span<const XlaOp>)>(&Conditional),
+          py::arg("branch_index"), py::arg("branch_computations"),
+          py::arg("branch_operands"));
+  ops.def("Conditional",
+          static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
+                                const XlaComputation&)>(&Conditional),
+          py::arg("predicate"), py::arg("true_operand"),
+          py::arg("true_computation"), py::arg("false_operand"),
+          py::arg("false_computation"));
+  ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
+  ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
+          py::arg("literal"));
+  ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
+          py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
+          py::arg("lhs_dilation"), py::arg("rhs_dilation"),
+          py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
+          py::arg("batch_group_count") = 1,
+          py::arg("precision_config") = nullptr);
+  ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
+          py::arg("new_element_type"));
+  ops.def(
+      "CustomCall",
+      [](XlaBuilder* builder, const py::bytes& call_target_name,
+         absl::Span<const XlaOp> operands, const Shape& shape,
+         const py::bytes& opaque) -> XlaOp {
+        return CustomCall(builder, call_target_name, operands, shape, opaque);
+      },
+      py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
+      py::arg("shape"), py::arg("opaque") = py::bytes(""));
+  ops.def(
+      "CustomCallWithLayout",
+      [](XlaBuilder* builder, const py::bytes& call_target_name,
+         absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
+         absl::Span<const Shape> operand_shapes_with_layout,
+         const py::bytes& opaque) -> XlaOp {
+        return CustomCallWithLayout(builder, call_target_name, operands,
+                                    shape_with_layout,
+                                    operand_shapes_with_layout, opaque);
+      },
+      py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
+      py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
+      py::arg("opaque") = py::bytes(""));
+  ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
+          py::arg("precision_config") = nullptr);
+  ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
+          py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
+  ops.def("DynamicSlice",
+          static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
+                                absl::Span<const int64>)>(&DynamicSlice),
+          py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
+  ops.def("DynamicUpdateSlice",
+          static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
+              &DynamicUpdateSlice),
+          py::arg("operand"), py::arg("update"), py::arg("start_indices"));
+
+  ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
+          py::arg("fft_length"));
+
+  ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
+          py::arg("dimension_numbers"), py::arg("slice_sizes"),
+          py::arg("indices_are_sorted") = false);
+  ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
+          py::arg("index"));
+  ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
+          py::arg("shape"), py::arg("config") = "");
+  ops.def("Iota",
+          static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
+          py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
+  ops.def("Iota",
+          static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
+          py::arg("builder"), py::arg("type"), py::arg("size"));
+  ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
+          py::arg("computation"), py::arg("dimensions"),
+          py::arg("static_operands") = py::list());
+  ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
+  ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
+          py::arg("token"), py::arg("shape_with_layout"),
+          py::arg("outfeed_config") = "");
+  ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
+          py::arg("padding_config"));
+  ops.def("Parameter",
+          static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
+                                const std::string&, const std::vector<bool>&)>(
+              &Parameter),
+          py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
+          py::arg("name") = "",
+          py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
+  ops.def(
+      "QR",
+      [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
+        TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
+        return std::make_pair(qr.q, qr.r);
+      },
+      py::arg("operand"), py::arg("full_matrices"));
+  ops.def(
+      "Eigh",
+      [](XlaOp a, bool lower, int64 max_iter,
+         float epsilon) -> std::pair<XlaOp, XlaOp> {
+        auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon);
+        return std::make_pair(eigh.v, eigh.w);
+      },
+      py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100,
+      py::arg("epsilon") = 1e-6);
+  ops.def(
+      "SVD",
+      [](XlaOp a, int64 max_iter,
+         float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
+        auto svd = SVD(a, max_iter, epsilon);
+        return std::make_tuple(svd.u, svd.d, svd.v);
+      },
+      py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
+  ops.def("Reduce",
+          static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
+                                absl::Span<const XlaOp>, const XlaComputation&,
+                                absl::Span<const int64>)>(&Reduce),
+          py::arg("builder"), py::arg("operands"), py::arg("init_values"),
+          py::arg("computation"), py::arg("dimensions_to_reduce"));
+  ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
+          py::arg("exponent_bits"), py::arg("mantissa_bits"));
+  ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
+          py::arg("operand"), py::arg("init_value"), py::arg("computation"),
+          py::arg("window_dimensions"), py::arg("window_strides"),
+          py::arg("base_dilations"), py::arg("window_dilations"),
+          py::arg("padding"));
+  ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
+  ops.def("Reshape",
+          static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
+                                absl::Span<const int64>)>(&Reshape),
+          py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
+  ops.def("Reshape",
+          static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
+          py::arg("operand"), py::arg("new_sizes"));
+  ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
+  ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
+          py::arg("shape"));
+  ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
+          py::arg("shape"));
+  ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
+          py::arg("updates"), py::arg("update_computation"),
+          py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
+          py::arg("unique_indices") = false);
+  ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
+          py::arg("on_false"));
+  ops.def("SelectAndScatterWithGeneralPadding",
+          &SelectAndScatterWithGeneralPadding, py::arg("operand"),
+          py::arg("select"), py::arg("window_dimensions"),
+          py::arg("window_strides"), py::arg("padding"), py::arg("source"),
+          py::arg("init_value"), py::arg("scatter"));
+  ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
+          py::arg("limit_indices"), py::arg("strides"));
+  ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
+          py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
+  ops.def(
+      "Sort",
+      [](XlaBuilder* builder, absl::Span<const XlaOp> operands,
+         absl::optional<const XlaComputation*> comparator, int64 dimension,
+         bool is_stable) -> XlaOp {
+        return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+          std::vector<PrimitiveType> operand_types;
+          for (const auto& operand : operands) {
+            TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
+            operand_types.push_back(operand_shape.element_type());
+          }
+
+          if (comparator) {
+            return Sort(operands, **comparator, dimension, is_stable);
+          } else {
+            return Sort(operands,
+                        CreateScalarLtComputation(operand_types, builder),
+                        dimension, is_stable);
+          }
+        });
+      },
+      py::arg("builder"), py::arg("operands"),
+      py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
+      py::arg("is_stable") = false);
+  ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
+  ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
+  ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
+          py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
+          py::arg("transpose_a"));
+  ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
+  ops.def("While", &While, py::arg("condition"), py::arg("body"),
+          py::arg("init"));
+
+  ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
+  ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
+  ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
+  ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
+  ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
+          py::arg("b"), py::arg("x"));
+
+#define BINARY_OP(op)                                                 \
+  ops.def(                                                            \
+      #op,                                                            \
+      [](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
+        return dims ? op(a, b, *dims) : op(a, b);                     \
+      },                                                              \
+      py::arg("lhs"), py::arg("rhs"),                                 \
+      py::arg("broadcast_dimensions") = absl::nullopt)
+  BINARY_OP(Eq);
+  BINARY_OP(Ne);
+  BINARY_OP(Ge);
+  BINARY_OP(Gt);
+  BINARY_OP(Lt);
+  BINARY_OP(Le);
+  BINARY_OP(Add);
+  BINARY_OP(Sub);
+  BINARY_OP(Mul);
+  BINARY_OP(Div);
+  BINARY_OP(Rem);
+  BINARY_OP(Max);
+  BINARY_OP(Min);
+  BINARY_OP(And);
+  BINARY_OP(Or);
+  BINARY_OP(Xor);
+  BINARY_OP(ShiftLeft);
+  BINARY_OP(ShiftRightArithmetic);
+  BINARY_OP(ShiftRightLogical);
+  BINARY_OP(Atan2);
+  BINARY_OP(Pow);
+  BINARY_OP(Complex);
+#undef BINARY_OP
+
+#define UNARY_OP(op) ops.def(#op, &op)
+  UNARY_OP(Not);
+  UNARY_OP(PopulationCount);
+  UNARY_OP(Clz);
+  UNARY_OP(Abs);
+  UNARY_OP(Exp);
+  UNARY_OP(Expm1);
+  UNARY_OP(Floor);
+  UNARY_OP(Ceil);
+  UNARY_OP(Round);
+  UNARY_OP(Log);
+  UNARY_OP(Log1p);
+  UNARY_OP(Sign);
+  UNARY_OP(Cos);
+  UNARY_OP(Sin);
+  UNARY_OP(Tanh);
+  UNARY_OP(IsFinite);
+  UNARY_OP(Neg);
+  UNARY_OP(Sqrt);
+  UNARY_OP(Rsqrt);
+  UNARY_OP(Square);
+  UNARY_OP(Reciprocal);
+  UNARY_OP(Erfc);
+  UNARY_OP(Erf);
+  UNARY_OP(ErfInv);
+  UNARY_OP(Lgamma);
+  UNARY_OP(Digamma);
+  UNARY_OP(BesselI0e);
+  UNARY_OP(BesselI1e);
+  UNARY_OP(Acos);
+  UNARY_OP(Asin);
+  UNARY_OP(Atan);
+  UNARY_OP(Tan);
+  UNARY_OP(Acosh);
+  UNARY_OP(Asinh);
+  UNARY_OP(Atanh);
+  UNARY_OP(Cosh);
+  UNARY_OP(Sinh);
+  UNARY_OP(Real);
+  UNARY_OP(Imag);
+  UNARY_OP(Conj);
+#undef UNARY_OP
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/python/ops.h b/tensorflow/compiler/xla/python/ops.h
new file mode 100644
index 0000000..7fe34e9
--- /dev/null
+++ b/tensorflow/compiler/xla/python/ops.h
@@ -0,0 +1,27 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
+
+#include "pybind11/pybind11.h"
+
+namespace xla {
+
+void BuildOpsSubmodule(pybind11::module* m);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index abf0937..fb7d7df 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -30,12 +30,6 @@
 #include "pybind11/pybind11.h"
 #include "pybind11/pytypes.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/lib/comparators.h"
-#include "tensorflow/compiler/xla/client/lib/math.h"
-#include "tensorflow/compiler/xla/client/lib/qr.h"
-#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
-#include "tensorflow/compiler/xla/client/lib/sorting.h"
-#include "tensorflow/compiler/xla/client/lib/svd.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -48,6 +42,7 @@
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/python/bfloat16.h"
 #include "tensorflow/compiler/xla/python/dlpack.h"
+#include "tensorflow/compiler/xla/python/ops.h"
 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
 #include "tensorflow/compiler/xla/python/types.h"
 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
@@ -306,321 +301,6 @@
   return result;
 }
 
-void BuildOpsSubmodule(py::module* m) {
-  // ops submodule, containing free functions that add operators to an
-  // XlaBuilder.
-  py::module ops = m->def_submodule("ops", "XLA operations");
-
-  py::enum_<TriangularSolveOptions::Transpose>(
-      ops, "TriangularSolveOptions_Transpose")
-      .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
-      .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
-      .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
-      .value("ADJOINT", TriangularSolveOptions::ADJOINT);
-
-  ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
-  ops.def(
-      "AllReduce",
-      static_cast<XlaOp (*)(
-          XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
-          const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
-          &AllReduce),
-      py::arg("operand"), py::arg("computation"),
-      py::arg("replica_groups") = py::list(),
-      py::arg("channel_id") = absl::nullopt,
-      py::arg("shape_with_layout") = absl::nullopt);
-  ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
-          py::arg("concat_dimension"), py::arg("split_count"),
-          py::arg("replica_groups") = py::list(),
-          py::arg("layout") = absl::nullopt);
-  ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
-          py::arg("source_target_pairs"));
-  ops.def("CreateToken", &CreateToken, py::arg("builder"));
-  ops.def("CrossReplicaSum",
-          static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
-              &CrossReplicaSum),
-          py::arg("operand"), py::arg("replica_groups") = py::list());
-  ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
-          py::arg("new_element_type"));
-  ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
-  ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
-          py::arg("shape"), py::arg("broadcast_dimensions"));
-  ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
-          py::arg("operands"));
-  ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
-  ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
-  ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
-  ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
-          py::arg("dimension"));
-  ops.def("Conditional",
-          static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
-                                absl::Span<const XlaOp>)>(&Conditional),
-          py::arg("branch_index"), py::arg("branch_computations"),
-          py::arg("branch_operands"));
-  ops.def("Conditional",
-          static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
-                                const XlaComputation&)>(&Conditional),
-          py::arg("predicate"), py::arg("true_operand"),
-          py::arg("true_computation"), py::arg("false_operand"),
-          py::arg("false_computation"));
-  ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
-  ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
-          py::arg("literal"));
-  ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
-          py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
-          py::arg("lhs_dilation"), py::arg("rhs_dilation"),
-          py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
-          py::arg("batch_group_count") = 1,
-          py::arg("precision_config") = nullptr);
-  ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
-          py::arg("new_element_type"));
-  ops.def(
-      "CustomCall",
-      [](XlaBuilder* builder, const py::bytes& call_target_name,
-         absl::Span<const XlaOp> operands, const Shape& shape,
-         const py::bytes& opaque) -> XlaOp {
-        return CustomCall(builder, call_target_name, operands, shape, opaque);
-      },
-      py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
-      py::arg("shape"), py::arg("opaque") = py::bytes(""));
-  ops.def(
-      "CustomCallWithLayout",
-      [](XlaBuilder* builder, const py::bytes& call_target_name,
-         absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
-         absl::Span<const Shape> operand_shapes_with_layout,
-         const py::bytes& opaque) -> XlaOp {
-        return CustomCallWithLayout(builder, call_target_name, operands,
-                                    shape_with_layout,
-                                    operand_shapes_with_layout, opaque);
-      },
-      py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
-      py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
-      py::arg("opaque") = py::bytes(""));
-  ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
-          py::arg("precision_config") = nullptr);
-  ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
-          py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
-  ops.def("DynamicSlice",
-          static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
-                                absl::Span<const int64>)>(&DynamicSlice),
-          py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
-  ops.def("DynamicUpdateSlice",
-          static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
-              &DynamicUpdateSlice),
-          py::arg("operand"), py::arg("update"), py::arg("start_indices"));
-
-  ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
-          py::arg("fft_length"));
-
-  ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
-          py::arg("dimension_numbers"), py::arg("slice_sizes"),
-          py::arg("indices_are_sorted") = false);
-  ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
-          py::arg("index"));
-  ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
-          py::arg("shape"), py::arg("config") = "");
-  ops.def("Iota",
-          static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
-          py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
-  ops.def("Iota",
-          static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
-          py::arg("builder"), py::arg("type"), py::arg("size"));
-  ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
-          py::arg("computation"), py::arg("dimensions"),
-          py::arg("static_operands") = py::list());
-  ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
-  ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
-          py::arg("token"), py::arg("shape_with_layout"),
-          py::arg("outfeed_config") = "");
-  ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
-          py::arg("padding_config"));
-  ops.def("Parameter",
-          static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
-                                const std::string&, const std::vector<bool>&)>(
-              &Parameter),
-          py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
-          py::arg("name") = "",
-          py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
-  ops.def(
-      "QR",
-      [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
-        TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
-        return std::make_pair(qr.q, qr.r);
-      },
-      py::arg("operand"), py::arg("full_matrices"));
-  ops.def(
-      "Eigh",
-      [](XlaOp a, bool lower, int64 max_iter,
-         float epsilon) -> std::pair<XlaOp, XlaOp> {
-        auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon);
-        return std::make_pair(eigh.v, eigh.w);
-      },
-      py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100,
-      py::arg("epsilon") = 1e-6);
-  ops.def(
-      "SVD",
-      [](XlaOp a, int64 max_iter,
-         float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
-        auto svd = SVD(a, max_iter, epsilon);
-        return std::make_tuple(svd.u, svd.d, svd.v);
-      },
-      py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
-  ops.def("Reduce",
-          static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
-                                absl::Span<const XlaOp>, const XlaComputation&,
-                                absl::Span<const int64>)>(&Reduce),
-          py::arg("builder"), py::arg("operands"), py::arg("init_values"),
-          py::arg("computation"), py::arg("dimensions_to_reduce"));
-  ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
-          py::arg("exponent_bits"), py::arg("mantissa_bits"));
-  ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
-          py::arg("operand"), py::arg("init_value"), py::arg("computation"),
-          py::arg("window_dimensions"), py::arg("window_strides"),
-          py::arg("base_dilations"), py::arg("window_dilations"),
-          py::arg("padding"));
-  ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
-  ops.def("Reshape",
-          static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
-                                absl::Span<const int64>)>(&Reshape),
-          py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
-  ops.def("Reshape",
-          static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
-          py::arg("operand"), py::arg("new_sizes"));
-  ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
-  ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
-          py::arg("shape"));
-  ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
-          py::arg("shape"));
-  ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
-          py::arg("updates"), py::arg("update_computation"),
-          py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
-          py::arg("unique_indices") = false);
-  ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
-          py::arg("on_false"));
-  ops.def("SelectAndScatterWithGeneralPadding",
-          &SelectAndScatterWithGeneralPadding, py::arg("operand"),
-          py::arg("select"), py::arg("window_dimensions"),
-          py::arg("window_strides"), py::arg("padding"), py::arg("source"),
-          py::arg("init_value"), py::arg("scatter"));
-  ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
-          py::arg("limit_indices"), py::arg("strides"));
-  ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
-          py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
-  ops.def(
-      "Sort",
-      [](XlaBuilder* builder, absl::Span<const XlaOp> operands,
-         absl::optional<const XlaComputation*> comparator, int64 dimension,
-         bool is_stable) -> XlaOp {
-        return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-          std::vector<PrimitiveType> operand_types;
-          for (const auto& operand : operands) {
-            TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
-            operand_types.push_back(operand_shape.element_type());
-          }
-
-          if (comparator) {
-            return Sort(operands, **comparator, dimension, is_stable);
-          } else {
-            return Sort(operands,
-                        CreateScalarLtComputation(operand_types, builder),
-                        dimension, is_stable);
-          }
-        });
-      },
-      py::arg("builder"), py::arg("operands"),
-      py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
-      py::arg("is_stable") = false);
-  ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
-  ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
-  ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
-          py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
-          py::arg("transpose_a"));
-  ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
-  ops.def("While", &While, py::arg("condition"), py::arg("body"),
-          py::arg("init"));
-
-  ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
-  ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
-  ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
-  ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
-  ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
-          py::arg("b"), py::arg("x"));
-
-#define BINARY_OP(op)                                                 \
-  ops.def(                                                            \
-      #op,                                                            \
-      [](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
-        return dims ? op(a, b, *dims) : op(a, b);                     \
-      },                                                              \
-      py::arg("lhs"), py::arg("rhs"),                                 \
-      py::arg("broadcast_dimensions") = absl::nullopt)
-  BINARY_OP(Eq);
-  BINARY_OP(Ne);
-  BINARY_OP(Ge);
-  BINARY_OP(Gt);
-  BINARY_OP(Lt);
-  BINARY_OP(Le);
-  BINARY_OP(Add);
-  BINARY_OP(Sub);
-  BINARY_OP(Mul);
-  BINARY_OP(Div);
-  BINARY_OP(Rem);
-  BINARY_OP(Max);
-  BINARY_OP(Min);
-  BINARY_OP(And);
-  BINARY_OP(Or);
-  BINARY_OP(Xor);
-  BINARY_OP(ShiftLeft);
-  BINARY_OP(ShiftRightArithmetic);
-  BINARY_OP(ShiftRightLogical);
-  BINARY_OP(Atan2);
-  BINARY_OP(Pow);
-  BINARY_OP(Complex);
-#undef BINARY_OP
-
-#define UNARY_OP(op) ops.def(#op, &op)
-  UNARY_OP(Not);
-  UNARY_OP(PopulationCount);
-  UNARY_OP(Clz);
-  UNARY_OP(Abs);
-  UNARY_OP(Exp);
-  UNARY_OP(Expm1);
-  UNARY_OP(Floor);
-  UNARY_OP(Ceil);
-  UNARY_OP(Round);
-  UNARY_OP(Log);
-  UNARY_OP(Log1p);
-  UNARY_OP(Sign);
-  UNARY_OP(Cos);
-  UNARY_OP(Sin);
-  UNARY_OP(Tanh);
-  UNARY_OP(IsFinite);
-  UNARY_OP(Neg);
-  UNARY_OP(Sqrt);
-  UNARY_OP(Rsqrt);
-  UNARY_OP(Square);
-  UNARY_OP(Reciprocal);
-  UNARY_OP(Erfc);
-  UNARY_OP(Erf);
-  UNARY_OP(ErfInv);
-  UNARY_OP(Lgamma);
-  UNARY_OP(Digamma);
-  UNARY_OP(BesselI0e);
-  UNARY_OP(BesselI1e);
-  UNARY_OP(Acos);
-  UNARY_OP(Asin);
-  UNARY_OP(Atan);
-  UNARY_OP(Tan);
-  UNARY_OP(Acosh);
-  UNARY_OP(Asinh);
-  UNARY_OP(Atanh);
-  UNARY_OP(Cosh);
-  UNARY_OP(Sinh);
-  UNARY_OP(Real);
-  UNARY_OP(Imag);
-  UNARY_OP(Conj);
-#undef UNARY_OP
-}
 
 void BuildProfilerSubmodule(py::module* m) {
   py::module profiler =