/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <cstdint>
#include <set>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
// Note: Most of the operators defined in this module are used by the jax2tf
// converter (see go/jax2tf for details) and are used in SavedModel produced
// by jax2tf. Hence, we need to maintain backwards compatibility for these
// operators. Please reach out to the JAX team if you want to make changes.
namespace tensorflow {
namespace {
// Helper shape function for operators that return an output with the same rank
// as their first input.
Status UnchangedRank(shape_inference::InferenceContext* c) {
if (c->RankKnown(c->input(0))) {
c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
} else {
c->set_output(0, c->input(0));
}
return OkStatus();
}
REGISTER_OP("XlaBroadcastHelper")
.Input("lhs: T")
.Input("rhs: T")
.Input("broadcast_dims: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Output("lhs_output: T")
.Output("rhs_output: T")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Helper operator for performing XLA-style broadcasts.
Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
for binary operators.
lhs: the LHS input tensor
rhs: the RHS input tensor
broadcast_dims: an XLA-style broadcast dimension specification
lhs_output: the broadcasted LHS tensor
rhs_output: the broadcasted RHS tensor
)doc");
REGISTER_OP("XlaSelfAdjointEig")
.Input("a: T")
.Attr("lower: bool")
.Attr("max_iter: int")
.Attr("epsilon: float")
.Output("w: T")
.Output("v: T")
.SetShapeFn(shape_inference::UnknownShape)
.Attr("T: numbertype")
.Doc(R"doc(
Computes the eigen decomposition of a batch of self-adjoint matrices
(Note: Only real inputs are supported).
Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
tensor such that tensor[...,:,:] * v[..., :,i] = w[..., i] * v[...,:,i], for
i=0...N-1.
a: the input tensor.
lower: a boolean that specifies whether the calculation is done with the lower
triangular part or the upper triangular part.
max_iter: maximum number of sweep updates; a sweep updates the whole lower or
upper triangular part, as selected by `lower`. Heuristically, it has been
argued that approximately log(N) sweeps are needed in practice (Ref: Golub &
van Loan, "Matrix Computations").
epsilon: the tolerance ratio.
w: The eigenvalues in ascending order, each repeated according to its
multiplicity.
v: The column v[..., :, i] is the normalized eigenvector corresponding to the
eigenvalue w[..., i].
)doc");
REGISTER_OP("XlaSvd")
.Input("a: T")
.Attr("max_iter: int")
.Attr("epsilon: float")
.Attr("precision_config: string")
.Output("s: T")
.Output("u: T")
.Output("v: T")
.SetShapeFn(shape_inference::UnknownShape)
.Attr("T: numbertype")
.Doc(R"doc(
Computes the singular value decomposition of a batch of matrices
(Note: Only real inputs are supported).
Computes the singular values and singular vectors of the innermost M-by-N
matrices in tensor such that
tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).
a: the input tensor.
max_iter: maximum number of sweep updates. Heuristically, it has been argued
that approximately log(min(M, N)) sweeps are needed in practice
(Ref: Golub & van Loan, "Matrix Computations").
epsilon: the tolerance ratio.
precision_config: a serialized xla::PrecisionConfig proto.
s: Singular values. The values are sorted in descending order of magnitude, so
s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
u: Left singular vectors.
v: Right singular vectors.
)doc");
REGISTER_OP("XlaConv")
.Input("lhs: T")
.Input("rhs: T")
.Input("window_strides: Tindices")
.Input("padding: Tindices")
.Input("lhs_dilation: Tindices")
.Input("rhs_dilation: Tindices")
.Input("feature_group_count: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("dimension_numbers: string")
.Attr("precision_config: string")
.Output("output: T")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.
lhs: the input tensor
rhs: the kernel tensor
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
REGISTER_OP("XlaConvV2")
.Input("lhs: LhsT")
.Input("rhs: RhsT")
.Input("window_strides: Tindices")
.Input("padding: Tindices")
.Input("lhs_dilation: Tindices")
.Input("rhs_dilation: Tindices")
.Input("feature_group_count: Tindices")
.Attr("LhsT: numbertype")
.Attr("RhsT: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("dimension_numbers: string")
.Attr("precision_config: string")
.Attr("preferred_element_type: numbertype")
.Attr("batch_group_count: int = 1")
.Output("output: preferred_element_type")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.
lhs: input tensor
rhs: kernel tensor
window_strides: inter-window strides
padding: padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: serialized xla::ConvolutionDimensionNumbers proto.
precision_config: serialized xla::PrecisionConfig proto.
preferred_element_type: the element type of the output tensor.
batch_group_count: number of batch groups or grouped filters.
)doc");
static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
if (!c->RankKnown(lhs_shape_handle) || !c->RankKnown(rhs_shape_handle)) {
return shape_inference::UnknownShape(c);
}
string dimension_numbers_string;
TF_RETURN_IF_ERROR(
c->GetAttr("dimension_numbers", &dimension_numbers_string));
xla::DotDimensionNumbers dimension_numbers;
dimension_numbers.ParseFromString(dimension_numbers_string);
// Check that number of contracting dimensions match.
if (dimension_numbers.lhs_contracting_dimensions_size() !=
dimension_numbers.rhs_contracting_dimensions_size())
return errors::InvalidArgument(
"Must specify the same number of contracting dimensions for lhs "
"and rhs. Got: ",
dimension_numbers.lhs_contracting_dimensions_size(), " and ",
dimension_numbers.rhs_contracting_dimensions_size());
// Check that contracting dimension sizes match.
for (int64_t i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
++i) {
const int64_t lhs_contracting_dimension =
dimension_numbers.lhs_contracting_dimensions(i);
const int64_t rhs_contracting_dimension =
dimension_numbers.rhs_contracting_dimensions(i);
shape_inference::DimensionHandle unused;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
c->Merge(c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension),
c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension),
&unused),
"For contracting dimension ", i, " which is lhs dimension ",
lhs_contracting_dimension, " and rhs dimension ",
rhs_contracting_dimension);
}
// Check that number of batch dimensions match.
if (dimension_numbers.lhs_batch_dimensions_size() !=
dimension_numbers.rhs_batch_dimensions_size())
return errors::InvalidArgument(
"Must specify the same number of batch dimensions for lhs "
"and rhs. Got: ",
dimension_numbers.lhs_batch_dimensions_size(), " and ",
dimension_numbers.rhs_batch_dimensions_size());
  // The output dimensions are generated in order: first the batch dimensions
  // (shared by lhs and rhs), then the non-contracted, non-batch lhs
  // dimensions, then the non-contracted, non-batch rhs dimensions. When an
  // input tensor is a scalar, it contributes no dimensions to the result.
std::vector<shape_inference::DimensionHandle> output_dims;
// Check that batch dimension sizes match, and add them to output_dims.
for (int64_t i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
const int64_t lhs_batch_dimension =
dimension_numbers.lhs_batch_dimensions(i);
const int64_t rhs_batch_dimension =
dimension_numbers.rhs_batch_dimensions(i);
shape_inference::DimensionHandle out;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
c->Merge(c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension),
c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension), &out),
"For batch dimension ", i, " which is lhs dimension ",
lhs_batch_dimension, " and rhs dimension ", rhs_batch_dimension);
output_dims.emplace_back(out);
}
const int32_t lhs_rank = c->Rank(lhs_shape_handle);
for (int64_t i = 0; i < lhs_rank; ++i) {
if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
i) ||
absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
continue;
}
output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
}
const int32_t rhs_rank = c->Rank(rhs_shape_handle);
for (int64_t i = 0; i < rhs_rank; ++i) {
if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
i) ||
absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
continue;
}
output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
}
c->set_output(0, c->MakeShape(output_dims));
return OkStatus();
}
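// A standalone restatement (assumed example mirroring XlaDotShapeFunction
// above) of the output-dimension ordering over plain integer shapes: batch
// dimensions first, then the non-contracted, non-batch lhs dimensions, then
// the non-contracted, non-batch rhs dimensions. E.g. lhs [B, M, K] x
// rhs [B, K, N] with batch dims {0}/{0} and contracting dims {2}/{1} yields
// [B, M, N]. DotOutputShapeSketch is a hypothetical helper.
[[maybe_unused]] std::vector<int64_t> DotOutputShapeSketch(
    const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs,
    const std::vector<int64_t>& lhs_batch,
    const std::vector<int64_t>& rhs_batch,
    const std::vector<int64_t>& lhs_contracting,
    const std::vector<int64_t>& rhs_contracting) {
  std::vector<int64_t> out;
  for (int64_t d : lhs_batch) out.push_back(lhs[d]);  // batch dims first
  for (int64_t i = 0; i < static_cast<int64_t>(lhs.size()); ++i) {
    if (!absl::c_linear_search(lhs_contracting, i) &&
        !absl::c_linear_search(lhs_batch, i)) {
      out.push_back(lhs[i]);  // surviving lhs dims
    }
  }
  for (int64_t i = 0; i < static_cast<int64_t>(rhs.size()); ++i) {
    if (!absl::c_linear_search(rhs_contracting, i) &&
        !absl::c_linear_search(rhs_batch, i)) {
      out.push_back(rhs[i]);  // surviving rhs dims
    }
  }
  return out;
}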
REGISTER_OP("XlaDot")
.Input("lhs: T")
.Input("rhs: T")
.Attr("T: numbertype")
.Attr("dimension_numbers: string")
.Attr("precision_config: string")
.Output("output: T")
.SetShapeFn(XlaDotShapeFunction)
.Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.
lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
REGISTER_OP("XlaDotV2")
.Input("lhs: LhsT")
.Input("rhs: RhsT")
.Attr("LhsT: numbertype")
.Attr("RhsT: numbertype")
.Attr("dimension_numbers: string")
.Attr("precision_config: string")
.Attr("preferred_element_type: numbertype")
.Output("output: preferred_element_type")
.SetShapeFn(XlaDotShapeFunction)
.Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.
lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
preferred_element_type: the element type of the output tensor.
)doc");
REGISTER_OP("XlaSetBound")
.Input("input: int32")
.Input("bound: int32")
.Output("output: int32")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(
R"doc(Set a bound for the given input value as a hint to Xla compiler,
returns the same value.
)doc");
REGISTER_OP("XlaSetDynamicDimensionSize")
.Input("input: T")
.Input("dim_index: int32")
.Input("size: int32")
.Output("output: T")
.Attr("T: type")
// Use unknown shape to prevent constant folding.
.SetShapeFn(shape_inference::UnknownShape)
.Doc(
R"doc(Make a static dimension into a xla bounded dynamic dimension.
The current static dimension size will become the bound and the second
operand becomes the dynamic size of the dimension.)doc");
REGISTER_OP("XlaRemoveDynamicDimensionSize")
.Input("input: T")
.Input("dim_index: int32")
.Output("output: T")
.Attr("T: type")
// Use unknown shape to prevent constant folding.
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Inverse of XlaSetDynamicDimensionSize.
Makes an XLA bounded dynamic dimension into a static dimension. The bound of the
size of dimension `dim_index` becomes the static dimension size.
)doc");
REGISTER_OP("XlaDynamicSlice")
.Input("input: T")
.Input("start_indices: Tindices")
.Input("size_indices: Tindices")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
shape_inference::ShapeHandle size_indices_shape = c->input(2);
if (!c->RankKnown(size_indices_shape)) {
return UnchangedRank(c);
}
if (c->Rank(size_indices_shape) != 1) {
return errors::InvalidArgument("size_indices must be a 1D tensor");
}
shape_inference::ShapeHandle size_indices_value;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &size_indices_value));
if (!c->RankKnown(size_indices_value)) {
// If we cannot tell the rank of the output from the value of
// size_indices, perhaps we can find it from the rank of first operand.
return UnchangedRank(c);
}
c->set_output(0, size_indices_value);
return OkStatus();
})
.Doc(R"doc(
Wraps the XLA DynamicSlice operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
.
DynamicSlice extracts a sub-array from the input array at dynamic
start_indices. The size of the slice in each dimension is passed in
size_indices, which specify the end point of exclusive slice intervals in each
dimension -- [start, start + size). start_indices must have rank 1, with
dimension size equal to the rank of the operand.
input: A `Tensor` of type T.
start_indices: Rank 1 tensor of N integers containing the starting indices of
the slice for each dimension. Each value must be greater than or equal to zero.
size_indices: Rank 1 tensor of N integers containing the slice size for each
dimension. Each value must be strictly greater than zero, and start + size
must be less than or equal to the size of the dimension to avoid
implementation defined behavior.
)doc");
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
.Input("update: T")
.Input("indices: Tindices")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA DynamicUpdateSlice operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
.
XlaDynamicUpdateSlice generates a result which is the value of the `input`
operand, with a slice update overwritten at `indices`. The shape of `update`
determines the shape of the sub-array of the result which is updated. `indices`
must have rank 1, with dimension size equal to the rank of `input`.
Handling of out-of-bounds slice indices is implementation-defined.
input: A `Tensor` of type T.
indices: A vector of indices into `input`. Must have length equal to the rank of
`input`.
update: A `Tensor` of type T. Same rank as `input`.
output: A `Tensor` of type T.
)doc");
// TODO(b/37549631) setting the If Op to always be stateful is too
// conservative.
REGISTER_OP("XlaIf")
.Input("cond: Tcond")
.Input("inputs: Tin")
.Output("output: Tout")
.Attr("Tcond: type")
.Attr("then_branch: func")
.Attr("else_branch: func")
.Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type) >= 0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = cond ? then_branch(inputs) : else_branch(inputs).
cond: A boolean scalar.
inputs: A list of input tensors.
output: A list of tensors returned by either then_branch(inputs) or
else_branch(inputs). The input shapes of the then_branch and
else_branch must match.
then_branch: A function that takes 'inputs' and returns a list of tensors,
whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of tensors,
whose types are the same as what then_branch returns.
)doc");
REGISTER_OP("XlaPad")
.Input("input: T")
.Input("padding_value: T")
.Input("padding_low: Tindices")
.Input("padding_high: Tindices")
.Input("padding_interior: Tindices")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle input_shape_handle = c->input(0);
if (!c->RankKnown(input_shape_handle)) {
return UnchangedRank(c);
}
const int32_t op_rank = c->Rank(input_shape_handle);
shape_inference::ShapeHandle padding_shape_handle = c->input(1);
if (c->RankKnown(padding_shape_handle) &&
c->Rank(padding_shape_handle) != 0) {
return errors::InvalidArgument(
"padding_value input must be scalar, found rank ",
c->Rank(padding_shape_handle));
}
const Tensor* padding_low_tensor = c->input_tensor(2);
const Tensor* padding_high_tensor = c->input_tensor(3);
const Tensor* padding_interior_tensor = c->input_tensor(4);
if (padding_low_tensor == nullptr || padding_high_tensor == nullptr ||
padding_interior_tensor == nullptr) {
return UnchangedRank(c);
}
if (padding_low_tensor->shape().dims() != 1 ||
padding_low_tensor->shape().dim_size(0) != op_rank) {
return errors::InvalidArgument(
"padding_low must be a 1D tensor of size ", op_rank);
}
if (padding_high_tensor->shape().dims() != 1 ||
padding_high_tensor->shape().dim_size(0) != op_rank) {
return errors::InvalidArgument(
"padding_high must be a 1D tensor of size ", op_rank);
}
if (padding_interior_tensor->shape().dims() != 1 ||
padding_interior_tensor->shape().dim_size(0) != op_rank) {
return errors::InvalidArgument(
"padding_interior must be a 1D tensor of size ", op_rank);
}
std::vector<shape_inference::DimensionHandle> output_dims;
output_dims.reserve(op_rank);
for (int64_t i = 0; i < op_rank; ++i) {
int64_t low, high, interior;
TF_RETURN_IF_ERROR(c->GetScalarFromTensor(padding_low_tensor, i, &low));
TF_RETURN_IF_ERROR(
c->GetScalarFromTensor(padding_high_tensor, i, &high));
TF_RETURN_IF_ERROR(
c->GetScalarFromTensor(padding_interior_tensor, i, &interior));
if (interior < 0) {
return errors::InvalidArgument(
"padding_interior must contain only non-negative values, found ",
interior);
}
shape_inference::DimensionHandle orig_size_handle =
c->Dim(input_shape_handle, i);
if (c->ValueKnown(orig_size_handle)) {
auto orig_dim = c->Value(orig_size_handle);
int64_t new_dim = orig_dim + low + high;
if (orig_dim > 0) {
new_dim += interior * (orig_dim - 1);
}
if (new_dim < 0) {
return errors::InvalidArgument(
"resulting padded dimension has negative size ", new_dim);
}
output_dims.emplace_back(c->MakeDim(new_dim));
} else {
output_dims.emplace_back(c->UnknownDim());
}
}
c->set_output(0, c->MakeShape(output_dims));
return OkStatus();
})
.Doc(R"doc(
Wraps the XLA Pad operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#pad
.
input: A `Tensor` of type T.
padding_value: A scalar `Tensor` of type T.
padding_low: the padding to apply at the start of each input dimension. Must
be a compile-time constant 1D tensor of length equal to rank of input.
padding_high: the padding to apply at the end of each input dimension. Must
be a compile-time constant 1D tensor of length equal to rank of input.
padding_interior: the padding to apply between each input element. Must
be a compile-time constant 1D tensor of length equal to rank of input,
containing only non-negative values.
output: A `Tensor` of type T.
)doc");
REGISTER_OP("XlaRecv")
.Output("tensor: dtype")
.Attr("dtype: type")
.Attr("tensor_name: string")
.Attr("shape: shape")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
TensorShape shape_attr;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
c->set_output(0, s);
return OkStatus();
})
.Doc(R"doc(
Receives the named tensor from another XLA computation. Wraps the XLA Recv
operator documented at
https://www.tensorflow.org/performance/xla/operation_semantics#recv .
tensor: The tensor to receive.
dtype: The type of the tensor.
tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");
REGISTER_OP("XlaReduce")
.Input("input: T")
.Input("init_value: T")
.Attr("T: {numbertype, bool}")
.Attr("dimensions_to_reduce: list(int)")
.Attr("reducer: func")
.Output("output: T")
.SetShapeFn([](shape_inference::InferenceContext* c) {
if (c->RankKnown(c->input(0))) {
int rank = c->Rank(c->input(0));
std::vector<int64_t> dimensions_to_reduce;
TF_RETURN_IF_ERROR(
c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
std::set<int64_t> dims_set(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
auto dim_in_range = [rank](int64_t dim) {
return dim >= 0 && dim < rank;
};
const int dimensions_to_reduce_size = dimensions_to_reduce.size();
if (rank < dimensions_to_reduce_size ||
dims_set.size() != dimensions_to_reduce.size() ||
!absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
return errors::InvalidArgument(
"Invalid dimensions_to_reduce argument to XlaReduce");
}
c->set_output(
0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
} else {
c->set_output(0, c->input(0));
}
return OkStatus();
})
.Doc(R"doc(
Wraps the XLA Reduce operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reduce .
input: the input tensor
init_value: a scalar representing the initial value for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");
REGISTER_OP("XlaVariadicReduce")
.Input("input: N * T")
.Input("init_value: N * T")
.Attr("N: int >= 1")
.Attr("T: {numbertype, bool}")
.Attr("dimensions_to_reduce: list(int)")
.Attr("reducer: func")
.Output("output: N * T")
.SetShapeFn([](shape_inference::InferenceContext* c) {
int n;
TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
c->MergeInput(i, c->input(j));
}
}
if (c->RankKnown(c->input(0))) {
int rank = c->Rank(c->input(0));
std::vector<int64_t> dimensions_to_reduce;
TF_RETURN_IF_ERROR(
c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
std::set<int64_t> dims_set(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
auto dim_in_range = [rank](int64_t dim) {
return dim >= 0 && dim < rank;
};
const int dimensions_to_reduce_size = dimensions_to_reduce.size();
if (rank < dimensions_to_reduce_size ||
dims_set.size() != dimensions_to_reduce.size() ||
!absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
return errors::InvalidArgument(
"Invalid dimensions_to_reduce argument to XlaVariadicReduce");
}
for (int i = 0; i < n; i++) {
c->set_output(
i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
}
} else {
for (int i = 0; i < n; i++) {
c->set_output(i, c->input(i));
}
}
return OkStatus();
})
.Doc(R"doc(
Wraps the variadic XLA Reduce operator.
Semantics are documented at
https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.
This version is limited to operands of the same dtype.
XlaVariadicReduceV2 is a version that supports heterogeneous operands.
input: the input tensor(s)
init_value: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");
REGISTER_OP("XlaVariadicReduceV2")
.Input("inputs: T")
.Input("init_values: T")
.Attr("T: list(type) >= 1")
.Attr("dimensions_to_reduce: list(int)")
.Attr("reducer: func")
.Output("outputs: T")
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<shape_inference::ShapeHandle> input_shapes;
TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
std::vector<shape_inference::ShapeHandle> init_values_shapes;
TF_RETURN_IF_ERROR(c->input("init_values", &init_values_shapes));
const int nr_inputs = input_shapes.size();
if (nr_inputs != init_values_shapes.size()) {
return errors::InvalidArgument(
"Must specify the same number of inputs and init_values. ", "Got ",
nr_inputs, " and ", init_values_shapes.size());
}
if (nr_inputs == 0) {
return errors::InvalidArgument("Must specify at least one input");
}
shape_inference::ShapeHandle input_shape = input_shapes[0];
for (int i = 1; i < nr_inputs; ++i) {
shape_inference::ShapeHandle merged;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
c->Merge(input_shape, input_shapes[i], &merged),
"All inputs must have the same shape. Input ", i,
" (zero-based) has shape ", c->DebugString(input_shapes[i]),
" incompatible with the shape ", "inferred from previous inputs ",
c->DebugString(input_shape));
input_shape = merged;
}
// All outputs have the same shape
shape_inference::ShapeHandle output_shape = c->UnknownShape();
if (c->RankKnown(input_shape)) {
int rank = c->Rank(input_shape);
std::vector<int64_t> dimensions_to_reduce;
TF_RETURN_IF_ERROR(
c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
std::set<int64_t> dims_set(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
auto dim_in_range = [rank](int64_t dim) {
return dim >= 0 && dim < rank;
};
const int dimensions_to_reduce_size = dimensions_to_reduce.size();
if (rank < dimensions_to_reduce_size ||
dims_set.size() != dimensions_to_reduce.size() ||
!absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
return errors::InvalidArgument(
"Invalid dimensions_to_reduce argument to XlaVariadicReduceV2");
}
std::vector<shape_inference::DimensionHandle> output_dims;
for (int64_t i = 0; i < rank; ++i) {
if (dims_set.find(i) == dims_set.end()) {
output_dims.emplace_back(c->Dim(input_shape, i));
}
}
output_shape = c->MakeShape(output_dims);
}
for (int i = 0; i < nr_inputs; ++i) {
c->set_output(i, output_shape);
}
return OkStatus();
})
.Doc(R"doc(
Wraps the variadic XLA Reduce operator.
Semantics are documented at
https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.
This is an expanded version of XlaVariadicReduce, with support for
operands of different dtypes, and improved shape inference.
inputs: the input tensor(s)
init_values: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");
REGISTER_OP("XlaReduceWindow")
.Input("input: T")
.Input("init_value: T")
.Input("window_dimensions: Tindices")
.Input("window_strides: Tindices")
.Input("base_dilations: Tindices")
.Input("window_dilations: Tindices")
.Input("padding: Tindices")
.Attr("T: {numbertype, bool}")
.Attr("Tindices: {int32, int64}")
.Attr("computation: func")
.Output("output: T")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
input: the input tensor
init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
)doc");
REGISTER_OP("XlaRngBitGenerator")
.Input("algorithm: int32")
.Input("initial_state: uint64")
.Input("shape: Tshape")
.Output("output_key: uint64")
.Output("output: dtype")
.Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
.Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle algorithm;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &algorithm));
shape_inference::ShapeHandle initial_state;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &initial_state));
c->set_output(0, initial_state);
shape_inference::ShapeHandle output;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output));
c->set_output(1, output);
return OkStatus();
})
.Doc(R"doc(
Stateless PRNG bit generator.
Wraps the XLA RngBitGenerator operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.
algorithm: The PRNG algorithm to use, one of
tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
initial_state: Initial state for the PRNG algorithm. For THREEFRY, it should be
a u64[2] and for PHILOX a u64[3].
shape: The output shape of the generated data.
dtype: The type of the tensor.
)doc");
REGISTER_OP("XlaSelectAndScatter")
.Input("operand: T")
.Input("window_dimensions: Tindices")
.Input("window_strides: Tindices")
.Input("padding: Tindices")
.Input("source: T")
.Input("init_value: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("select: func")
.Attr("scatter: func")
.Output("output: T")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA SelectAndScatter operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
.
operand: the input tensor
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
source: a tensor of values to scatter
init_value: a scalar representing the initial value for the output tensor
select: a selection function to apply
scatter: a scatter function to apply
)doc");
REGISTER_OP("XlaSend")
.Input("tensor: T")
.Attr("T: type")
.Attr("tensor_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
documented at
https://www.tensorflow.org/performance/xla/operation_semantics#send .
tensor: The tensor to send.
tensor_name: A string key that identifies the channel.
)doc");
REGISTER_OP("XlaSort")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.
Sorts a tensor. Currently only sorts in ascending order are supported.
input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");
REGISTER_OP("XlaKeyValueSort")
.Input("keys: K")
.Input("values: V")
.Output("sorted_keys: K")
.Output("sorted_values: V")
.Attr("K: realnumbertype")
.Attr("V: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
c->set_output(1, c->input(1));
return OkStatus();
})
.Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.
Sorts a tensor. Currently only sorts in ascending order are supported.
keys: A `Tensor` of type K.
values: A `Tensor` of type V.
sorted_keys: A `Tensor` of type K.
sorted_values: A `Tensor` of type V.
)doc");
REGISTER_OP("XlaVariadicSort")
.Input("inputs: T")
.Input("dimension: int32")
.Output("outputs: T")
.Attr("T: list(type) >= 1")
.Attr("comparator: func")
.Attr("is_stable: bool")
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<shape_inference::ShapeHandle> input_shapes;
TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes));
return OkStatus();
})
.Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.
Sorts one or more tensors, with support for custom comparator, dimension, and
is_stable attributes.
inputs: A list of `Tensor` of identical shape but possibly different types.
dimension: The dimension along which to sort. Must be a compile-time constant.
is_stable: Whether to use stable sort.
comparator: A comparator function that is applied to 2*N scalars and returns a
boolean. N is the number of sort inputs. If you want to sort in ascending
order then the comparator should perform a less-than comparison.
outputs: A list of `Tensor` of the same shapes and types as `inputs`.
)doc");
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
.Input("input: T")
.Output("output: T")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }
input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor. If the tensor is
a non-boolean scalar, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical
value, non-zero means True and zero means False; if the scalar is
a string, non-empty means True and empty means False. If the
tensor is not a scalar, non-emptiness means True and emptiness
means False.
body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified by T.
)doc");
REGISTER_OP("XlaDequantize")
.Input("input: uint32")
.Output("output: bfloat16")
.Attr("min_range: float")
.Attr("max_range: float")
.Attr("mode: string")
.Attr("transpose_output: bool")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Takes the packed uint32 input and unpacks the input to uint8 to do
dequantization on device.
input: Input tensor whose type is uint32, shape is [d0, ..., dn].
output: Output tensor whose type is bfloat16. If transpose_output is true,
output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
is false, output shape is [d0,..., dn * 4].
min_range: The minimum scalar value possibly produced for the input.
max_range: The maximum scalar value possibly produced for the input.
mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
transpose_output: Boolean to determine if output is transposed. transpose_output
is faster when input is large and rank of input is higher than 1.
)doc");
REGISTER_OP("XlaEinsum")
.Input("a: T")
.Input("b: T")
.Output("product: T")
.Attr("equation: string")
.Attr("T: {complex64, bfloat16, float}")
.SetShapeFn([](shape_inference::InferenceContext* context) {
string equation;
TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
// XlaEinsum supports only two-input einsum equations.
if (!absl::StrContains(equation, ",")) {
return errors::InvalidArgument("Expected one \",\" in equation. Got: ",
equation);
}
// Use EinsumShape for the rest of the inference now that we know we must
// have a two-input einsum.
return shape_inference::EinsumShape(context);
})
.Doc(R"doc(
An op which supports a basic einsum op with 2 inputs and 1 output.
This op has better TPU performance since it doesn't have the explicit reshape
and transpose operations that tf.einsum does.
)doc");
REGISTER_OP("XlaSpmdFullToShardShape")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.Attr("manual_sharding: string")
.Attr("dim: int = -1")
.Attr("unspecified_dims: list(int) = []")
.SetShapeFn([](shape_inference::InferenceContext* c) {
auto input_handle = c->input(0);
if (!c->RankKnown(input_handle)) {
return shape_inference::UnknownShape(c);
}
string sharding_attr;
TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
int32 single_dim;
TF_RETURN_IF_ERROR(c->GetAttr("dim", &single_dim));
xla::OpSharding sharding;
sharding.ParseFromString(sharding_attr);
if (sharding.type() != xla::OpSharding::OTHER) {
return shape_inference::UnchangedShape(c);
}
std::vector<shape_inference::DimensionHandle> dims;
for (int64_t i = 0; i < c->Rank(input_handle); ++i) {
auto dim = c->Value(c->Dim(input_handle, i));
if (single_dim < 0 || single_dim == i) {
int64_t partitions_i = sharding.tile_assignment_dimensions(i);
if (dim != shape_inference::InferenceContext::kUnknownDim &&
partitions_i != 1) {
dim = (dim + partitions_i - 1) / partitions_i;
}
}
dims.push_back(c->MakeDim(dim));
}
c->set_output(0, c->MakeShape(dims));
return OkStatus();
})
.Doc(R"doc(
An op used by XLA SPMD partitioner to switch from automatic partitioning to
manual partitioning. It annotates the input (full-shape, to be automatically
partitioned) with the same sharding used by manual partitioning, and outputs a
shard-shaped tensor to be consumed by later manually-partitioned ops. If the
shape is not evenly partitionable, the padding region will be masked with 0s.
The conversion can happen partially in subgroups, by specifying the dim
attribute, where only that dim will be converted.
)doc");
REGISTER_OP("XlaSpmdShardToFullShape")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.Attr("manual_sharding: string")
.Attr("full_shape: shape")
.Attr("dim: int = -1")
.Attr("unspecified_dims: list(int) = []")
.SetShapeFn([](shape_inference::InferenceContext* c) {
TensorShape shape_attr;
TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &shape_attr));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
c->set_output(0, s);
return OkStatus();
})
.Doc(R"doc(
An op used by XLA SPMD partitioner to switch from manual partitioning to
automatic partitioning. It converts the shard-shaped, manually partitioned input
into a full-shaped tensor to be partitioned automatically with the same sharding
used by manual partitioning. The conversion can happen partially in subgroups,
by specifying the dim attribute, where only that dim will be converted.
)doc");
REGISTER_OP("XlaSharding")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.Attr("sharding: string = ''")
.Attr("unspecified_dims: list(int) = []")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
An op which shards the input based on the given sharding attribute. It can
selectively annotate a subset of tensor dimensions by skipping unspecified_dims,
and the sharding annotation should be replicated in those dims.
)doc");
REGISTER_OP("XlaReplicaId")
.Output("id: int32")
.SetShapeFn([](shape_inference::InferenceContext* context) {
context->set_output(0, context->MakeShape({}));
return OkStatus();
})
.Doc("Replica ID.");
REGISTER_OP("XlaGather")
.Input("operand: T")
.Input("start_indices: Tindices")
.Input("slice_sizes: Tindices")
.Attr("dimension_numbers: string")
.Attr("indices_are_sorted: bool")
.Attr("T: {numbertype, bool}")
.Attr("Tindices: {int32, int64}")
.Output("output: T")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Wraps the XLA Gather operator documented at
https://www.tensorflow.org/xla/operation_semantics#gather
operand: The array we're gathering from.
start_indices: Array containing the starting indices of the slices we gather.
dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");
REGISTER_OP("XlaScatter")
.Input("operand: T")
.Input("scatter_indices: Tindices")
.Input("updates: T")
.Attr("update_computation: func")
.Attr("dimension_numbers: string")
.Attr("indices_are_sorted: bool")
.Attr("T: {numbertype, bool}")
.Attr("Tindices: {int32, int64}")
.Output("output: T")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA Scatter operator documented at
https://www.tensorflow.org/xla/operation_semantics#scatter.
operand: Array to be scattered into.
scatter_indices: Array containing the starting indices of the slices that must
be scattered to.
updates: Array containing the values that must be used for scattering.
update_computation: Computation to be used for combining the existing values in
the input array and the updates during scatter.
dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");
REGISTER_OP("XlaAllReduce")
.Input("input: T")
.Input("group_assignment: int32")
.Output("output: T")
.Attr("T: {half, bfloat16, float, int32, uint32}")
.Attr("reduce_op: {'Min', 'Max', 'Mul', 'Add', 'Mean'}")
.Attr("mode: {'CrossReplica', 'CrossReplicaAndPartition'}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA AllReduce operator
documented at https://www.tensorflow.org/xla/operation_semantics#allreduce.
input: Array or a non-empty tuple of arrays to reduce across replicas.
group_assignment: Groups between which the reductions are performed.
reduce_op: Reduction computation.
mode: group mode.
CrossReplica: group_assignment contains replica_id. Each group contains the
replicas for the current partition.
CrossReplicaAndPartition: group_assignment contains replica_id. Each group
contains the replicas for all partitions.
)doc");
REGISTER_OP("XlaReduceScatter")
.Input("input: T")
.Input("group_assignment: int32")
.Input("scatter_dimension: int32")
.Output("output: T")
.Attr("T: {half, bfloat16, float, int32, uint32}")
.Attr("reduce_op: {'Min', 'Max', 'Mul', 'Add', 'Mean'}")
.SetShapeFn(shape_inference::ReduceScatterShape)
.Doc(R"doc(
Wraps the XLA ReduceScatter operator
documented at https://www.tensorflow.org/xla/operation_semantics#reducescatter.
input: Array or a non-empty tuple of arrays to reduce across replicas.
group_assignment: Groups between which the reductions are performed.
scatter_dimension: Dimension to scatter.
reduce_op: Reduction computation.
)doc");
Status OptimizationBarrierShape(shape_inference::InferenceContext* c) {
for (int i = 0; i < c->num_inputs(); ++i) {
c->set_output(i, c->input(i));
}
return OkStatus();
}
REGISTER_OP("XlaOptimizationBarrier")
.Input("input: T")
.Output("output: T")
.Attr("T: list(type) >= 0")
.SetShapeFn(OptimizationBarrierShape)
.Doc(R"doc(
Wraps the XLA OptimizationBarrier operator.
Documented at https://www.tensorflow.org/xla/operation_semantics#optimizationbarrier.
input: A Tuple of Arrays of any type.
)doc");
REGISTER_OP("XlaReducePrecision")
.Input("operand: T")
.Output("output: T")
.Attr("T: {bfloat16, half, float, double}")
.Attr("exponent_bits: int")
.Attr("mantissa_bits: int")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA ReducePrecision operator
documented at https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
operand: array of floating-point type.
exponent_bits: number of exponent bits in lower-precision format
mantissa_bits: number of mantissa bits in lower-precision format
)doc");
REGISTER_OP("XlaCustomCall")
.Input("args: T")
.Output("output: dtype")
.Attr("target_name: string")
.Attr("backend_config: string")
.Attr("T: list(type) >= 0")
.Attr("dtype: type")
.Attr("shape: shape")
.SetShapeFn([](shape_inference::InferenceContext* c) {
TensorShape shape_attr;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
c->set_output(0, s);
return OkStatus();
})
.Doc(R"doc(
Wraps the XLA CustomCall operator
documented at https://www.tensorflow.org/xla/operation_semantics#customcall.
args: A list of `Tensor` with possibly different types.
target_name: Name of the function. A call instruction will be emitted which
targets this symbol name.
backend_config: String, used to encode serialized metadata to the backend.
dtype: Output tensor data type.
shape: Output tensor shape.
)doc");
REGISTER_OP("XlaCallModule")
.Input("args: Tin")
.Output("output: Tout")
.Attr("module: string")
.Attr("Sout: list(shape) >= 0")
.Attr("Tout: list(type) >= 0")
.Attr("Tin: list(type) >= 0")
.Attr("dim_args_spec: list(string) >= 0")
.SetShapeFn([](shape_inference::InferenceContext* c) {
// For debugging
VLOG(3) << "XlaCallModule.shape_inference";
std::vector<shape_inference::ShapeHandle> args_shapes;
TF_RETURN_IF_ERROR(c->input("args", &args_shapes));
for (int i = 0; i < args_shapes.size(); ++i) {
VLOG(3) << "XlaCallModule.shape_inference args[" << i
<< "] : " << c->DebugString(args_shapes[i]);
}
std::vector<PartialTensorShape> shapes_attr;
TF_RETURN_IF_ERROR(c->GetAttr("Sout", &shapes_attr));
for (int i = 0; i < shapes_attr.size(); ++i) {
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(
c->MakeShapeFromPartialTensorShape(shapes_attr[i], &s));
VLOG(3) << "XlaCallModule.shape_inference out[" << i
<< "] : " << c->DebugString(s);
c->set_output(i, s);
}
return OkStatus();
})
.Doc(R"doc(
Temporary op for experimenting with jax2tf.
DO NOT USE THIS OP. It has no backwards compatibility guarantees. It is also
very likely to change. This op will be used only in jax2tf under an
experimental flag.
This is an experimental op to allow a smooth evolution of jax2tf towards
emitting and serializing MHLO directly from JAX. At the moment this op
carries a serialized MHLO module; therefore there are no backward-compatibility
guarantees, and it should not be used for serialization.
Eventually, the op will carry a MHLO object, which will have
backwards-compatibility guarantees.
The serialized module must return a tuple if and only if Sout is an empty
list or a list with more than 1 element. The lengths of Tout and Sout must
match. This op always returns a tuple of results, even if the module returns
a single result.
The handling of dynamic shapes is work-in-progress. At the moment, the
JAX lowering for dynamic shapes will prepend one dimension parameter to the
serialized module for each dimension whose value must be passed in.
The "args" correspond to the non-dimension arguments. During compilation
we compute the values of the dimension arguments based on the static shapes of
the "args". In order to do this, we encode for each dimension argument a
specification of how to compute its value, as a string, in the form
"<arg_idx>.<axis_idx>".
E.g., the specification "2.1" denotes the value args[2].shape[1].
args: A list of `Tensor` with possibly different types to be passed as arguments
to the HLO module.
module: A serialized computation, a text representation of mlir.Module.
Tout: List of output tensor data types.
Sout: List of output tensor shapes.
dim_args_spec: the specification for the dimension arguments, one for each
dimension argument. In the absence of dynamic shapes this list is empty.
)doc");
} // namespace
} // namespace tensorflow