blob: a948c324c0f114e848f655ece9e898227a4c725c [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef HLO_OPS_BASE
#define HLO_OPS_BASE
include "mlir/IR/OpBase.td"
// The MHLO dialect: ops use the "mhlo" prefix and are generated into the
// ::mlir::mhlo C++ namespace.
def HLO_Dialect : Dialect {
let name = "mhlo";
let cppNamespace = "::mlir::mhlo";
}
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"
// Predicate element type: an alias for the 1-bit integer type I1.
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
// TODO(hinsu): Use signed integers instead of signless integers, which are
// being used for legacy reasons.
// Signless integers of width 8, 16, 32 or 64.
def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
// Unsigned integers of width 8, 16, 32 or 64.
def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
// Either of the two integer families above.
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;
// Complex numbers with f32 or f64 components.
def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
// The broadcasting dimensions correspond to a tuple that describes how a
// smaller rank shape is broadcast into a larger rank shape. For example,
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
// matching the matrix to dimensions 1 and 2 of the cuboid.
// Declared as a file-local alias (defvar) of I64ElementsAttr.
defvar BroadcastDimAttr = I64ElementsAttr;
//===----------------------------------------------------------------------===//
// MHLO on tensors type definitions.
//===----------------------------------------------------------------------===//
// Token type.
// Matches any type for which `$_self.isa<TokenType>()` holds.
def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">;
// Any integer tensor types
def HLO_IntTensor : TensorOf<[HLO_Int]>;
// Any integer tensor type with rank 0 (i.e. representing a single integer).
def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;
// Any floating-point tensor types
def HLO_FpTensor : TensorOf<[AnyFloat]>;
// Any tensor of predicates (HLO_Pred, i.e. i1).
def HLO_PredTensor : TensorOf<[HLO_Pred]>;
// Any tensor whose element type is a float, pred, HLO integer or HLO complex.
def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
// Any complex tensor (complex<f32> or complex<f64>).
def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;
// A (possibly nested) tuple whose leaves are HLO tensors or tokens.
def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;
// Element types permitted in a dimension tensor: index, pred or HLO integers.
def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;
// Dynamic representation of a shape vector as a tensor.
def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;
// In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
//===----------------------------------------------------------------------===//
// MHLO on tensors combined type definitions.
//===----------------------------------------------------------------------===//
// The constraints below combine the element-type families defined above.
// Any integer or floating-point tensor types
def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;
// Any integer or predicate tensor types
def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
// Any floating-point or complex tensor types
def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>;
// Any int, floating-point or complex tensor types
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
// Any pred, int or floating-point tensor types
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
//===----------------------------------------------------------------------===//
// MHLO nullary op definitions.
//===----------------------------------------------------------------------===//
// Documentation strings (summary/description) for the Const op.
class BASE_HLO_ConstOp {
string summary = "Constant operator";
string description = [{
Represents a constant value.
}];
}
// Documentation strings (summary/description) for the Iota op.
class BASE_HLO_IotaOp {
string summary = "Iota operator";
string description = [{
Creates a rank 1 array of values starting at zero and incrementing by one.
}];
}
//===----------------------------------------------------------------------===//
// MHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
// Documentation strings (summary/description) for the Abs op.
class BASE_HLO_AbsOp {
string summary = "Absolute value operator";
string description = [{
Returns `abs(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Cbrt op.
class BASE_HLO_CbrtOp {
string summary = "Cubic root operator";
string description = [{
Returns element-wise cubic root of the operand.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Ceil op.
class BASE_HLO_CeilOp {
string summary = "Ceil operator";
string description = [{
Returns `Ceil(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Clz op.
class BASE_HLO_ClzOp {
string summary = "Count-leading-zeros (Clz) operator";
string description = [{
Returns the number of leading zeros in each operand element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Convert op.
class BASE_HLO_ConvertOp {
string summary = "Convert operator";
string description = [{
Performs element-wise conversion of values from one type to another, e.g.
float to int.
See https://www.tensorflow.org/xla/operation_semantics#convertelementtype.
}];
}
// Documentation strings (summary/description) for the Cos op.
class BASE_HLO_CosOp {
string summary = "Cos operator";
string description = [{
Returns `Cos(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Exp op.
class BASE_HLO_ExpOp {
string summary = "Exponential operator";
string description = [{
Returns `e^(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Expm1 op.
class BASE_HLO_Expm1Op {
string summary = "Exponential minus one operator";
string description = [{
Returns `e^(operand) - 1` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Floor op.
class BASE_HLO_FloorOp {
string summary = "Floor operator";
string description = [{
Returns `Floor(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the GetDimensionSize op.
class BASE_HLO_GetDimensionSizeOp {
string summary = "GetDimensionSize operator";
string description = [{
Returns the size of the given dimension of the operand.
See
https://www.tensorflow.org/xla/operation_semantics#getdimensionsize.
}];
}
// Documentation strings (summary/description) for the Imag op.
class BASE_HLO_ImagOp {
string summary = "Imag operator";
string description = [{
Returns `Imag(operand)` element-wise.
}];
}
// Documentation strings (summary/description) for the IsFinite op.
class BASE_HLO_IsFiniteOp {
string summary = "IsFinite operator";
string description = [{
Tests whether each element of operand is finite, i.e., is not positive or
negative infinity, and is not NaN. Returns a tensor of 1-bit integers with
the same shape as the input, where each element is nonzero (i.e. true) if
and only if the corresponding input element is finite.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Log op.
class BASE_HLO_LogOp {
string summary = "Logarithm operator";
string description = [{
Returns `log(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Log1p op.
class BASE_HLO_Log1pOp {
string summary = "Log1p operator";
string description = [{
Returns `log(operand+1)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Logistic op.
class BASE_HLO_LogisticOp {
string summary = "Logistic operator";
string description = [{
Returns `logistic(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Neg op.
class BASE_HLO_NegOp {
string summary = "Negation operator";
string description = [{
Returns `-operand` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Not op.
class BASE_HLO_NotOp {
string summary = "Not operator";
string description = [{
Returns `!operand` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the PopulationCount op.
class BASE_HLO_PopulationCountOp {
string summary = "PopulationCount operator";
string description = [{
Returns the number of bits set in each operand element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Real op.
class BASE_HLO_RealOp {
string summary = "Real operator";
string description = [{
Returns `Real(operand)` element-wise.
}];
}
// Documentation strings (summary/description) for the RngBitGenerator op.
class BASE_HLO_RngBitGeneratorOp {
string summary = "Uniform random number generator operator";
string description = [{
Returns an output with a given shape filled with uniform random bits using
the specified algorithm (or backend default) and returns an updated state
(with the same shape as initial state) and the generated random data.
See
https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator.
}];
}
// Documentation strings (summary/description) for the Round op.
class BASE_HLO_RoundOp {
string summary = "Round operator";
string description = [{
Returns `Round(operand)` element-wise, rounding to nearest integer with
half-way cases rounding away from zero.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Rsqrt op.
class BASE_HLO_RsqrtOp {
string summary = "Reciprocal Square-root operator";
string description = [{
Returns `1.0 / sqrt(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Sign op.
class BASE_HLO_SignOp {
string summary = "Sign operator";
string description = [{
Returns `sign(operand)` element-wise, where
```
sign(x) = -1 : x < 0
= -0 : x = -0
= NaN : x = NaN
= +0 : x = +0
= 1 : x > 0
```
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Sin op.
class BASE_HLO_SinOp {
string summary = "Sin operator";
string description = [{
Returns `Sin(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Sqrt op.
class BASE_HLO_SqrtOp {
string summary = "Square-root operator";
string description = [{
Returns `sqrt(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Documentation strings (summary/description) for the Tanh op.
class BASE_HLO_TanhOp {
string summary = "Tanh operator";
string description = [{
Returns `tanh(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// Documentation strings (summary/description) for the Add op.
class BASE_HLO_AddOp {
string summary = "Addition operator";
string description = [{
Returns `lhs + rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the Complex op.
class BASE_HLO_ComplexOp {
string summary = "Complex operator";
string description = [{
Performs element-wise conversion of a pair of real and imaginary values to
a complex value.
}];
}
// Documentation strings (summary/description) for the Div op.
class BASE_HLO_DivOp {
string summary = "Division operator";
string description = [{
Returns `lhs / rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the Max op.
class BASE_HLO_MaxOp {
string summary = "Maximum operator";
string description = [{
Returns `max(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the Min op.
class BASE_HLO_MinOp {
string summary = "Minimum operator";
string description = [{
Returns `min(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the Mul op.
class BASE_HLO_MulOp {
string summary = "Multiplication operator";
string description = [{
Returns `lhs * rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the Pow op.
class BASE_HLO_PowOp {
string summary = "Power operator";
string description = [{
Returns `lhs ^ rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the Rem op.
class BASE_HLO_RemOp {
string summary = "Remainder operator";
string description = [{
Returns `lhs % rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the Sub op.
class BASE_HLO_SubOp {
string summary = "Subtraction operator";
string description = [{
Returns `lhs - rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the ShiftLeft op.
class BASE_HLO_ShiftLeftOp {
string summary = "Shift Left operator";
string description = [{
Returns `lhs << rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the ShiftRightArithmetic op.
class BASE_HLO_ShiftRightArithmeticOp {
string summary = "Shift right arithmetic operator";
string description = [{
Returns arithmetic `lhs >> rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the ShiftRightLogical op.
class BASE_HLO_ShiftRightLogicalOp {
string summary = "Shift right logical operator";
string description = [{
Returns logical `lhs >> rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the two-argument arctangent
// op. `lhs` and `rhs` are passed to atan2 as separate arguments, not as a
// quotient.
class BASE_HLO_Atan2Op {
string summary = "Atan2 operator";
string description = [{
Returns `atan2(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the And op.
class BASE_HLO_AndOp {
string summary = "Logical and";
string description = [{
Returns `logical_and(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the Or op.
class BASE_HLO_OrOp {
string summary = "Logical or";
string description = [{
Returns `logical_or(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Documentation strings (summary/description) for the Xor op.
class BASE_HLO_XorOp {
string summary = "Logical xor";
string description = [{
Returns `logical_xor(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
//===----------------------------------------------------------------------===//
// XLA control flow related op definitions.
//===----------------------------------------------------------------------===//
// Documentation strings (summary/description) for the Case op. The inline
// code spans in the description must be balanced so that the generated
// markdown renders correctly.
class BASE_HLO_CaseOp {
string summary = "Switch-Case operator";
string description = [{
Returns the result of executing `branches[index]`. If
`index` is < 0 or >= N, then `branches[N-1]` is executed as
the default branch.
Each branch `branches[b]` must take in a single argument of the same type as
`branch_operands[b]` and will be invoked with `branch_operands[b]`. The type
of the returned value of each branch must be the same.
Note that only one of the branches will be executed depending on the value
of `index`.
See https://www.tensorflow.org/xla/operation_semantics#conditional.
}];
}
//===----------------------------------------------------------------------===//
// XLA parallelism related op definitions.
//===----------------------------------------------------------------------===//
// Documentation strings (summary/description) for the ReplicaId op.
// NOTE(review): the description says both "int32 scalar" and "unsigned
// integer"; confirm against the op's declared result type.
class BASE_HLO_ReplicaIdOp {
string summary = "ReplicaId operator";
string description = [{
Returns the unique ID (int32 scalar) of the replica.
The unique ID of each replica is an unsigned integer in the interval [0, N),
where N is the number of replicas. Since all the replicas are running the
same program, a ReplicaId() call in the program will return a different
value on each replica.
See https://www.tensorflow.org/xla/operation_semantics#replicaid.
}];
}
// Documentation strings (summary/description) for the PartitionId op.
class BASE_HLO_PartitionIdOp {
string summary = "PartitionId operator";
string description = [{
Returns the unique ID (int32 scalar) of the partition.
}];
}
// Documentation strings (summary/description) for the AllReduce op.
class BASE_HLO_AllReduceOp {
string summary = "AllReduce operator";
string description = [{
Performs a custom reduction across replicas.
See https://www.tensorflow.org/xla/operation_semantics#allreduce.
}];
}
// Documentation strings (summary/description) for the Reduce op.
class BASE_HLO_ReduceOp {
string summary = "Reduce operator";
string description = [{
Returns the result of executing a reduction function on one or more arrays
in parallel.
See https://www.tensorflow.org/xla/operation_semantics#reduce.
}];
}
// Documentation strings (summary/description) for the ReduceWindow op.
class BASE_HLO_ReduceWindowOp {
string summary = "ReduceWindow operator";
string description = [{
Returns the result of executing a reduction function over all elements in
each window of one or more arrays in parallel.
See https://www.tensorflow.org/xla/operation_semantics#reducewindow.
}];
}
//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
//===----------------------------------------------------------------------===//
// Documentation strings (summary/description) for the GetTupleElement op.
class BASE_HLO_GetTupleElementOp {
string summary = "GetTupleElement operator";
string description = [{
Returns a member of a tuple specified by an index.
See https://www.tensorflow.org/xla/operation_semantics#gettupleelement.
}];
}
// Documentation strings (summary/description) for the Tuple op.
class BASE_HLO_TupleOp {
string summary = "XLA's tuple op";
string description = [{
Groups a set of tensor inputs into a single tuple object.
See https://www.tensorflow.org/xla/operation_semantics#tuple.
}];
}
// Documentation strings (summary/description) for the Compare op.
class BASE_HLO_CompareOp {
string summary = "Comparison operator";
string description = [{
Compares `lhs` and `rhs` elementwise according to `comparison_direction`
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
types, SIGNED for signed element types and UNSIGNED for unsigned element
types.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
}];
}
//===----------------------------------------------------------------------===//
// Quantize op definitions.
//===----------------------------------------------------------------------===//
// Documentation strings (summary/description) for the Dequantize op, which
// unpacks uint8/uint16 values stored in packed uint32 words into bfloat16.
class BASE_HLO_DequantizeOp {
string summary = "Dequantize operator";
string description = [{
Dequantizes the quantized input of packed uint32 to bfloat16. Only uint8 or
uint16 is supported for the original unpacked input.
Returns a tensor of shape [d0,..., dn * unpack_size] if the (packed) input
shape is [d0, ..., dn], where unpack_size = sizeof(uint32) / sizeof(T), where
T is the unpacked input type. If transpose_output is true, will return a
tensor of shape [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is
faster when the input's rank is higher than 1. The input needs to be
transposed to use the transpose_output feature.
}];
}
//===----------------------------------------------------------------------===//
// XLA Slice definitions.
//===----------------------------------------------------------------------===//
// Documentation strings (summary/description) for the Slice op.
class BASE_HLO_SliceOp {
string summary = "Slice operator";
string description = [{
Slices a portion of the `operand` into a new configuration.
See https://www.tensorflow.org/xla/operation_semantics#slice.
}];
}
// Documentation strings (summary/description) for the DynamicSlice op.
class BASE_HLO_DynamicSliceOp {
string summary = "Dynamic Slice operator";
string description = [{
Extracts a sub-array from the input array at dynamic start_indices.
See https://www.tensorflow.org/xla/operation_semantics#dynamicslice.
}];
}
// Documentation strings (summary/description) for the DynamicUpdateSlice op.
class BASE_HLO_DynamicUpdateSliceOp {
string summary = "Dynamic Update Slice operator";
string description = [{
DynamicUpdateSlice generates a result which is the value of the input array
operand, with a slice update overwritten at start_indices.
See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice.
}];
}
//===----------------------------------------------------------------------===//
// XLA Other op definitions.
//===----------------------------------------------------------------------===//
// Documentation strings (summary/description) for the AllToAll collective op.
class BASE_HLO_AllToAllOp {
string summary = "AllToAll";
string description = [{
AllToAll is a collective operation that sends data from all cores to all
cores. It has two phases:
- The scatter phase. On each core, the operand is split into `split_count`
number of blocks along the `split_dimension`, and the blocks are
scattered to all cores, e.g., the i-th block is sent to the i-th core.
- The gather phase. Each core concatenates the received blocks along the
`concat_dimension`.
The participating cores can be configured by:
- replica_groups: each ReplicaGroup contains a list of replica ids
participating in the computation (replica id for the current replica can
be retrieved using ReplicaId op). AllToAll will be applied within
subgroups in the specified order. For example,
`replica_groups` = {{1,2,3}, {4,5,0}} means that an AllToAll will be applied
within replicas {1, 2, 3}, and in the gather phase, the received blocks
will be concatenated in the order 1, 2, 3. Then, another AllToAll
will be applied within replicas 4, 5, 0, and the concatenation order is
also 4, 5, 0. If `replica_groups` is empty, all replicas belong to one
group, and the concatenation order is the numerical order (0, 1, 2, ...).
Prerequisites:
- The dimension size of the operand on the split_dimension is divisible by
`split_count`.
- The operand's shape is not a tuple.
See https://www.tensorflow.org/xla/operation_semantics#alltoall.
}];
}
// Documentation strings (summary/description) for the BatchNormGrad op.
class BASE_HLO_BatchNormGradOp {
string summary = "Batch Normalization Gradient";
string description = [{
Calculates gradients of batch norm.
See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad
}];
}
// Documentation strings (summary/description) for the BatchNormInference op.
class BASE_HLO_BatchNormInferenceOp {
string summary = "Batch Normalization for Inference";
string description = [{
Normalizes an array across batch and spatial dimensions.
See https://www.tensorflow.org/xla/operation_semantics#batchnorminference
}];
}
// Documentation strings (summary/description) for the BatchNormTraining op.
class BASE_HLO_BatchNormTrainingOp {
string summary = "Batch Normalization for Training";
string description = [{
Normalizes an array across batch and spatial dimensions.
See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining
}];
}
// Documentation strings (summary/description) for the BitcastConvert op.
class BASE_HLO_BitcastConvertOp {
string summary = "BitcastConvert operator";
string description = [{
Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast
operation from a data shape to a target shape. The dimensions must match,
and the conversion is an element-wise one. Bitcast is implemented as a
low-level cast, so machines with different floating-point representations
will give different results.
See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype.
}];
}
// Documentation strings (summary/description) for the Broadcast op.
class BASE_HLO_BroadcastOp {
string summary = "Broadcast a tensor to a higher rank by prepending dimensions";
string description = [{
Broadcasts the operand tensor to a higher rank by prepending
`broadcast_sizes` to the dimensions. The current values of the operand are
copied into the other dimensions.
This is a more limited form of broadcasting, that corresponds to the XLA
client Broadcast method. For a more general form of broadcasting, see the
BroadcastInDimOp.
See https://www.tensorflow.org/xla/operation_semantics#broadcast.
}];
}
// Documentation strings (summary/description) for the BroadcastInDim op.
class BASE_HLO_BroadcastInDimOp {
string summary = "Broadcast a tensor into the given shape by adding dimensions.";
string description = [{
Broadcasts the `operand` tensor to a higher rank. This is not the limited
form of broadcasting exposed as the XLA client broadcast op, but rather the
more powerful "InDim" broadcasting, which is closer to the HLO broadcast op
and exposed in the XLA client BroadcastInDim method.
`broadcast_dimensions` maps the operand dimension number to the target shape
dimension number. It must have the same size as the rank of the operand. The
mapped dimensions must either be the same size or the dimension being
broadcast from must be size 1 (degenerate broadcasting).
For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The
scalar value will be broadcast to every element in the target shape.
See https://www.tensorflow.org/xla/broadcasting.
}];
}
// Documentation strings (summary/description) for the Cholesky op.
class BASE_HLO_CholeskyOp {
string summary = "Cholesky operator";
string description = [{
Computes the Cholesky decomposition of a batch of symmetric (Hermitian)
positive definite matrices.
If lower is true, computes lower-triangular matrices l such that
`a=l.Transpose(l)`. If lower is false, computes upper-triangular matrices u such
that `a=Transpose(u).u`.
Input data is read only from the lower/upper triangle of a, depending on the
value of lower. Values from the other triangle are ignored. Output data is
returned in the same triangle; the values in the other triangle are
implementation-defined and may be anything.
If the rank of a is greater than 2, a is treated as a batch of matrices, where
all except the minor 2 dimensions are batch dimensions.
If a is not symmetric (Hermitian) positive definite, the result is
implementation-defined.
See https://www.tensorflow.org/xla/operation_semantics#cholesky.
}];
}
// Documentation strings (summary/description) for the Clamp op.
class BASE_HLO_ClampOp {
string summary = "Clamp operator";
string description = [{
Clamps an operand to within the range between a minimum and maximum value.
Note: All three arrays must be the same shape. Alternatively, as a
restricted form of broadcasting, min and/or max can be a scalar (0D
tensor) of the element type of the tensor operand.
See https://www.tensorflow.org/xla/operation_semantics#clamp.
}];
}
// Documentation strings (summary/description) for the CollectivePermute op.
class BASE_HLO_CollectivePermuteOp {
string summary = "CollectivePermute operator";
string description = [{
CollectivePermute is a collective operation that sends and receives data
across replicas.
Note that there are the following restrictions on the source_target_pair:
- Any two pairs should not have the same target replica id, and they should
not have the same source replica id.
- If a replica id is not a target in any pair, then the output on that
replica is a tensor consisting of 0(s) with the same shape as the input.
See https://www.tensorflow.org/xla/operation_semantics#collectivepermute.
}];
}
// Documentation strings (summary/description) for the Concatenate op.
class BASE_HLO_ConcatenateOp {
string summary = "XLA's concatenate op";
string description = [{
Concatenates a set of tensors along the specified dimension.
See https://www.tensorflow.org/xla/operation_semantics#concatenate.
}];
}
//===----------------------------------------------------------------------===//
// Common convolution attributes
//===----------------------------------------------------------------------===//
// Optional array attribute whose elements are HLO_PrecisionAttr values.
// TODO(b/129153247) See if it's possible to also validate the size.
def HLO_PrecisionConfigAttr:
OptionalAttr<
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
// Elements attribute accepted only when it is a DenseIntOrFPElementsAttr whose
// element type is the 1-bit integer type. It is stored and returned as a
// plain ::mlir::DenseElementsAttr, with no conversion on access.
def BoolElementsAttr :
ElementsAttrBase<
And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
"constant boolean vector/tensor attribute"> {
let storageType = [{ ::mlir::DenseElementsAttr }];
let returnType = [{ ::mlir::DenseElementsAttr }];
let convertFromStorage = "$_self";
}
// Attribute list shared by convolution ops. The optional window attributes
// have one entry per spatial dimension.
def ConvolutionAttributes {
dag attributes = (ins
// Default value: one for each of the spatial dimensions.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimensions.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimensions.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimensions.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
// Default value: false for each of the spatial dimensions (no reversal);
// an absent attribute is treated as "no reversal" by hasWindowReversal().
OptionalAttr<BoolElementsAttr>:$window_reversal,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
}
// Documentation strings (summary/description) for the Conv op, plus a C++
// helper injected into the generated op class.
class BASE_HLO_ConvOp {
string summary = "Convolution operator";
string description = [{
Computes a convolution of the kind used in neural networks.
See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
}];
// hasWindowReversal() returns true iff the optional window_reversal attribute
// is present and at least one of its boolean entries is true.
code extraClassDeclaration = [{
bool hasWindowReversal() {
auto reversal = window_reversalAttr();
return reversal && llvm::any_of(reversal.getBoolValues(),
[](bool v) { return v; });
}
}];
}
// Documentation strings (summary/description) for the Copy op.
class BASE_HLO_CopyOp {
string summary = "Copy operator";
string description = [{
Returns a copy of `operand`.
}];
}
// Documentation strings (summary/description) for the CrossReplicaSum op.
class BASE_HLO_CrossReplicaSumOp {
string summary = "Sums input across replicated instances.";
string description = [{
For each of the replica groups, operands of the group devices are summed
so that each device has the sum.
For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`.
Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,
and `B, D, F, H` as group 1. Thus we get the outputs:
`[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.
See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum.
}];
}
// Documentation strings (summary/description) for the CustomCall op.
class BASE_HLO_CustomCallOp {
string summary = "CustomCall operator";
string description = [{
A custom call invokes code external to XLA. The `args` are passed to the
external code, and the external code is expected to produce a result of the
given type. The exact mechanism is backend-specific. For example, in the CPU
backend, a call instruction is emitted which targets a symbol with the name
`call_target_name`.
`call_target_name` and `backend_config` can be arbitrary strings, but
`call_target_name` should be short as it may be used in labels.
`backend_config` can encode arbitrarily large amounts of information.
See https://www.tensorflow.org/xla/operation_semantics#customcall.
}];
}
// Documentation strings (summary/description) for the Dot op.
class BASE_HLO_DotOp {
string summary = "Dot operator";
string description = [{
Performs dot products between vectors, vector/matrix and matrix/matrix
multiplication.
See https://www.tensorflow.org/xla/operation_semantics#dot.
}];
}
// Documentation strings (summary/description) for the DotGeneral op.
class BASE_HLO_DotGeneralOp {
string summary = "General Dot operator";
string description = [{
Performs general dot products between vectors, vector/matrix and
matrix/matrix multiplication.
See https://www.tensorflow.org/xla/operation_semantics#dotgeneral.
}];
}
// Documentation strings (summary/description) for the Fft op. "Fourier" is a
// proper noun and is capitalized.
class BASE_HLO_FftOp {
string summary = "Fast Fourier transform operator";
string description = [{
Returns the fast Fourier transform of the input array.
See
https://www.tensorflow.org/xla/operation_semantics#fft.
}];
}
// Documentation strings (summary/description) for the Gather op.
class BASE_HLO_GatherOp{
string summary = "Gather operator";
string description = [{
Stitches together several slices of an input array.
See https://www.tensorflow.org/xla/operation_semantics#gather.
}];
}
// Documentation strings (summary/description) for the Map op.
class BASE_HLO_MapOp {
string summary = "Map operator";
string description = [{
Applies a scalar function over the given operand arrays, producing an array
of the same dimensions where each element is the result of the mapped function
applied to the corresponding elements in the input arrays.
The mapped function is an arbitrary computation with the restriction that it
has N inputs of scalar type T and a single output with type S. The output has
the same dimensions as the operands except that the element type T is replaced
with S.
See https://www.tensorflow.org/xla/operation_semantics#map.
}];
}
// Documentation strings (summary/description) for the Reshape op.
class BASE_HLO_ReshapeOp {
string summary = "Reshape operator";
string description = [{
Reshapes the dimensions of `operand` into a new configuration.
See https://www.tensorflow.org/xla/operation_semantics#reshape.
}];
}
// Documentation strings (summary/description) for the Scatter op.
class BASE_HLO_ScatterOp {
string summary = "Scatter operator";
string description = [{
Generates a result which is the value of the input array `operand`,
with several slices (at indices specified by `scatter_indices`)
updated with the values in `updates` using `update_computation`.
See https://www.tensorflow.org/xla/operation_semantics#scatter.
}];
}
// Documentation strings (summary/description) for the Select op.
class BASE_HLO_SelectOp {
string summary = "Select operator";
string description = [{
Constructs an output tensor from the elements of `on_true` and `on_false`
based on the values of `pred`.
`pred`, `on_true` and `on_false` must be broadcast compatible.
}];
}
// Documentation holder for the SelectAndScatter operator: defines only the
// `summary` and `description` fields, presumably inherited by the concrete op
// definition(s) declared elsewhere.
class BASE_HLO_SelectAndScatterOp {
string summary = "SelectAndScatter operator";
string description = [{
Runs a windowed selection function `select` over `operand`, using windows of
shape `window_dimensions` moved with strides `window_strides`. This produces a
set of selected locations whose shape matches `source`. The elements of
`source` are then scattered to the selected locations of the output, which is
initialized with `init_value`.
Multiple scattered elements which land in the same output location are
combined using the `scatter` function.
See https://www.tensorflow.org/xla/operation_semantics#selectandscatter.
}];
}
// Documentation holder for the SetDimensionSize operator: defines only the
// `summary` and `description` fields, presumably inherited by the concrete op
// definition(s) declared elsewhere.
class BASE_HLO_SetDimensionSizeOp {
string summary = "SetDimensionSize operator";
string description = [{
Sets the dynamic size of the given dimension of `operand`. The operand is
passed through as the result, with the dynamic dimension tracked by the
compiler. Padded values will be ignored by downstream reduction ops.
See https://www.tensorflow.org/xla/operation_semantics#setdimensionsize.
}];
}
// Documentation holder for the Sort operator. Only `summary` and
// `description` are defined here; presumably the concrete op definition(s)
// elsewhere inherit them.
class BASE_HLO_SortOp {
string summary = "Sort operator";

string description = [{
Sorts the given `operands` at the given `dimension` with the given
`comparator`.
See https://www.tensorflow.org/xla/operation_semantics#sort.
}];
}
// Documentation holder for the Reverse operator. Only `summary` and
// `description` are defined here; presumably the concrete op definition(s)
// elsewhere inherit them.
class BASE_HLO_ReverseOp {
string summary = "Reverse operator";

string description = [{
Reverses the specified dimensions of `operand` according to the given
`dimensions`.
See https://www.tensorflow.org/xla/operation_semantics#rev_reverse.
}];
}
// Documentation holder for the Pad operator. Only `summary` and
// `description` are defined here; presumably the concrete op definition(s)
// elsewhere inherit them.
class BASE_HLO_PadOp {
string summary = "Pad operator";

string description = [{
Pads the edges of `operand` with the `padding_value` and according to
the passed configuration.
See https://www.tensorflow.org/xla/operation_semantics#pad.
}];
}
// Documentation holder for the Trace operator. Only `summary` and
// `description` are defined here; presumably the concrete op definition(s)
// elsewhere inherit them.
class BASE_HLO_TraceOp {
string summary = "Trace operator";

string description = [{
Emits a logging message `tag` with the `operand`.
}];
}
// Documentation holder for the Transpose operator. Only `summary` and
// `description` are defined here; presumably the concrete op definition(s)
// elsewhere inherit them.
class BASE_HLO_TransposeOp {
string summary = "Transpose operator";

string description = [{
Permutes the dimensions of `operand` according to the given `permutation`.
`res_dimensions[i] = operand_dimensions[permutation[i]]`
See https://www.tensorflow.org/xla/operation_semantics#transpose.
}];
}
// Documentation holder for the TriangularSolve operator: defines only the
// `summary` and `description` fields, presumably inherited by the concrete op
// definition(s) declared elsewhere.
class BASE_HLO_TriangularSolveOp {
string summary = "TriangularSolve operator";
string description = [{
Solves systems of linear equations with lower or upper triangular
coefficient matrices by forward- or back-substitution. Broadcasting along
leading dimensions, this routine solves one of the matrix systems
`op(a) * x = b`, or `x * op(a) = b`, for the variable `x`, given `a` and `b`,
where `op(a)` is either `op(a) = a`, or `op(a) = Transpose(a)`, or
`op(a) = Conj(Transpose(a))`.
Input data is read only from the lower/upper triangle of `a`, depending on the
value of `lower`. Values from the other triangle are ignored. Output data is
returned in the same triangle; the values in the other triangle are
implementation-defined and may be anything.
If the ranks of `a` and `b` are greater than 2, they are treated as batches of
matrices, where all except the minor 2 dimensions are batch dimensions. `a`
and `b` must have equal batch dimensions.
See https://www.tensorflow.org/xla/operation_semantics#triangularsolve.
}];
}
// Documentation holder for the uniform RNG operator: defines only the
// `summary` and `description` fields, presumably inherited by the concrete op
// definition(s) declared elsewhere.
class BASE_HLO_RngUniformOp {
string summary = "RNG with uniform distribution.";
string description = [{
Constructs an output of a given shape with random numbers generated
following the uniform distribution over the interval `[a,b)`. The parameters
and output element type have to be a boolean type, an integral type or a
floating point type, and the types have to be consistent.
See https://www.tensorflow.org/xla/operation_semantics#rnguniform.
}];
}
// Documentation holder for the normal RNG operator: defines only the
// `summary` and `description` fields, presumably inherited by the concrete op
// definition(s) declared elsewhere.
class BASE_HLO_RngNormalOp {
string summary = "RNG with normal distribution.";
string description = [{
Constructs an output of a given shape with random numbers generated
following the normal distribution with parameters `mu` and `sigma`. The
parameters and output shape have to have a floating point element type.
The parameters furthermore have to be scalar valued.
See https://www.tensorflow.org/xla/operation_semantics#rngnormal.
}];
}
// Documentation holder for the ReducePrecision operator: defines only the
// `summary` and `description` fields, presumably inherited by the concrete op
// definition(s) declared elsewhere.
class BASE_HLO_ReducePrecisionOp {
string summary = "Reduce precision operator";
string description = [{
Models the effect of converting floating-point values to a lower-precision
format (such as IEEE-FP16) and back to the original format. The number of
exponent and mantissa bits in the lower-precision format can be specified
arbitrarily, although all bit sizes may not be supported on all hardware
implementations.
See https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
}];
}
// Documentation holder for the Infeed operator: defines only the `summary`
// and `description` fields, presumably inherited by the concrete op
// definition(s) declared elsewhere.
class BASE_HLO_InfeedOp {
string summary = "Infeed operator";
string description = [{
Reads a single data item from the implicit Infeed streaming interface of
the device, interpreting the data as the given shape and its layout, and
produces the data. Multiple Infeed operations are allowed in a computation,
but there must be a total order among the Infeed operations. For example,
two Infeeds in separate while loops are totally ordered when there is a
dependency between the loops.
See https://www.tensorflow.org/xla/operation_semantics#infeed.
}];
}
// Documentation holder for the While operator: defines only the `summary` and
// `description` fields, presumably inherited by the concrete op definition(s)
// declared elsewhere.
class BASE_HLO_WhileOp {
string summary = "While operator";
string description = [{
Returns the result of repeatedly executing a body function for as long as a
condition function returns true, i.e. the loop stops once the condition
returns false.
See https://www.tensorflow.org/xla/operation_semantics#while.
}];
}
// Documentation holder for the Bitcast operator: defines only the `summary`
// and `description` fields, presumably inherited by the concrete op
// definition(s) declared elsewhere.
class BASE_HLO_BitcastOp {
string summary = "Bitcast operator";
string description = [{
This op changes the shape of the input in such a way that the physical
arrangement of elements is unchanged.
However, the op needs layout information to make sense of "physical
arrangement of elements". Layout support in MHLO is currently under
exploration.
}];
}
#endif // HLO_OPS_BASE