| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef HLO_OPS_BASE |
| #define HLO_OPS_BASE |
| |
| include "mlir/IR/OpBase.td" |
| |
| // The MHLO dialect: registered under the name `mhlo`, with its generated C++ |
| // classes placed in the ::mlir::mhlo namespace. |
| def HLO_Dialect : Dialect { |
| let cppNamespace = "::mlir::mhlo"; |
| let name = "mhlo"; |
| } |
| |
| include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td" |
| include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td" |
| |
| // Predicate element type: an alias of the ODS `I1` type. |
| def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">; |
| |
| // TODO(hinsu): Use signed integers instead of signless integer which is being |
| // used for legacy reasons. |
| // HLO_SInt: 8/16/32/64-bit signless integers, standing in for signed integers |
| // per the TODO above. |
| def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>; |
| // HLO_UInt: 8/16/32/64-bit explicitly unsigned integers. |
| def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>; |
| // HLO_Int: any of the signless or unsigned integer types above. |
| def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>; |
| |
| // Complex element type with f32 or f64 components. |
| def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>; |
| |
| // The broadcasting dimensions correspond to a tuple that describes how a |
| // smaller rank shape is broadcast into a larger rank shape. For example, |
| // given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means |
| // matching the matrix to dimensions 1 and 2 of the cuboid. |
| // Represented as an I64ElementsAttr. |
| defvar BroadcastDimAttr = I64ElementsAttr; |
| |
| //===----------------------------------------------------------------------===// |
| // MHLO on tensors type definitions. |
| //===----------------------------------------------------------------------===// |
| |
| // Token type, matched by a C++ `isa<TokenType>()` check on the type. |
| def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">; |
| |
| // Any integer tensor types |
| def HLO_IntTensor : TensorOf<[HLO_Int]>; |
| |
| // Any integer tensor type with rank 0 (i.e. representing a single integer). |
| def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>; |
| |
| // Any floating-point tensor types |
| def HLO_FpTensor : TensorOf<[AnyFloat]>; |
| |
| // Tensor of predicate (HLO_Pred) elements. |
| def HLO_PredTensor : TensorOf<[HLO_Pred]>; |
| |
| // Tensor of any element type supported here: float, pred, integer or complex. |
| def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>; |
| |
| // Tensor of complex elements. |
| def HLO_ComplexTensor : TensorOf<[HLO_Complex]>; |
| |
| // Possibly nested tuple whose leaves are HLO tensors or tokens. |
| def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>; |
| |
| // Either an HLO tensor or an HLO tuple. |
| def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>; |
| |
| // An HLO tensor, a token, or an HLO tuple. |
| def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>; |
| |
| // Element type allowed in a dynamic shape vector (see HLO_DimensionTensor). |
| def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>; |
| |
| // Dynamic representation of a shape vector as a tensor. |
| def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>; |
| |
| // In general, static shaped tensor constraints should be avoided unless |
| // it is for a legacy op which is only correct with static shapes. |
| def HLO_StaticShapeTensor : StaticShapeTensorOf<[ |
| AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>; |
| |
| //===----------------------------------------------------------------------===// |
| // MHLO on tensors combined type definitions. |
| //===----------------------------------------------------------------------===// |
| |
| // The constraints below are unions of the element types defined above: each |
| // accepts a tensor whose element type is any one of the listed types. |
| |
| // Any integer or floating-point tensor types |
| def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>; |
| |
| // Any integer or predicate tensor types |
| def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>; |
| |
| // Any floating-point or complex tensor types |
| def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>; |
| |
| // Any int, floating-point or complex tensor types |
| def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>; |
| |
| // Any pred, int or floating-point tensor types |
| def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>; |
| |
| //===----------------------------------------------------------------------===// |
| // MHLO nullary op definitions. |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation-only mixins: each BASE_HLO_* class below defines just the |
| // `summary` and `description` strings for one op (presumably inherited by |
| // the concrete op definition elsewhere -- TODO confirm). |
| class BASE_HLO_ConstOp { |
| string summary = "Constant operator"; |
| |
| string description = [{ |
| Represents a constant value. |
| }]; |
| } |
| |
| class BASE_HLO_IotaOp { |
| string summary = "Iota operator"; |
| |
| string description = [{ |
| Creates a rank 1 array of values starting at zero and incrementing by one. |
| }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MHLO unary elementwise op definitions. |
| //===----------------------------------------------------------------------===// |
| // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions |
| |
| // Documentation-only mixins for the unary element-wise ops: each class defines |
| // just the `summary` and `description` strings for one op (presumably |
| // inherited by the concrete op definition elsewhere -- TODO confirm). |
| class BASE_HLO_AbsOp { |
| string summary = "Absolute value operator"; |
| |
| string description = [{ |
| Returns `abs(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_CbrtOp { |
| string summary = "Cubic root operator"; |
| |
| string description = [{ |
| Returns element-wise cubic root of the operand. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_CeilOp { |
| string summary = "Ceil operator"; |
| |
| string description = [{ |
| Returns `Ceil(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_ClzOp { |
| string summary = "Count-leading-zeros (Clz) operator"; |
| |
| string description = [{ |
| Returns the number of leading zeros in each operand element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_ConvertOp { |
| string summary = "Convert operator"; |
| |
| string description = [{ |
| Performs element-wise conversion of values from one type to another, e.g. |
| float to int. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#convertelementtype. |
| }]; |
| } |
| |
| class BASE_HLO_CosOp { |
| string summary = "Cos operator"; |
| |
| string description = [{ |
| Returns `Cos(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_ExpOp { |
| string summary = "Exponential operator"; |
| |
| string description = [{ |
| Returns `e^(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_Expm1Op { |
| string summary = "Exponential minus one operator"; |
| |
| string description = [{ |
| Returns `e^(operand) - 1` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_FloorOp { |
| string summary = "Floor operator"; |
| |
| string description = [{ |
| Returns `Floor(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_GetDimensionSizeOp { |
| string summary = "GetDimensionSize operator"; |
| |
| string description = [{ |
| Returns the size of the given dimension of the operand. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#getdimensionsize. |
| }]; |
| } |
| |
| class BASE_HLO_ImagOp { |
| string summary = "Imag operator"; |
| |
| string description = [{ |
| Returns `Imag(operand)` element-wise. |
| }]; |
| } |
| |
| class BASE_HLO_IsFiniteOp { |
| string summary = "IsFinite operator"; |
| |
| string description = [{ |
| Tests whether each element of operand is finite, i.e., is not positive or |
| negative infinity, and is not NaN. Returns a tensor of 1-bit integers with |
| the same shape as the input, where each element is nonzero (i.e. true) if |
| and only if the corresponding input element is finite. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_LogOp { |
| string summary = "Logarithm operator"; |
| |
| string description = [{ |
| Returns `log(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_Log1pOp { |
| string summary = "Log1p operator"; |
| |
| string description = [{ |
| Returns `log(operand+1)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_LogisticOp { |
| string summary = "Logistic operator"; |
| |
| string description = [{ |
| Returns `logistic(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_NegOp { |
| string summary = "Negation operator"; |
| |
| string description = [{ |
| Returns `-operand` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_NotOp { |
| string summary = "Not operator"; |
| |
| string description = [{ |
| Returns `!operand` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_PopulationCountOp { |
| string summary = "PopulationCount operator"; |
| |
| string description = [{ |
| Returns the number of bits set in each operand element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_RealOp { |
| string summary = "Real operator"; |
| |
| string description = [{ |
| Returns `Real(operand)` element-wise. |
| }]; |
| } |
| |
| class BASE_HLO_RngBitGeneratorOp { |
| string summary = "Uniform random number generator operator"; |
| |
| string description = [{ |
| Returns an output with a given shape filled with uniform random bits using |
| the specified algorithm (or backend default) and returns an updated state |
| (with the same shape as initial state) and the generated random data. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator. |
| }]; |
| } |
| |
| class BASE_HLO_RoundOp { |
| string summary = "Round operator"; |
| |
| string description = [{ |
| Returns `Round(operand)` element-wise, rounding to nearest integer with |
| half-way cases rounding away from zero. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_RsqrtOp { |
| string summary = "Reciprocal Square-root operator"; |
| |
| string description = [{ |
| Returns `1.0 / sqrt(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_SignOp { |
| string summary = "Sign operator"; |
| |
| string description = [{ |
| Returns `sign(operand)` element-wise, where |
| |
| ``` |
| sign(x) = -1 : x < 0 |
| = -0 : x = -0 |
| = NaN : x = NaN |
| = +0 : x = +0 |
| = 1 : x > 0 |
| ``` |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_SinOp { |
| string summary = "Sin operator"; |
| |
| string description = [{ |
| Returns `Sin(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_SqrtOp { |
| string summary = "Square-root operator"; |
| |
| string description = [{ |
| Returns `sqrt(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| class BASE_HLO_TanhOp { |
| string summary = "Tanh operator"; |
| |
| string description = [{ |
| Returns `tanh(operand)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |
| }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XLA binary elementwise op definitions. |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation-only mixins for the binary element-wise ops: each class |
| // defines just the `summary` and `description` strings for one op (presumably |
| // inherited by the concrete op definition elsewhere -- TODO confirm). |
| class BASE_HLO_AddOp { |
| string summary = "Addition operator"; |
| |
| string description = [{ |
| Returns `lhs + rhs` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_ComplexOp { |
| string summary = "Complex operator"; |
| |
| string description = [{ |
| Performs element-wise conversion of a pair of real and imaginary values to |
| a complex value. |
| }]; |
| } |
| |
| class BASE_HLO_DivOp { |
| string summary = "Division operator"; |
| |
| string description = [{ |
| Returns `lhs / rhs` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_MaxOp { |
| string summary = "Maximum operator"; |
| |
| string description = [{ |
| Returns `max(lhs, rhs)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_MinOp { |
| string summary = "Minimum operator"; |
| |
| string description = [{ |
| Returns `min(lhs, rhs)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_MulOp { |
| string summary = "Multiplication operator"; |
| |
| string description = [{ |
| Returns `lhs * rhs` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| class BASE_HLO_PowOp { |
| string summary = "Power operator"; |
| |
| string description = [{ |
| Returns `lhs ^ rhs` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_RemOp { |
| string summary = "Remainder operator"; |
| |
| string description = [{ |
| Returns `lhs % rhs` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_SubOp { |
| string summary = "Subtraction operator"; |
| |
| string description = [{ |
| Returns `lhs - rhs` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_ShiftLeftOp { |
| string summary = "Shift Left operator"; |
| |
| string description = [{ |
| Returns `lhs << rhs` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_ShiftRightArithmeticOp { |
| string summary = "Shift right arithmetic operator"; |
| |
| string description = [{ |
| Returns arithmetic `lhs >> rhs` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_ShiftRightLogicalOp { |
| string summary = "Shift right logical operator"; |
| |
| string description = [{ |
| Returns logical `lhs >> rhs` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| // Documentation mixin for the atan2 op (summary/description strings only). |
| // atan2 takes two operands; the sign of each selects the quadrant, which is |
| // why it is not written as a single quotient `lhs/rhs`. |
| class BASE_HLO_Atan2Op { |
| string summary = "Atan2 operator"; |
| |
| string description = [{ |
| Returns `atan2(lhs, rhs)` element-wise, i.e. the angle of the point |
| `(rhs, lhs)`, using the signs of both operands to determine the quadrant. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| // Documentation-only mixins for the logical binary ops: each class defines |
| // just the `summary` and `description` strings for one op (presumably |
| // inherited by the concrete op definition elsewhere -- TODO confirm). |
| class BASE_HLO_AndOp { |
| string summary = "Logical and"; |
| |
| string description = [{ |
| Returns `logical_and(lhs, rhs)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_OrOp { |
| string summary = "Logical or"; |
| |
| string description = [{ |
| Returns `logical_or(lhs, rhs)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| class BASE_HLO_XorOp { |
| string summary = "Logical xor"; |
| |
| string description = [{ |
| Returns `logical_xor(lhs, rhs)` element-wise. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. |
| }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XLA control flow related op definitions. |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation mixin for the switch-case op (summary/description strings |
| // only). |
| class BASE_HLO_CaseOp { |
| string summary = "Switch-Case operator"; |
| |
| string description = [{ |
| Returns the result of executing `branches[index]`, where N is the number |
| of branches. If `index` is < 0 or >= N, then `branches[N-1]` is executed |
| as the default branch. |
| |
| Each branch `branches[b]` must take in a single argument of same type as |
| `branch_operands[b]` and will be invoked with `branch_operands[b]`. The type |
| of the returned value of each branch must be the same. |
| |
| Note that only one of the branches will be executed depending on the value |
| of index. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#conditional. |
| }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XLA parallelism related op definitions. |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation-only mixins for parallelism/reduction ops: each class defines |
| // just the `summary` and `description` strings for one op (presumably |
| // inherited by the concrete op definition elsewhere -- TODO confirm). |
| class BASE_HLO_ReplicaIdOp { |
| string summary = "ReplicaId operator"; |
| |
| string description = [{ |
| Returns the unique ID (int32 scalar) of the replica. |
| |
| The unique ID of each replica is an unsigned integer in the interval [0, N), |
| where N is the number of replicas. Since all the replicas are running the |
| same program, a ReplicaId() call in the program will return a different |
| value on each replica. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#replicaid. |
| }]; |
| } |
| |
| class BASE_HLO_PartitionIdOp { |
| string summary = "PartitionId operator"; |
| |
| string description = [{ |
| Returns the unique ID (int32 scalar) of the partition. |
| }]; |
| } |
| |
| |
| class BASE_HLO_AllReduceOp { |
| string summary = "AllReduce operator"; |
| |
| string description = [{ |
| Performs a custom reduction across replicas. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#allreduce. |
| }]; |
| } |
| |
| class BASE_HLO_ReduceOp { |
| string summary = "Reduce operator"; |
| |
| string description = [{ |
| Returns the result of executing a reduction function on one or more arrays |
| in parallel. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#reduce. |
| }]; |
| } |
| |
| class BASE_HLO_ReduceWindowOp { |
| string summary = "ReduceWindow operator"; |
| |
| string description = [{ |
| Returns the result of executing a reduction function over all elements in |
| each window of one or more arrays in parallel. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#reducewindow. |
| }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XLA tuple op definitions. |
| //===----------------------------------------------------------------------===// |
| // Documentation-only mixins for the tuple ops: each class defines just the |
| // `summary` and `description` strings for one op (presumably inherited by the |
| // concrete op definition elsewhere -- TODO confirm). |
| class BASE_HLO_GetTupleElementOp { |
| string summary = "GetTupleElement operator"; |
| |
| string description = [{ |
| Returns a member of a tuple specified by an index. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#gettupleelement. |
| }]; |
| } |
| |
| class BASE_HLO_TupleOp { |
| string summary = "XLA's tuple op"; |
| |
| string description = [{ |
| Groups a set of tensor inputs into a single tuple object. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#tuple. |
| }]; |
| } |
| |
| |
| |
| |
| // Documentation mixin for the comparison op (summary/description strings |
| // only; presumably inherited by the concrete op definition -- TODO confirm). |
| class BASE_HLO_CompareOp { |
| string summary = "Comparison operator"; |
| |
| string description = [{ |
| Compares `lhs` and `rhs` elementwise according to `comparison_direction` |
| and `compare_type`. If unspecified, `compare_type` is FLOAT for float element |
| types, SIGNED for signed element types and UNSIGNED for unsigned element |
| types. |
| |
| See |
| https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. |
| }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Quantize op definitions. |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation mixin for the dequantize op (summary/description strings |
| // only). The input is the packed uint32 tensor, so the shapes below are given |
| // in terms of that packed input. |
| class BASE_HLO_DequantizeOp { |
| string summary = "Dequantize operator"; |
| |
| string description = [{ |
| Dequantizes the quantized input of packed uint32 to bfloat16. Only uint8 or |
| uint16 is supported for the original unpacked input. |
| |
| Returns a tensor of shape [d0, ..., dn * unpack_size] if the packed input |
| shape is [d0, ..., dn], where unpack_size = sizeof(uint32) / sizeof(T), where |
| T is the unpacked input type. If transpose_output is true, will return a |
| tensor of shape [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is |
| faster when the input's rank is higher than 1. The input needs to be |
| transposed to use the transpose_output feature. |
| }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XLA Slice definitions. |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation-only mixins for the slice ops: each class defines just the |
| // `summary` and `description` strings for one op (presumably inherited by the |
| // concrete op definition elsewhere -- TODO confirm). |
| class BASE_HLO_SliceOp { |
| string summary = "Slice operator"; |
| |
| string description = [{ |
| Slices a portion of the `operand` into a new configuration. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#slice. |
| }]; |
| } |
| |
| class BASE_HLO_DynamicSliceOp { |
| string summary = "Dynamic Slice operator"; |
| |
| string description = [{ |
| Extracts a sub-array from the input array at dynamic start_indices. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#dynamicslice. |
| }]; |
| } |
| |
| class BASE_HLO_DynamicUpdateSliceOp { |
| string summary = "Dynamic Update Slice operator"; |
| |
| string description = [{ |
| DynamicUpdateSlice generates a result which is the value of the input array |
| operand, with a slice update overwritten at start_indices. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice. |
| }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XLA Other op definitions. |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation mixin for the all-to-all collective (summary/description |
| // strings only). |
| class BASE_HLO_AllToAllOp { |
| string summary = "AllToAll"; |
| |
| string description = [{ |
| AllToAll is a collective operation that sends data from all cores to all |
| cores. It has two phases: |
| - The scatter phase. On each core, the operand is split into `split_count` |
| number of blocks along the `split_dimension`, and the blocks are |
| scattered to all cores, i.e., the i-th block is sent to the i-th core. |
| - The gather phase. Each core concatenates the received blocks along the |
| `concat_dimension`. |
| |
| The participating cores can be configured by: |
| - replica_groups: each ReplicaGroup contains a list of replica ids |
| participating in the computation (replica id for the current replica can |
| be retrieved using the ReplicaId op). AllToAll will be applied within |
| subgroups in the specified order. For example, |
| `replica_groups` = {{1,2,3}, {4,5,0}} means that an AllToAll will be applied |
| within replicas {1, 2, 3}, and in the gather phase, the received blocks |
| will be concatenated in the same order of 1, 2, 3. Then, another AllToAll |
| will be applied within replicas 4, 5, 0, and the concatenation order is |
| also 4, 5, 0. If `replica_groups` is empty, all replicas belong to one |
| group, and the concatenation order is the numerical order (0, 1, 2, ...). |
| |
| Prerequisites: |
| - The dimension size of the operand on the split_dimension is divisible by |
| `split_count`. |
| - The operand's shape is not tuple. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#alltoall. |
| }]; |
| } |
| |
| // Documentation-only mixins: each class defines just the `summary` and |
| // `description` strings for one op (presumably inherited by the concrete op |
| // definition elsewhere -- TODO confirm). |
| class BASE_HLO_BatchNormGradOp { |
| string summary = "Batch Normalization Gradient"; |
| |
| string description = [{ |
| Calculates gradients of batch norm. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad |
| }]; |
| } |
| |
| class BASE_HLO_BatchNormInferenceOp { |
| string summary = "Batch Normalization for Inference"; |
| |
| string description = [{ |
| Normalizes an array across batch and spatial dimensions. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#batchnorminference |
| }]; |
| } |
| |
| class BASE_HLO_BatchNormTrainingOp { |
| string summary = "Batch Normalization for Training"; |
| |
| string description = [{ |
| Normalizes an array across batch and spatial dimensions. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining |
| }]; |
| } |
| |
| class BASE_HLO_BitcastConvertOp { |
| string summary = "BitcastConvert operator"; |
| |
| string description = [{ |
| Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast |
| operation from a data shape to a target shape. The dimensions must match, |
| and the conversion is an element-wise one. Bitcast is implemented as a |
| low-level cast, so machines with different floating-point representations |
| will give different results. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype. |
| }]; |
| } |
| |
| class BASE_HLO_BroadcastOp { |
| string summary = "Broadcast a tensor to a higher rank by prepending dimensions"; |
| |
| string description = [{ |
| Broadcasts the operand tensor to a higher rank by prepending |
| `broadcast_sizes` to the dimensions. The current values of the operand are |
| copied into the other dimensions. |
| |
| This is a more limited form of broadcasting, that corresponds to the XLA |
| client Broadcast method. For a more general form of broadcasting, see the |
| BroadcastInDimOp. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#broadcast. |
| }]; |
| } |
| |
| // Documentation mixin for the general "InDim" broadcast op |
| // (summary/description strings only). |
| class BASE_HLO_BroadcastInDimOp { |
| string summary = "Broadcast a tensor into the given shape by adding dimensions."; |
| |
| string description = [{ |
| Broadcasts the `operand` tensor to a higher rank. This is not the limited |
| form of broadcasting exposed as the XLA client broadcast op, but rather the |
| more powerful "InDim" broadcasting, which is closer to the HLO broadcast op |
| and exposed in the XLA client BroadcastInDim method. |
| |
| `broadcast_dimensions` maps the operand dimension number to the target shape |
| dimension number. It must have the same size as the rank of the operand. The |
| mapped dimensions must either be the same size or the dimension being |
| broadcast from must be size 1 (degenerate broadcasting). |
| |
| For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The |
| scalar value will be broadcast to every element in the target shape. |
| |
| See https://www.tensorflow.org/xla/broadcasting. |
| }]; |
| } |
| |
| // Documentation mixin for the Cholesky decomposition op (summary/description |
| // strings only). |
| class BASE_HLO_CholeskyOp { |
| string summary = "Cholesky operator"; |
| |
| string description = [{ |
| Computes the Cholesky decomposition of a batch of symmetric (Hermitian) |
| positive definite matrices. |
| |
| If lower is true, computes lower-triangular matrices l such that |
| `a = l . l^T`. If lower is false, computes upper-triangular matrices u such |
| that `a = u^T . u`. |
| |
| Input data is read only from the lower/upper triangle of a, depending on the |
| value of lower. Values from the other triangle are ignored. Output data is |
| returned in the same triangle; the values in the other triangle are |
| implementation-defined and may be anything. |
| |
| If the rank of a is greater than 2, a is treated as a batch of matrices, where |
| all except the minor 2 dimensions are batch dimensions. |
| |
| If a is not symmetric (Hermitian) positive definite, the result is |
| implementation-defined. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#cholesky. |
| }]; |
| } |
| |
| // Documentation mixin for the clamp op (summary/description strings only; |
| // presumably inherited by the concrete op definition -- TODO confirm). |
| class BASE_HLO_ClampOp { |
| string summary = "Clamp operator"; |
| |
| string description = [{ |
| Clamps an operand to within the range between a minimum and maximum value. |
| |
| Note: All three arrays must be the same shape. Alternatively, as a |
| restricted form of broadcasting, min and/or max can be a scalar (0D |
| tensor) of the element type of the tensor operand. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#clamp. |
| }]; |
| } |
| |
| // Documentation mixin for the collective-permute op (summary/description |
| // strings only). |
| class BASE_HLO_CollectivePermuteOp { |
| string summary = "CollectivePermute operator"; |
| |
| string description = [{ |
| CollectivePermute is a collective operation that sends and receives data |
| across replicas. |
| |
| Note that there are the following restrictions on the source_target_pair: |
| - Any two pairs should not have the same target replica id, and they should |
| not have the same source replica id. |
| - If a replica id is not a target in any pair, then the output on that |
| replica is a tensor consisting of 0(s) with the same shape as the input. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#collectivepermute. |
| }]; |
| } |
| |
| // Documentation mixin for the concatenate op (summary/description strings |
| // only; presumably inherited by the concrete op definition -- TODO confirm). |
| class BASE_HLO_ConcatenateOp { |
| string summary = "XLA's concatenate op"; |
| |
| string description = [{ |
| Concatenates a set of tensors along the specified dimension. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#concatenate. |
| }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Common convolution attributes |
| //===----------------------------------------------------------------------===// |
| |
| // Optional array attribute whose elements are HLO_PrecisionAttr values. |
| // NOTE(review): presumably one entry per operand -- confirm against users. |
| // TODO(b/129153247) See if it's possible to also validate the size. |
| def HLO_PrecisionConfigAttr: |
| OptionalAttr< |
| TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>; |
| |
| // Elements attribute accepted only when it is a DenseIntOrFPElementsAttr whose |
| // element type is a 1-bit integer; stored and returned as a |
| // ::mlir::DenseElementsAttr without any conversion. |
| def BoolElementsAttr : |
| ElementsAttrBase< |
| And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">, |
| CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>, |
| "constant boolean vector/tensor attribute"> { |
| let storageType = [{ ::mlir::DenseElementsAttr }]; |
| let returnType = [{ ::mlir::DenseElementsAttr }]; |
| |
| let convertFromStorage = "$_self"; |
| } |
| |
| // Attribute list (an `ins` dag) shared by the convolution ops. |
| def ConvolutionAttributes { |
| dag attributes = (ins |
| // Default value: one for each of the spatial dimensions. |
| OptionalAttr<I64ElementsAttr>:$window_strides, |
| // Default value: zero for each of the spatial dimensions. |
| OptionalAttr<I64ElementsAttr>:$padding, |
| // Default value: one for each of the spatial dimensions. |
| OptionalAttr<I64ElementsAttr>:$lhs_dilation, |
| // Default value: one for each of the spatial dimensions. |
| OptionalAttr<I64ElementsAttr>:$rhs_dilation, |
| // Default value: false for each of the spatial dimensions (no reversal); |
| // this is a boolean attribute, so "one" does not apply. |
| OptionalAttr<BoolElementsAttr>:$window_reversal, |
| ConvDimensionNumbers:$dimension_numbers, |
| I64Attr:$feature_group_count, |
| I64Attr:$batch_group_count, |
| HLO_PrecisionConfigAttr:$precision_config |
| ); |
| } |
| |
| // Documentation mixin for the convolution op, plus a C++ helper injected into |
| // the generated op class via `extraClassDeclaration`. |
| class BASE_HLO_ConvOp { |
| string summary = "Convolution operator"; |
| |
| string description = [{ |
| Computes a convolution of the kind used in neural networks. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#conv_convolution. |
| }]; |
| |
| // hasWindowReversal(): true iff the optional `window_reversal` attribute is |
| // present and at least one of its entries is true; an absent attribute is |
| // treated as "no reversal". |
| code extraClassDeclaration = [{ |
| bool hasWindowReversal() { |
| auto reversal = window_reversalAttr(); |
| return reversal && llvm::any_of(reversal.getBoolValues(), |
| [](bool v) { return v; }); |
| } |
| }]; |
| } |
| |
| // Common summary and description for the copy operator. |
| class BASE_HLO_CopyOp { |
| string summary = "Copy operator"; |
| |
| string description = [{ |
| Returns a copy of `operand`. |
| }]; |
| } |
| |
| // Common summary and description for the cross-replica sum operator. |
| class BASE_HLO_CrossReplicaSumOp { |
| string summary = "Sums input across replicated instances."; |
| |
| string description = [{ |
| For each of the replica groups, operands of the group devices are summed |
| so that each device has the sum. |
| |
| For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`. |
| Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0, |
| and `B, D, F, H` as group 1. Thus we get the outputs: |
| `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum. |
| }]; |
| } |
| |
| |
| // Common summary and description for the custom-call operator. |
| class BASE_HLO_CustomCallOp { |
| string summary = "CustomCall operator"; |
| |
| string description = [{ |
| A custom call invokes code external to XLA. The `args` are passed to the |
| external code, and the external code is expected to produce a result of the |
| given type. The exact mechanism is backend-specific. For example, in the CPU |
| backend, a call instruction is emitted which targets a symbol with the name |
| `call_target_name`. |
| |
| `call_target_name` and `backend_config` can be arbitrary strings, but |
| `call_target_name` should be short as it may be used in labels. |
| `backend_config` can encode arbitrarily large amounts of information. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#customcall. |
| }]; |
| } |
| |
| // Common summary and description for the dot operator. |
| class BASE_HLO_DotOp { |
| string summary = "Dot operator"; |
| string description = [{ |
| Performs dot products between vectors, vector/matrix and matrix/matrix |
| multiplication. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#dot. |
| }]; |
| } |
| |
| // Common summary and description for the general dot operator. |
| class BASE_HLO_DotGeneralOp { |
| string summary = "General Dot operator"; |
| string description = [{ |
| Performs general dot products between vectors, vector/matrix and |
| matrix/matrix multiplication. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#dotgeneral. |
| }]; |
| } |
| |
| // Common summary and description for the FFT operator. |
| class BASE_HLO_FftOp { |
| string summary = "Fast Fourier transform operator"; |
| |
| string description = [{ |
| Returns the fast Fourier transform of the input array. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#fft. |
| }]; |
| } |
| |
| // Common summary and description for the gather operator. |
| class BASE_HLO_GatherOp { |
| string summary = "Gather operator"; |
| |
| string description = [{ |
| Stitches together several slices of an input array. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#gather. |
| }]; |
| } |
| |
| // Common summary and description for the map operator. |
| class BASE_HLO_MapOp { |
| string summary = "Map operator"; |
| |
| string description = [{ |
| Applies a scalar function over the given operand arrays, producing an array |
| of the same dimensions where each element is the result of the mapped function |
| applied to the corresponding elements in the input arrays. |
| |
| The mapped function is an arbitrary computation with the restriction that it |
| has N inputs of scalar type T and a single output with type S. The output has |
| the same dimensions as the operands except that the element type T is replaced |
| with S. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#map. |
| }]; |
| } |
| |
| // Common summary and description for the reshape operator. |
| class BASE_HLO_ReshapeOp { |
| string summary = "Reshape operator"; |
| |
| string description = [{ |
| Reshapes the dimensions of `operand` into a new configuration. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#reshape. |
| }]; |
| } |
| |
| // Common summary and description for the scatter operator. |
| class BASE_HLO_ScatterOp { |
| string summary = "Scatter operator"; |
| |
| string description = [{ |
| Generates a result which is the value of the input array `operand`, |
| with several slices (at indices specified by `scatter_indices`) |
| updated with the values in `updates` using `update_computation`. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#scatter. |
| }]; |
| } |
| |
| // Common summary and description for the select operator. |
| class BASE_HLO_SelectOp { |
| string summary = "Select operator"; |
| |
| string description = [{ |
| Constructs an output tensor from the elements of `on_true` and `on_false` |
| based on the values of `pred`. |
| |
| `pred`, `on_true` and `on_false` must be broadcast compatible. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#select. |
| }]; |
| } |
| |
| // Common summary and description for the select-and-scatter operator. |
| class BASE_HLO_SelectAndScatterOp { |
| string summary = "SelectAndScatter operator"; |
| |
| string description = [{ |
| Runs a windowed selection `select` function over `operand` with shape |
| `window_dimensions` and stride `window_strides`. This will produce an amount |
| of selected locations whose shape matches `source`. These are then scattered |
| to the output which is initialized with `init_value`. |
| Multiple scattered elements which land in the same output location are |
| combined using the `scatter` function. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#selectandscatter. |
| }]; |
| } |
| |
| // Common summary and description for the set-dimension-size operator. |
| class BASE_HLO_SetDimensionSizeOp { |
| string summary = "SetDimensionSize operator"; |
| |
| string description = [{ |
| Sets the dynamic size of the operand's given dimension. Passes the operand |
| through as the result, with the dynamic dimension tracked by the compiler. |
| Padded values will be ignored by downstream reduction ops. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#setdimensionsize. |
| }]; |
| } |
| |
| // Common summary and description for the sort operator. |
| class BASE_HLO_SortOp { |
| string summary = "Sort operator"; |
| |
| string description = [{ |
| Sorts the given `operands` at the given `dimension` with the given |
| `comparator`. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#sort. |
| }]; |
| } |
| |
| // Common summary and description for the reverse operator. |
| class BASE_HLO_ReverseOp { |
| string summary = "Reverse operator"; |
| |
| string description = [{ |
| Reverses the specified dimensions of `operand` according to the given |
| `dimensions`. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. |
| }]; |
| } |
| |
| // Common summary and description for the pad operator. |
| class BASE_HLO_PadOp { |
| string summary = "Pad operator"; |
| |
| string description = [{ |
| Pads the edges of `operand` with the `padding_value` and according to |
| the passed configuration. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#pad. |
| }]; |
| } |
| |
| // Common summary and description for the trace operator. |
| class BASE_HLO_TraceOp { |
| string summary = "Trace operator"; |
| |
| string description = [{ |
| Emits a logging message `tag` with the `operand`. |
| }]; |
| } |
| |
| // Common summary and description for the transpose operator. |
| class BASE_HLO_TransposeOp { |
| string summary = "Transpose operator"; |
| |
| string description = [{ |
| Permutes the dimensions of `operand` according to the given `permutation`. |
| |
| `res_dimensions[i] = operand_dimensions[permutation[i]]` |
| |
| See https://www.tensorflow.org/xla/operation_semantics#transpose. |
| }]; |
| } |
| |
| // Common summary and description for the triangular-solve operator. |
| class BASE_HLO_TriangularSolveOp { |
| string summary = "TriangularSolve operator"; |
| |
| string description = [{ |
| Solves systems of linear equations with lower or upper triangular |
| coefficient matrices by forward- or back-substitution. Broadcasting along |
| leading dimensions, this routine solves one of the matrix systems |
| op(a) * x = b, or x * op(a) = b, for the variable x, given a and b, where |
| op(a) is either op(a) = a, or op(a) = Transpose(a), or |
| op(a) = Conj(Transpose(a)). |
| |
| Input data is read only from the lower/upper triangle of a, depending on the |
| value of lower. Values from the other triangle are ignored. Output data is |
| returned in the same triangle; the values in the other triangle are |
| implementation-defined and may be anything. |
| |
| If the ranks of a and b are greater than 2, they are treated as batches of |
| matrices, where all except the minor 2 dimensions are batch dimensions. a |
| and b must have equal batch dimensions. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#triangularsolve. |
| }]; |
| } |
| |
| // Common summary and description for the uniform RNG operator. |
| class BASE_HLO_RngUniformOp { |
| string summary = "RNG with uniform distribution."; |
| |
| string description = [{ |
| Constructs an output of a given shape with random numbers generated |
| following the uniform distribution over the interval `[a,b)`. The parameters |
| and output element type have to be a boolean type, an integral type or a |
| floating point type, and the types have to be consistent. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#rnguniform. |
| }]; |
| } |
| |
| // Common summary and description for the normal-distribution RNG operator. |
| class BASE_HLO_RngNormalOp { |
| string summary = "RNG with normal distribution."; |
| |
| string description = [{ |
| Constructs an output of a given shape with random numbers generated |
| following the normal distribution with parameters `mu` and `sigma`. The |
| parameters and output shape have to have a floating point elemental type. |
| The parameters furthermore have to be scalar valued. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#rngnormal. |
| }]; |
| } |
| |
| // Common summary and description for the reduce-precision operator. |
| class BASE_HLO_ReducePrecisionOp { |
| string summary = "Reduce precision operator"; |
| |
| string description = [{ |
| Models the effect of converting floating-point values to a lower-precision |
| format (such as IEEE-FP16) and back to the original format. The number of |
| exponent and mantissa bits in the lower-precision format can be specified |
| arbitrarily, although all bit sizes may not be supported on all hardware |
| implementations. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#reduceprecision. |
| }]; |
| } |
| |
| // Common summary and description for the infeed operator. |
| class BASE_HLO_InfeedOp { |
| string summary = "Infeed operator"; |
| |
| string description = [{ |
| Reads a single data item from the implicit Infeed streaming interface of |
| the device, interpreting the data as the given shape and its layout, and |
| returns the data. Multiple Infeed operations are allowed in a |
| computation, but there must be a total order among the Infeed operations. |
| For example, two Infeeds placed in different while loops have a total order |
| if there is a dependency between those while loops. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#infeed. |
| }]; |
| } |
| |
| // Common summary and description for the while operator. |
| class BASE_HLO_WhileOp { |
| string summary = "While operator"; |
| |
| string description = [{ |
| Returns the result of repeatedly executing a body function for as long as |
| the condition function returns true, i.e. until the condition returns false. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#while. |
| }]; |
| } |
| |
| // Common summary and description for the bitcast operator. |
| class BASE_HLO_BitcastOp { |
| string summary = "Bitcast operator"; |
| |
| string description = [{ |
| This op changes the shape of the input such that the physical |
| arrangement of elements is unchanged. |
| |
| However, the op needs layout information to make sense of "physical |
| arrangement of elements". Layout support in MHLO is currently under |
| exploration. |
| }]; |
| } |
| |
| #endif // HLO_OPS_BASE |