/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef HLO_OPS_BASE
#define HLO_OPS_BASE
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
def HLO_Dialect : Dialect {
let name = "mhlo";
let cppNamespace = "::mlir::mhlo";
let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;
}
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
// TODO(hinsu): Use signed integers instead of the signless integers that are
// currently used for legacy reasons.
def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;
def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
// The broadcasting dimensions correspond to a tuple that describes how a
// smaller rank shape is broadcast into a larger rank shape. For example,
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
// matching the matrix to dimensions 1 and 2 of the cuboid.
defvar BroadcastDimAttr = I64ElementsAttr;
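// Illustrative sketch of the example above, written as generic-form IR and
// assuming the mhlo.broadcast_in_dim op with its broadcast_dimensions
// attribute (a sketch only, not a definition from this file):
//
//   %result = "mhlo.broadcast_in_dim"(%matrix) {
//     broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
//   } : (tensor<3x4xf32>) -> tensor<2x3x4xf32>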
//===----------------------------------------------------------------------===//
// MHLO on tensors type definitions.
//===----------------------------------------------------------------------===//
// Token type.
def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">;
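// Sketch of how a token value is created and typed, assuming the
// mhlo.create_token op (illustrative only):
//
//   %token = "mhlo.create_token"() : () -> !mhlo.token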
// Any integer tensor type.
def HLO_IntTensor : TensorOf<[HLO_Int]>;
// Any integer tensor type with rank 0 (i.e. representing a single integer).
def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;
// Any floating-point tensor type.
def HLO_FpTensor : TensorOf<[AnyFloat]>;
def HLO_PredTensor : TensorOf<[HLO_Pred]>;
def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;
def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
def HLO_TensorOrToken : AnyTypeOf<[HLO_Tensor, HLO_Token]>;
def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;
def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;
// Dynamic representation of a shape vector as a tensor.
def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;
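// Sketch of a dimension tensor in use, assuming mhlo.dynamic_reshape, whose
// output-shape operand is an HLO_DimensionTensor (illustrative only):
//
//   %r = "mhlo.dynamic_reshape"(%operand, %shape)
//       : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>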
// In general, static-shaped tensor constraints should be avoided unless they
// are for a legacy op that is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
    AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
//===----------------------------------------------------------------------===//
// MHLO on tensors combined type definitions.
//===----------------------------------------------------------------------===//
// Any integer or floating-point tensor type.
def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;
// Any integer or predicate tensor type.
def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
// Any floating-point or complex tensor type.
def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>;
// Any int, floating-point, or complex tensor type.
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
// Any pred, int, or floating-point tensor type.
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
// A layout attribute (1D tensor of index type)
def HLO_LayoutAttr : Attr<
    And<[IndexElementsAttr.predicate,
         CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
                 == 1}]>]>,
    "A 1D tensor of index type (layout)"> {
  let storageType = IndexElementsAttr.storageType;
  let returnType = IndexElementsAttr.returnType;
  let convertFromStorage = IndexElementsAttr.convertFromStorage;
}
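// Sketch of a rank-2 layout attribute as it would be written in IR, treating
// the layout as a permutation of the dimensions (illustrative only):
//
//   dense<[1, 0]> : tensor<2xindex>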
// An array of layout (1D tensor) attributes.
def HLO_ArrayOfLayoutAttr : TypedArrayAttrBase<HLO_LayoutAttr,
    "Array of layout (1D tensor of index type) attributes">;
// An array of FlatSymbolRef attributes that can be used as a default-valued
// attribute.
def HLO_FlatSymbolRefArrayAttr :
    TypedArrayAttrBase<FlatSymbolRefAttr, "flat symbol ref array attribute"> {
  let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)";
}
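// Sketch of such an array attribute referencing two functions, using
// placeholder symbol names (illustrative only):
//
//   [@callee_a, @callee_b]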
//===----------------------------------------------------------------------===//
// Common convolution attributes
//===----------------------------------------------------------------------===//
// TODO(b/129153247) See if it's possible to also validate the size.
def HLO_PrecisionConfigAttr:
    OptionalAttr<
        TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
def BoolElementsAttr :
    ElementsAttrBase<
        And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
             CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
        "constant boolean vector/tensor attribute"> {
  let storageType = [{ ::mlir::DenseElementsAttr }];
  let returnType = [{ ::mlir::DenseElementsAttr }];
  let convertFromStorage = "$_self";
}
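// Sketch of a boolean elements attribute for a 2-D window, as used by the
// $window_reversal attribute below (illustrative only):
//
//   dense<[false, true]> : tensor<2xi1>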
def ConvolutionAttributes {
  dag attributes = (ins
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    // Default value: zero for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$padding,
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
    // Default value: false for each spatial dimension.
    OptionalAttr<BoolElementsAttr>:$window_reversal,
    ConvDimensionNumbers:$dimension_numbers,
    I64Attr:$feature_group_count,
    I64Attr:$batch_group_count,
    HLO_PrecisionConfigAttr:$precision_config
  );
}
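// Rough sketch of how some of these attributes might appear on a 2-D
// convolution; dimension_numbers is elided here since its syntax comes from
// ConvDimensionNumbers in the included attrs/structs files (illustrative
// only):
//
//   window_strides = dense<[1, 1]> : tensor<2xi64>,
//   padding = dense<0> : tensor<2x2xi64>,
//   feature_group_count = 1 : i64,
//   batch_group_count = 1 : i64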
//===----------------------------------------------------------------------===//
// Common traits
//===----------------------------------------------------------------------===//
class HLO_NativeOpTrait<string name> : NativeOpTrait<name> {
  let cppNamespace = "::mlir::mhlo::OpTrait";
}
// An operation that is essentially element-wise but may implement broadcasting
// semantics.
def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">;
// Op has pairwise operand and result type matching: the number of operands
// must be equal to the number of results, and the type of the i-th operand
// must match the type of the i-th result.
// TODO(b/195086460) Promote this to be an mlir trait and remove it here.
def HLO_PairwiseSameOperandAndResultType :
    HLO_NativeOpTrait<"PairwiseSameOperandAndResultType">;
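// For example, an op with this trait whose operands have types
// (tensor<2xf32>, tensor<3xi32>) must produce results with exactly the types
// (tensor<2xf32>, tensor<3xi32>), in that order.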
#endif // HLO_OPS_BASE