[XLA:Python] Split bindings for XLA ops into a separate file. No functional changes.
This is partially to make xla.cc shorter and partially to parallelize its build time.
PiperOrigin-RevId: 313307447
Change-Id: I4f6de5723dbef4464599813bc9284b4ac9e271d7
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 5b4182b..3dcdc46 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -186,6 +186,32 @@
],
)
+cc_library(
+ name = "ops",
+ srcs = ["ops.cc"],
+ hdrs = ["ops.h"],
+ copts = [
+ "-fexceptions",
+ "-fno-strict-aliasing",
+ ],
+ features = ["-use_header_modules"],
+ deps = [
+ ":types",
+ "//tensorflow/compiler/xla:xla_data_proto_cc",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/client/lib:comparators",
+ "//tensorflow/compiler/xla/client/lib:math",
+ "//tensorflow/compiler/xla/client/lib:qr",
+ "//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
+ "//tensorflow/compiler/xla/client/lib:sorting",
+ "//tensorflow/compiler/xla/client/lib:svd",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
+ "@pybind11",
+ ],
+)
+
config_setting(
name = "enable_gpu",
values = {"define": "xla_python_enable_gpu=true"},
@@ -205,6 +231,7 @@
deps = [
":bfloat16",
":dlpack",
+ ":ops",
":python_ref_manager",
":types",
"@com_google_absl//absl/base",
@@ -228,12 +255,6 @@
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/lib:comparators",
- "//tensorflow/compiler/xla/client/lib:math",
- "//tensorflow/compiler/xla/client/lib:qr",
- "//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
- "//tensorflow/compiler/xla/client/lib:sorting",
- "//tensorflow/compiler/xla/client/lib:svd",
"//tensorflow/compiler/xla/pjrt:cpu_device",
"//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc
new file mode 100644
index 0000000..89891d3
--- /dev/null
+++ b/tensorflow/compiler/xla/python/ops.cc
@@ -0,0 +1,356 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/ops.h"
+
+#include <string>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "pybind11/attr.h"
+#include "pybind11/pybind11.h"
+#include "tensorflow/compiler/xla/client/lib/comparators.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
+#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
+#include "tensorflow/compiler/xla/client/lib/svd.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/python/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+namespace py = pybind11;
+
+void BuildOpsSubmodule(py::module* m) {
+ // ops submodule, containing free functions that add operators to an
+ // XlaBuilder.
+ py::module ops = m->def_submodule("ops", "XLA operations");
+
+ py::enum_<TriangularSolveOptions::Transpose>(
+ ops, "TriangularSolveOptions_Transpose")
+ .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
+ .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
+ .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
+ .value("ADJOINT", TriangularSolveOptions::ADJOINT);
+
+ ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
+ ops.def(
+ "AllReduce",
+ static_cast<XlaOp (*)(
+ XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
+ const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
+ &AllReduce),
+ py::arg("operand"), py::arg("computation"),
+ py::arg("replica_groups") = py::list(),
+ py::arg("channel_id") = absl::nullopt,
+ py::arg("shape_with_layout") = absl::nullopt);
+ ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
+ py::arg("concat_dimension"), py::arg("split_count"),
+ py::arg("replica_groups") = py::list(),
+ py::arg("layout") = absl::nullopt);
+ ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
+ py::arg("source_target_pairs"));
+ ops.def("CreateToken", &CreateToken, py::arg("builder"));
+ ops.def("CrossReplicaSum",
+ static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
+ &CrossReplicaSum),
+ py::arg("operand"), py::arg("replica_groups") = py::list());
+ ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
+ py::arg("new_element_type"));
+ ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
+ ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
+ py::arg("shape"), py::arg("broadcast_dimensions"));
+ ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
+ py::arg("operands"));
+ ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
+ ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
+ ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
+ ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
+ py::arg("dimension"));
+ ops.def("Conditional",
+ static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
+ absl::Span<const XlaOp>)>(&Conditional),
+ py::arg("branch_index"), py::arg("branch_computations"),
+ py::arg("branch_operands"));
+ ops.def("Conditional",
+ static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
+ const XlaComputation&)>(&Conditional),
+ py::arg("predicate"), py::arg("true_operand"),
+ py::arg("true_computation"), py::arg("false_operand"),
+ py::arg("false_computation"));
+ ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
+ ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
+ py::arg("literal"));
+ ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
+ py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
+ py::arg("lhs_dilation"), py::arg("rhs_dilation"),
+ py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
+ py::arg("batch_group_count") = 1,
+ py::arg("precision_config") = nullptr);
+ ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
+ py::arg("new_element_type"));
+ ops.def(
+ "CustomCall",
+ [](XlaBuilder* builder, const py::bytes& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const py::bytes& opaque) -> XlaOp {
+ return CustomCall(builder, call_target_name, operands, shape, opaque);
+ },
+ py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
+ py::arg("shape"), py::arg("opaque") = py::bytes(""));
+ ops.def(
+ "CustomCallWithLayout",
+ [](XlaBuilder* builder, const py::bytes& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ const py::bytes& opaque) -> XlaOp {
+ return CustomCallWithLayout(builder, call_target_name, operands,
+ shape_with_layout,
+ operand_shapes_with_layout, opaque);
+ },
+ py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
+ py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
+ py::arg("opaque") = py::bytes(""));
+ ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
+ py::arg("precision_config") = nullptr);
+ ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
+ py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
+ ops.def("DynamicSlice",
+ static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
+ absl::Span<const int64>)>(&DynamicSlice),
+ py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
+ ops.def("DynamicUpdateSlice",
+ static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
+ &DynamicUpdateSlice),
+ py::arg("operand"), py::arg("update"), py::arg("start_indices"));
+
+ ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
+ py::arg("fft_length"));
+
+ ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
+ py::arg("dimension_numbers"), py::arg("slice_sizes"),
+ py::arg("indices_are_sorted") = false);
+ ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
+ py::arg("index"));
+ ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
+ py::arg("shape"), py::arg("config") = "");
+ ops.def("Iota",
+ static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
+ py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
+ ops.def("Iota",
+ static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
+ py::arg("builder"), py::arg("type"), py::arg("size"));
+ ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
+ py::arg("computation"), py::arg("dimensions"),
+ py::arg("static_operands") = py::list());
+ ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
+ ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
+ py::arg("token"), py::arg("shape_with_layout"),
+ py::arg("outfeed_config") = "");
+ ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
+ py::arg("padding_config"));
+ ops.def("Parameter",
+ static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
+ const std::string&, const std::vector<bool>&)>(
+ &Parameter),
+ py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
+ py::arg("name") = "",
+ py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
+ ops.def(
+ "QR",
+ [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
+ TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
+ return std::make_pair(qr.q, qr.r);
+ },
+ py::arg("operand"), py::arg("full_matrices"));
+ ops.def(
+ "Eigh",
+ [](XlaOp a, bool lower, int64 max_iter,
+ float epsilon) -> std::pair<XlaOp, XlaOp> {
+ auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon);
+ return std::make_pair(eigh.v, eigh.w);
+ },
+ py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100,
+ py::arg("epsilon") = 1e-6);
+ ops.def(
+ "SVD",
+ [](XlaOp a, int64 max_iter,
+ float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
+ auto svd = SVD(a, max_iter, epsilon);
+ return std::make_tuple(svd.u, svd.d, svd.v);
+ },
+ py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
+ ops.def("Reduce",
+ static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
+ absl::Span<const XlaOp>, const XlaComputation&,
+ absl::Span<const int64>)>(&Reduce),
+ py::arg("builder"), py::arg("operands"), py::arg("init_values"),
+ py::arg("computation"), py::arg("dimensions_to_reduce"));
+ ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
+ py::arg("exponent_bits"), py::arg("mantissa_bits"));
+ ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
+ py::arg("operand"), py::arg("init_value"), py::arg("computation"),
+ py::arg("window_dimensions"), py::arg("window_strides"),
+ py::arg("base_dilations"), py::arg("window_dilations"),
+ py::arg("padding"));
+ ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
+ ops.def("Reshape",
+ static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
+ absl::Span<const int64>)>(&Reshape),
+ py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
+ ops.def("Reshape",
+ static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
+ py::arg("operand"), py::arg("new_sizes"));
+ ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
+ ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
+ py::arg("shape"));
+ ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
+ py::arg("shape"));
+ ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
+ py::arg("updates"), py::arg("update_computation"),
+ py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
+ py::arg("unique_indices") = false);
+ ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
+ py::arg("on_false"));
+ ops.def("SelectAndScatterWithGeneralPadding",
+ &SelectAndScatterWithGeneralPadding, py::arg("operand"),
+ py::arg("select"), py::arg("window_dimensions"),
+ py::arg("window_strides"), py::arg("padding"), py::arg("source"),
+ py::arg("init_value"), py::arg("scatter"));
+ ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
+ py::arg("limit_indices"), py::arg("strides"));
+ ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
+ py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
+ ops.def(
+ "Sort",
+ [](XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ absl::optional<const XlaComputation*> comparator, int64 dimension,
+ bool is_stable) -> XlaOp {
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ std::vector<PrimitiveType> operand_types;
+ for (const auto& operand : operands) {
+ TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
+ operand_types.push_back(operand_shape.element_type());
+ }
+
+ if (comparator) {
+ return Sort(operands, **comparator, dimension, is_stable);
+ } else {
+ return Sort(operands,
+ CreateScalarLtComputation(operand_types, builder),
+ dimension, is_stable);
+ }
+ });
+ },
+ py::arg("builder"), py::arg("operands"),
+ py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
+ py::arg("is_stable") = false);
+ ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
+ ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
+ ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
+ py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
+ py::arg("transpose_a"));
+ ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
+ ops.def("While", &While, py::arg("condition"), py::arg("body"),
+ py::arg("init"));
+
+ ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
+ ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
+ ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
+ ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
+ ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
+ py::arg("b"), py::arg("x"));
+
+#define BINARY_OP(op) \
+ ops.def( \
+ #op, \
+ [](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
+ return dims ? op(a, b, *dims) : op(a, b); \
+ }, \
+ py::arg("lhs"), py::arg("rhs"), \
+ py::arg("broadcast_dimensions") = absl::nullopt)
+ BINARY_OP(Eq);
+ BINARY_OP(Ne);
+ BINARY_OP(Ge);
+ BINARY_OP(Gt);
+ BINARY_OP(Lt);
+ BINARY_OP(Le);
+ BINARY_OP(Add);
+ BINARY_OP(Sub);
+ BINARY_OP(Mul);
+ BINARY_OP(Div);
+ BINARY_OP(Rem);
+ BINARY_OP(Max);
+ BINARY_OP(Min);
+ BINARY_OP(And);
+ BINARY_OP(Or);
+ BINARY_OP(Xor);
+ BINARY_OP(ShiftLeft);
+ BINARY_OP(ShiftRightArithmetic);
+ BINARY_OP(ShiftRightLogical);
+ BINARY_OP(Atan2);
+ BINARY_OP(Pow);
+ BINARY_OP(Complex);
+#undef BINARY_OP
+
+#define UNARY_OP(op) ops.def(#op, &op)
+ UNARY_OP(Not);
+ UNARY_OP(PopulationCount);
+ UNARY_OP(Clz);
+ UNARY_OP(Abs);
+ UNARY_OP(Exp);
+ UNARY_OP(Expm1);
+ UNARY_OP(Floor);
+ UNARY_OP(Ceil);
+ UNARY_OP(Round);
+ UNARY_OP(Log);
+ UNARY_OP(Log1p);
+ UNARY_OP(Sign);
+ UNARY_OP(Cos);
+ UNARY_OP(Sin);
+ UNARY_OP(Tanh);
+ UNARY_OP(IsFinite);
+ UNARY_OP(Neg);
+ UNARY_OP(Sqrt);
+ UNARY_OP(Rsqrt);
+ UNARY_OP(Square);
+ UNARY_OP(Reciprocal);
+ UNARY_OP(Erfc);
+ UNARY_OP(Erf);
+ UNARY_OP(ErfInv);
+ UNARY_OP(Lgamma);
+ UNARY_OP(Digamma);
+ UNARY_OP(BesselI0e);
+ UNARY_OP(BesselI1e);
+ UNARY_OP(Acos);
+ UNARY_OP(Asin);
+ UNARY_OP(Atan);
+ UNARY_OP(Tan);
+ UNARY_OP(Acosh);
+ UNARY_OP(Asinh);
+ UNARY_OP(Atanh);
+ UNARY_OP(Cosh);
+ UNARY_OP(Sinh);
+ UNARY_OP(Real);
+ UNARY_OP(Imag);
+ UNARY_OP(Conj);
+#undef UNARY_OP
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/python/ops.h b/tensorflow/compiler/xla/python/ops.h
new file mode 100644
index 0000000..7fe34e9
--- /dev/null
+++ b/tensorflow/compiler/xla/python/ops.h
@@ -0,0 +1,27 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
+
+#include "pybind11/pybind11.h"
+
+namespace xla {
+
+void BuildOpsSubmodule(pybind11::module* m);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index abf0937..fb7d7df 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -30,12 +30,6 @@
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/lib/comparators.h"
-#include "tensorflow/compiler/xla/client/lib/math.h"
-#include "tensorflow/compiler/xla/client/lib/qr.h"
-#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
-#include "tensorflow/compiler/xla/client/lib/sorting.h"
-#include "tensorflow/compiler/xla/client/lib/svd.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -48,6 +42,7 @@
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/python/dlpack.h"
+#include "tensorflow/compiler/xla/python/ops.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
@@ -306,321 +301,6 @@
return result;
}
-void BuildOpsSubmodule(py::module* m) {
- // ops submodule, containing free functions that add operators to an
- // XlaBuilder.
- py::module ops = m->def_submodule("ops", "XLA operations");
-
- py::enum_<TriangularSolveOptions::Transpose>(
- ops, "TriangularSolveOptions_Transpose")
- .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
- .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
- .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
- .value("ADJOINT", TriangularSolveOptions::ADJOINT);
-
- ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
- ops.def(
- "AllReduce",
- static_cast<XlaOp (*)(
- XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
- const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
- &AllReduce),
- py::arg("operand"), py::arg("computation"),
- py::arg("replica_groups") = py::list(),
- py::arg("channel_id") = absl::nullopt,
- py::arg("shape_with_layout") = absl::nullopt);
- ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
- py::arg("concat_dimension"), py::arg("split_count"),
- py::arg("replica_groups") = py::list(),
- py::arg("layout") = absl::nullopt);
- ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
- py::arg("source_target_pairs"));
- ops.def("CreateToken", &CreateToken, py::arg("builder"));
- ops.def("CrossReplicaSum",
- static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
- &CrossReplicaSum),
- py::arg("operand"), py::arg("replica_groups") = py::list());
- ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
- py::arg("new_element_type"));
- ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
- ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
- py::arg("shape"), py::arg("broadcast_dimensions"));
- ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
- py::arg("operands"));
- ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
- ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
- ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
- ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
- py::arg("dimension"));
- ops.def("Conditional",
- static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
- absl::Span<const XlaOp>)>(&Conditional),
- py::arg("branch_index"), py::arg("branch_computations"),
- py::arg("branch_operands"));
- ops.def("Conditional",
- static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
- const XlaComputation&)>(&Conditional),
- py::arg("predicate"), py::arg("true_operand"),
- py::arg("true_computation"), py::arg("false_operand"),
- py::arg("false_computation"));
- ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
- ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
- py::arg("literal"));
- ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
- py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
- py::arg("lhs_dilation"), py::arg("rhs_dilation"),
- py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
- py::arg("batch_group_count") = 1,
- py::arg("precision_config") = nullptr);
- ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
- py::arg("new_element_type"));
- ops.def(
- "CustomCall",
- [](XlaBuilder* builder, const py::bytes& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape,
- const py::bytes& opaque) -> XlaOp {
- return CustomCall(builder, call_target_name, operands, shape, opaque);
- },
- py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
- py::arg("shape"), py::arg("opaque") = py::bytes(""));
- ops.def(
- "CustomCallWithLayout",
- [](XlaBuilder* builder, const py::bytes& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
- absl::Span<const Shape> operand_shapes_with_layout,
- const py::bytes& opaque) -> XlaOp {
- return CustomCallWithLayout(builder, call_target_name, operands,
- shape_with_layout,
- operand_shapes_with_layout, opaque);
- },
- py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
- py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
- py::arg("opaque") = py::bytes(""));
- ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
- py::arg("precision_config") = nullptr);
- ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
- py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
- ops.def("DynamicSlice",
- static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
- absl::Span<const int64>)>(&DynamicSlice),
- py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
- ops.def("DynamicUpdateSlice",
- static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
- &DynamicUpdateSlice),
- py::arg("operand"), py::arg("update"), py::arg("start_indices"));
-
- ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
- py::arg("fft_length"));
-
- ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
- py::arg("dimension_numbers"), py::arg("slice_sizes"),
- py::arg("indices_are_sorted") = false);
- ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
- py::arg("index"));
- ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
- py::arg("shape"), py::arg("config") = "");
- ops.def("Iota",
- static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
- py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
- ops.def("Iota",
- static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
- py::arg("builder"), py::arg("type"), py::arg("size"));
- ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
- py::arg("computation"), py::arg("dimensions"),
- py::arg("static_operands") = py::list());
- ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
- ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
- py::arg("token"), py::arg("shape_with_layout"),
- py::arg("outfeed_config") = "");
- ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
- py::arg("padding_config"));
- ops.def("Parameter",
- static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
- const std::string&, const std::vector<bool>&)>(
- &Parameter),
- py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
- py::arg("name") = "",
- py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
- ops.def(
- "QR",
- [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
- TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
- return std::make_pair(qr.q, qr.r);
- },
- py::arg("operand"), py::arg("full_matrices"));
- ops.def(
- "Eigh",
- [](XlaOp a, bool lower, int64 max_iter,
- float epsilon) -> std::pair<XlaOp, XlaOp> {
- auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon);
- return std::make_pair(eigh.v, eigh.w);
- },
- py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100,
- py::arg("epsilon") = 1e-6);
- ops.def(
- "SVD",
- [](XlaOp a, int64 max_iter,
- float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
- auto svd = SVD(a, max_iter, epsilon);
- return std::make_tuple(svd.u, svd.d, svd.v);
- },
- py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
- ops.def("Reduce",
- static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
- absl::Span<const XlaOp>, const XlaComputation&,
- absl::Span<const int64>)>(&Reduce),
- py::arg("builder"), py::arg("operands"), py::arg("init_values"),
- py::arg("computation"), py::arg("dimensions_to_reduce"));
- ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
- py::arg("exponent_bits"), py::arg("mantissa_bits"));
- ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
- py::arg("operand"), py::arg("init_value"), py::arg("computation"),
- py::arg("window_dimensions"), py::arg("window_strides"),
- py::arg("base_dilations"), py::arg("window_dilations"),
- py::arg("padding"));
- ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
- ops.def("Reshape",
- static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
- absl::Span<const int64>)>(&Reshape),
- py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
- ops.def("Reshape",
- static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
- py::arg("operand"), py::arg("new_sizes"));
- ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
- ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
- py::arg("shape"));
- ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
- py::arg("shape"));
- ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
- py::arg("updates"), py::arg("update_computation"),
- py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
- py::arg("unique_indices") = false);
- ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
- py::arg("on_false"));
- ops.def("SelectAndScatterWithGeneralPadding",
- &SelectAndScatterWithGeneralPadding, py::arg("operand"),
- py::arg("select"), py::arg("window_dimensions"),
- py::arg("window_strides"), py::arg("padding"), py::arg("source"),
- py::arg("init_value"), py::arg("scatter"));
- ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
- py::arg("limit_indices"), py::arg("strides"));
- ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
- py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
- ops.def(
- "Sort",
- [](XlaBuilder* builder, absl::Span<const XlaOp> operands,
- absl::optional<const XlaComputation*> comparator, int64 dimension,
- bool is_stable) -> XlaOp {
- return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- std::vector<PrimitiveType> operand_types;
- for (const auto& operand : operands) {
- TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
- operand_types.push_back(operand_shape.element_type());
- }
-
- if (comparator) {
- return Sort(operands, **comparator, dimension, is_stable);
- } else {
- return Sort(operands,
- CreateScalarLtComputation(operand_types, builder),
- dimension, is_stable);
- }
- });
- },
- py::arg("builder"), py::arg("operands"),
- py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
- py::arg("is_stable") = false);
- ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
- ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
- ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
- py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
- py::arg("transpose_a"));
- ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
- ops.def("While", &While, py::arg("condition"), py::arg("body"),
- py::arg("init"));
-
- ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
- ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
- ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
- ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
- ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
- py::arg("b"), py::arg("x"));
-
-#define BINARY_OP(op) \
- ops.def( \
- #op, \
- [](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
- return dims ? op(a, b, *dims) : op(a, b); \
- }, \
- py::arg("lhs"), py::arg("rhs"), \
- py::arg("broadcast_dimensions") = absl::nullopt)
- BINARY_OP(Eq);
- BINARY_OP(Ne);
- BINARY_OP(Ge);
- BINARY_OP(Gt);
- BINARY_OP(Lt);
- BINARY_OP(Le);
- BINARY_OP(Add);
- BINARY_OP(Sub);
- BINARY_OP(Mul);
- BINARY_OP(Div);
- BINARY_OP(Rem);
- BINARY_OP(Max);
- BINARY_OP(Min);
- BINARY_OP(And);
- BINARY_OP(Or);
- BINARY_OP(Xor);
- BINARY_OP(ShiftLeft);
- BINARY_OP(ShiftRightArithmetic);
- BINARY_OP(ShiftRightLogical);
- BINARY_OP(Atan2);
- BINARY_OP(Pow);
- BINARY_OP(Complex);
-#undef BINARY_OP
-
-#define UNARY_OP(op) ops.def(#op, &op)
- UNARY_OP(Not);
- UNARY_OP(PopulationCount);
- UNARY_OP(Clz);
- UNARY_OP(Abs);
- UNARY_OP(Exp);
- UNARY_OP(Expm1);
- UNARY_OP(Floor);
- UNARY_OP(Ceil);
- UNARY_OP(Round);
- UNARY_OP(Log);
- UNARY_OP(Log1p);
- UNARY_OP(Sign);
- UNARY_OP(Cos);
- UNARY_OP(Sin);
- UNARY_OP(Tanh);
- UNARY_OP(IsFinite);
- UNARY_OP(Neg);
- UNARY_OP(Sqrt);
- UNARY_OP(Rsqrt);
- UNARY_OP(Square);
- UNARY_OP(Reciprocal);
- UNARY_OP(Erfc);
- UNARY_OP(Erf);
- UNARY_OP(ErfInv);
- UNARY_OP(Lgamma);
- UNARY_OP(Digamma);
- UNARY_OP(BesselI0e);
- UNARY_OP(BesselI1e);
- UNARY_OP(Acos);
- UNARY_OP(Asin);
- UNARY_OP(Atan);
- UNARY_OP(Tan);
- UNARY_OP(Acosh);
- UNARY_OP(Asinh);
- UNARY_OP(Atanh);
- UNARY_OP(Cosh);
- UNARY_OP(Sinh);
- UNARY_OP(Real);
- UNARY_OP(Imag);
- UNARY_OP(Conj);
-#undef UNARY_OP
-}
void BuildProfilerSubmodule(py::module* m) {
py::module profiler =