blob: c633bd2cd81d1f351268c6e3e9e509942060bd38 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Defines "client" aligned HLO ops.
// These ops are not necessarily orthogonal or optimized for transformation;
// instead they are designed for ease of expression in certain cases deemed
// important for client libraries (e.g. implicit broadcasting, helper ops).
// This dialect exists alongside the mhlo dialect to augment it for ergonomic
// needs; it does not duplicate or replace it.
//
// The typical use of this dialect is for client libraries to be able to emit
// less constrained ops and rely on the conversion framework to lower any
// chlo ops to canonical mhlo ops.
//
// See: https://www.tensorflow.org/xla/operation_semantics
#ifndef CHLO_OPS
#define CHLO_OPS
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
// The `chlo` ("client HLO") dialect. Its ops model XlaBuilder-level API calls
// and are converted into lower level dialects such as `mhlo`.
def HLOClient_Dialect : Dialect {
// C++ namespace that the generated dialect and op classes are placed in.
let cppNamespace = "::mlir::chlo";
// Textual prefix of every op in this dialect (e.g. `chlo.broadcast_add`).
let name = "chlo";
let summary = [{
Client HLO Ops
}];
let description = [{
This dialect contains ops that align closely with the API surface area
of the XlaBuilder C++ API, where such ops have semantics that go beyond
what exists in the lower level dialects (such as `mhlo`). Essentially,
whenever the client library uses syntactic sugar or composition
of multiple ops for an API call, this dialect tries to model the API call
and provide conversion patterns to fully materialize into lower level
dialects.
}];
}
// Base class shared by every op in the chlo dialect. It binds the op to
// HLOClient_Dialect and installs a verifier that forwards to a C++
// `Verify(op)` overload (not defined in this file; presumably in the
// dialect's C++ sources — confirm).
class HLOClient_Op<string mnemonic, list<OpTrait> traits>
: Op<HLOClient_Dialect, mnemonic, traits> {
// TODO(b/129012527) Much of this custom verification should be expressed as
// type constraints.
let verifier = [{ return Verify(*this); }];
}
//===----------------------------------------------------------------------===//
// CHLO binary elementwise op definitions.
// From the client perspective, each of these supports both explicit rank
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
// shape broadcasting.
//
// These correspond to operations in the mhlo dialect without the
// "broadcast_" prefix, except that those ops require same-shaped operands and
// results.
//
// See:
// https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
// https://www.tensorflow.org/xla/broadcasting
//===----------------------------------------------------------------------===//
// Common base for the broadcasting binary elementwise ops. Subclasses supply
// the mnemonic and their own traits. Every subclass additionally declares the
// InferShapedTypeOpInterface `reifyReturnTypeShapes` method, which must be
// implemented in C++.
class HLOClient_BroadcastBinaryElementwiseOp<
string mnemonic, list<OpTrait> traits>
: HLOClient_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapes"]>])> {
let results = (outs HLO_Tensor);
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
// Optional explicit rank-broadcast mapping. When absent, "numpy"-style
// prefix-padded rank broadcasting is used.
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
// Custom syntax:
//   chlo.<mnemonic> %lhs, %rhs {attrs} : (lhs-type, rhs-type) -> result-type
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:`
`(` type($lhs) `,` type($rhs) `)` `->` type(results)
}];
// Only declared here (no body). The C++ implementation must be provided
// elsewhere.
let builders = [
OpBuilderDAG<(ins "Value":$left, "Value":$right,
"DenseIntElementsAttr":$broadcast_dimensions)>];
}
// chlo.broadcast_add: elementwise addition with optional broadcasting.
def HLOClient_BroadcastAddOp : HLOClient_BroadcastBinaryElementwiseOp<"broadcast_add",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Addition operator (with optional broadcasting)";
let description = [{
Returns `lhs + rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_atan2: elementwise two-argument arctangent with optional
// broadcasting. `lhs` is the first argument and `rhs` the second, as in XLA's
// Atan2.
def HLOClient_BroadcastAtan2Op : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_atan2",
[NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Atan2 operator (with optional broadcasting)";
let description = [{
Returns `atan2(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_divide: elementwise division with optional broadcasting.
def HLOClient_BroadcastDivOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_divide",
[NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Division operator (with optional broadcasting)";
let description = [{
Returns `lhs / rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_maximum: elementwise maximum with optional broadcasting.
def HLOClient_BroadcastMaxOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_maximum",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Maximum operator (with optional broadcasting)";
let description = [{
Returns `max(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_minimum: elementwise minimum with optional broadcasting.
def HLOClient_BroadcastMinOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_minimum",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Minimum operator (with optional broadcasting)";
let description = [{
Returns `min(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_multiply: elementwise multiplication with optional
// broadcasting.
def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_multiply",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Multiplication operator (with optional broadcasting)";
let description = [{
Returns `lhs * rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_power: elementwise exponentiation with optional broadcasting.
// In the description below, `^` means "raised to the power of".
def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_power",
[NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Power operator (with optional broadcasting)";
let description = [{
Returns `lhs ^ rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_remainder: elementwise remainder with optional broadcasting.
def HLOClient_BroadcastRemOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_remainder",
[NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Remainder operator (with optional broadcasting)";
let description = [{
Returns `lhs % rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_shift_left: elementwise left shift with optional
// broadcasting.
def HLOClient_BroadcastShiftLeftOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_shift_left",
[NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Shift left operator (with optional broadcasting)";
let description = [{
Returns `lhs << rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_shift_right_arithmetic: elementwise arithmetic right shift
// with optional broadcasting.
def HLOClient_BroadcastShiftRightArithmeticOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_shift_right_arithmetic",
[NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect. The description
// names the shift kind so it is distinguishable from the logical variant.
let summary = "Shift right arithmetic operator (with optional broadcasting)";
let description = [{
Returns `lhs >> rhs` element-wise, where the shift is arithmetic (the sign
bit of `lhs` is replicated into the vacated high-order bits).
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_shift_right_logical: elementwise logical right shift with
// optional broadcasting.
def HLOClient_BroadcastShiftRightLogicalOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_shift_right_logical",
[NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect. The description
// names the shift kind so it is distinguishable from the arithmetic variant.
let summary = "Shift right logical operator (with optional broadcasting)";
let description = [{
Returns `lhs >> rhs` element-wise, where the shift is logical (the vacated
high-order bits are filled with zeros).
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_subtract: elementwise subtraction with optional broadcasting.
def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_subtract",
[NoSideEffect, SameOperandsAndResultElementType]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Subtraction operator (with optional broadcasting)";
let description = [{
Returns `lhs - rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
//===----------------------------------------------------------------------===//
// XLA binary logical elementwise op definitions.
// The broadcasting semantics described above for the arithmetic binary
// elementwise ops apply here as well.
//===----------------------------------------------------------------------===//
// Base class for the broadcasting logical ops (and/or/xor). These ops are
// commutative and side-effect free. Both operands are restricted to predicate
// or integer tensors; results, builders and syntax come from the broadcasting
// binary base class.
class HLOClient_BroadcastBinaryLogicalElementwiseOp<string mnemonic>
: HLOClient_BroadcastBinaryElementwiseOp<mnemonic,
[Commutative, NoSideEffect]> {
// Same operand list as the arithmetic base, with narrower type constraints.
let arguments = (ins
HLO_PredOrIntTensor:$lhs,
HLO_PredOrIntTensor:$rhs,
// Optional explicit rank-broadcast mapping. When absent, "numpy"-style
// prefix-padded rank broadcasting is used.
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
}
// chlo.broadcast_and: elementwise logical and with optional broadcasting.
// Operands are predicate or integer tensors (per the logical base class).
def HLOClient_BroadcastAndOp: HLOClient_BroadcastBinaryLogicalElementwiseOp<
"broadcast_and"> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Logical and operator (with optional broadcasting)";
let description = [{
Returns `logical_and(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_or: elementwise logical or with optional broadcasting.
// Operands are predicate or integer tensors (per the logical base class).
def HLOClient_BroadcastOrOp: HLOClient_BroadcastBinaryLogicalElementwiseOp<
"broadcast_or"> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Logical or operator (with optional broadcasting)";
let description = [{
Returns `logical_or(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// chlo.broadcast_xor: elementwise logical xor with optional broadcasting.
// Operands are predicate or integer tensors (per the logical base class).
def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp<
"broadcast_xor"> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Logical xor operator (with optional broadcasting)";
let description = [{
Returns `logical_xor(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
//===----------------------------------------------------------------------===//
// Broadcasting complex op
//===----------------------------------------------------------------------===//
// chlo.broadcast_complex: combines a real part (`lhs`) and an imaginary part
// (`rhs`) into a complex tensor, with optional broadcasting.
def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_complex", [NoSideEffect]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Complex operator (with optional broadcasting)";
let description = [{
Performs element-wise conversion of a pair of real and imaginary values to
a complex value.
}];
// Overrides the base operands: both inputs must be floating-point tensors.
let arguments = (ins
HLO_FpTensor:$lhs,
HLO_FpTensor:$rhs,
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
// padded rank-broadcast semantics if omitted.
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
// Overrides the base result: the output element type is complex.
let results = (outs HLO_ComplexTensor);
}
//===----------------------------------------------------------------------===//
// Unary op
//===----------------------------------------------------------------------===//
// Base class for single-operand elementwise chlo ops. `TensorType` selects the
// accepted operand/result tensor type. Every subclass is side-effect free,
// carries InferFusibilityOpInterface, and requires the result type to equal
// the operand type.
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
Type TensorType>
: HLOClient_Op<mnemonic, !listconcat(traits, [
InferFusibilityOpInterface, NoSideEffect, SameOperandsAndResultType])> {
// Custom syntax: `chlo.<mnemonic> %operand {attrs} : operand-type`.
let assemblyFormat = "$operand attr-dict `:` type($operand)";
let results = (outs TensorType:$result);
let arguments = (ins TensorType:$operand);
}
// chlo.acos: elementwise arc cosine of a floating-point or complex tensor.
def HLOClient_AcosOp
: HLOClient_UnaryElementwiseOp<"acos", [], HLO_FpOrComplexTensor> {
let description = [{
Returns `Acos(operand)` element-wise.
$$
\acos(x) = 2 * \atan(\sqrt(1 - x^2) / (1 + x)) if x != -1
= pi if x == -1
$$
}];
let summary = "Acos operator";
}
// chlo.asin: elementwise arc sine of a floating-point or complex tensor.
def HLOClient_AsinOp
: HLOClient_UnaryElementwiseOp<"asin", [], HLO_FpOrComplexTensor> {
let description = [{
Returns `Asin(operand)` element-wise.
$$
\asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
$$
}];
let summary = "Asin operator";
}
// chlo.asinh: elementwise inverse hyperbolic sine of a floating-point or
// complex tensor.
def HLOClient_AsinhOp
: HLOClient_UnaryElementwiseOp<"asinh", [], HLO_FpOrComplexTensor> {
let description = [{
Returns `Asinh(operand)` element-wise.
$$
\asinh(x) = log(x + sqrt(x^2 + 1))
$$
}];
let summary = "Asinh operation";
}
// chlo.atan: elementwise arc tangent of a floating-point or complex tensor.
def HLOClient_AtanOp
: HLOClient_UnaryElementwiseOp<"atan", [], HLO_FpOrComplexTensor> {
let description = [{
Returns `Atan(operand)` element-wise.
$$
\atan(x) = \atan2(x, 1)
$$
}];
let summary = "Atan operator";
}
// chlo.atanh: elementwise inverse hyperbolic tangent of a floating-point or
// complex tensor.
def HLOClient_AtanhOp
: HLOClient_UnaryElementwiseOp<"atanh", [], HLO_FpOrComplexTensor> {
let description = [{
Returns `Atanh(operand)` element-wise.
$$
\atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
= nan otherwise
$$
}];
let summary = "Atanh operator";
}
// chlo.conj: elementwise complex conjugate (the imaginary part is negated).
def HLOClient_ConjOp
: HLOClient_UnaryElementwiseOp<"conj", [], HLO_FpOrComplexTensor> {
let description = [{
Returns `Conj(operand)` element-wise.
$$
\conj(x) = (\real(x), \neg(\imag(x)))
$$
}];
let summary = "Conj operator";
}
// chlo.cosh: elementwise hyperbolic cosine of a floating-point or complex
// tensor.
def HLOClient_CoshOp
: HLOClient_UnaryElementwiseOp<"cosh", [], HLO_FpOrComplexTensor> {
let description = [{
Returns `Cosh(operand)` element-wise.
$$
\cosh(x) = (e^x + e^-x) / 2
$$
}];
let summary = "Cosh operator";
}
// chlo.sinh: elementwise hyperbolic sine of a floating-point or complex
// tensor. The description gives a separate formula for |x| >= 1.
def HLOClient_SinhOp
: HLOClient_UnaryElementwiseOp<"sinh", [], HLO_FpOrComplexTensor> {
let description = [{
Returns `Sinh(operand)` element-wise.
$$
\sinh(x) = (e^x - e^-x) / 2 if |x| < 1
= e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
$$
}];
let summary = "Sinh operation";
}
// chlo.tan: elementwise tangent of a floating-point or complex tensor.
def HLOClient_TanOp
: HLOClient_UnaryElementwiseOp<"tan", [], HLO_FpOrComplexTensor> {
let description = [{
Returns `Tan(operand)` element-wise.
$$
\tan(x) = \sin(x) / \cos(x)
$$
}];
let summary = "Tan operation";
}
// chlo.constant_like: produces a splat of the `value` attribute with the same
// shape as `operand`. The op carries type-inference traits
// (InferTypeOpInterface, InferShapedTypeOpInterface, InferTensorType) and
// registers a canonicalizer, which is implemented in C++ outside this file.
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
[NoSideEffect, SameOperandsAndResultShape,
InferTypeOpInterface,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
NativeOpTrait<"InferTensorType">]> {
let hasCanonicalizer = 1;
let results = (outs HLO_Tensor);
// TODO(jpienaar): value's type could be tightened.
let arguments = (ins AnyAttr:$value, HLO_Tensor:$operand);
let description = [{
Returns a splat constant of the same shape as the operand.
}];
let summary = "Constant like operator";
}
// chlo.erf: elementwise Gauss error function of a floating-point tensor.
// NOTE(review): NoSideEffect here repeats a trait the unary base class
// already adds. It is kept so the trait list is unchanged.
def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
[NoSideEffect, SameOperandsAndResultShape],
HLO_FpTensor> {
// This op computes erf itself, so the summary says "Erf". The complement
// is provided by `chlo.erfc`.
let summary = "Erf operator";
let description = [{
Computes the Gauss error function of `x` element-wise.
erf(x) = erf_impl(x) if |x| < 1
= 1 - erfc_impl(x) otherwise
}];
}
// chlo.erfc: elementwise complementary error function (1 - erf(x)) of a
// floating-point tensor.
def HLOClient_ErfcOp
: HLOClient_UnaryElementwiseOp<"erfc",
[NoSideEffect, SameOperandsAndResultShape], HLO_FpTensor> {
let description = [{
Computes an approximation of the error function complement (1 - erf(x)).
erfc(x) = erfc_impl(x) if |x| > 1
= 1 - erf_impl(x) otherwise
}];
let summary = "Erfc operator";
}
//===----------------------------------------------------------------------===//
// Broadcasting compare op
//===----------------------------------------------------------------------===//
// chlo.broadcast_compare: elementwise comparison with optional broadcasting,
// producing a predicate tensor.
def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_compare", [NoSideEffect]> {
// `let` sets the summary/description fields inherited from the ODS `Op`
// class, consistent with the other ops in this dialect.
let summary = "Compare operator (with optional broadcasting)";
let description = [{
Compares `lhs` and `rhs` elementwise according to `comparison_direction`
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
types, SIGNED for signed element types and UNSIGNED for unsigned element
types.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
}];
// Overrides the base operands to add the comparison attributes.
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
);
let results = (outs HLO_PredTensor);
// Replaces the inherited builder. `compare_type` defaults to `{}` (a null
// StringAttr) when the caller omits it.
let builders = [
OpBuilderDAG<(ins "Value":$lhs, "Value":$rhs,
"DenseIntElementsAttr":$broadcast_dimensions,
"StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>];
}
#endif // CHLO_OPS