blob: 6a11d63487d14c6391cbb37ddd86311d40b4c1f1 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/ops.h"

#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "pybind11/attr.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/lu_decomposition.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/lib/svd.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {

namespace py = pybind11;

// Registers the `ops` submodule on `m`.
//
// The submodule exposes free functions that append XLA operations to an
// XlaBuilder, plus the enums TriangularSolveOptions_Transpose,
// RandomAlgorithm and CustomCallSchedule that some of those functions accept.
//
// Some Python names (Conditional, Iota, ReduceWindowWithGeneralPadding,
// Reshape) are registered twice with different C++ signatures. The
// static_casts select which C++ overload backs each registration; the
// registration order below is kept as-is because pybind11 tries same-named
// overloads in the order they were registered.
void BuildOpsSubmodule(py::module* m) {
  // ops submodule, containing free functions that add operators to an
  // XlaBuilder.
  py::module ops = m->def_submodule("ops", "XLA operations");

  py::enum_<TriangularSolveOptions::Transpose>(
      ops, "TriangularSolveOptions_Transpose")
      .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
      .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
      .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
      .value("ADJOINT", TriangularSolveOptions::ADJOINT);

  py::enum_<RandomAlgorithm>(ops, "RandomAlgorithm")
      .value("RNG_DEFAULT", RandomAlgorithm::RNG_DEFAULT)
      .value("RNG_THREE_FRY", RandomAlgorithm::RNG_THREE_FRY)
      .value("RNG_PHILOX", RandomAlgorithm::RNG_PHILOX);

  py::enum_<CustomCallSchedule>(ops, "CustomCallSchedule")
      .value("SCHEDULE_NONE", CustomCallSchedule::SCHEDULE_NONE)
      .value("SCHEDULE_LATEST", CustomCallSchedule::SCHEDULE_LATEST)
      .value("SCHEDULE_EARLIEST", CustomCallSchedule::SCHEDULE_EARLIEST);

  ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
  ops.def("AllGather", &AllGather, py::arg("operand"),
          py::arg("all_gather_dimension"), py::arg("shard_count"),
          py::arg("replica_groups") = py::list(),
          py::arg("channel_id") = absl::nullopt,
          py::arg("shape_with_layout") = absl::nullopt,
          py::arg("use_global_device_ids") = absl::nullopt);
  ops.def(
      "AllReduce",
      static_cast<XlaOp (*)(
          XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
          const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
          &AllReduce),
      py::arg("operand"), py::arg("computation"),
      py::arg("replica_groups") = py::list(),
      py::arg("channel_id") = absl::nullopt,
      py::arg("shape_with_layout") = absl::nullopt);
  ops.def("AllReduceScatter", &AllReduceScatter, py::arg("operand"),
          py::arg("computation"), py::arg("scatter_dimension"),
          py::arg("shard_count"), py::arg("replica_groups") = py::list(),
          py::arg("channel_id") = absl::nullopt,
          py::arg("layout") = absl::nullopt,
          py::arg("use_global_device_ids") = absl::nullopt);
  ops.def("AllToAll", &AllToAll, py::arg("operand"),
          py::arg("split_dimension"), py::arg("concat_dimension"),
          py::arg("split_count"), py::arg("replica_groups") = py::list(),
          py::arg("layout") = absl::nullopt);
  ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
          py::arg("new_element_type"));
  ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
  ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
          py::arg("shape"), py::arg("broadcast_dimensions"));
  ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
          py::arg("operands"));
  ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
  ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"),
          py::arg("max"));
  ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
  ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
          py::arg("source_target_pairs"));
  ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"),
          py::arg("operands"), py::arg("dimension"));

  // Conditional, overload 1: N-way branch selected by an index.
  ops.def("Conditional",
          static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
                                absl::Span<const XlaOp>)>(&Conditional),
          py::arg("branch_index"), py::arg("branch_computations"),
          py::arg("branch_operands"));
  // Conditional, overload 2: two-way branch selected by a predicate.
  ops.def("Conditional",
          static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
                                const XlaComputation&)>(&Conditional),
          py::arg("predicate"), py::arg("true_operand"),
          py::arg("true_computation"), py::arg("false_operand"),
          py::arg("false_computation"));

  // `Constant` is a second Python name bound to ConstantLiteral.
  ops.def("Constant", &ConstantLiteral, py::arg("builder"),
          py::arg("literal"));
  ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
          py::arg("literal"));
  ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
          py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
          py::arg("lhs_dilation"), py::arg("rhs_dilation"),
          py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
          py::arg("batch_group_count") = 1,
          py::arg("precision_config") = nullptr,
          py::arg("preferred_element_type") = absl::nullopt);
  ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
          py::arg("new_element_type"));
  ops.def("CreateToken", &CreateToken, py::arg("builder"));
  ops.def("CrossReplicaSum",
          static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
              &CrossReplicaSum),
          py::arg("operand"), py::arg("replica_groups") = py::list());

  // Custom calls. `call_target_name` and `opaque` arrive as Python bytes.
  // CustomCall and CustomCallWithLayout always pass an empty
  // output_operand_aliasing and a null literal; CustomCallWithAliasing lets
  // the caller supply both.
  ops.def(
      "CustomCall",
      [](XlaBuilder* builder, const py::bytes& call_target_name,
         absl::Span<const XlaOp> operands, const Shape& shape,
         const py::bytes& opaque, bool has_side_effect,
         CustomCallSchedule schedule) -> XlaOp {
        return CustomCall(builder, call_target_name, operands, shape, opaque,
                          has_side_effect, /*output_operand_aliasing=*/{},
                          /*literal=*/nullptr, schedule);
      },
      py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
      py::arg("shape"), py::arg("opaque") = py::bytes(""),
      py::arg("has_side_effect") = false,
      py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE);
  ops.def(
      "CustomCallWithLayout",
      [](XlaBuilder* builder, const py::bytes& call_target_name,
         absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
         absl::Span<const Shape> operand_shapes_with_layout,
         const py::bytes& opaque, bool has_side_effect,
         CustomCallSchedule schedule) -> XlaOp {
        return CustomCallWithLayout(
            builder, call_target_name, operands, shape_with_layout,
            operand_shapes_with_layout, opaque, has_side_effect,
            /*output_operand_aliasing=*/{},
            /*literal=*/nullptr, schedule);
      },
      py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
      py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
      py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false,
      py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE);
  ops.def(
      "CustomCallWithAliasing",
      [](XlaBuilder* builder, const py::bytes& call_target_name,
         absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
         absl::Span<const Shape> operand_shapes_with_layout,
         const py::bytes& opaque, bool has_side_effect,
         absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
             output_operand_aliasing,
         const Literal* literal, CustomCallSchedule schedule) -> XlaOp {
        return CustomCallWithLayout(
            builder, call_target_name, operands, shape_with_layout,
            operand_shapes_with_layout, opaque, has_side_effect,
            output_operand_aliasing, literal, schedule);
      },
      py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
      py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
      py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false,
      // Defaulted (to an empty list, i.e. no aliasing) so that no required
      // parameter follows the defaulted `opaque` / `has_side_effect`; the
      // Python signature would otherwise be ill-formed. Existing callers that
      // pass this argument are unaffected.
      py::arg("output_operand_aliasing") = py::list(),
      py::arg("literal") = nullptr,
      py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE);

  ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
          py::arg("precision_config") = nullptr,
          py::arg("preferred_element_type") = absl::nullopt);
  ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
          py::arg("dimension_numbers"), py::arg("precision_config") = nullptr,
          py::arg("preferred_element_type") = absl::nullopt);
  ops.def(
      "DynamicReshape",
      static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
                            absl::Span<const int64>, const std::vector<bool>&)>(
          &DynamicReshape),
      py::arg("operand"), py::arg("dim_sizes"), py::arg("new_size_bounds"),
      py::arg("dims_are_dynamic"));
  ops.def("DynamicSlice",
          static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
                                absl::Span<const int64>)>(&DynamicSlice),
          py::arg("operand"), py::arg("start_indices"),
          py::arg("slice_sizes"));
  ops.def("DynamicUpdateSlice",
          static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
              &DynamicUpdateSlice),
          py::arg("operand"), py::arg("update"), py::arg("start_indices"));

  // Eigh returns the (v, w) fields of SelfAdjointEig's result as a 2-tuple.
  ops.def(
      "Eigh",
      [](XlaOp a, bool lower, int64 max_iter, float epsilon,
         bool sort_eigenvalues) -> std::pair<XlaOp, XlaOp> {
        auto eigh =
            SelfAdjointEig(a, lower, max_iter, epsilon, sort_eigenvalues);
        return std::make_pair(eigh.v, eigh.w);
      },
      py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 15,
      py::arg("epsilon") = 1e-5, py::arg("sort_eigenvalues") = true);

  ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
          py::arg("fft_length"));
  ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
          py::arg("dimension_numbers"), py::arg("slice_sizes"),
          py::arg("indices_are_sorted") = false);
  ops.def("GetDimensionSize", &GetDimensionSize, py::arg("operand"),
          py::arg("dimension"));
  ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
          py::arg("index"));
  ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
          py::arg("shape"), py::arg("config") = "");

  // Iota, overload 1: explicit shape plus the dimension to count along.
  ops.def("Iota",
          static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
          py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
  // Iota, overload 2: element type plus a length.
  ops.def("Iota",
          static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
          py::arg("builder"), py::arg("type"), py::arg("size"));

  // LU returns (lu, pivots, permutation) from LuDecomposition as a 3-tuple.
  ops.def(
      "LU",
      [](XlaOp a) -> StatusOr<std::tuple<XlaOp, XlaOp, XlaOp>> {
        LuDecompositionResult lu = LuDecomposition(a);
        return std::make_tuple(lu.lu, lu.pivots, lu.permutation);
      },
      py::arg("operand"));

  ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
          py::arg("computation"), py::arg("dimensions"),
          py::arg("static_operands") = py::list());
  ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
  ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
          py::arg("token"), py::arg("shape_with_layout"),
          py::arg("outfeed_config") = "");
  ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
          py::arg("padding_config"));
  ops.def("Parameter",
          static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
                                const std::string&, const std::vector<bool>&)>(
              &Parameter),
          py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
          py::arg("name") = "",
          py::arg("replicated_at_leaf_buffers") = std::vector<bool>());

  // QR returns (q, r); QrExplicit writes both through its reference
  // arguments.
  ops.def(
      "QR",
      [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
        XlaOp q, r;
        QrExplicit(a, full_matrices, q, r);
        return std::make_pair(q, r);
      },
      py::arg("operand"), py::arg("full_matrices"));

  ops.def("Reduce",
          static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
                                absl::Span<const XlaOp>, const XlaComputation&,
                                absl::Span<const int64>)>(&Reduce),
          py::arg("builder"), py::arg("operands"), py::arg("init_values"),
          py::arg("computation"), py::arg("dimensions_to_reduce"));
  ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
          py::arg("exponent_bits"), py::arg("mantissa_bits"));

  // ReduceWindowWithGeneralPadding, overload 1: a single operand/init value.
  ops.def(
      "ReduceWindowWithGeneralPadding",
      static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&,
                            absl::Span<const int64>, absl::Span<const int64>,
                            absl::Span<const int64>, absl::Span<const int64>,
                            absl::Span<const std::pair<int64, int64>>)>(
          &ReduceWindowWithGeneralPadding),
      py::arg("operand"), py::arg("init_value"), py::arg("computation"),
      py::arg("window_dimensions"), py::arg("window_strides"),
      py::arg("base_dilations"), py::arg("window_dilations"),
      py::arg("padding"));
  // ReduceWindowWithGeneralPadding, overload 2: spans of operands/init values.
  ops.def(
      "ReduceWindowWithGeneralPadding",
      static_cast<XlaOp (*)(absl::Span<const XlaOp>, absl::Span<const XlaOp>,
                            const XlaComputation&, absl::Span<const int64>,
                            absl::Span<const int64>, absl::Span<const int64>,
                            absl::Span<const int64>,
                            absl::Span<const std::pair<int64, int64>>)>(
          &ReduceWindowWithGeneralPadding),
      py::arg("operands"), py::arg("init_values"), py::arg("computation"),
      py::arg("window_dimensions"), py::arg("window_strides"),
      py::arg("base_dilations"), py::arg("window_dilations"),
      py::arg("padding"));

  ops.def("RemoveDynamicDimension", &RemoveDynamicDimension,
          py::arg("operand"), py::arg("dimension"));
  ops.def("ReplicaId", &ReplicaId, py::arg("builder"));

  // Reshape, overload 1: with an explicit dimension order.
  ops.def("Reshape",
          static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
                                absl::Span<const int64>)>(&Reshape),
          py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
  // Reshape, overload 2: new sizes only.
  ops.def("Reshape",
          static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
          py::arg("operand"), py::arg("new_sizes"));

  ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
  ops.def("RngBitGenerator", &RngBitGenerator, py::arg("algorithm"),
          py::arg("initial_state"), py::arg("shape"));
  ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
          py::arg("shape"));
  ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
          py::arg("shape"));
  ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
          py::arg("updates"), py::arg("update_computation"),
          py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
          py::arg("unique_indices") = false);
  ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
          py::arg("on_false"));
  ops.def("SelectAndScatterWithGeneralPadding",
          &SelectAndScatterWithGeneralPadding, py::arg("operand"),
          py::arg("select"), py::arg("window_dimensions"),
          py::arg("window_strides"), py::arg("padding"), py::arg("source"),
          py::arg("init_value"), py::arg("scatter"));
  ops.def("SetDimensionSize", &SetDimensionSize, py::arg("operand"),
          py::arg("val"), py::arg("dimension"));
  ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
          py::arg("limit_indices"), py::arg("strides"));
  ops.def("SliceInDim", &SliceInDim, py::arg("operand"),
          py::arg("start_index"), py::arg("limit_index"), py::arg("stride"),
          py::arg("dimno"));

  // Sort. When no comparator is supplied, a scalar less-than comparator over
  // the operands' element types is built with CreateScalarLtComputation.
  // Failures from GetShape are routed through builder->ReportErrorOrReturn.
  ops.def(
      "Sort",
      [](XlaBuilder* builder, absl::Span<const XlaOp> operands,
         absl::optional<const XlaComputation*> comparator, int64 dimension,
         bool is_stable) -> XlaOp {
        return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
          std::vector<PrimitiveType> operand_types;
          // One entry per operand; reserve up front to avoid regrowth.
          operand_types.reserve(operands.size());
          for (const auto& operand : operands) {
            TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
            operand_types.push_back(operand_shape.element_type());
          }
          if (comparator) {
            return Sort(operands, **comparator, dimension, is_stable);
          } else {
            return Sort(operands,
                        CreateScalarLtComputation(operand_types, builder),
                        dimension, is_stable);
          }
        });
      },
      py::arg("builder"), py::arg("operands"),
      py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
      py::arg("is_stable") = false);

  // SVD returns the (u, d, v) fields of the SVD result as a 3-tuple.
  ops.def(
      "SVD",
      [](XlaOp a, int64 max_iter,
         float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
        auto svd = SVD(a, max_iter, epsilon);
        return std::make_tuple(svd.u, svd.d, svd.v);
      },
      py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);

  ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
  ops.def("Transpose", &Transpose, py::arg("operand"),
          py::arg("permutation"));
  ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
          py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
          py::arg("transpose_a"));
  ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
  ops.def("While", &While, py::arg("condition"), py::arg("body"),
          py::arg("init"));
  ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
  ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
  ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
  ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
  ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta,
          py::arg("a"), py::arg("b"), py::arg("x"));
  ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));

  // Binary element-wise ops. When `broadcast_dimensions` is None the
  // two-argument form of `op` is called; otherwise the dimensions are
  // forwarded as the third argument.
#define BINARY_OP(op)                                                 \
  ops.def(                                                            \
      #op,                                                            \
      [](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
        return dims ? op(a, b, *dims) : op(a, b);                     \
      },                                                              \
      py::arg("lhs"), py::arg("rhs"),                                 \
      py::arg("broadcast_dimensions") = absl::nullopt)
  BINARY_OP(Eq);
  BINARY_OP(Ne);
  BINARY_OP(Ge);
  BINARY_OP(Gt);
  BINARY_OP(Lt);
  BINARY_OP(Le);
  BINARY_OP(Add);
  BINARY_OP(Sub);
  BINARY_OP(Mul);
  BINARY_OP(Div);
  BINARY_OP(Rem);
  BINARY_OP(Max);
  BINARY_OP(Min);
  BINARY_OP(And);
  BINARY_OP(Or);
  BINARY_OP(Xor);
  BINARY_OP(ShiftLeft);
  BINARY_OP(ShiftRightArithmetic);
  BINARY_OP(ShiftRightLogical);
  BINARY_OP(Atan2);
  BINARY_OP(Pow);
  BINARY_OP(Complex);
#undef BINARY_OP

  // Unary ops are bound directly with no named arguments.
#define UNARY_OP(op) ops.def(#op, &op)
  UNARY_OP(Not);
  UNARY_OP(PopulationCount);
  UNARY_OP(Clz);
  UNARY_OP(Abs);
  UNARY_OP(Exp);
  UNARY_OP(Expm1);
  UNARY_OP(Floor);
  UNARY_OP(Ceil);
  UNARY_OP(Round);
  UNARY_OP(Log);
  UNARY_OP(Log1p);
  UNARY_OP(Sign);
  UNARY_OP(Cos);
  UNARY_OP(Sin);
  UNARY_OP(Tanh);
  UNARY_OP(IsFinite);
  UNARY_OP(Neg);
  UNARY_OP(Sqrt);
  UNARY_OP(Rsqrt);
  UNARY_OP(Square);
  UNARY_OP(Reciprocal);
  UNARY_OP(Erfc);
  UNARY_OP(Erf);
  UNARY_OP(ErfInv);
  UNARY_OP(Lgamma);
  UNARY_OP(Digamma);
  UNARY_OP(BesselI0e);
  UNARY_OP(BesselI1e);
  UNARY_OP(Acos);
  UNARY_OP(Asin);
  UNARY_OP(Atan);
  UNARY_OP(Tan);
  UNARY_OP(Acosh);
  UNARY_OP(Asinh);
  UNARY_OP(Atanh);
  UNARY_OP(Cosh);
  UNARY_OP(Sinh);
  UNARY_OP(Real);
  UNARY_OP(Imag);
  UNARY_OP(Conj);
#undef UNARY_OP
}

}  // namespace xla