blob: 9902174813c75d3f7518f60e9ea7b72b0136cb53 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the MHLO dialect.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <set>
#include <unordered_map>
#include <utility>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"
namespace mlir {
#include ""
} // namespace mlir
#include "mlir-hlo/Dialect/mhlo/IR/"
namespace mlir {
namespace mhlo {
// Generic verifier: accepts any op unconditionally. The op-specific `Verify`
// overloads defined in this file are more specialized than this template.
template <typename T>
static LogicalResult Verify(T op) {
  return success();
}
namespace {
// Utilities for the canonicalize patterns

// Clamps value to the range [lower, upper]. Requires lower <= upper.
//
// Returns `lower` if `value < lower`, `upper` if `value > upper`, and `value`
// otherwise. The precondition is only checked when assertions are enabled.
template <typename T>
static T Clamp(const T& value, const T& lower, const T& upper) {
  assert(lower <= upper);
  return std::max(lower, std::min(value, upper));
}
// Verifies that dimension attribute for the op correctly indexes in operand or
// result shape.
//
// The rank is taken from the operand type if it is ranked, otherwise from the
// result type. If neither is ranked there is nothing to check and verification
// succeeds. Emits an op error when `dimension` is outside [0, rank).
template <typename OpT>
static LogicalResult VerifyDimAttr(OpT op) {
  int64_t rank = -1;
  if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else {
    return success();
  }

  int64_t dim = op.dimension();
  if (dim < 0 || dim >= rank)
    return op.emitOpError() << "requires dimension attribute in range [0, "
                            << rank << "); found (" << dim << ")";
  return success();
}
// Returns 1D 64-bit dense elements attribute with the given values.
// The attribute's tensor type has shape [values.size()] and a 64-bit integer
// element type obtained from `builder`.
DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                        Builder* builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}
// Given the start indices and slice sizes for a dynamic-slice that can be
// converted to a static slice, returns the limits for the static slice.
//
// limit[i] = start_indices[i] + slice_sizes[i]; the result is a 1D i64
// attribute with one entry per element of `slice_sizes`.
DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
                                      DenseIntElementsAttr slice_sizes,
                                      Builder* builder) {
  // Fetch the value ranges once rather than on every iteration.
  auto starts = start_indices.getValues<IntegerAttr>();
  auto sizes = slice_sizes.getValues<IntegerAttr>();

  SmallVector<int64_t, 4> slice_limits;
  const int64_t num_elements = slice_sizes.getNumElements();
  slice_limits.reserve(num_elements);
  for (int64_t i = 0; i < num_elements; ++i) {
    int64_t start_index = starts[i].getInt();
    int64_t slice_size = sizes[i].getInt();
    slice_limits.push_back(start_index + slice_size);
  }
  return GetI64ElementsAttr(slice_limits, builder);
}
/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
///
/// `blockArgs` are forwarded to `mergeBlockBefore` together with the block.
/// The terminator's operands are captured before the merge and then used as
/// the replacement values for `op`.
static void ReplaceOpWithRegion(PatternRewriter& rewriter, Operation* op,
                                Region& region, ValueRange blockArgs = {}) {
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  Block* block = &region.front();
  Operation* terminator = block->getTerminator();
  ValueRange results = terminator->getOperands();
  rewriter.mergeBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
}
#include ""
// Common shape function helper for RngNormal and RngUniform.
//
// Expects exactly three operands; operand 2 is the shape operand and the
// element type is taken from operand 1.
//  * Constant shape operand: the result shape is the constant's values.
//  * Non-constant, ranked 1D shape operand with static length N: the result is
//    a rank-N shape with all dimensions dynamic.
//  * Unranked shape operand, or dynamic length: succeeds without adding an
//    inferred shape.
//    NOTE(review): confirm that callers accept an empty `inferredReturnShapes`
//    in these cases.
static LogicalResult rngInferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  if (operands.size() != 3)
    return emitOptionalError(location, "expected 3 operands");

  SmallVector<int64_t> shapeVector;
  Value shapeOperand = operands[2];
  auto shapeOperandType = shapeOperand.getType().cast<ShapedType>();
  Type elementType = getElementTypeOrSelf(operands[1]);

  // Match constant shape arguments.
  DenseIntElementsAttr shape;
  if (!matchPattern(shapeOperand, m_Constant(&shape))) {
    if (!shapeOperandType.hasRank()) {
      return success();
    }
    if (shapeOperandType.getRank() != 1)
      return emitOptionalError(location, "shape operand required to be 1D");
    // Keep the full 64-bit width of the dimension size.
    int64_t size = shapeOperandType.getDimSize(0);
    if (size == ShapedType::kDynamicSize) {
      return success();
    }
    shapeVector.resize(size, ShapedType::kDynamicSize);
    inferredReturnShapes.emplace_back(shapeVector, elementType);
    return success();
  }

  // Collect every constant dimension first, then emit a single shape.
  shapeVector.reserve(shape.getNumElements());
  for (const APInt& dim : shape.getValues<APInt>())
    shapeVector.push_back(dim.getSExtValue());
  inferredReturnShapes.emplace_back(shapeVector, elementType);
  return success();
}
// Returns a new scalar integer value having type `type`. Here `type` must be
// an integer or index type.
//
// If `value` already has type `type` it is returned unchanged; otherwise an
// arith.index_cast is created at `loc`. One of the two types must be `index`
// (checked by assertion only).
Value MaybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
  if (type == value.getType()) return value;
  assert(type.isIndex() || value.getType().isIndex());
  return b.create<arith::IndexCastOp>(loc, value, type);
}
} // namespace
// ReduceScatterOp
// Verifies a ReduceScatterOp by delegating to the shared
// mlir::hlo::VerifyReduceScatter helper.
// NOTE(review): the argument list and the closing of this call are not present
// in this view of the file; restore them from the upstream source before
// relying on this definition.
static LogicalResult Verify(ReduceScatterOp op) {
return mlir::hlo::VerifyReduceScatter(
// ConstOp
// Folds a constant op to the attribute it holds. A constant has no operands,
// so `operands` must be empty.
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Return the held attribute value.
  return value();
}
// Builds a constant op with the specified attribute `value`.
//
// An ElementsAttr is used as-is and its type becomes the result type. A scalar
// Bool/Float/Integer attribute is wrapped into a rank-0 DenseElementsAttr.
// Any other attribute kind trips the assertion below.
void ConstOp::build(OpBuilder& builder, OperationState& result,
                    Attribute value) {
  Type type;
  if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
    type = elemAttr.getType();
  } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
             value.isa<IntegerAttr>()) {
    // All XLA types must be tensor types. In the build() method, we want to
    // provide more flexibility by allowing attributes of scalar types. But we
    // need to wrap it up with ElementsAttr to construct valid XLA constants.
    type = RankedTensorType::get(/*shape=*/{}, value.getType());
    value = DenseElementsAttr::get(type.cast<TensorType>(), value);
  }

  // TODO: support other XLA specific types.
  assert(type && "unsupported attribute type for building mhlo.constant");
  // Record the result type; without this the computed `type` is never used and
  // the op would be created without a result.
  result.types.push_back(type);
  result.addAttribute("value", value);
}
// CustomCallOp
// Verifies the optional layout attributes of a CustomCallOp.
//
// Rules enforced:
//  * `operand_layouts` and `result_layouts` must be both present or both
//    absent; when absent there is nothing to verify.
//  * The number of layouts must match the number of operands / results.
//  * Tuple-typed values are rejected (except a single tuple result, which is
//    unpacked into its element types first).
//  * Non-tensor values must have an empty layout.
//  * For ranked tensors the layout must be a permutation of [0, rank).
static LogicalResult Verify(CustomCallOp op) {
  // If both operand and result layout attributes are not specified then nothing
  // to verify.
  if (!op.operand_layouts().hasValue() && !op.result_layouts().hasValue())
    return success();

  // Layout constraints for either both operands & results or none should be
  // specified.
  if (op.operand_layouts().hasValue() != op.result_layouts().hasValue())
    return op.emitOpError() << "Layout attributes should be specified for "
                               "either both operands and results or none.";

  // Helper function to verify types and the corresponding layouts.
  auto verify_types_and_layouts =
      [&op](TypeRange types, mlir::ArrayAttr layouts,
            const std::string& value_name) -> LogicalResult {
    if (types.size() != layouts.size())
      return op.emitOpError()
             << "Number of " << value_name << "s must match the number of "
             << value_name << " layouts, " << types.size()
             << " != " << layouts.size();

    for (const auto& indexed_type_and_layout :
         llvm::enumerate(llvm::zip(types, layouts))) {
      // Get index for more descriptive error message.
      auto index = indexed_type_and_layout.index();

      auto type = std::get<0>(indexed_type_and_layout.value());
      // NOTE(review): the cast target was truncated in the reviewed text;
      // DenseIntElementsAttr matches the empty()/size()/begin() uses below.
      auto layout = std::get<1>(indexed_type_and_layout.value())
                        .cast<DenseIntElementsAttr>();

      if (type.isa<TupleType>())
        return op.emitOpError() << "Tuple types are not fully supported with "
                                   "layout constraints yet";
      auto tensor_type = type.dyn_cast<TensorType>();

      // For non-tensor types such as !mhlo.token, the layout should be empty.
      if (!tensor_type) {
        if (layout.empty()) continue;
        return op.emitOpError()
               << "Only tensor types can have non-empty layout: " << value_name
               << " #" << index << " of type " << type << " has layout "
               << layout;
      }

      // For unranked tensors, we cannot verify the compatibility with layout
      // any further.
      if (!tensor_type.hasRank()) continue;

      // Layout must be a permutation of [0, N) where N is the rank of the
      // tensor type.
      std::vector<int64_t> range(tensor_type.getRank());
      std::iota(range.begin(), range.end(), 0);
      if (tensor_type.getRank() != layout.size() ||
          !std::is_permutation(range.begin(), range.end(), layout.begin()))
        return op.emitOpError()
               << "incorrect layout " << layout << " for type " << type
               << ", layout must be a permutation of [0, "
               << tensor_type.getRank() << ")";
    }
    return success();
  };

  // At this point both `operand_layouts` and `result_layouts` are defined.
  ArrayAttr operand_layouts = op.operand_layouts().getValue();
  ArrayAttr result_layouts = op.result_layouts().getValue();

  // Full support for layouts for arbitrary nesting of tuples is not
  // supported yet.
  //
  // If result does not have any tuples, then i-th element of `result_layouts`
  // specifies the layout constraints on i-th result.
  //
  // For the common case of a single tuple result packing non-tuple values, the
  // i-th element of `result_layouts` specifies layout for i-th element of the
  // result tuple.
  TypeRange result_types;
  if (op->getNumResults() == 1 && op->getResult(0).getType().isa<TupleType>())
    result_types = op->getResult(0).getType().cast<TupleType>().getTypes();
  else
    result_types = op->getResultTypes();

  // Verify that operands and operand layouts match.
  if (failed(verify_types_and_layouts(op->getOperandTypes(), operand_layouts,
                                      "operand")))
    return failure();

  // Verify that results and result layouts match.
  return verify_types_and_layouts(result_types, result_layouts, "result");
}
// DotOp
namespace {
bool dimCompatible(int64_t a, int64_t b) {
return ShapedType::kDynamicSize == a || ShapedType::kDynamicSize == b ||
a == b;
ShapedType inferDotReturnType(ShapedType lhs, ShapedType rhs) {
auto element_type = lhs.getElementType();
if (!lhs.hasRank() || !rhs.hasRank()) {
return UnrankedTensorType::get(element_type);
// vector dot vector
if (1 == lhs.getRank() && 1 == rhs.getRank() &&
dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) {
return RankedTensorType::get({}, element_type);
// matrix dot vector
if (2 == lhs.getRank() && 1 == rhs.getRank() &&
dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) {
return RankedTensorType::get({lhs.getDimSize(0)}, element_type);
// vector dot matrix
if (1 == lhs.getRank() && 2 == rhs.getRank() &&
dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) {
return RankedTensorType::get({rhs.getDimSize(1)}, element_type);
// matrix dot matrix
if (2 == lhs.getRank() && 2 == rhs.getRank() &&
dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) {
int64_t shape[2] = {lhs.getDimSize(0), rhs.getDimSize(1)};
return RankedTensorType::get(shape, element_type);
return {};
} // namespace
// Verifies a DotOp: the operand types must form a supported vector/matrix
// combination, and when both the declared and the inferred result types are
// ranked their shapes must be identical.
static LogicalResult Verify(DotOp op) {
  auto lhs_type = op.lhs().getType().cast<ShapedType>();
  auto rhs_type = op.rhs().getType().cast<ShapedType>();
  auto result_type = op.getType().cast<ShapedType>();
  auto expect_return_type = inferDotReturnType(lhs_type, rhs_type);
  if (!expect_return_type) {
    return op.emitError() << "Unexpected operands type: " << lhs_type << " and "
                          << rhs_type;
  }
  if (result_type.hasRank() && expect_return_type.hasRank()) {
    if (result_type.getShape() != expect_return_type.getShape()) {
      return op.emitError()
             << "Unexpected result type: has " << result_type
             << " but inferred " << expect_return_type << " from operands "
             << lhs_type << " and " << rhs_type;
    }
  }
  return success();
}
// DotGeneralOp
// Verifies a DotGeneralOp: lhs and rhs must declare the same number of
// batching dimensions and the same number of contracting dimensions.
// NOTE(review): the four initializers were truncated in the reviewed text and
// are restored from the variable names; confirm the accessor names against
// the dot dimension-numbers attribute definition.
static LogicalResult Verify(DotGeneralOp op) {
  auto dot_dimension_numbers = op.dot_dimension_numbers();
  int64_t lhs_batching_dimensions_size =
      dot_dimension_numbers.getLhsBatchingDimensions().size();
  int64_t rhs_batching_dimensions_size =
      dot_dimension_numbers.getRhsBatchingDimensions().size();
  if (lhs_batching_dimensions_size != rhs_batching_dimensions_size) {
    return op.emitError()
           << "lhs and rhs should have the same number of batching dimensions";
  }
  int64_t lhs_contracting_dimensions_size =
      dot_dimension_numbers.getLhsContractingDimensions().size();
  int64_t rhs_contracting_dimensions_size =
      dot_dimension_numbers.getRhsContractingDimensions().size();
  if (lhs_contracting_dimensions_size != rhs_contracting_dimensions_size) {
    return op.emitError() << "lhs and rhs should have the same number of "
                             "contracting dimensions";
  }
  return success();
}
// GatherOp
// Converts gather ops to slice ops in case we have a single set of constant
// indices.
//
// Applies only when start_indices is a constant of rank <= 1 with
// index_vector_dim == 0 and the operand has a static shape. The gather is
// replaced by a unit-stride slice, followed by a reshape that drops the
// collapsed slice dimensions when there are any.
// NOTE(review): the initializers of `operand_type` and `slice_end` and the
// body of the reshape-shape loop were truncated in the reviewed text; they
// are restored from how those variables are used below.
struct GatherSlice : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    DenseIntElementsAttr index;
    if (!matchPattern(gather.start_indices(), m_Constant(&index)))
      return failure();

    const auto& dnums = gather.dimension_numbers();
    if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1)
      return failure();

    // TODO(tberghammer): Remove when the verifier catches this case what is
    // invalid if all previous condition holds.
    if (index.getNumElements() != dnums.getStartIndexMap().size())
      return failure();

    RankedTensorType operand_type =
        gather.getOperand(0).getType().dyn_cast<RankedTensorType>();
    if (!operand_type || !operand_type.hasStaticShape()) return failure();

    // Start from [0, slice_size) in every dimension, then shift by the
    // (clamped) constant start index for the mapped dimensions.
    auto slice_end =
        llvm::to_vector<8>(gather.slice_sizes().getValues<int64_t>());
    llvm::SmallVector<int64_t, 8> slice_start(slice_end.size(), 0);
    for (auto it :
         llvm::zip(dnums.getStartIndexMap(), index.getValues<APInt>())) {
      int64_t map_index = std::get<0>(it);
      // Clamp the indices within bounds to faithfully mirror gather semantics.
      int64_t offset =
          Clamp(std::get<1>(it).getSExtValue(), static_cast<int64_t>(0),
                operand_type.getDimSize(map_index) - slice_end[map_index]);
      slice_start[map_index] += offset;
      slice_end[map_index] += offset;
    }

    llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
    llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
    for (size_t i = 0; i < slice_end.size(); ++i) {
      slice_shape[i] = slice_end[i] - slice_start[i];
    }
    Type element_type = gather.getType().cast<TensorType>().getElementType();
    auto slice_type = RankedTensorType::get(slice_shape, element_type);
    Value result = rewriter.create<SliceOp>(
        gather.getLoc(), slice_type, gather.getOperand(0),
        GetI64ElementsAttr(slice_start, &rewriter),
        GetI64ElementsAttr(slice_end, &rewriter),
        GetI64ElementsAttr(slice_stride, &rewriter));

    auto collapsed_slice_dims = dnums.getCollapsedSliceDims();
    if (!collapsed_slice_dims.empty()) {
      // Keep only the dimensions that are not collapsed.
      llvm::SmallVector<int64_t, 8> reshape_shape;
      for (size_t i = 0; i < slice_shape.size(); ++i) {
        if (llvm::count(collapsed_slice_dims, i) == 0) {
          reshape_shape.push_back(slice_shape[i]);
        }
      }
      auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
      result =
          rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
    }

    rewriter.replaceOp(gather, result);
    return success();
  }
};
// Registers the canonicalization patterns for GatherOp.
void GatherOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                           MLIRContext* context) {
  results.add<GatherSlice>(context);
}
namespace {
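// Gather shape inference below follows the XLA operation semantics:
// https://www.tensorflow.org/xla/operation_semantics#gather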
// The bounds for the output array along dimension i is computed as follows:
// (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some k)
// then we pick
// the corresponding dimension bounds out of start_indices.shape, skipping
// index_vector_dim
// (i.e. pick start_indices.shape.dims[k] if k < index_vector_dim and
// start_indices.shape.dims[k+1] otherwise).
// (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some k)
// then we pick
// the corresponding bound out of slice_sizes after accounting for
// collapsed_slice_dims
// (i.e. we pick adjusted_slice_sizes[k] where adjusted_slice_sizes is
// slice_sizes with the bounds at indices collapsed_slice_dims removed).
// Materializes the static `slice_sizes` attribute of a GatherOp as index
// constants, appended to `slice_sizes` in order. `operands` is unused; it
// keeps the signature parallel with the DynamicGatherOp overload.
void getSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
                        ValueRange operands,
                        SmallVectorImpl<Value>& slice_sizes) {
  for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
    slice_sizes.push_back(builder.create<arith::ConstantIndexOp>(loc, val));
  }
}
// Extracts every element of the dynamic `slice_sizes` operand of a
// DynamicGatherOp with tensor.extract and appends them to `slice_size_values`
// in order. The loop bound is dimension 0 of the operand's type.
void getSliceSizeValues(DynamicGatherOp* d_gather, OpBuilder& builder,
                        Location loc, ValueRange operands,
                        SmallVectorImpl<Value>& slice_size_values) {
  DynamicGatherOp::Adaptor adaptor(operands);
  Value slice_sizes = adaptor.slice_sizes();
  auto slice_sizes_ty = slice_sizes.getType().cast<ShapedType>();
  for (int64_t i = 0; i < slice_sizes_ty.getDimSize(0); ++i) {
    Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
    slice_size_values.push_back(
        builder.create<tensor::ExtractOp>(loc, slice_sizes, idx));
  }
}
// Shape-level verification shared by the static and dynamic gather ops.
//
// Checks, where the relevant ranks/sizes are known:
//  * slice_sizes is 1D;
//  * index_vector_dim is within [0, start_indices.rank];
//  * start_index_map has as many entries as the index vector dimension
//    (size 1 when index_vector_dim == start_indices.rank);
//  * offset_dims.size + collapsed_slice_dims.size equals the operand rank and
//    the number of slice sizes;
//  * every collapsed dimension indexes into slice_sizes.
// Diagnostics are emitted through `errorEmitter`.
static LogicalResult verifyGather(
    ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
    ShapeAdaptor sliceSizesShape, GatherDimensionNumbersAttr dimensionNumbers,
    llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
  // This should be fully expressible with type constraints, but it isn't
  // obvious how to do that with the current infrastructure.
  if (sliceSizesShape.hasRank() && sliceSizesShape.getRank() != 1)
    return errorEmitter() << "slice_sizes.rank != 1";

  int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();
  if (startIndicesShape.hasRank()) {
    // index_vector_dim == start_indices.rank implies a trailing 1 on the shape
    // of start_indices.
    if (indexVectorDim > startIndicesShape.getRank())
      return errorEmitter() << "index_vector_dim " << indexVectorDim
                            << " is out of bounds for start indices with rank "
                            << startIndicesShape.getRank();

    bool impliedTrailingDim = indexVectorDim == startIndicesShape.getRank();
    if (impliedTrailingDim || !startIndicesShape.isDynamicDim(indexVectorDim)) {
      // The implied trailing dimension has size 1; otherwise read the actual
      // (static) size of the index vector dimension.
      int64_t effectiveDimSize =
          impliedTrailingDim ? 1
                             : startIndicesShape.getDimSize(indexVectorDim);
      if (effectiveDimSize != dimensionNumbers.getStartIndexMap().size())
        return errorEmitter() << "start_index_map size ("
                              << dimensionNumbers.getStartIndexMap().size()
                              << ") is not equal to size of index dimension ("
                              << indexVectorDim << ") of start_indices ("
                              << effectiveDimSize << ")";
    }
  }

  int64_t impliedOperandRank = dimensionNumbers.getOffsetDims().size() +
                               dimensionNumbers.getCollapsedSliceDims().size();
  if (operandShape.hasRank() && operandShape.getRank() != impliedOperandRank)
    return errorEmitter() << "offset_dims size ("
                          << dimensionNumbers.getOffsetDims().size()
                          << ") plus collapse_slice_dims size ("
                          << dimensionNumbers.getCollapsedSliceDims().size()
                          << ") is not equal to operand rank ("
                          << operandShape.getRank() << ")";

  if (sliceSizesShape.hasStaticShape()) {
    int64_t sliceRank = sliceSizesShape.getNumElements();

    if (sliceRank != impliedOperandRank)
      return errorEmitter() << "slice_sizes size (" << sliceRank
                            << ") not equal to (implied) operand rank ("
                            << impliedOperandRank << ")";

    for (auto dim : dimensionNumbers.getCollapsedSliceDims())
      if (dim >= sliceRank)
        return errorEmitter()
               << "collapsed dimension " << dim
               << " is greater than slice_sizes.size (" << sliceRank << ")";
  }

  return success();
}
// Verification for gather with constant slice sizes. Runs the shape-level
// checks in verifyGather, then additionally requires that
//  * every collapsed slice dimension has slice size 1, and
//  * no slice size exceeds the corresponding static operand dimension.
static LogicalResult verifyStaticGather(
    ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
    DenseIntElementsAttr sliceSizes,
    GatherDimensionNumbersAttr dimensionNumbers,
    llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
  // For some reason the getType call is necessary here
  if (failed(verifyGather(
          /*operandShape=*/operandShape,
          /*startIndicesShape=*/startIndicesShape,
          /*sliceSizesShape=*/sliceSizes.getType(), dimensionNumbers,
          errorEmitter)))
    return failure();

  for (auto dim : dimensionNumbers.getCollapsedSliceDims()) {
    int64_t sliceDimSize = sliceSizes.getValues<int64_t>()[dim];
    if (sliceDimSize != 1) {
      return errorEmitter() << "slice_sizes collapsed dimension " << dim
                            << " != 1 (" << sliceDimSize << ")";
    }
  }

  if (operandShape.hasRank()) {
    for (const auto& it : llvm::enumerate(sliceSizes.getValues<int64_t>())) {
      // Dynamic operand dimensions cannot be bounds-checked statically.
      if (operandShape.isDynamicDim(it.index())) continue;
      auto operandDimSize = operandShape.getDimSize(it.index());
      auto sliceDimSize = it.value();
      if (sliceDimSize > operandDimSize)
        return errorEmitter() << "slice size (" << sliceDimSize
                              << ") is larger than operand dimension ("
                              << operandDimSize << ") at index " << it.index();
    }
  }
  return success();
}
template <typename dimTy>
static void inferGatherShape(
int64_t resultRank, llvm::function_ref<dimTy(int64_t)> getStartIndicesDim,
llvm::function_ref<dimTy(int64_t)> getSliceDim,
GatherDimensionNumbersAttr dimensionNumbers,
SmallVectorImpl<dimTy>& shape) {
ArrayRef<int64_t> collapsedSliceDims =
int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();
// We don't necessarily know the rank of sliceSizes, but we do know that it
// can't be larger than the highest collapsed dimension. So go through those
// and populate the leading dimensions of adjustedSliceSizes. The trailing
// dimensions can just be adjusted by an offset.
const auto* maxCollapsedDimIt =
std::max_element(collapsedSliceDims.begin(), collapsedSliceDims.end());
int64_t maxCollapsedDim = -1;
if (maxCollapsedDimIt != collapsedSliceDims.end())
maxCollapsedDim = *maxCollapsedDimIt;
SmallVector<dimTy> adjustedSliceSizePrefix;
for (int dimIndex = 0; dimIndex <= maxCollapsedDim; ++dimIndex) {
if (llvm::is_contained(collapsedSliceDims, dimIndex)) continue;
auto getAdjustedSliceDim = [&](int64_t index) -> dimTy {
if (index < adjustedSliceSizePrefix.size())
return adjustedSliceSizePrefix[index];
return getSliceDim(index + collapsedSliceDims.size());
ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims();
// Dimensions in the output that aren't offset dimensions are called batch
// dimensions.
SmallVector<int64_t> batchDims;
for (int dim = 0; dim < resultRank; ++dim)
if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);
for (int i = 0; i < resultRank; ++i) {
const auto* offsetDimsIt =
std::find(offsetDims.begin(), offsetDims.end(), i);
if (offsetDimsIt != offsetDims.end()) {
auto index = std::distance(offsetDims.begin(), offsetDimsIt);
auto* batchDimsIt = std::find(batchDims.begin(), batchDims.end(), i);
assert(batchDimsIt != batchDims.end());
auto index = std::distance(batchDims.begin(), batchDimsIt);
// This can never run into the special case where start_indices gets
// implicitly expanded with a trailing 1 if
// index_vector_dim = start_indices.rank because then index would equal
// index_vector_dim, which means we'd be looking at index+1, which would be
// out of bounds anyway.
if (index >= indexVectorDim) ++index;
// Infers the result shape of a gather and appends it to
// `inferredReturnShapes`. The element type is taken from `operandShape`.
// `getSliceDim(i)` returns the i-th slice size (possibly dynamic).
//
// When start_indices is unranked the result rank is unknown and the function
// succeeds without adding a shape.
// NOTE(review): confirm callers accept an empty result in that case.
static LogicalResult inferGatherReturnTypeComponents(
    ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
    llvm::function_ref<int64_t(int64_t)> getSliceDim,
    GatherDimensionNumbersAttr dimensionNumbers,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  Type elementType = operandShape.getElementType();

  // We need this to determine the result rank. We could still place bounds on
  // the result rank if that was something ShapedTypeComponents could express.
  if (!startIndicesShape.hasRank()) {
    return success();
  }

  ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims();
  int64_t startIndicesRank = startIndicesShape.getRank();
  // If index_vector_dim == start_indices.rank, then an implicit trailing 1 is
  // appended to start_indices shape.
  if (dimensionNumbers.getIndexVectorDim() == startIndicesRank)
    ++startIndicesRank;
  int64_t resultRank = offsetDims.size() + startIndicesRank - 1;

  auto getStartIndicesDim = [&](int64_t index) {
    return startIndicesShape.getDimSize(index);
  };

  SmallVector<int64_t> shape;
  inferGatherShape<int64_t>(resultRank, getStartIndicesDim, getSliceDim,
                            dimensionNumbers, shape);

  inferredReturnShapes.emplace_back(shape, elementType);
  return success();
}
// Builds IR computing the result shape of a gather-like op (`Op` is GatherOp
// or DynamicGatherOp) and appends the resulting shape tensor to
// `reifiedReturnShapes`. Shape elements use the element type of
// start_indices. Fails if the op's result is not a ranked tensor.
template <typename Op>
LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands,
                               SmallVectorImpl<Value>& reifiedReturnShapes) {
  // No support for unranked gather output shape a.t.m.
  auto resultTy =
      op->getResult().getType().template dyn_cast<RankedTensorType>();
  if (!resultTy) return failure();

  typename Op::Adaptor adaptor(operands);
  Value startIndices = adaptor.start_indices();

  Location loc = op->getLoc();
  int resultRank = resultTy.getRank();
  Type shapeElTy = startIndices.getType().cast<ShapedType>().getElementType();
  auto toShapeElType = [&](Value v) {
    return MaybeCastTo(builder, loc, v, shapeElTy);
  };

  SmallVector<Value, 4> sliceSizes;
  getSliceSizeValues(op, builder, loc, operands, sliceSizes);
  llvm::transform(sliceSizes, sliceSizes.begin(),
                  [&](Value v) { return toShapeElType(v); });

  auto getStartIndicesDim = [&](int64_t index) {
    return toShapeElType(
        builder.create<tensor::DimOp>(loc, startIndices, index));
  };
  SmallVector<Value, 4> shapeValues;
  auto getSliceDim = [&sliceSizes](int64_t index) -> Value {
    return sliceSizes[index];
  };
  inferGatherShape<Value>(resultRank, getStartIndicesDim, getSliceDim,
                          op->dimension_numbers(), shapeValues);

  Value outputShape = builder.create<tensor::FromElementsOp>(
      loc, RankedTensorType::get({resultRank}, shapeElTy), shapeValues);
  // Hand the computed shape back to the caller.
  reifiedReturnShapes.push_back(outputShape);

  return success();
}
} // namespace
// Reifies the result shape of a GatherOp via the shared gather helper.
LogicalResult GatherOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return reifyGatherShape(this, builder, operands, reifiedReturnShapes);
}
// Infers the result shape of a GatherOp. Verifies the adaptor and the gather
// constraints first, emitting diagnostics at `location` (or an unknown
// location), then delegates to inferGatherReturnTypeComponents with the
// static slice sizes.
LogicalResult GatherOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  // This can get called before other op verify methods, so we have to do a
  // bunch of verification up front. With a better story for ordering and/or
  // multi-phase op verification, this should hopefully all go away.
  Location loc = location.getValueOr(UnknownLoc::get(context));
  auto errorEmitter = [&loc]() {
    return mlir::emitError(loc)
           << "'" << GatherOp::getOperationName() << "' op ";
  };
  GatherOp::Adaptor adaptor(operands, attributes, regions);
  if (failed(adaptor.verify(loc))) return failure();

  // We want the ShapeAdaptors, so can't route via the adaptor :-/
  ShapeAdaptor operandShape = operands.getShape(0);
  ShapeAdaptor startIndicesShape = operands.getShape(1);
  GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers();
  DenseIntElementsAttr sliceSizesAttr = adaptor.slice_sizes();

  if (failed(verifyStaticGather(/*operandShape=*/operandShape,
                                /*startIndicesShape=*/startIndicesShape,
                                /*sliceSizes=*/sliceSizesAttr, dimensionNumbers,
                                errorEmitter)))
    return failure();

  auto getSliceDim = [&sliceSizesAttr](int64_t index) -> int64_t {
    return sliceSizesAttr.getValues<int64_t>()[index];
  };

  return inferGatherReturnTypeComponents(operandShape, startIndicesShape,
                                         getSliceDim, dimensionNumbers,
                                         inferredReturnShapes);
}
// DynamicGatherOp
// Reifies the result shape of a DynamicGatherOp via the shared gather helper.
LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return reifyGatherShape(this, builder, operands, reifiedReturnShapes);
}
// Infers the result shape of a DynamicGatherOp. Verifies the adaptor and the
// shape-level gather constraints first, then delegates to
// inferGatherReturnTypeComponents. Slice sizes are only known at runtime, so
// every slice dimension is reported as dynamic.
LogicalResult DynamicGatherOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  // This can get called before other op verify methods, so we have to do a
  // bunch of verification up front. With a better story for ordering and/or
  // multi-phase op verification, this should hopefully all go away.
  Location loc = location.getValueOr(UnknownLoc::get(context));
  auto errorEmitter = [&loc]() {
    return mlir::emitError(loc)
           << "'" << DynamicGatherOp::getOperationName() << "' op ";
  };
  DynamicGatherOp::Adaptor adaptor(operands, attributes, regions);
  if (failed(adaptor.verify(loc))) return failure();

  // We want the ShapeAdaptors, so can't route via the adaptor :-/
  ShapeAdaptor operandShape = operands.getShape(0);
  ShapeAdaptor startIndicesShape = operands.getShape(1);
  ShapeAdaptor sliceSizesShape = operands.getShape(2);
  GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers();

  if (failed(verifyGather(/*operandShape=*/operandShape,
                          /*startIndicesShape=*/startIndicesShape,
                          /*sliceSizesShape=*/sliceSizesShape, dimensionNumbers,
                          errorEmitter)))
    return failure();

  auto getSliceDim = [](int64_t index) { return ShapedType::kDynamicSize; };
  return inferGatherReturnTypeComponents(operandShape, startIndicesShape,
                                         getSliceDim, dimensionNumbers,
                                         inferredReturnShapes);
}
// GetDimensionSizeOp
// GetDimensionSizeOp only needs its `dimension` attribute range-checked
// against the operand/result rank, so it reuses the shared helper.
static LogicalResult Verify(GetDimensionSizeOp op) {
  return VerifyDimAttr(op);
}
/// Fold get_dimension_size when the said shape dimension is a constant.
///
/// Returns an empty result (no fold) when the operand is not a ranked tensor
/// or the queried dimension is dynamic.
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
  RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
  if (!type) return {};

  int32_t dim = dimension();
  if (type.isDynamicDim(dim)) return {};
  // The result type is always is a 0-d i32 tensor.
  return DenseIntElementsAttr::get<int32_t>(
      getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim));
}
// IotaOp
// Verifies an IotaOp: for ranked results the rank must be non-zero and
// `iota_dimension` must lie in [0, rank). Unranked results are accepted.
static LogicalResult Verify(IotaOp op) {
  auto shape = op.getType().cast<ShapedType>();
  if (!shape.hasRank()) return success();

  if (shape.getRank() == 0)
    return op.emitOpError() << "does not support scalars.";

  auto iota_dimension = op.iota_dimension();
  if (iota_dimension >= shape.getRank() || iota_dimension < 0)
    return op.emitOpError() << "iota dimension cannot go beyond the output "
                               "rank or be negative.";
  return success();
}
// Iota operations across multiple dimensions can be reduced to an iota and a
// ranked broadcast.
//
// For a ranked result of rank >= 2, creates a 1D iota along the iota
// dimension and broadcasts it into the original result type.
// NOTE(review): the trailing arguments of the three builder calls were
// truncated in the reviewed text; they are restored here (dimension 0 for the
// 1D iota, `{iota_dimension}` as the broadcast dimensions). Confirm against
// the upstream source.
struct IotaBroadcast : public OpRewritePattern<IotaOp> {
  using OpRewritePattern<IotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
      return failure();
    }

    auto iota_dimension = iota.iota_dimension();

    auto iota_type = RankedTensorType::get(
        {result_ty.getDimSize(iota_dimension)}, result_ty.getElementType());

    auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
                                            rewriter.getI64IntegerAttr(0));

    auto broadcast_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iota_dimension});
    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
                                                  broadcast_attr);
    return success();
  }
};
// Registers the canonicalization patterns for IotaOp.
void IotaOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                         MLIRContext* context) {
  results.add<IotaBroadcast>(context);
}
// Folds an iota whose iota dimension has size 1: the only value along that
// dimension is 0, so the whole result is a zero constant. Returns an empty
// result (no fold) otherwise.
OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
  auto dimension = iota_dimension();
  auto result_ty = getResult().getType().cast<ShapedType>();
  if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
    Builder builder(getContext());
    return builder.getZeroAttr(result_ty);
  }

  return {};
}
// DynamicIotaOp
namespace {
// Replaces a dynamic_iota whose result type has a fully static shape with a
// plain iota of the same type and iota dimension.
struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasStaticShape()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<IotaOp>(iota, result_ty, iota.iota_dimension());
    return success();
  }
};
// Dynamic Iota operations across multiple dimensions can be reduced to an iota
// and a ranked broadcast.
//
// Visible outline: for a ranked result of rank >= 2, the output shape is
// index-cast, sliced to the single entry [iota_dimension, iota_dimension + 1),
// cast again, used to build a 1D dynamic_iota, and the original op is then
// replaced using that iota, the original output_shape and a one-element
// broadcast attribute.
// NOTE(review): several multi-line argument lists (both arith::IndexCastOp
// creations, the tail of the DynamicIotaOp creation, the broadcast attribute
// values, the head of the final replace call) and all closing braces are
// missing from this view of the file; restore them from the upstream source
// before building.
struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicIotaOp iota,
PatternRewriter& rewriter) const override {
auto result_ty = iota.getType().cast<ShapedType>();
// Only multi-dimensional (rank >= 2) ranked results are rewritten.
if (!result_ty.hasRank() || result_ty.getRank() < 2) {
return failure();
auto iota_dimension = iota.iota_dimension();
auto iota_dimension_int = iota_dimension;
// NOTE(review): arguments of this cast are missing.
auto converted_shape = rewriter.create<arith::IndexCastOp>(
// Slice out the single shape entry for the iota dimension (stride 1).
auto sliced_shape = rewriter.create<SliceOp>(
iota.getLoc(), converted_shape,
GetI64ElementsAttr(iota_dimension_int, &rewriter),
GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
GetI64ElementsAttr(1, &rewriter));
// NOTE(review): arguments of this cast are missing.
auto converted_sliced_shape = rewriter.create<arith::IndexCastOp>(
// 1D type holding just the iota dimension of the original result.
auto iota_type = RankedTensorType::get(
{result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());
auto new_iota = rewriter.create<DynamicIotaOp>(
iota.getLoc(), iota_type, converted_sliced_shape,
auto broadcast_attr = DenseIntElementsAttr::get(
RankedTensorType::get({1}, rewriter.getIntegerType(64)),
iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
return success();
} // namespace
// Registers the DynamicIotaOp canonicalization patterns defined in this file:
//  * DynamicIotaIsStatic  - use IotaOp when the result shape is static.
//  * DynamicIotaBroadcast - split a multi-dimensional dynamic iota into a
//                           1-D iota followed by a broadcast.
void DynamicIotaOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                MLIRContext* context) {
  results.insert<DynamicIotaIsStatic, DynamicIotaBroadcast>(context);
}
// Returns `shape_op` as a value of the extent-tensor type computed below.
// When `shape_op` already has that type it is returned unchanged; otherwise an
// arith::IndexCastOp to that type is created at `loc`.
// NOTE(review): the argument list of shape::getExtentTensorType is cut off in
// this copy — confirm against the original file.
static Value castToIndexTensor(OpBuilder& builder, Location loc,
Value shape_op) {
ShapedType result_ty = shape::getExtentTensorType(
if (shape_op.getType() == result_ty) return shape_op; // Nothing to do.
return builder.create<arith::IndexCastOp>(loc, shape_op, result_ty);
// Reifies the result shape of DynamicIotaOp.
//
// The result shape is the `output_shape` operand itself, converted to an
// index extent tensor by castToIndexTensor when needed. The value is appended
// to `reifiedReturnShapes`; the function always succeeds.
LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  DynamicIotaOp::Adaptor adaptor(operands);
  reifiedReturnShapes.push_back(
      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
  return success();
}
// DynamicUpdateSliceOp
// Verifies that all start indices of a DynamicUpdateSliceOp share one element
// type. With zero or one index there is nothing to compare.
static LogicalResult Verify(DynamicUpdateSliceOp op) {
  OperandRange start_indices = op.start_indices();
  if (start_indices.size() <= 1) return success();

  // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
  // is OK to cast indices to ShapedType here.
  const Type expected_elem_ty =
      start_indices.front().getType().cast<ShapedType>().getElementType();
  for (Value index : llvm::drop_begin(start_indices, 1)) {
    const Type actual_elem_ty =
        index.getType().cast<ShapedType>().getElementType();
    if (actual_elem_ty == expected_elem_ty) continue;
    return op.emitOpError() << "start indices must have same element type "
                               "(encountered mismatch: "
                            << expected_elem_ty << " vs " << actual_elem_ty
                            << ")";
  }
  return success();
}
// Folds a DynamicUpdateSliceOp that overwrites its whole operand: when operand
// and update have the same static ranked type and every start index is the
// constant 0, the result is simply the update value.
//
// Returns a null result when no fold applies. Unranked operand or update
// types are not folded (previously they tripped a hard cast assertion).
OpFoldResult DynamicUpdateSliceOp::fold(ArrayRef<Attribute> operands) {
  auto operand_type = this->operand().getType().dyn_cast<RankedTensorType>();
  auto update_type = this->update().getType().dyn_cast<RankedTensorType>();
  if (!operand_type || !update_type) return {};
  if (operand_type != update_type || !operand_type.hasStaticShape()) {
    return {};
  }

  // Ensure that indices are 0 constants. The 0 check mostly ensures
  // correctness. For non-constants, the pattern does not fold to avoid hiding
  // the behavior of incorrect user input.
  for (Value index : this->start_indices()) {
    DenseIntElementsAttr index_attr;
    if (!matchPattern(index, m_Constant(&index_attr))) return {};
    // getSplatValue is only valid on splat attributes.
    if (!index_attr.isSplat()) return {};
    if (!index_attr.getSplatValue<IntegerAttr>().getValue().isZero()) return {};
  }
  return this->update();
}
// AbsOp
// Infers the result type of AbsOp from its single operand.
//
// The result keeps the operand's shape (or stays unranked); a complex element
// type is replaced by its underlying element type, any other element type is
// kept. The inferred type is appended to `inferredReturnTypes`.
LogicalResult AbsOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto operand_ty = (*operands.begin()).getType().cast<ShapedType>();
  Type element_ty = operand_ty.getElementType();
  if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
    element_ty = complex_ty.getElementType();
  }

  Type result_ty;
  if (operand_ty.hasRank()) {
    result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty);
  } else {
    result_ty = UnrankedTensorType::get(element_ty);
  }
  // Report the computed type; returning success with an empty list would
  // leave the caller without a result type.
  inferredReturnTypes.push_back(result_ty);
  return success();
}
// CollectivePermuteOp
// Delegates verification of the op's source/target pairs to the shared HLO
// helper.
static LogicalResult Verify(CollectivePermuteOp op) {
  auto pairs = op.source_target_pairs();
  return mlir::hlo::VerifyCollectivePermuteSourceTargetPairs(op, pairs);
}
// ConvertOp
// Builds a ConvertOp whose result has the operand's shape (when the operand
// is a ranked tensor) and the requested element type; otherwise the result is
// an unranked tensor of that element type.
void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
                      Type result_element_ty) {
  auto ranked_operand_ty = operand.getType().dyn_cast<RankedTensorType>();
  Type converted_ty =
      ranked_operand_ty
          ? Type(RankedTensorType::get(ranked_operand_ty.getShape(),
                                       result_element_ty))
          : Type(UnrankedTensorType::get(result_element_ty));
  build(builder, result, converted_ty, operand);
}
// Folds ConvertOp when operand and result types are identical, or when the
// operand is a constant ElementsAttr that can be converted immediately.
// Returns a null result when no fold applies.
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
auto operand_ty = getOperand().getType().cast<TensorType>();
auto result_ty = getResult().getType().cast<TensorType>();
// Identity conversion: forward the operand unchanged.
if (operand_ty == result_ty) return getOperand();
// If the result has non-static shape, a convert op is necessary to go from
// static shape to non-static shape.
if (!result_ty.hasStaticShape()) return {};
// TODO(hinsu): Handle unsigned types.
if (operand_ty.getElementType().isUnsignedInteger() ||
result_ty.getElementType().isUnsignedInteger()) {
return {};
// If the operand is constant, we can do the conversion now.
// NOTE(review): the remaining arguments of hlo::ConvertElementsAttr are cut
// off in this copy — confirm against the original file.
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
return hlo::ConvertElementsAttr(elementsAttr,
return {};
// Canonicalization hook for ConvertOp; receives the pattern set to populate
// and the MLIR context to build patterns with.
// NOTE(review): no registration statement and no closing brace are visible in
// this copy of the function — confirm the pattern list against the original
// file.
void ConvertOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
// DequantizeOp
// Verifies the shapes of a DequantizeOp.
//
// Requirements checked, in order:
//  * input and output are ranked shaped types;
//  * after optionally reversing the output shape (transpose_output), input
//    and output have the same non-zero rank and agree on every dimension
//    except the last one;
//  * the last output dimension is 2x (16-bit) or 4x (8-bit) the last input
//    dimension.
static LogicalResult Verify(DequantizeOp op) {
  auto input_type = op.input().getType().dyn_cast<ShapedType>();
  auto output_type = op.output().getType().dyn_cast<ShapedType>();
  // getShape() is only valid on ranked types, so unranked inputs/outputs are
  // rejected with the same diagnostic as non-shaped ones.
  if (!input_type || !output_type || !input_type.hasRank() ||
      !output_type.hasRank()) {
    return op.emitError() << "ranked input and output.";
  }
  auto input_shape = input_type.getShape();
  auto output_shape = output_type.getShape().vec();
  if (op.transpose_output()) {
    std::reverse(output_shape.begin(), output_shape.end());
  }

  // Check the input rank and output rank are same, and also the lower
  // dimensions are same. Rank-0 shapes have no last dimension to scale and
  // are rejected here; this also keeps `size() - 1` from wrapping around.
  if (input_shape.size() != output_shape.size() || input_shape.empty() ||
      !std::equal(input_shape.begin(),
                  std::next(input_shape.begin(), input_shape.size() - 1),
                  output_shape.begin())) {
    return op.emitError() << "mismatched dimensions.";
  }

  // Check that the last dimension of the output is 2x or 4x of that of the
  // input depending on the unpacked input is 16 or 8 bits.
  const int64_t input_last_dim = input_shape.back();
  const int64_t output_last_dim = output_shape.back();
  const int64_t scale_factor = op.is_16bits() ? 2 : 4;
  if (output_last_dim != scale_factor * input_last_dim) {
    return op.emitError() << "last dimension of output should be "
                          << scale_factor << "x of the input.";
  }

  return success();
}
// GetTupleElementOp
// Verifies that the element index is within the operand tuple and that the
// op's result type equals the type of the selected tuple element.
static LogicalResult Verify(GetTupleElementOp op) {
  const auto element_index = op.index();
  auto tuple_type = op.getOperand().getType().cast<TupleType>();
  const auto tuple_size = tuple_type.size();
  if (element_index >= tuple_size) {
    return op.emitOpError(
        llvm::formatv("index {0} is out of bounds of operand with size {1}",
                      element_index, tuple_size));
  }

  auto element_type = tuple_type.getType(element_index);
  if (op.getType() == element_type) return success();

  return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
                                      op.getType(), element_type));
}
// Folds get_tuple_element(tuple(...), i) to the i-th operand of the tuple op.
// Returns a null result when the operand is not produced by a TupleOp.
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
  auto producer = getOperand().getDefiningOp<mhlo::TupleOp>();
  if (!producer) return {};
  return producer.getOperand(index());
}
// TupleOp
// Verifies that a TupleOp produces a tuple type whose element count and
// element types match the op's operands, position by position.
static LogicalResult Verify(TupleOp op) {
  auto tuple_type = op.getType().dyn_cast<TupleType>();
  if (!tuple_type) return op.emitOpError("tuple op with non-tuple result");

  if (op.getNumOperands() != tuple_type.size()) {
    return op.emitOpError(
        "number of operands to tuple expected to match number of types in "
        "resultant tuple type");
  }

  for (const auto& indexed : llvm::enumerate(
           llvm::zip_first(op.getOperandTypes(), tuple_type.getTypes()))) {
    Type operand_ty = std::get<0>(indexed.value());
    Type element_ty = std::get<1>(indexed.value());
    if (operand_ty == element_ty) continue;
    return op.emitOpError("has return type mismatch at ")
           << indexed.index() << "th value (" << operand_ty
           << " != " << element_ty << ")";
  }
  return success();
}
namespace {
// Pattern for unpacking and repacking the same tuple.
// Matches a TupleOp whose operands are GetTupleElementOp results taken, in
// order, from one tuple value of the same type as the TupleOp result, and
// replaces the TupleOp with that tuple value.
// NOTE(review): the initializer of `element_op` is cut off in this copy —
// confirm against the original file.
struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
using OpRewritePattern<TupleOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TupleOp op,
PatternRewriter& rewriter) const override {
// An empty tuple has no predecessor to forward.
if (op.val().empty()) return failure();
Value first_element = op.val().front();
auto first_element_op = first_element.getDefiningOp<GetTupleElementOp>();
// The first operand must be element 0 of some tuple.
if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
return failure();
Value tuple_predecessor = first_element_op.getOperand();
// Only a tuple of exactly the result type can stand in for this op.
if (tuple_predecessor.getType() != op.getType()) return failure();
// Enumeration restarts at 0 after drop_front(1), hence the `+ 1` below.
for (const auto& element_and_idx :
llvm::enumerate(op.val().drop_front(1))) {
auto element_op =
if (!element_op ||
element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
element_op.getOperand() != tuple_predecessor)
return failure();
rewriter.replaceOp(op, tuple_predecessor);
return success();
} // namespace
// Registers the TupleOp canonicalization pattern defined in this file:
// UnpackRepackSameTuple, which forwards a tuple that is unpacked and repacked
// unchanged.
void TupleOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                          MLIRContext* context) {
  results.insert<UnpackRepackSameTuple>(context);
}
// AllToAllOp
// Verifies an AllToAllOp with a ranked operand.
//
// Checks that the split dimension is a valid dimension of the operand, that
// split_count is positive, and, when the split dimension is static, that its
// size is a multiple of split_count. Unranked operands and dynamic split
// dimensions cannot be checked statically and are accepted.
static LogicalResult Verify(AllToAllOp op) {
  auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
  if (!type) return success();

  const int64_t split_dim = op.split_dimension();
  if (split_dim < 0 || split_dim >= type.getRank()) {
    return op.emitError() << "split dimension " << split_dim
                          << " is out of bounds for operand of rank "
                          << type.getRank();
  }

  // Guard the modulo below against division by zero.
  const int64_t split_count = op.split_count();
  if (split_count <= 0) {
    return op.emitError() << "split_count must be positive, got "
                          << split_count;
  }

  // A dynamic extent is encoded as a sentinel, not a real size.
  if (type.isDynamicDim(split_dim)) return success();

  // If operand is ranked, size of split dimension should be a multiple of split
  // count.
  const int64_t split_dim_size = type.getDimSize(split_dim);
  if (split_dim_size % split_count != 0) {
    return op.emitError() << "split dimension has size " << split_dim_size
                          << ", expected to be a multiple of split_count "
                          << split_count;
  }
  return success();
}
// AllGatherOp
// Verifies the gather dimension sizes of an AllGatherOp.
// NOTE(review): the last clause of the early-return condition is cut off in
// this copy — confirm against the original file.
static LogicalResult Verify(AllGatherOp op) {
// If operand and result are both ranked, then the size of the gather
// dimension in the result should be a multiple of the size of the gather
// dimension in the operand.
auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
auto resultType = op.getType().dyn_cast<RankedTensorType>();
uint64_t allGatherDimIndex = op.all_gather_dim();
// Nothing can be checked when a type is unranked or the operand's gather
// dimension is dynamic.
if (!operandType || !resultType ||
operandType.isDynamicDim(allGatherDimIndex) ||
return success();
// A zero-sized operand dimension is rejected before it is used as a divisor.
if (operandType.getDimSize(allGatherDimIndex) == 0)
return op.emitOpError() << "operand gather dimension cannot be zero.";
if ((resultType.getDimSize(allGatherDimIndex) %
operandType.getDimSize(allGatherDimIndex)) != 0)
return op.emitOpError()
<< "result gather dimension has size "
<< resultType.getDimSize(allGatherDimIndex)
<< ", expected to be a multiple of operand gather dimension size "
<< operandType.getDimSize(allGatherDimIndex);
return success();
// BitcastConvertOp
// Reifies the result shape of BitcastConvertOp by deriving it from the first
// operand. Fails for unranked operand/result types and when the operand and
// result element sizes differ.
// NOTE(review): the initializers of `operand_element_size` and
// `result_element_size` are cut off in this copy — confirm against the
// original file.
LogicalResult BitcastConvertOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
auto operand_type = operands[0].getType().dyn_cast<RankedTensorType>();
auto result_type = getType().dyn_cast<RankedTensorType>();
// Only ranked tensors are supported.
if (!operand_type || !result_type) return failure();
// Shape-changing bitcast convert is not implemented.
// TODO(kramerb): This could be done by adjusting the last dimension.
DataLayout data_layout = DataLayout::closest(*this);
unsigned operand_element_size =
unsigned result_element_size =
if (operand_element_size != result_element_size) return failure();
return ::mlir::mhlo::deriveShapeFromOperand(
&builder, getOperation(), operands.front(), &reifiedReturnShapes);
// BroadcastOp
// TODO(b/129012527) These should be expressed as type constraints.
// Verifies a BroadcastOp: broadcast_sizes must be rank 1, the result rank
// must equal operand rank plus the number of broadcast sizes, and the result
// shape must be the broadcast sizes followed by the operand shape.
// NOTE(review): the `expectedShape.insert(...)` call is cut off in this copy —
// confirm against the original file.
static LogicalResult Verify(BroadcastOp op) {
auto sizes = op.broadcast_sizes();
auto sizesType = sizes.getType();
auto sizesRank = sizesType.getRank();
if (sizesRank != 1) {
return op.emitOpError(llvm::formatv(
"broadcast_sizes has rank {0} instead of rank 1", sizesRank));
auto resultType = op.getResult().getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank();
auto operandType = op.operand().getType().cast<RankedTensorType>();
auto operandRank = operandType.getRank();
auto sizesSize = sizesType.getNumElements();
auto expectedRank = operandRank + sizesSize;
if (resultRank != expectedRank) {
return op.emitOpError(
llvm::formatv("result rank ({0}) does not match operand rank "
"({1}) plus size of broadcast_sizes ({2})",
resultRank, operandRank, sizesSize));
// Expected shape starts with the broadcast sizes; the operand shape is
// appended after them.
llvm::SmallVector<int64_t, 10> expectedShape(sizes.getValues<int64_t>());
auto operandShape = operandType.getShape();
expectedShape.insert(expectedShape.end(), operandShape.begin(),
auto resultShape = resultType.getShape();
if (resultShape != llvm::makeArrayRef(expectedShape)) {
return op.emitOpError(llvm::formatv(
"result has shape [{0}] instead of [{1}]",
llvm::make_range(resultShape.begin(), resultShape.end()),
llvm::make_range(expectedShape.begin(), expectedShape.end())));
return success();
// Folds a BroadcastOp that adds no dimensions (empty broadcast_sizes) to its
// operand, and constant-folds a splat operand into a splat of the static
// result type. Returns a null result when no fold applies.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> attrs) {
  auto result_type = getType().cast<RankedTensorType>();
  if (broadcast_sizes().getType().getNumElements() == 0) return getOperand();

  // Constant fold when an operand is a splat tensor attribute.
  if (!attrs[0] || !result_type.hasStaticShape()) return {};
  auto splat = attrs[0].dyn_cast<SplatElementsAttr>();
  if (!splat) return {};

  // Handle complex type
  if (auto complex_ty = result_type.getElementType().dyn_cast<ComplexType>()) {
    Type part_ty = complex_ty.getElementType();
    if (part_ty.isa<FloatType>()) {
      return DenseElementsAttr::get(
          result_type, {splat.getSplatValue<std::complex<APFloat>>()});
    }
    if (part_ty.isa<IntegerType>()) {
      return DenseElementsAttr::get(
          result_type, {splat.getSplatValue<std::complex<APInt>>()});
    }
    return {};
  }

  return SplatElementsAttr::get(result_type,
                                splat.getSplatValue<mlir::Attribute>());
}
// Reifies the result shape of BroadcastOp from constants for the broadcast
// sizes and tensor::DimOp values for the operand dimensions. Fails for
// unranked operands.
// NOTE(review): the statements that consume the created values (both loop
// bodies and whatever fills `reifiedReturnShapes`) are cut off in this copy —
// confirm against the original file.
LogicalResult BroadcastOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
BroadcastOp::Adaptor adaptor(operands);
Value operand = adaptor.operand();
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
// Unranked tensors are not supported.
if (!operand_type) return failure();
Location loc = getLoc();
SmallVector<Value, 4> shape_values;
// Collect the broadcast sizes.
for (const auto& size : broadcast_sizes()) {
builder.create<arith::ConstantIndexOp>(loc, size.getZExtValue()));
// Collect the operand sizes.
for (auto index : llvm::seq<int64_t>(0, operand_type.getRank())) {
builder.createOrFold<tensor::DimOp>(loc, operand, index));
return success();
// BroadcastInDimOp
// Verifies a BroadcastInDimOp with a ranked operand.
//
// Checks, in order: broadcast_dimensions is present unless the operand is
// rank 0; it is rank 1 with one entry per operand dimension; the result rank
// is at least the operand rank; every entry names a valid result dimension
// (negative values are rejected as well); and every static operand dimension
// is either 1 or equal to the result dimension it maps to.
static LogicalResult Verify(BroadcastInDimOp op) {
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  if (!operandType) {
    // The following verification checks all depend on knowing the rank of
    // the operand. Bail out now if we don't know the rank of the operand.
    return success();
  }

  auto operandRank = operandType.getRank();
  if (!op.broadcast_dimensions()) {
    if (operandRank == 0) {
      return success();
    }
    return op.emitOpError(
        llvm::formatv("broadcast_dimensions is absent, but required because "
                      "operand has non-zero rank ({0})",
                      operandRank));
  }

  auto dimensions = op.broadcast_dimensions();
  auto dimensionsType = dimensions.getType();
  auto dimensionsRank = dimensionsType.getRank();
  if (dimensionsRank != 1) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank));
  }

  auto dimensionsSize = dimensionsType.getNumElements();
  if (dimensionsSize != operandRank) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions size ({0}) does not match operand rank ({1})",
        dimensionsSize, operandRank));
  }

  auto resultType = op.getResult().getType().cast<RankedTensorType>();
  auto resultRank = resultType.getRank();
  if (resultRank < operandRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is less than operand rank ({1})",
                      resultRank, operandRank));
  }

  for (int64_t i = 0; i != dimensionsSize; ++i) {
    int64_t dimIndex = dimensions.getValues<int64_t>()[i];
    // A negative index must be rejected here: it would otherwise be used to
    // index the result shape below.
    if (dimIndex < 0 || dimIndex >= resultRank) {
      return op.emitOpError(
          llvm::formatv("broadcast_dimensions contains invalid value {0} for "
                        "result with rank {1}",
                        dimIndex, resultRank));
    }

    if (!operandType.isDynamicDim(i)) {
      auto dimSize = operandType.getDimSize(i);
      auto resultDimSize = resultType.getDimSize(dimIndex);
      if (dimSize != 1 && dimSize != resultDimSize) {
        return op.emitOpError(
            llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
                          "1 or size of result dimension {2} ({3})",
                          i, dimSize, dimIndex, resultDimSize));
      }
    }
  }

  return success();
}
// Folds a BroadcastInDimOp that is an identity (same type, dimensions
// 0..rank-1 in order) to its operand, and constant-folds a splat operand into
// a splat of the static result type. Returns a null result otherwise.
OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
  auto result_type = getType().cast<RankedTensorType>();
  if (result_type == getOperand().getType()) {
    auto dims = broadcast_dimensions().getValues<int64_t>();
    const bool is_identity =
        std::equal(dims.begin(), dims.end(),
                   llvm::seq<int64_t>(0, result_type.getRank()).begin());
    if (is_identity) return getOperand();
    return {};
  }

  // Constant fold when an operand is a splat tensor attribute.
  if (!attrs[0] || !result_type.hasStaticShape()) return {};
  auto splat = attrs[0].dyn_cast<SplatElementsAttr>();
  if (!splat) return {};

  // Handle complex type
  if (auto complex_ty = result_type.getElementType().dyn_cast<ComplexType>()) {
    Type part_ty = complex_ty.getElementType();
    if (part_ty.isa<FloatType>()) {
      return DenseElementsAttr::get(
          result_type, {splat.getSplatValue<std::complex<APFloat>>()});
    }
    if (part_ty.isa<IntegerType>()) {
      return DenseElementsAttr::get(
          result_type, {splat.getSplatValue<std::complex<APInt>>()});
    }
    return {};
  }

  return SplatElementsAttr::get(result_type,
                                splat.getSplatValue<mlir::Attribute>());
}
// Simplify BroadcastInDim has the following behaviors: replace BroadcastInDim
// with Reshape or Transpose if they are equivalent or replace
// BroadcastInDim(BroadcastInDim(X)) with BroadcastInDim(X)
// NOTE(review): the TransposeOp replacement and the construction of
// `new_indices` are cut off in this copy, and no access specifier is visible
// after the class header — confirm against the original file.
class BroadcastInDimSimplifier : public OpRewritePattern<BroadcastInDimOp> {
using OpRewritePattern<BroadcastInDimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(BroadcastInDimOp op,
PatternRewriter& rewriter) const override {
auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
// Both types must be ranked for any of the rewrites below.
if (!operand_type || !result_type) {
return failure();
auto bs_dim_indices = op.broadcast_dimensions().getValues<int64_t>();
// The reshape/transpose rewrites need static shapes so that element counts
// can be compared.
if (operand_type.hasStaticShape() && result_type.hasStaticShape()) {
bool same_total_elements =
operand_type.getNumElements() == result_type.getNumElements();
// BroadcastInDim equivalent to reshape
if (llvm::is_sorted(bs_dim_indices) && same_total_elements) {
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
return success();
// BroadcastInDim equivalent to transpose
if (operand_type.getRank() == result_type.getRank() &&
same_total_elements) {
rewriter.replaceOpWithNewOp<TransposeOp>(op, op.getType(), op.operand(),
return success();
// eliminate redundant BroadcastInDim
if (auto broadcast_in_dim_op = llvm::dyn_cast_or_null<BroadcastInDimOp>(
op.operand().getDefiningOp())) {
auto new_indices =
[&bs_dim_indices](const APInt& dim) -> APInt {
return APInt(dim.getBitWidth(),
op, op.getType(), broadcast_in_dim_op.operand(), new_indices);
return success();
return failure();
// Registers the BroadcastInDimOp canonicalization pattern defined in this
// file: BroadcastInDimSimplifier (reshape/transpose equivalence and collapsing
// of chained broadcasts).
void BroadcastInDimOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                   MLIRContext* context) {
  results.insert<BroadcastInDimSimplifier>(context);
}
// DynamicBroadcastInDimOp
// Verifies a DynamicBroadcastInDimOp whose operand and result are ranked.
//
// Checks, in order: broadcast_dimensions is rank 1 with one entry per operand
// dimension; the result rank is at least the operand rank; every entry names
// a valid result dimension (negative values are rejected as well); each
// operand dimension is 1 or compatible with the result dimension it maps to;
// and output_dimensions has one entry per result dimension.
static LogicalResult Verify(DynamicBroadcastInDimOp op) {
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();

  // If either the operand or result are unranked, there is very little
  // to verify statically.
  if (!operandType || !resultType) {
    return success();
  }

  auto outputDimensionsType =
      op.output_dimensions().getType().cast<RankedTensorType>();
  auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
  auto operandRank = operandType.getRank();
  auto resultRank = resultType.getRank();

  // Verify broadcast_dimensions.
  auto bcastDimensions = op.broadcast_dimensions();
  auto bcastDimensionsType = bcastDimensions.getType();
  auto bcastDimensionsRank = bcastDimensionsType.getRank();
  // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
  if (bcastDimensionsRank != 1) {
    return op.emitOpError(
        llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1",
                      bcastDimensionsRank));
  }

  auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
  if (bcastDimensionsSize != operandRank) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions size ({0}) does not match operand rank ({1})",
        bcastDimensionsSize, operandRank));
  }

  if (resultRank < operandRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is less than operand rank ({1})",
                      resultRank, operandRank));
  }

  for (int64_t i = 0; i != bcastDimensionsSize; ++i) {
    int64_t dimIndex = bcastDimensions.getValues<int64_t>()[i];
    // A negative index must be rejected here: it would otherwise be used to
    // index the result shape below.
    if (dimIndex < 0 || dimIndex >= resultRank) {
      return op.emitOpError(
          llvm::formatv("broadcast_dimensions contains invalid value {0} for "
                        "result with rank {1}",
                        dimIndex, resultRank));
    }

    auto dimSize = operandType.getDimSize(i);
    auto resultDimSize = resultType.getDimSize(dimIndex);
    // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we
    // add a manual check for this.
    if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
      return op.emitOpError(
          llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
                        "with size of result dimension {2} ({3})",
                        i, dimSize, dimIndex, resultDimSize));
    }
  }

  if (outputDimensionsSize != resultRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is not equal to number of output "
                      "dimensions ({1})",
                      resultRank, outputDimensionsSize));
  }

  return success();
}
namespace {
// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary
// BroadcastInDimOp.
// Matches only when both the result and the operand are ranked tensors with
// static shapes; otherwise reports a match failure.
// NOTE(review): the start of the replacement call is cut off in this copy,
// and no access specifier is visible after the class header — confirm against
// the original file.
class DynamicBroadcastInDimOpNotActuallyDynamic
: public OpRewritePattern<DynamicBroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
PatternRewriter& rewriter) const override {
auto type = op.getType().dyn_cast<RankedTensorType>();
auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
if (!type || !type.hasStaticShape() || !operandType ||
!operandType.hasStaticShape()) {
return rewriter.notifyMatchFailure(op, "requires static shape");
op, op.getType(), op.operand(), op.broadcast_dimensions());
return success();
// Collapses a DynamicBroadcastInDimOp applied to the result of another
// broadcast into a single op on the inner operand, using a composition of the
// two broadcast_dimensions attributes and the outer output_dimensions.
// NOTE(review): the lookup of `preceding_bcast`, the initializer of
// `preceding_bcast_dims`, the loop body that fills `composition` and the
// start of the replacement call are all cut off in this copy — confirm the
// details against the original file.
class ChainedDynamicBroadcastInDimCanonicalization
: public OpRewritePattern<DynamicBroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast,
PatternRewriter& rewriter) const override {
auto preceding_bcast =
if (!preceding_bcast) return failure();
// Compose broadcast dimensions.
DenseIntElementsAttr preceding_bcast_dims =
DenseIntElementsAttr bcast_dims = bcast.broadcast_dimensions();
SmallVector<APInt, 4> composition;
for (APInt preceding_dim : preceding_bcast_dims) {
// The composed attribute reuses the type of the inner op's dimensions.
auto composed_bcast_dims =
DenseIntElementsAttr::get(preceding_bcast_dims.getType(), composition);
bcast, bcast.getType(), preceding_bcast.operand(),
bcast.output_dimensions(), composed_bcast_dims);
return success();
} // namespace
// Registers the DynamicBroadcastInDimOp canonicalization patterns: the two
// hand-written patterns defined in this file plus the
// DynamicBroadcastToOwnShape_* patterns (declared elsewhere).
void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
    RewritePatternSet& results, MLIRContext* context) {
  results.insert<ChainedDynamicBroadcastInDimCanonicalization,
                 DynamicBroadcastInDimOpNotActuallyDynamic,
                 DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2,
                 DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>(
      context);
}
// Reifies the result shape of DynamicBroadcastInDimOp.
//
// The result shape is the `output_dimensions` operand itself, converted to an
// index extent tensor by castToIndexTensor when needed. The value is appended
// to `reifiedReturnShapes`; the function always succeeds.
LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  DynamicBroadcastInDimOp::Adaptor adaptor(operands);
  reifiedReturnShapes.push_back(
      castToIndexTensor(builder, getLoc(), adaptor.output_dimensions()));
  return success();
}
// ClampOp
// Verifies that the `min` and `max` operands of a ClampOp are each either
// rank 0 or have exactly the operand's shape.
static LogicalResult Verify(ClampOp op) {
  auto operand_shape =
      op.operand().getType().cast<RankedTensorType>().getShape();

  auto min_type = op.min().getType().cast<RankedTensorType>();
  auto min_shape = min_type.getShape();
  const bool min_ok = min_shape == operand_shape || min_type.getRank() == 0;
  if (!min_ok) {
    return op.emitOpError(llvm::formatv(
        "min shape [{0}] is not scalar and does not match operand shape [{1}]",
        llvm::make_range(min_shape.begin(), min_shape.end()),
        llvm::make_range(operand_shape.begin(), operand_shape.end())));
  }

  auto max_type = op.max().getType().cast<RankedTensorType>();
  auto max_shape = max_type.getShape();
  const bool max_ok = max_shape == operand_shape || max_type.getRank() == 0;
  if (!max_ok) {
    return op.emitOpError(llvm::formatv(
        "max shape [{0}] is not scalar and does not match operand shape [{1}]",
        llvm::make_range(max_shape.begin(), max_shape.end()),
        llvm::make_range(operand_shape.begin(), operand_shape.end())));
  }

  return success();
}
// Reifies the result shape of ClampOp by deriving it from operands[1] rather
// than operands[0] (see the comment below).
// NOTE(review): the final argument of deriveShapeFromOperand is cut off in
// this copy — confirm against the original file.
LogicalResult ClampOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
// For `mhlo.clamp`, the first operand may be a scalar.
return deriveShapeFromOperand(&builder, getOperation(), operands[1],
// ComplexOp
// Infers the result type of ComplexOp from its first operand.
//
// The element type is the complex type over the operand's element type. A
// ranked tensor operand yields a ranked tensor of the same shape, an unranked
// tensor yields an unranked tensor, and any other type yields the bare
// complex type. The inferred type is appended to `inferredReturnTypes`.
LogicalResult ComplexOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto type = operands[0].getType();
  auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
  Type result_ty;
  if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
    result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
  } else if (type.isa<UnrankedTensorType>()) {
    result_ty = UnrankedTensorType::get(element_ty);
  } else {
    result_ty = element_ty;
  }
  // Report the computed type; returning success with an empty list would
  // leave the caller without a result type.
  inferredReturnTypes.push_back(result_ty);
  return success();
}
// Folds complex(real(x), imag(x)) to x. Returns a null result when the two
// operands are not a RealOp/ImagOp pair over the same value.
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
  auto real_producer = getOperand(0).getDefiningOp<mhlo::RealOp>();
  auto imag_producer = getOperand(1).getDefiningOp<mhlo::ImagOp>();
  if (!real_producer || !imag_producer) return {};

  Value source = real_producer.getOperand();
  if (source != imag_producer.getOperand()) return {};
  return source;
}
// ImagOp
namespace {
// Returns `type` with a complex element type replaced by its underlying
// element type. Ranked tensors keep their shape, unranked tensors stay
// unranked, and any other type yields the bare element type.
Type CreateRealType(Type type) {
  Type real_element_ty = getElementTypeOrSelf(type);
  if (auto complex_ty = real_element_ty.dyn_cast<ComplexType>()) {
    real_element_ty = complex_ty.getElementType();
  }

  if (auto ranked = type.dyn_cast<RankedTensorType>()) {
    return RankedTensorType::get(ranked.getShape(), real_element_ty);
  }
  if (type.isa<UnrankedTensorType>()) {
    return UnrankedTensorType::get(real_element_ty);
  }
  return real_element_ty;
}
} // namespace
// Infers the result type of ImagOp: the type of the first operand with a
// complex element type replaced by its underlying element type (see
// CreateRealType). The inferred type is appended to `inferredReturnTypes`.
LogicalResult ImagOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
  return success();
}
// Folds imag(complex(re, im)) to `im`. Returns a null result when the operand
// is not produced by a ComplexOp.
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
  auto producer = getOperand().getDefiningOp<mhlo::ComplexOp>();
  if (!producer) return {};
  return producer.getOperand(1);
}
// IsFiniteOp
// Returns a tensor type with the shape (or unranked-ness) of `tensor_type`
// and the element type `element_type`. Any tensor type that is neither ranked
// nor unranked is treated as unreachable.
TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type) {
  if (auto ranked = tensor_type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(ranked.getShape(), element_type);
  if (tensor_type.isa<UnrankedTensorType>())
    return UnrankedTensorType::get(element_type);
  llvm_unreachable("unhandled type");
}
// Infers the result type of IsFiniteOp: a tensor with the shape of the first
// operand and i1 elements.
LogicalResult IsFiniteOp::inferReturnTypes(
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto operand_type = operands.front().getType().cast<TensorType>();
  Type i1_type = Builder(ctx).getI1Type();
  inferredReturnTypes.push_back(getSameShapeTensorType(operand_type, i1_type));
  return success();
}
// RealOp
// Infers the result type of RealOp: the type of the first operand with a
// complex element type replaced by its underlying element type (see
// CreateRealType). The inferred type is appended to `inferredReturnTypes`.
LogicalResult RealOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
  return success();
}
// Folds real(complex(re, im)) to `re`. Returns a null result when the operand
// is not produced by a ComplexOp.
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
  auto producer = getOperand().getDefiningOp<mhlo::ComplexOp>();
  if (!producer) return {};
  return producer.getOperand(0);
}
// ConcatenateOp
namespace {
// Drops operands of a ConcatenateOp whose extent along the concatenation
// dimension is 0, since they contribute no elements.
//
// The rewrite only fires when at least one operand is kept and at least one
// is dropped; the result type is unchanged.
class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    auto axis = op.dimension();
    llvm::SmallVector<Value, 6> new_operands;
    for (auto operand : op.getOperands()) {
      auto ty = operand.getType().cast<ShapedType>();
      // Keep only operands that actually contribute elements.
      if (ty.getDimSize(axis) != 0) {
        new_operands.push_back(operand);
      }
    }

    if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
                                                 new_operands, op.dimension());
      return success();
    }

    return failure();
  }
};
// Flattens nested concatenations: an operand that is itself a ConcatenateOp
// along the same dimension, and has exactly one use, is replaced by its own
// operands. Fails when no operand can be flattened.
class ConcatenateForwarding : public OpRewritePattern<ConcatenateOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    // Returns the operands of `val`'s defining ConcatenateOp when it can be
    // flattened into `op`, and `val` itself otherwise.
    auto getFlattenedOperands = [&](const Value& val) -> ValueRange {
      auto definingOp = dyn_cast_or_null<ConcatenateOp>(val.getDefiningOp());
      // To avoid inflate the memory footprint, only flatten the ConcatenateOp
      // when it has only one use.
      if (definingOp && definingOp->hasOneUse() &&
          definingOp.dimension() == op.dimension())
        return definingOp.val();
      return val;
    };

    // First pass: decide whether anything changes and count the operands of
    // the flattened op.
    bool needToFlatten = false;
    int operandCount = 0;
    llvm::for_each(op.val(), [&](Value val) {
      auto result = getFlattenedOperands(val);
      if (result.size() != 1 || result[0] != val) needToFlatten = true;
      operandCount += result.size();
    });

    if (!needToFlatten) return failure();

    // Second pass: collect the flattened operand list.
    llvm::SmallVector<Value, 6> newOperands;
    newOperands.reserve(operandCount);
    for (auto operand : op.val()) {
      auto flattenedOperands = getFlattenedOperands(operand);
      newOperands.append(flattenedOperands.begin(), flattenedOperands.end());
    }

    rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
                                               newOperands, op.dimension());
    return success();
  }
};
} // namespace
// Infers the result type of ConcatenateOp from its operand types and the
// "dimension" attribute. Fails when there are no operands or when operand
// element types differ; reports an error for rank-0 values.
// NOTE(review): several control-flow statements (loop exits/skips and the
// results pushed on the unranked paths) are not visible in this copy —
// confirm against the original file before relying on the flow below.
LogicalResult ConcatenateOp::inferReturnTypes(
MLIRContext*, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type>& inferredReturnTypes) {
if (operands.empty()) {
return failure();
auto dimension_attr = attributes.get("dimension").cast<IntegerAttr>();
auto dimension = dimension_attr.getInt();
auto first_type = (*operands.begin()).getType().cast<ShapedType>();
auto out_element = first_type.getElementType();
// Every operand must have the element type of the first operand.
for (auto operand : operands.getTypes()) {
auto element_type = getElementTypeOrSelf(operand);
if (element_type != out_element) {
return failure();
// Find the first ranked input to determine the output rank.
for (auto type : operands.getTypes()) {
auto shaped_type = type.cast<ShapedType>();
if (shaped_type.hasRank()) {
first_type = shaped_type;
// If all inputs are unranked, the result must be unranked.
if (!first_type.hasRank()) {
return success();
if (first_type.getRank() == 0)
return emitOptionalError(location, "rank-0 values cannot be concatenated");
auto out_shape = llvm::to_vector<6>(first_type.getShape());
// Determine what the non-concatenate dimensions should be.
for (auto type : operands.getTypes()) {
auto shaped_ty = type.cast<ShapedType>();
if (!shaped_ty.hasRank()) {
for (const auto& it : llvm::enumerate(shaped_ty.getShape())) {
// If a dimension is not dynamic, the output shape should match.
if (ShapedType::isDynamic(out_shape[it.index()])) {
out_shape[it.index()] = it.value();
// The concatenation dimension is accumulated from 0 over all operands.
out_shape[dimension] = 0;
for (auto operand : operands.getTypes()) {
auto type = operand.cast<ShapedType>();
if (!type.hasRank()) {
return success();
// If the dimension is dynamic we know the output dimension is dynamic.
auto dim = type.getShape()[dimension];
if (dim == -1) {
out_shape[dimension] = -1;
out_shape[dimension] += dim;
inferredReturnTypes.push_back(RankedTensorType::get(out_shape, out_element));
return success();
// Registers the ConcatenateOp canonicalization patterns: removal of operands
// that are empty along the concat dimension, and flattening of nested
// concatenations.
void ConcatenateOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                MLIRContext* context) {
  results.insert<ConcatenateOperandRemoval>(context);
  results.insert<ConcatenateForwarding>(context);
}
// Concatenates constant operands of element storage type T into one dense
// attribute of the op's result type.
//
// The result is built one outer slice at a time: for each index over the
// dimensions before the concat dimension, the matching chunk of every operand
// is appended in operand order. All operands must be DenseElementsAttr.
template <typename T>
static Attribute foldConcatenateHelper(ConcatenateOp* op,
                                       ArrayRef<Attribute> operands) {
  const auto concat_dim = op->dimension();
  auto result_type = op->getType().cast<ShapedType>();
  auto result_shape = result_type.getShape();

  // Product of the extents of all dimensions before the concat dimension.
  size_t outer_count = 1;
  for (int d = 0, e = concat_dim; d < e; ++d) {
    outer_count = outer_count * result_shape[d];
  }

  SmallVector<T, 6> flattened;
  for (size_t outer = 0; outer < outer_count; ++outer) {
    for (Attribute operand : operands) {
      DenseElementsAttr dense = operand.cast<DenseElementsAttr>();
      const size_t chunk = dense.getNumElements() / outer_count;
      auto chunk_begin = dense.getValues<T>().begin() + outer * chunk;
      flattened.append(chunk_begin, chunk_begin + chunk);
    }
  }

  return DenseElementsAttr::get(result_type, flattened);
}
// Constant-folds a ConcatenateOp when every operand is a known attribute and
// the result element type is an integer or float type. Returns a null
// attribute otherwise.
static Attribute foldConcatenate(ConcatenateOp* op,
                                 ArrayRef<Attribute> operands) {
  // Every operand must be constant.
  if (llvm::any_of(operands, [](Attribute attr) { return !attr; })) return {};

  Type element_type =
      op->getResult().getType().cast<ShapedType>().getElementType();
  if (element_type.isa<IntegerType>()) {
    return foldConcatenateHelper<APInt>(op, operands);
  }
  if (element_type.isa<FloatType>()) {
    return foldConcatenateHelper<APFloat>(op, operands);
  }
  return {};
}
// Folds a ConcatenateOp in three cases, tried in order:
//  * a single operand is forwarded unchanged;
//  * with a static result shape, all-constant operands are folded into one
//    constant;
//  * with a static result shape, if every operand is empty along the concat
//    dimension the result is an empty dense attribute.
// Returns a null result otherwise.
OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
  if (getNumOperands() == 1) return getOperand(0);

  auto result_type = getResult().getType().cast<ShapedType>();
  if (!result_type.hasStaticShape()) return {};

  const auto concat_dim = dimension();
  if (Attribute folded = foldConcatenate(this, operands)) return folded;

  for (Value operand : getOperands()) {
    auto operand_type = operand.getType().cast<ShapedType>();
    if (operand_type.getDimSize(concat_dim) != 0) return {};
  }

  return DenseElementsAttr::get(result_type, ArrayRef<Attribute>());
}
// Verifies a concatenate op:
//  * every operand has the same element type as operand 0;
//  * all ranked operands have the same rank;
//  * all ranked operands agree on every dimension except the concatenation
//    dimension.
// Unranked operands are only checked for their element type. Shapes are
// compared against the first ranked operand that is encountered.
static LogicalResult Verify(ConcatenateOp op) {
  Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
  RankedTensorType first_ranked_type;
  int num_operands = op.getNumOperands();
  for (int i = 0; i < num_operands; i++) {
    auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
    if (second_type.getElementType() != element_type) {
      return op.emitOpError(
          llvm::formatv("operands (0) and ({0}) do not match element type", i));
    }

    // Nothing more can be checked statically for an unranked operand.
    if (!second_type.hasRank()) {
      continue;
    }

    // The first ranked operand becomes the reference for later comparisons.
    if (!first_ranked_type) {
      first_ranked_type = second_type.cast<RankedTensorType>();
      continue;
    }

    if (first_ranked_type.getRank() != second_type.getRank()) {
      return op.emitOpError(
          llvm::formatv("operands (0) and ({0}) do not match rank", i));
    }

    // Compare against the reference operand's shape. Reading both shapes from
    // `second_type` would make the check below vacuous.
    auto first_shape = first_ranked_type.getShape();
    auto second_shape = second_type.getShape();
    for (int d = 0; d < first_ranked_type.getRank(); ++d) {
      if (first_shape[d] != second_shape[d] && d != op.dimension()) {
        return op.emitOpError(llvm::formatv(
            "operands (0) and ({0}) non-concat dimensions do not match "
            "({1}) != ({2})",
            i, llvm::make_range(first_shape.begin(), first_shape.end()),
            llvm::make_range(second_shape.begin(), second_shape.end())));
      }
    }
  }
  return success();
}
// Emits IR that computes the result shape of this concatenate op from the
// runtime dimensions of its operands: each operand's dimensions are queried
// with tensor.dim, and the sizes along the concatenation axis are added
// together. Returns failure() when an operand is unranked or when the
// operands do not all have the same rank.
// NOTE(review): several statements in this function look truncated in this
// copy of the file (the lambda terminator, the code that stores `value_dim`,
// the diagnostic object in front of `<<`, and the tail of the
// tensor.from_elements call). Confirm against the upstream source before
// changing any code here.
LogicalResult ConcatenateOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
ConcatenateOp::Adaptor adaptor(operands);
auto inputs = adaptor.val();
auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
// Unranked operands are not supported at the moment.
if (!operand_type) return failure();
Location loc = this->getLoc();
Type shape_scalar_type = builder.getIndexType();
// Routes a value through MaybeCastTo with the index type as the target type.
auto to_shape_scalar_type = [&](Value v) {
return MaybeCastTo(builder, loc, v, shape_scalar_type);
// One vector of per-dimension size values for each operand.
SmallVector<SmallVector<Value, 4>, 4> all_shape_values;
for (size_t input_id = 0; input_id < inputs.size(); ++input_id) {
Value operand = inputs[input_id];
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
if (!operand_type) return failure();
SmallVector<Value, 4> shape_vals;
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
// Runtime size of dimension `element.index()` of this operand.
Value value_dim = to_shape_scalar_type(
builder.create<tensor::DimOp>(loc, operand, element.index()));
int axis = this->dimension();
// The first operand's sizes are used as the accumulator for the result.
auto& shape_values = all_shape_values[0];
for (size_t vec_id = 1; vec_id < all_shape_values.size(); ++vec_id) {
auto& other_shape_values = all_shape_values[vec_id];
if (other_shape_values.size() != shape_values.size()) {
<< "Concatenate expects all operands must be of the same rank";
return failure();
// Accumulate the size along the concatenation axis.
shape_values[axis] = builder.create<arith::AddIOp>(
loc, shape_values[axis], other_shape_values[axis]);
Value output_shape = builder.create<tensor::FromElementsOp>(
return success();
// DynamicReshapeOp
// Verifies a dynamic_reshape op: when both the result type and
// `output_shape_type` are available and `output_shape_type` is statically
// shaped, its first dimension must equal the rank of the result.
// NOTE(review): the initializer of `output_shape_type` is not visible in this
// copy of the file; presumably it is the ranked tensor type of the
// output_shape operand - confirm against upstream.
static LogicalResult Verify(DynamicReshapeOp op) {
auto result_type = op.result().getType().dyn_cast<RankedTensorType>();
auto output_shape_type =
if (result_type && output_shape_type && output_shape_type.hasStaticShape() &&
output_shape_type.getDimSize(0) != result_type.getRank()) {
return op.emitError() << "output should have a rank equal to the number of "
"elements in output_shape";
return success();
// Reifies the result shape of a dynamic_reshape op from its output_shape
// operand, passed through castToIndexTensor.
// NOTE(review): the statement that consumes the castToIndexTensor result is
// truncated in this copy of the file; presumably the value is appended to
// `reifiedReturnShapes` - confirm against upstream.
LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
DynamicReshapeOp::Adaptor adaptor(operands);
castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
return success();
namespace {
// Rewrite pattern: replaces a DynamicReshapeOp whose result type is a ranked
// tensor with a static shape by a ReshapeOp on the same operand and with the
// same result type.
class DynamicReshapeOpNotActuallyDynamic
: public OpRewritePattern<DynamicReshapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicReshapeOp op,
PatternRewriter& rewriter) const override {
auto type = op.result().getType().dyn_cast<RankedTensorType>();
// Only applies when the result shape is fully known at compile time.
if (!type || !type.hasStaticShape()) {
return rewriter.notifyMatchFailure(op, "requires static shape tensor");
// The new op takes only the data operand; the result type carries the shape.
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
return success();
// Canonicalizes
// %0 = some_op(%tensor)
// %1 = "mhlo.dynamic_reshape"(%0, %shape)
// (tensor<?xT>, tensor<1xindex>) -> tensor<?xT>
// ... uses of %1.
// into
// ... uses of %0.
// This canonicalization is only correct if the input is correct!
// TODO(b/178779691): Use a more sophisticated canonicalization that preserves
// errors in input, and still allows us to get rid of redundant reshapes.
class RemoveRedundantRank1DynamicReshape
: public OpRewritePattern<DynamicReshapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicReshapeOp op,
PatternRewriter& rewriter) const override {
// The result must be a rank-1 tensor whose single dimension is dynamic.
auto type = op.result().getType().dyn_cast<RankedTensorType>();
if (!type || type.getRank() != 1 || type.hasStaticShape()) {
return rewriter.notifyMatchFailure(
op, "requires rank 1 shape tensor with dynamic dimension");
// The same requirement applies to the operand being reshaped.
auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
if (!operand_type || operand_type.getRank() != 1 ||
operand_type.hasStaticShape()) {
return rewriter.notifyMatchFailure(
op, "requires rank 1 shape tensor with dynamic dimension");
// The reshape is dropped and its uses are forwarded to the operand.
rewriter.replaceOp(op, {op.operand()});
return success();
// Canonicalizes
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// %2 = "mhlo.dynamic_reshape"(%1, %shape)
// ... uses of %2.
// into
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// ... uses of %1.
class DynamicReshapeOpSameShapeOpResult
: public OpRewritePattern<DynamicReshapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicReshapeOp op,
PatternRewriter& rewriter) const override {
// The reshaped value must be produced by an op that carries the
// SameOperandsAndResultShape trait.
Operation* def_op = op.operand().getDefiningOp();
if (!def_op ||
!def_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
return failure();
// Only the first operand of that op is inspected.
Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
if (!input_def_op) {
return failure();
// If that operand is itself a dynamic_reshape to the very same shape value,
// the outer reshape is redundant and is replaced by def_op's first result.
auto reshape = dyn_cast<DynamicReshapeOp>(*input_def_op);
if (reshape && reshape.output_shape() == op.output_shape()) {
rewriter.replaceOp(op, {def_op->getResult(0)});
return success();
return failure();
} // namespace
// Adds the canonicalization patterns of the dynamic_reshape op to `results`.
// NOTE(review): the pattern list between the clang-format markers is not
// visible in this copy of the file; confirm against upstream.
void DynamicReshapeOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
// clang-format off
// clang-format on
// DynamicSliceOp
namespace {
// Canonicalizes DynamicSlice ops that can be replaced with Slice ops instead.
// This canonicalization is applied in the case where the `begin` input values
// are compile-time constants and thus can be made into a tensor.
// NOTE(review): the statement that stores `clamped_start` into
// `temp_start_indices` is not visible in this copy of the file; confirm
// against upstream.
struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> {
using OpRewritePattern<DynamicSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice,
PatternRewriter& rewriter) const override {
Value input = dynamic_slice.operand();
// The input must be a ranked tensor with a fully static shape.
auto input_tensor = input.getType().dyn_cast<RankedTensorType>();
if (!input_tensor || !input_tensor.hasStaticShape()) return failure();
auto slice_sizes = dynamic_slice.slice_sizes().getValues<int64_t>();
SmallVector<int64_t, 4> temp_start_indices;
for (const auto& index_and_slice_start :
llvm::enumerate(dynamic_slice.start_indices())) {
APInt val;
Value start = index_and_slice_start.value();
int64_t index = index_and_slice_start.index();
// Every start index has to be a constant integer for the rewrite to apply.
if (!matchPattern(start, m_ConstantInt(&val))) {
return failure();
// Clamp the indices within bounds to faithfully mirror dynamic slice
// semantics.
int64_t clamped_start =
Clamp(val.getSExtValue(), static_cast<int64_t>(0),
input_tensor.getDimSize(index) - slice_sizes[index]);
// At this point we've determined that the start indices are all constants;
// pack them into a single tensor.
auto loc = dynamic_slice.getLoc();
int64_t input_rank = input_tensor.getRank();
auto slice_start_indices =
GetI64ElementsAttr(temp_start_indices, &rewriter);
DenseIntElementsAttr slice_limits = BuildSliceLimits(
slice_start_indices, dynamic_slice.slice_sizes(), &rewriter);
// The replacement slice uses a stride of 1 in every dimension.
DenseIntElementsAttr slice_strides =
GetI64ElementsAttr(SmallVector<int64_t, 4>(input_rank, 1), &rewriter);
auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
slice_limits, slice_strides);
rewriter.replaceOp(dynamic_slice, {result});
return success();
} // namespace
// Adds the canonicalization patterns of the dynamic_slice op to `results`.
// NOTE(review): the function body is not visible in this copy of the file;
// confirm against upstream.
void DynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
// Verifies that a dynamic_slice op has exactly as many start indices as it
// has slice sizes.
static LogicalResult Verify(DynamicSliceOp op) {
  int size_count = op.slice_sizes().getNumElements();
  int start_count = op.start_indices().size();
  if (start_count == size_count) return success();

  return op.emitOpError()
         << "has mismatched number of slice sizes (" << size_count
         << ") and number of start indices (" << start_count << ")";
}
// RealDynamicSliceOp
// Verifies that the rank of the operand equals the number of elements of
// each of start_indices, limit_indices and strides. An unranked operand is
// accepted without further checks.
static LogicalResult Verify(RealDynamicSliceOp op) {
  auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
  // If operand is unranked, there is very little to verify statically.
  if (!operand_ty) return success();
  int input_rank = operand_ty.getRank();

  auto start_count =
      op.start_indices().getType().cast<RankedTensorType>().getNumElements();
  auto limit_count =
      op.limit_indices().getType().cast<RankedTensorType>().getNumElements();
  auto strides_count =
      op.strides().getType().cast<RankedTensorType>().getNumElements();

  if (input_rank != start_count)
    return op.emitOpError() << "has mismatched number of operand rank ("
                            << input_rank << ") and start_indices size ("
                            << start_count << ")";

  if (input_rank != limit_count)
    return op.emitOpError() << "has mismatched number of operand rank ("
                            << input_rank << ") and limit_indices size ("
                            << limit_count << ")";

  if (input_rank != strides_count)
    return op.emitOpError()
           << "has mismatched number of operand rank (" << input_rank
           << ") and strides size (" << strides_count << ")";

  return success();
}
namespace {
// Canonicalizes RealDynamicSlice ops that can be replaced with Slice ops
// instead. This canonicalization is applied in the case where the `begin`
// input values are compile-time constants and thus can be made into a tensor.
// NOTE(review): the initializers of `start_attr`, `limit_attr` and
// `stride_attr`, and the statements that fill the three `temp_*` vectors, are
// not visible in this copy of the file; confirm against upstream.
struct RealDynamicSliceIsStatic : public OpRewritePattern<RealDynamicSliceOp> {
using OpRewritePattern<RealDynamicSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(RealDynamicSliceOp real_dynamic_slice,
PatternRewriter& rewriter) const override {
Location loc = real_dynamic_slice.getLoc();
Value input = real_dynamic_slice.operand();
Value output = real_dynamic_slice.result();
// Both the input and the output must be ranked tensors with static shapes.
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
auto output_ty = output.getType().dyn_cast<RankedTensorType>();
if (!input_ty || !output_ty || !input_ty.hasStaticShape() ||
!output_ty.hasStaticShape()) {
return failure();
int64_t input_rank = input_ty.getRank();
auto start_val = real_dynamic_slice.start_indices();
auto limit_val = real_dynamic_slice.limit_indices();
auto stride_val = real_dynamic_slice.strides();
// start/limit/strides must each be produced by an arith.constant op.
auto start_op = start_val.getDefiningOp<mlir::arith::ConstantOp>();
auto limit_op = limit_val.getDefiningOp<mlir::arith::ConstantOp>();
auto stride_op = stride_val.getDefiningOp<mlir::arith::ConstantOp>();
if (!start_op || !limit_op || !stride_op) return failure();
auto start_attr =
auto limit_attr =
auto stride_attr =
if (!start_attr || !limit_attr || !stride_attr) return failure();
SmallVector<int64_t, 4> temp_start_indices;
SmallVector<int64_t, 4> temp_limit_indices;
SmallVector<int64_t, 4> temp_stride;
// Read the constant start/limit/stride value of every input dimension.
for (int64_t dim_idx = 0; dim_idx < input_rank; dim_idx++) {
int64_t start = start_attr.getValues<IntegerAttr>()[dim_idx].getInt();
int64_t limit = limit_attr.getValues<IntegerAttr>()[dim_idx].getInt();
// Despite its name, `end` holds the stride of this dimension.
int64_t end = stride_attr.getValues<IntegerAttr>()[dim_idx].getInt();
DenseIntElementsAttr slice_start_indices =
GetI64ElementsAttr(temp_start_indices, &rewriter);
DenseIntElementsAttr slice_limit_indices =
GetI64ElementsAttr(temp_limit_indices, &rewriter);
DenseIntElementsAttr slice_strides =
GetI64ElementsAttr(temp_stride, &rewriter);
auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
slice_limit_indices, slice_strides);
rewriter.replaceOp(real_dynamic_slice, {result});
return success();
} // namespace
// Adds the canonicalization patterns of the real_dynamic_slice op to
// `results`.
void RealDynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                     MLIRContext* context) {
  // Registered one pattern per call, in the same order as a single
  // variadic insert would use.
  results.insert<RealDynamicSliceIsStatic>(context);
  results.insert<RealDSliceToSlice>(context);
}
// Emits IR that computes the result shape of a real_dynamic_slice op. For
// every dimension of the operand, the start, limit and stride values are
// extracted from the corresponding index tensors; according to the comment
// in the loop, the size is (limit - start + stride - 1) / stride.
// Returns failure() for an unranked operand.
// NOTE(review): the initializer of `shape_scalar_type`, most of the size
// computation and the code that fills `reifiedReturnShapes` are not visible
// in this copy of the file; confirm against upstream.
LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
RealDynamicSliceOp::Adaptor adaptor(operands);
Value operand = adaptor.operand();
Value start_indices = adaptor.start_indices();
Value limit_indices = adaptor.limit_indices();
Value strides = adaptor.strides();
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
// Unranked operands are not supported at the moment.
if (!operand_type) return failure();
Location loc = this->getLoc();
SmallVector<Value, 4> shape_values;
Type shape_scalar_type =
// Constant 1, passed through MaybeCastTo with `shape_scalar_type`.
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
one = MaybeCastTo(builder, loc, one, shape_scalar_type);
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
// Index constant used to address this dimension in the index tensors.
Value offset = builder.create<arith::ConstantIndexOp>(loc, element.index());
Value value_start =
builder.create<tensor::ExtractOp>(loc, start_indices, offset);
Value value_limit =
builder.create<tensor::ExtractOp>(loc, limit_indices, offset);
Value value_stride =
builder.create<tensor::ExtractOp>(loc, strides, offset);
// size = (limit - start + stride - 1) / stride
loc, value_stride,
builder.create<arith::SubIOp>(loc, value_limit, value_start)),
return success();
// InfeedOp
// Checks that the result type is of the form `tuple< any_type, token >`:
// a tuple with exactly two elements whose second element is a token.
static LogicalResult Verify(InfeedOp op) {
  ArrayRef<Type> elements =
      op.getResult().getType().cast<TupleType>().getTypes();
  if (elements.size() != 2) {
    return op.emitOpError()
           << "result is expected to be a tuple of size 2, but got "
           << elements.size();
  }

  if (elements[1].isa<TokenType>()) return success();

  return op.emitOpError() << "second element of result tuple is expected to "
                             "be of token type, but got "
                          << elements[1];
}
// Logical Ops
// Folds a bitwise/logical AND.
//  * x & x                  -> x
//  * splat(all ones) & x    -> x   (either side)
//  * splat(zero) & x        -> the zero constant   (either side)
//  * constant & constant    -> element-wise AND of the two constants
// Returns a null OpFoldResult otherwise.
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
  if (lhs() == rhs()) return lhs();

  auto result_type = getType().cast<ShapedType>();
  auto lhs_const = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhs_const = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhs_const && lhs_const.isSplat()) {
    APInt splat = lhs_const.getSplatValue<IntegerAttr>().getValue();
    if (splat.isAllOnesValue()) return rhs();
    if (splat.isNullValue()) return lhs_const;
  }

  if (rhs_const && rhs_const.isSplat()) {
    APInt splat = rhs_const.getSplatValue<IntegerAttr>().getValue();
    if (splat.isAllOnesValue()) return lhs();
    if (splat.isNullValue()) return rhs_const;
  }

  // Element-wise folding needs both sides to be constant.
  if (!rhs_const || !lhs_const) return {};

  llvm::SmallVector<APInt, 4> folded;
  for (auto pair :
       llvm::zip(rhs_const.getValues<APInt>(), lhs_const.getValues<APInt>())) {
    folded.push_back(std::get<0>(pair) & std::get<1>(pair));
  }
  return DenseIntElementsAttr::get(result_type, folded);
}
// Folds a bitwise/logical OR.
//  * x | x                  -> x
//  * splat(all ones) | x    -> the all-ones constant   (either side)
//  * splat(zero) | x        -> x   (either side)
//  * constant | constant    -> element-wise OR of the two constants
// Returns a null OpFoldResult otherwise.
OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
  if (lhs() == rhs()) return lhs();

  auto result_type = getType().cast<ShapedType>();
  auto lhs_const = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhs_const = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhs_const && lhs_const.isSplat()) {
    APInt splat = lhs_const.getSplatValue<IntegerAttr>().getValue();
    if (splat.isAllOnesValue()) return lhs_const;
    if (splat.isNullValue()) return rhs();
  }

  if (rhs_const && rhs_const.isSplat()) {
    APInt splat = rhs_const.getSplatValue<IntegerAttr>().getValue();
    if (splat.isAllOnesValue()) return rhs_const;
    if (splat.isNullValue()) return lhs();
  }

  // Element-wise folding needs both sides to be constant.
  if (!rhs_const || !lhs_const) return {};

  llvm::SmallVector<APInt, 4> folded;
  for (auto pair :
       llvm::zip(rhs_const.getValues<APInt>(), lhs_const.getValues<APInt>())) {
    folded.push_back(std::get<0>(pair) | std::get<1>(pair));
  }
  return DenseIntElementsAttr::get(result_type, folded);
}
// Folds a bitwise/logical XOR.
//  * x ^ x (static result shape)  -> zero constant of the result type
//  * splat(zero) ^ x              -> x   (either side)
//  * constant ^ constant          -> element-wise XOR of the two constants
// Returns a null OpFoldResult otherwise.
OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) {
  auto result_type = getType().cast<ShapedType>();

  // Fold x^x to 0. Attributes only support static shapes.
  if (lhs() == rhs() && result_type.hasStaticShape()) {
    return Builder(getContext()).getZeroAttr(result_type);
  }

  auto lhs_const = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhs_const = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhs_const && lhs_const.isSplat() &&
      lhs_const.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
    return rhs();
  }

  if (rhs_const && rhs_const.isSplat() &&
      rhs_const.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
    return lhs();
  }

  // Element-wise folding needs both sides to be constant.
  if (!rhs_const || !lhs_const) return {};

  llvm::SmallVector<APInt, 4> folded;
  for (auto pair :
       llvm::zip(rhs_const.getValues<APInt>(), lhs_const.getValues<APInt>())) {
    folded.push_back(std::get<0>(pair) ^ std::get<1>(pair));
  }
  return DenseIntElementsAttr::get(result_type, folded);
}
// MapOp
// Verifies a map op:
//  * the number of operands equals the number of arguments of the
//    `computation` region;
//  * every computation argument is a rank-0 tensor whose element type matches
//    the element type of the first operand;
//  * the computation returns exactly one rank-0 tensor whose element type
//    matches the element type of the op result;
//  * `dimensions` is exactly 0, 1, 2, ... and, for a ranked first operand,
//    covers all of its dimensions.
// NOTE(review): the initializer of `computation_output_type` is not visible
// in this copy of the file; confirm against upstream.
static LogicalResult Verify(MapOp op) {
// Checks if the number of `operands` match the arity of the map `computation`
// region.
auto& computation_block = op.computation().front();
auto computation_args = computation_block.getArguments();
if (op.operands().size() != computation_args.size())
return op.emitOpError()
<< "expects number of operands to match the arity "
"of map computation, but got: "
<< op.operands().size() << " and " << computation_args.size();
// The parameters of computation should all be scalars and match the element
// type of operands.
auto operand_type = op.operands()[0].getType().cast<TensorType>();
auto operand_elem_ty = operand_type.getElementType();
for (const auto& indexed_arg : llvm::enumerate(computation_args)) {
auto arg_type = indexed_arg.value().getType().dyn_cast<TensorType>();
if (!arg_type || arg_type.getRank() != 0)
return op.emitOpError()
<< "computation arguments must be 0-rank tensor, but got: arg #"
<< indexed_arg.index() << " of type "
<< indexed_arg.value().getType();
if (arg_type.getElementType() != operand_elem_ty) {
return op.emitOpError()
<< "element type of operands and computation arguments must "
"match, but got: "
<< operand_elem_ty << " and " << arg_type.getElementType();
// Mapped computation must return single output
auto computation_outputs = computation_block.getTerminator()->getOperands();
if (computation_outputs.size() != 1)
return op.emitOpError()
<< "computation must return single output, but got: "
<< computation_outputs.size();
// The output of computation must be scalar and have the same element type
// as op result.
auto computation_output_type =
if (!computation_output_type || computation_output_type.getRank() != 0)
return op.emitOpError()
<< "computation must return 0-rank tensor, but got: "
<< computation_outputs[0].getType();
auto result_type = op.getType().cast<TensorType>();
if (computation_output_type.getElementType() != result_type.getElementType())
return op.emitOpError() << "element type of result and computation output "
"must match, but got: "
<< result_type.getElementType() << " and "
<< computation_output_type.getElementType();
// Checks that the requested map dimension numbers are monotonically
// increasing. The comparison below is stricter than that wording: the i-th
// entry has to be equal to i.
DenseIntElementsAttr dimensions = op.dimensions();
for (const auto& indexedValue :
llvm::enumerate(dimensions.getValues<int64_t>())) {
if (indexedValue.value() != indexedValue.index())
return op.emitOpError() << "requires monotonically increasing dimension "
"numbers, but got: "
<< op.dimensions();
// Checks that number of dimensions of operands matches the size of
// `dimensions` since we currently only support mapping across all
// dimensions: i.e., scalar map functions.
if (operand_type.hasRank()) {
if (dimensions.size() != operand_type.getShape().size())
return op.emitOpError()
<< "applied to a subset of dimensions currently not supported: "
"operand dimensions = "
<< operand_type.getShape().size()
<< ", requested map dimensions size = " << dimensions.size();
return success();
// Folds a map whose computation does nothing but return one of its own block
// arguments: the op is replaced by the operand at that argument's position.
// Returns nullptr when the first op of the computation block is not a
// ReturnOp with exactly one result, or when that result is not one of the
// block's arguments.
OpFoldResult MapOp::fold(ArrayRef<Attribute> operands) {
  mlir::Block& body = computation().front();

  auto terminator = mlir::dyn_cast<ReturnOp>(body.front());
  if (!terminator || terminator.results().size() != 1) return nullptr;

  for (mlir::BlockArgument arg : body.getArguments()) {
    if (arg != terminator.results()[0]) continue;
    return getOperands()[arg.getArgNumber()];
  }
  return nullptr;
}
// RecvOp
// Checks that the result type is of the form `tuple<any_type, mhlo::token>`:
// a tuple with exactly two elements whose second element is a token.
static LogicalResult Verify(RecvOp op) {
  auto tuple_ty = op.getResult().getType().cast<TupleType>();
  ArrayRef<Type> elements = tuple_ty.getTypes();

  const bool has_two_elements = elements.size() == 2;
  if (!has_two_elements) {
    return op.emitOpError()
           << "result is expected to be a tuple of size 2, but got "
           << elements.size();
  }

  const bool second_is_token = elements[1].isa<TokenType>();
  if (!second_is_token) {
    return op.emitOpError() << "second element of result tuple is expected to "
                               "be of token type, but got "
                            << elements[1];
  }

  return success();
}
// CopyOp
// A copy always folds to its operand.
OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) {
  Value source = getOperand();
  return source;
}
// ReduceWindowOp
// For reduce-window, all `inputs` need to have compatible shapes.
static LogicalResult Verify(ReduceWindowOp op) {
  if (succeeded(verifyCompatibleShapes(op.inputs().getTypes())))
    return success();
  return op.emitOpError() << "requires same shape for all inputs";
}
// Get the operation used for reduction applied to the `result_index`th
// result. It's expected to be a binary operation that consumes the
// `result_index`th and `result_index + inputs().size()`th arguments of the
// body. Returns nullptr when the defining op does not have that form.
// NOTE(review): getDefiningOp() returns null when the returned value is a
// block argument, and `compute_op` is dereferenced without a check; confirm
// whether a guard is needed.
// NOTE(review): the last clause of the swapped-argument condition is not
// visible in this copy of the file; confirm against upstream.
Operation* ReduceWindowOp::getReductionOp(int result_index) {
auto return_op = cast<ReturnOp>(body().front().getTerminator());
Operation* compute_op = return_op.results()[result_index].getDefiningOp();
if (compute_op->getNumOperands() != 2) return nullptr;
// Both operands of the candidate op must be block arguments.
auto arg0 = compute_op->getOperand(0).dyn_cast<BlockArgument>();
auto arg1 = compute_op->getOperand(1).dyn_cast<BlockArgument>();
if (!arg0 || !arg1) return nullptr;
int arg0_num = arg0.getArgNumber();
int arg1_num = arg1.getArgNumber();
int other_arg_index = result_index + inputs().size();
// Arguments used in their natural order.
if (arg0_num == result_index && arg1_num == other_arg_index)
return compute_op;
// Arguments used in swapped order.
if (arg0_num == other_arg_index && arg1_num == result_index &&
return compute_op;
return nullptr;
// ReverseOp
// Folds a reverse op to its operand when the reversal cannot change the
// data: either no dimensions are reversed, or every reversed dimension has
// size 1. Returns nullptr otherwise.
OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
  auto input = operand();

  // No dimensions to reverse.
  if (dimensions().getNumElements() == 0) return input;

  // Reversing along a dimension of size 1 is the identity; any other size
  // (including a dynamic one) blocks the fold.
  auto shaped_type = input.getType().cast<ShapedType>();
  for (auto dim : dimensions().getValues<APInt>()) {
    if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) {
      return nullptr;
    }
  }

  return input;
}
// ReduceOp
// Returns the result type after reducing an operand of type `operand_ty`
// across `dimensions`. An unranked operand yields an unranked tensor of the
// same element type; otherwise the result keeps, in order, every dimension
// that is not listed in `dimensions`. `builder` is not used.
static TensorType GetReduceResultType(Type operand_ty,
                                      DenseIntElementsAttr dimensions,
                                      Builder* builder) {
  Type element_ty = getElementTypeOrSelf(operand_ty);

  auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return UnrankedTensorType::get(element_ty);

  // Mark the dimensions that are reduced away.
  int64_t rank = ranked_ty.getRank();
  llvm::SmallVector<bool, 4> is_reduced(rank, false);
  for (int64_t reduced_dim : dimensions.getValues<int64_t>()) {
    is_reduced[reduced_dim] = true;
  }

  // Collect the sizes of the surviving dimensions.
  SmallVector<int64_t, 4> kept_sizes;
  for (int64_t d = 0; d < rank; ++d) {
    if (is_reduced[d]) continue;
    kept_sizes.push_back(ranked_ty.getDimSize(d));
  }

  return RankedTensorType::get(kept_sizes, element_ty);
}
// Builder that derives result types from the inputs: GetReduceResultType is
// evaluated for each input's type with `dimensions`, and the op is then built
// through the overload that takes explicit result types.
// NOTE(review): the statement that consumes the GetReduceResultType result
// is truncated in this copy of the file; presumably it is appended to
// `result_ty` - confirm against upstream.
void ReduceOp::build(OpBuilder& builder, OperationState& state,
ValueRange inputs, ValueRange init_values,
DenseIntElementsAttr dimensions) {
SmallVector<Type, 1> result_ty;
for (Value input : inputs) {
GetReduceResultType(input.getType(), dimensions, &builder));
build(builder, state, result_ty, inputs, init_values, dimensions);
// Multi-result folder for the reduce op. Two cases are handled: a reduction
// over no dimensions, and a body whose returned values are all defined
// outside of the body region.
// NOTE(review): the body of the `inputs()` loop, the statement that fills
// `replaced_results` and the tail of the `results.insert` call are not
// visible in this copy of the file; confirm against upstream.
LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult>& results) {
// No dimensions to reduce.
if (dimensions().getNumElements() == 0) {
for (Value input : this->inputs()) {
return success();
// If all returned values in the ReduceOp region exists outside
// the region replace the ReduceOp with those values.
mlir::Block& bb = this->body().front();
SmallVector<Value> replaced_results;
if (auto ret_op = mlir::dyn_cast<ReturnOp>(bb.back())) {
for (Value result : ret_op.results()) {
// A value that lives in the same region as the return op is computed
// inside the body, so the fold does not apply.
if (result.getParentRegion() == ret_op->getParentRegion())
return failure();
results.insert(results.end(), replaced_results.begin(),
return success();
return failure();
// Checks the following eligibility criteria for compact printing of
// mhlo.reduce:
// E1. The reduce-op wraps a single inner-op in the associated region.
// E2. The single operation is a commutative binary-op from mhlo dialect, zero
// region, producing single result such that the operands and result all
// have the same type.
// E3. The reduce-op consist of at least one input-operand; The operand-types of
// inner-op should be derived trivially from the element-type of reduce-op's
// first input-operand.
// E4. The arguments of the region's only basic block are forwarded perfectly
// to inner-op's operands.
// E5. The reduce-op, inner-op, blocks arguments, and the return-op all have the
// same location.
// E6. The single operation result is perfectly forwarded to the reduce op
// return.
// NOTE(review): the last clause of the E2 condition (after the trailing `||`)
// is not visible in this copy of the file; confirm against upstream.
static bool isEligibleForCompactPrint(ReduceOp op) {
// Check E1.
auto& block = op.body().front();
if (!hasSingleElement(block.without_terminator())) return false;
Operation& innerOp = *block.begin();
// Check E2.
if (innerOp.getDialect() != op->getDialect()) return false;
if (innerOp.getNumOperands() != 2 ||
!innerOp.hasTrait<mlir::OpTrait::OneResult>() ||
!innerOp.hasTrait<mlir::OpTrait::SameOperandsAndResultType>() ||
!innerOp.hasTrait<mlir::OpTrait::IsCommutative>() ||
return false;
// Check E3.
if (op.inputs().empty()) return false;
// The inner op must operate on rank-0 tensors of the first input's element
// type.
auto elemType = op.inputs()[0].getType().cast<TensorType>().getElementType();
auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType);
if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false;
// Check E4.
if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false;
// Check E5.
auto retOp = dyn_cast<ReturnOp>(block.getTerminator());
if (!retOp) return false;
auto blockArgLoc = block.getArgument(0).getLoc();
if (blockArgLoc != block.getArgument(1).getLoc()) return false;
if (innerOp.getLoc() != op.getLoc() || retOp.getLoc() != op.getLoc() ||
blockArgLoc != op.getLoc())
return false;
// Check E6.
return llvm::equal(innerOp.getResults(), retOp.getOperands());
// Custom assembly printer for the reduce op. The (input, init) operand pairs
// are printed first. An op that passes isEligibleForCompactPrint is then
// printed in the one-line "applies <inner-op>" form; any other op is printed
// with its dimensions, attribute dictionary, reducer block arguments and
// body region.
// NOTE(review): a few statements are not visible in this copy of the file
// (e.g. whatever follows each `p << " : ";` and the printing of the first
// block argument of each reducer pair); confirm against upstream.
void printReduceOp(ReduceOp op, OpAsmPrinter& p) {
// Print the pairs of operands under the form:
// (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
StringRef comma = "";
// The first half of the operands are the inputs, the second half the
// matching init values.
int numOperandPairs = op.getNumOperands() / 2;
for (int opId : llvm::seq<int>(0, numOperandPairs)) {
p << comma << "(" << op.getOperand(opId)
<< " init: " << op.getOperand(opId + numOperandPairs) << ")";
comma = ", ";
// If the reduce-op is eligible for compact printing, we emit the one-liner:
// mhlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
// Note: We are not printing the function type of reduction operation. We
// have some simplifying assumptions (refer to IsEligibleForCompactPrint::E3)
// to derive the type from that of reduce-op.
if (isEligibleForCompactPrint(op)) {
Operation& innerOp = op.body().front().front();
p << " applies ";
printEscapedString(innerOp.getName().getStringRef(), p.getStream());
p << " across dimensions = [";
llvm::interleaveComma(op.dimensions().getValues<int64_t>(), p);
p << "]";
p << " : ";
} else {
p << " across dimensions = [";
llvm::interleaveComma(op.dimensions().getValues<int64_t>(), p);
p << "]";
// "dimensions" was printed explicitly above, so it is elided here.
p.printOptionalAttrDict(op->getAttrs(), {"dimensions"});
p << " : ";
p << " reducer";
// Print the pairs of block operands under the form:
// (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc):
Block& reducer = op.body().front();
int numOperandPairs = op.getNumOperands() / 2;
for (int opId : llvm::seq<int>(0, numOperandPairs)) {
p << "(";
p << ", ";
p.printRegionArgument(reducer.getArgument(opId + numOperandPairs));
p << ") ";
p << ' ';
// The block arguments were printed above, so they are omitted from the
// region itself.
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
ParseResult parseReduceOp(OpAsmParser& parser, OperationState& result) {
llvm::SMLoc loc = parser.getCurrentLocation();
Location currLocation = parser.getEncodedSourceLoc(loc);
// Parse the operands of reduce-op, this is a list of pair under the form:
// (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
// Each input to reduce is paired with its init value, even though in memory
// they are stored with the input first and the init values after.
SmallVector<OpAsmParser::OperandType, 2> operands;
SmallVector<OpAsmParser::OperandType, 2> initOperands;
do {
if (parser.parseOptionalLParen()) break;
OpAsmParser::OperandType operand, initOperand;
if (parser.parseOperand(operand) || parser.parseKeyword("init") ||
parser.parseColon() || parser.parseOperand(initOperand) ||
return failure();
} while (true);
// Check if we are parsing the compact version of reduce-op:
// mhlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
// else parse the "region-based" variant.
if (failed(parser.parseOptionalKeyword("applies"))) {
// Parse the inner-op dimensions, reduce-op's function-type and
// optional location.
SmallVector<int64_t> dimensions;
auto parseDim = [&]() -> ParseResult {
if (parser.parseInteger(dimensions.emplace_back())) return failure();
return success();
FunctionType reduceOpFntype;
if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
parser.parseEqual() ||
parseDim) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColon() || parser.parseType(reduceOpFntype) ||
return failure();
OpBuilder builder(parser.getBuilder().getContext());
result.addAttribute("dimensions", GetI64ElementsAttr(dimensions, &builder));
// Parse the "reducer" region now.
SmallVector<OpAsmParser::OperandType, 2> reducerOperands;
SmallVector<OpAsmParser::OperandType, 2> reducerInitOperands;
SmallVector<Type, 2> reducerTypes;
SmallVector<Type, 2> reducerInitTypes;
SmallVector<Optional<Location>, 2> reducerLocs;
SmallVector<Optional<Location>, 2> reducerInitLocs;
auto parseBlockOperand =
[&](SmallVectorImpl<OpAsmParser::OperandType>& operands,
SmallVectorImpl<Type>& types,
SmallVectorImpl<Optional<Location>>& locs) -> ParseResult {
OpAsmParser::OperandType operand;
Type type;
Optional<Location> loc;
if (parser.parseRegionArgument(operand) || parser.parseColon() ||
parser.parseType(type) || parser.parseOptionalLocationSpecifier(loc))
return failure();
return success();
do {
if (failed(parser.parseOptionalLParen())) break;
if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) ||
parser.parseComma() ||
parseBlockOperand(reducerInitOperands, reducerInitTypes,
reducerInitLocs) ||
return failure();
} while (true);
// Derive the SSA-values for reduce-op's operands and parse the region, and
// the optional trailing location.
Optional<Location> trailingLoc;
if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
result.operands) ||
parser.parseRegion(*result.addRegion(), reducerOperands, reducerTypes))
return failure();
// Set the individual block arguments.
for (auto argAndLoc :
llvm::zip(result.regions.front()->front().getArguments(), reducerLocs))
if (std::get<1>(argAndLoc))
result.location = trailingLoc.getValueOr(currLocation);
return success();
// Parse the inner-op name and check if the contract on inner-op
// mentioned in "isEligibleForCompactPrint::E2" for pretty-printing is met.
FailureOr<OperationName> innerOpNameInfo = parser.parseCustomOperationName();
if (failed(innerOpNameInfo)) return failure();
StringRef innerOpName = innerOpNameInfo->getStringRef();
Dialect* innerOpDialect = innerOpNameInfo->getDialect();
if (!innerOpDialect || !innerOpDialect->getNamespace().equals("mhlo") ||
!innerOpNameInfo->hasTrait<mlir::OpTrait::NOperands<2>::Impl>() ||
!innerOpNameInfo->hasTrait<mlir::OpTrait::OneResult>() ||
!innerOpNameInfo->hasTrait<mlir::OpTrait::SameOperandsAndResultType>() ||
!innerOpNameInfo->hasTrait<mlir::OpTrait::IsCommutative>() ||
!innerOpNameInfo->hasTrait<mlir::OpTrait::ZeroRegion>()) {
"expected the inner-op to be a commutative binary-op from "
"mhlo dialect, zero region, producing single result such "
"that the operands and result all have the same type");
return failure();
// Parse the inner-op dimensions, reduce-op's function-type and
// optional location.
SmallVector<int64_t> dimensions;
auto parseDim = [&]() -> ParseResult {
if (parser.parseInteger(dimensions.emplace_back())) return failure();
return success();
Optional<Location> explicitLoc;
FunctionType reduceOpFntype;
if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
parser.parseEqual() ||
parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) ||
parser.parseColon() || parser.parseType(reduceOpFntype) ||
return failure();
if (!reduceOpFntype || reduceOpFntype.getInputs().empty()) {
if (!reduceOpFntype) return parser.emitError(loc, "expected function type");
return parser.emitError(loc,
"input types missing in reduce-op function type");
// If location of reduce-op is explicitly provided, then use it; Else use
// the parser's current location.
Location reduceOpLoc = explicitLoc.getValueOr(currLocation);
// Derive the SSA-values for reduce-op's operands.
if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
return failure();
// Derive the type of inner-op from that of reduce-op's input operand.
auto innerOpType = RankedTensorType::get(
/*shape=*/{}, getElementTypeOrSelf(reduceOpFntype.getInput(0)));
// Add a region for reduce-op.
Region& region = *result.addRegion();
// Create a basic-block inside reduce-op's region.
Block& block = region.emplaceBlock();
auto lhs = block.addArgument(innerOpType, reduceOpLoc);
auto rhs = block.addArgument(innerOpType, reduceOpLoc);
// Create and insert an "inner-op" operation in the block.
OpBuilder builder(parser.getBuilder().getContext());
OperationState innerOpState(reduceOpLoc, innerOpName);
Operation* innerOp = builder.createOperation(innerOpState);
// Insert a return statement in the block returning the inner-op's result.
builder.create<ReturnOp>(innerOp->getLoc(), innerOp->getResults());
// Populate the reduce-op operation-state with result-type, location, and
// dimension attribute.
result.location = innerOp->getLoc();
result.addAttribute("dimensions", GetI64ElementsAttr(dimensions, &builder));
return success();
// Return true if type1 and type2 are tensors and have the same
// element-type, else return false. With float element-types, ignore comparing
// floating-point precision if ignoreFpPrecision is True.
//
// Both arguments are dyn_cast to TensorType; if either cast fails the result
// is false.
// NOTE(review): the condition starting with `ignoreFpPrecision &&` ends in a
// dangling `&&`, and the closing brace is absent in this copy of the file. The
// text looks truncated (presumably the matching FloatType check on tensorTy2
// is missing) -- confirm against the upstream source.
static bool tensorsHaveSameElType(Type type1, Type type2,
bool ignoreFpPrecision) {
auto tensorTy1 = type1.dyn_cast<TensorType>();
auto tensorTy2 = type2.dyn_cast<TensorType>();
// Non-tensor types never compare as "same element type".
if (!tensorTy1 || !tensorTy2) return false;
if (ignoreFpPrecision && tensorTy1.getElementType().isa<FloatType>() &&
return true;
// Otherwise the element types must be identical.
return tensorTy1.getElementType() == tensorTy2.getElementType();
// Returns true when `type1` and `type2` pass verifyCompatibleShape and
// tensorsHaveSameElType agrees on their element types. `ignoreFpPrecision` is
// forwarded unchanged to tensorsHaveSameElType, so float element types of
// different precision may be accepted when it is set.
static bool compatibleShapeAndElementType(Type type1, Type type2,
                                          bool ignoreFpPrecision = false) {
  const bool shapesAgree = succeeded(verifyCompatibleShape(type1, type2));
  if (!shapesAgree) return false;

  // Both types are cast (not dyn_cast) to ShapedType before the element-type
  // comparison.
  auto lhsShaped = type1.cast<ShapedType>();
  auto rhsShaped = type2.cast<ShapedType>();
  return tensorsHaveSameElType(lhsShaped, rhsShaped, ignoreFpPrecision);
}
// Verifies the reducer block of a reduce-op against the op's operand types.
//
//   op                   : the reduce-op, used only to emit diagnostics.
//   block                : the reducer block being checked.
//   inputArgTypes        : types of the reduce-op's input operands.
//   initValueTypes       : types of the reduce-op's init-value operands.
//   numInputs            : number of inputs; the block must take twice that
//                          many arguments.
//   outputShape          : expected shape of the reduce-op result, used by
//                          check C4.2.
//   allInputsUnranked    : when true, the shape check C4.2 is skipped.
//   accumulatorSubShapes : compared against the block-argument and init-value
//                          types in checks C1-C3.
//
// Emits an error on `op` for the first violated constraint; returns success()
// otherwise.
// NOTE(review): several statements below are visibly truncated in this copy of
// the file (e.g. `auto tupleT =` has no initializer, some
// `compatibleShapeAndElementType(` calls lack their trailing arguments, and
// closing braces are absent). In particular, the statements that fill
// `accumulatorSubShapes` are not visible here -- confirm against the upstream
// source before relying on the exact control flow.
static LogicalResult verifyReducerShape(
ReduceOp op, Block& block, ArrayRef<TensorType> inputArgTypes,
ArrayRef<TensorType> initValueTypes, int64_t numInputs,
ArrayRef<int64_t> outputShape, bool allInputsUnranked,
SmallVectorImpl<TensorType>& accumulatorSubShapes) {
// Check that the number of reduction-region arguments matches with that of
// reduce-op's arguments.
if (block.getArguments().size() != numInputs * 2)
return op.emitError() << "Reduction-region must take " << numInputs * 2
<< " parameters, but takes "
<< block.getArguments().size() << " parameter(s)";
// Check if the reduction-region produces non-zero outputs.
if (block.getTerminator()->getOperands().empty())
return op.emitError()
<< "The reduction-region expected to return some value(s)";
// Check that the reduction-region returns a tuple- OR list- of tensors.
// The number of result-tensors must match the `numInputs`.
// TODO(b/171261845): Remove tuples from MHLO dialect.
auto tupleT =
// Tuple form: a single returned operand whose tuple has `numInputs` elements,
// each of which must be a tensor.
if (tupleT && block.getTerminator()->getOperands().size() == 1) {
if (tupleT.size() != numInputs)
return op.emitError()
<< "Reduction-region here must produce a tuple with " << numInputs
<< " tensors, but produces " << tupleT.size() << " instead";
for (Type elementType : tupleT.getTypes()) {
auto tensorTy = elementType.dyn_cast<TensorType>();
if (!tensorTy)
return op.emitError() << "Reduction-region here must produce tuple "
"of tensor-typed results, but "
"produces "
<< elementType << " instead";
} else {
// List form: exactly `numInputs` returned operands, each a tensor.
if (block.getTerminator()->getOperands().size() != numInputs)
return op.emitError()
<< "Reduction-region here must produce " << numInputs
<< " tensors, but produces "
<< block.getTerminator()->getOperands().size() << " instead";
for (Value retOperand : block.getTerminator()->getOperands()) {
auto tensorTy = retOperand.getType().dyn_cast<TensorType>();
if (!tensorTy)
return op.emitError() << "Reduction-region here must produce "
"tensor-typed result(s), but "
"produces "
<< retOperand.getType() << " instead";
// Consider typical reduce-op syntax:
// reduce(I(i), V(j)):
// block(BI(i), BV(j)):
// ... some computation ...
// return(R(i))
// where
// I(i) : i-th input of reduce-op
// V(j) : j-th init-value of reduce-op
// BI(i) : i-th input of reducer-function
// BV(j) : j-th init-value of reducer-function
// R(i) : i-th return-type
// Note that: |I(i)| == |V(j)| == |BI(i)| == |BV(j)| == |R(i)|
// Here are the type-constraints among V(j), BI(i), BV(j), and R(i).
// C1 : Check that BI(i) and R(i) have same shape and element-type.
// C2 : Check that BV(j) and R(i) have same shape and element-type.
// C3 : Check that V(j) and R(i) have same shape and element-type.
// From C1, C2, and C3, we can infer that V(j), BI(i), BV(j), and R(i) all
// have compatible shapes and element-types.
// The next check, C4, adds constraints on how the type of I(i) is related
// to any_of(V(j), BI(i), BV(j), and R(i)), say BV(j);
// C4.1 : Check that I(i) and BV(j) have same element-type.
// C4.2 : Check that shape of BV(j) is a 'sub-sequence' of the shape of
// output-array. The shape of output-array is determined from that
// of I(i) after removing the "dimensions-to-reduce" (as specified by
// the dimensions attribute of reduce-op).
for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
// Check C1.
if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
return op.emitError()
<< "The type of reduction-region's parameter at index " << inputIdx
<< " is different than the corresponding result type: "
<< block.getArgument(inputIdx).getType() << " vs "
<< accumulatorSubShapes[inputIdx];
// Check C2.
if (!compatibleShapeAndElementType(
block.getArgument(numInputs + inputIdx).getType(),
return op.emitError()
<< "The type of reduction-region's parameter at index "
<< numInputs + inputIdx
<< " is different than the corresponding result type: "
<< block.getArgument(numInputs + inputIdx).getType() << " vs "
<< accumulatorSubShapes[inputIdx];
// Check C3.
if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
return op.emitError()
<< "The type of reduction-region's result type at index "
<< inputIdx
<< " differs from the reduce-op's corresponding init-value type: "
<< accumulatorSubShapes[inputIdx] << " vs "
<< initValueTypes[inputIdx];
// Check C4.1.
if (!tensorsHaveSameElType(
block.getArgument(numInputs + inputIdx).getType(), true))
return op.emitError()
<< "The element-type of reduce-op's input-parameter at index "
<< inputIdx
<< " differs from that of reduction-region's argument at index "
<< numInputs + inputIdx << ": " << inputArgTypes[inputIdx]
<< " vs " << block.getArgument(numInputs + inputIdx).getType();
// Check C4.2.
Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
auto blockArgTensorTy = blockArgType.cast<TensorType>();
// NOTE(review): this statement uses `inputIdx`, so it appears to sit inside
// the per-input loop; if so, the early `return success()` skips the checks for
// all remaining inputs -- confirm whether `continue` was intended.
if (allInputsUnranked || !blockArgTensorTy.hasRank()) return success();
auto argShape = blockArgTensorTy.getShape();
if (argShape.size() > outputShape.size())
return op.emitError()
<< "The rank of reduction-region's argument at index "
<< numInputs + inputIdx
<< " is not compatible with that of reduce-op's result: "
<< argShape.size() << " vs " << outputShape.size()
<< " (expected)";
// Walk outputShape once, advancing argShapeIdx on every matching extent;
// argShape is accepted only if all of its extents were matched in order.
int64_t argShapeIdx = 0;
for (int64_t outputShapeIdx = 0;
outputShapeIdx < outputShape.size() && argShapeIdx < argShape.size();
if (outputShape[outputShapeIdx] == argShape[argShapeIdx]) argShapeIdx++;
if (argShapeIdx != argShape.size())
return op.emitError()
<< "The shape of reduction-region's argument at index "
<< numInputs + inputIdx
<< " is not compatible with that of reduce-op's input-parameter "
"at index "
<< inputIdx;
return success();
// Verifies a ReduceOp:
//   * the operand count is even and at least 2 (inputs followed by
//     init-values);
//   * all ranked inputs have mutually compatible shapes;
//   * every entry of the `dimensions` attribute is non-negative, in bounds for
//     the ranked inputs, and unique;
//   * the reducer block satisfies verifyReducerShape;
//   * the op results match the element types produced by the reducer and, for
//     ranked inputs/results, the input shape with the reduced dimensions
//     removed.
// Emits a diagnostic on `op` and returns failure on the first violation.
static LogicalResult Verify(ReduceOp op) {
  // Check that there are even number of operands and >= 2.
  if (op.getNumOperands() % 2 != 0 || op.getOperands().empty())
    return op.emitOpError()
           << "expects the size of operands to be even and >= 2";

  // Collect the input and init-value operands. Note that the operand-type is
  // enforced as "TensorType" by ODS.
  int64_t numInputs = op.getNumOperands() / 2;
  auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range(
      op.getOperandTypes(),
      [](Type t) -> TensorType { return t.cast<TensorType>(); }));
  ArrayRef<TensorType> inputArgTypes(operandTensorTypes.begin(),
                                     operandTensorTypes.begin() + numInputs);
  ArrayRef<TensorType> initValueTypes(operandTensorTypes.begin() + numInputs,
                                      operandTensorTypes.end());

  // Find the first ranked input (if any); it serves as the shape reference.
  int64_t rankedInputIdx = -1;
  for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
    if (inputArgTypes[inputIdx].hasRank()) {
      rankedInputIdx = inputIdx;
      break;
    }
  }

  bool allInputsUnranked = (rankedInputIdx == -1);

  // Check that all input operands have compatible shapes. The element types may
  // be different.
  if (!allInputsUnranked) {
    for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
      if (failed(mlir::verifyCompatibleShape(inputArgTypes[rankedInputIdx],
                                             inputArgTypes[inputIdx]))) {
        return op.emitOpError()
               << "expects all inputs to have compatible shapes. Shape at"
               << " input-index " << inputIdx
               << " is not compatible with shape at input-index "
               << rankedInputIdx;
      }
    }
  }

  // Check that
  //   1. the dimensions of reduce-op are in-bounds for the given shape.
  //   2. the dimension-attribute have no duplicate entries.
  DenseSet<int64_t> dimensionsToReduceSet;
  for (int64_t dimension : op.dimensions().getValues<int64_t>()) {
    if (allInputsUnranked) {
      // No ranked input exists, so there is no rank to compare against (and
      // rankedInputIdx is -1 and must not be used as an index). Only the sign
      // can be validated.
      if (dimension < 0)
        return op.emitError() << "Out-of-bounds dimension " << dimension
                              << ": reduction dimensions must be non-negative";
    } else if (dimension < 0 ||
               dimension >= inputArgTypes[rankedInputIdx].getRank()) {
      return op.emitError() << "Out-of-bounds dimension " << dimension
                            << " for input-tensor rank: "
                            << inputArgTypes[rankedInputIdx].getRank();
    }

    if (!dimensionsToReduceSet.insert(dimension).second) {
      return op.emitError() << "Duplicate reduction dimension: " << dimension;
    }
  }

  // Compute the expected result shape: the reference input shape with the
  // reduced dimensions dropped. Left empty when all inputs are unranked.
  SmallVector<int64_t> newDimensions;
  if (!allInputsUnranked) {
    for (int inputIdx = 0; inputIdx < inputArgTypes[rankedInputIdx].getRank();
         ++inputIdx) {
      if (!dimensionsToReduceSet.count(inputIdx)) {
        newDimensions.push_back(
            inputArgTypes[rankedInputIdx].getDimSize(inputIdx));
      }
    }
  }

  // Verify the inner block defining the reducer function.
  Block& block = op.body().front();
  SmallVector<TensorType> accumulatorSubShapes;
  if (failed(verifyReducerShape(op, block, inputArgTypes, initValueTypes,
                                numInputs, newDimensions, allInputsUnranked,
                                accumulatorSubShapes)))
    return failure();

  // Check if the reduce-op's result-type matches with the one derived from
  // the reducer-block and dimensions attribute.
  if (op.getResults().size() != accumulatorSubShapes.size())
    return op.emitError()
           << "Unexpected number of reduce-op's returned values: "
           << op.getResults().size() << " vs " << accumulatorSubShapes.size()
           << " (expected)";

  for (int64_t shapeIdx = 0; shapeIdx < accumulatorSubShapes.size();
       shapeIdx++) {
    // The result-type is enforced as "TensorType" by ODS.
    auto opResultType = op.getResult(shapeIdx).getType().cast<TensorType>();

    // Check element-type.
    if (accumulatorSubShapes[shapeIdx].getElementType() !=
        opResultType.getElementType()) {
      return op.emitError()
             << "Unexpected element-type for reduce-op's return value at index "
             << shapeIdx << ": " << opResultType.getElementType() << " vs "
             << accumulatorSubShapes[shapeIdx].getElementType()
             << " (expected)";
    }

    // Check shape.
    if (!allInputsUnranked && opResultType.hasRank() &&
        (newDimensions != opResultType.getShape())) {
      Type expectedResultType = RankedTensorType::get(
          newDimensions, accumulatorSubShapes[shapeIdx].getElementType());
      return op.emitError()
             << "Unexpected type for reduce-op's return value at index "
             << shapeIdx << ": " << opResultType << " vs " << expectedResultType
             << " (expected)";
    }
  }

  return success();
}
// Enable constant folding to occur within the region of the ReduceOp
// by replacing block argument uses with constants if:
//   1. All the ReduceOp operands are splat constants.
//   2. The ReduceOp region consists of a single logical AND or logical OR.
// The pattern leverages the idempotent property of the AND and OR operators
// to determine the value of a reduction on splat constants. Other boolean
// operators do not have this property, and need separate patterns to resolve
// reductions of their splat constants.
struct LowerBoolSplatConstantsIntoRegion : public OpRewritePattern<ReduceOp> {
  using OpRewritePattern<ReduceOp>::OpRewritePattern;

  // Matches a ReduceOp whose body is exactly {AndOp|OrOp, terminator} and whose
  // operands are all splat ConstOps, then replaces every use of each block
  // argument with a new splat constant of the block-argument type.
  LogicalResult matchAndRewrite(ReduceOp op,
                                PatternRewriter& rewriter) const override {
    mlir::Block& bb = op.body().front();

    // Ensure only a compute op and return op exist and the
    // compute op is an AND or OR op.
    if (bb.getOperations().size() != 2) return failure();
    if (!mlir::isa<mhlo::AndOp, mhlo::OrOp>(bb.front())) return failure();

    // Ensure all operands are splat constants, and build for each block
    // argument the splat attribute (in the block-argument type) that will
    // replace it.
    SmallVector<DenseElementsAttr, 4> barg_cst_attrs;
    for (auto inp_and_barg : llvm::zip(op.getOperands(), bb.getArguments())) {
      Value inp = std::get<0>(inp_and_barg);
      BlockArgument barg = std::get<1>(inp_and_barg);
      ConstOp cst = inp.getDefiningOp<ConstOp>();
      if (!cst) return failure();

      // dyn_cast_or_null yields a null attribute when the constant's value is
      // not a DenseElementsAttr; that must be rejected before calling
      // isSplat() on it.
      auto cst_attr = cst.value().dyn_cast_or_null<DenseElementsAttr>();
      if (!cst_attr || !cst_attr.isSplat()) {
        return rewriter.notifyMatchFailure(op, "Must be splat constant.");
      }

      auto barg_shaped_type = barg.getType().dyn_cast<ShapedType>();
      if (!barg_shaped_type) return failure();

      auto barg_cst_attr = DenseElementsAttr::get(
          barg_shaped_type, cst_attr.getSplatValue<mlir::Attribute>());
      barg_cst_attrs.push_back(barg_cst_attr);
    }

    // Create new splat constants to replace block arguments.
    for (BlockArgument barg : bb.getArguments()) {
      int arg_idx = barg.getArgNumber();
      mhlo::ConstOp new_cst = rewriter.create<mhlo::ConstOp>(
          bb.front().getLoc(), barg.getType(), barg_cst_attrs[arg_idx]);
      barg.replaceAllUsesWith(new_cst);
    }
    return success();
  }
};
// Canonicalization-pattern hook for ReduceOp; receives the pattern set to
// populate and the MLIR context.
// NOTE(review): no body or closing brace is visible in this copy of the file;
// the function text looks truncated -- confirm against the upstream source.
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
// Builds IR computing the result shape of this ReduceOp from the shape of its
// first input. Fails when the first input is not a ranked tensor.
// NOTE(review): several statements are visibly truncated in this copy of the
// file (the `to_shape_scalar_type` lambda is not closed, the body of
// `if (it != dimensions.end())` is absent, the FromElementsOp call has no
// arguments, and the final loop has no body) -- confirm against the upstream
// source.
LogicalResult ReduceOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
ReduceOp::Adaptor adaptor(operands);
auto inputs = adaptor.inputs();
auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
// Unranked inputs are not supported at the moment.
if (!operand_type) return failure();
Location loc = this->getLoc();
SmallVector<Value, 4> shape_values;
// Local copy of the `dimensions` attribute values, searched per input
// dimension below.
SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
// Shape extents are produced as `index` values.
Type shape_scalar_type = builder.getIndexType();
auto to_shape_scalar_type = [&](Value v) {
return MaybeCastTo(builder, loc, v, shape_scalar_type);
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
int64_t idx = element.index();
// Look up whether this input dimension is listed in `dimensions`.
auto* it = std::find(dimensions.begin(), dimensions.end(), idx);
if (it != dimensions.end()) {
Value value_dim = to_shape_scalar_type(
builder.create<tensor::DimOp>(loc, inputs[0], element.index()));
Value output_shape = builder.create<tensor::FromElementsOp>(
for (size_t i = 0; i < inputs.size(); ++i) {
return success();
// RngNormalOp

// Return-type inference for RngNormalOp: every argument is forwarded unchanged
// to rngInferReturnTypeComponents and its result is returned.
LogicalResult RngNormalOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  LogicalResult inferred = rngInferReturnTypeComponents(
      context, location, operands, attributes, regions, inferredReturnShapes);
  return inferred;
}
// Reifies the result shape of RngNormalOp from its `shape` operand, converted
// via castToIndexTensor.
// NOTE(review): the statement containing the castToIndexTensor call has lost
// its leading part in this copy of the file (unbalanced `)`), and the closing
// brace is absent -- presumably the value is appended to reifiedReturnShapes;
// confirm against the upstream source.
LogicalResult RngNormalOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
RngNormalOp::Adaptor adaptor(operands);
castToIndexTensor(builder, getLoc(), adaptor.shape()));
return success();
// RngUniformOp

// Return-type inference for RngUniformOp: every argument is forwarded
// unchanged to rngInferReturnTypeComponents and its result is returned.
LogicalResult RngUniformOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  LogicalResult inferred = rngInferReturnTypeComponents(
      context, location, operands, attributes, regions, inferredReturnShapes);
  return inferred;
}
// Reifies the result shape of RngUniformOp from its `shape` operand, converted
// via castToIndexTensor.
// NOTE(review): the statement containing the castToIndexTensor call has lost
// its leading part in this copy of the file (unbalanced `)`), and the closing
// brace is absent -- presumably the value is appended to reifiedReturnShapes;
// confirm against the upstream source.
LogicalResult RngUniformOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
RngUniformOp::Adaptor adaptor(operands);
castToIndexTensor(builder, getLoc(), adaptor.shape()));
return success();
// SelectOp
// Verifies the operand shapes of a SelectOp. Accepts the op when all operand
// types have compatible shapes; otherwise the predicate must be unranked or
// rank-0 for the op to be accepted.
// NOTE(review): the second condition is truncated in this copy of the file
// (`if (!predMayBeScalar ||` is followed directly by a braced type list), and
// closing braces are absent -- presumably a compatibility check on the on_true
// and on_false types is missing; confirm against the upstream source.
static LogicalResult Verify(SelectOp op) {
// Either, all operands could be the same shape ...
if (succeeded(verifyCompatibleShapes(op.getOperandTypes()))) return success();
// ... or the predicate could be a scalar and the remaining two operands could
// be of the same shape.
auto predTy = op.pred().getType().dyn_cast<RankedTensorType>();
// An unranked predicate is treated as possibly scalar.
bool predMayBeScalar = !predTy || predTy.getRank() == 0;
if (!predMayBeScalar ||
{op.on_true().getType(), op.on_false().getType()}))) {
return op.emitOpError()
<< "requires the same type for all operands and results";
return success();
// Folds a SelectOp when the outcome does not depend on evaluating the
// predicate element-wise:
//   * both branches are the same value -> that value;
//   * the predicate is a constant i1 splat -> the branch it selects.
// Returns an empty result when neither applies.
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
  if (on_true() == on_false()) return on_true();

  auto predAttr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!predAttr) return {};

  // Only i1 splat predicates are folded.
  auto predShapedTy = predAttr.getType().cast<ShapedType>();
  const bool isBoolPred = predShapedTy.getElementType().isInteger(1);
  if (!isBoolPred || !predAttr.isSplat()) return {};

  if (predAttr.getSplatValue<APInt>().getBoolValue()) return on_true();
  return on_false();
}
// Makes it such that a SelectOp that is a non-root operation in a DRR infers
// the return type based on operand type.
//
// The inferred type is derived from the on_true and on_false operands only:
// they must share an element type and have compatible shapes, otherwise an
// optional error is emitted at `location`. On success exactly one type is
// written to `inferredReturnTypes`.
LogicalResult SelectOp::inferReturnTypes(
    MLIRContext*, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  SelectOp::Adaptor op(operands);
  auto true_type = op.on_true().getType().cast<TensorType>();
  // Read the false branch from on_false(); reading on_true() twice would make
  // the compatibility check and the shape merge below vacuous.
  auto false_type = op.on_false().getType().cast<TensorType>();

  // Check for type compatibility in the select op. This requires that the two
  // non-predicate operands:
  //   (a) have the same element type
  //   (b) have compatible shapes (i.e. the same shape and/or at least one
  //       dynamic shape)
  if (true_type.getElementType() != false_type.getElementType() ||
      failed(mlir::verifyCompatibleShape(true_type, false_type))) {
    return emitOptionalError(location,
                             "incompatible operand types: ", true_type, " and ",
                             false_type);
  }

  // The output shape should be the most general of the operand shapes at each
  // dimension.
  Type output_type;
  if (true_type == false_type || !true_type.hasRank()) {
    output_type = true_type;
  } else if (!false_type.hasRank()) {
    output_type = false_type;
  } else {
    assert(true_type.getRank() == false_type.getRank());
    llvm::SmallVector<int64_t, 4> dims;
    dims.reserve(true_type.getRank());
    // Keep an extent where both operands agree; otherwise mark it dynamic.
    for (auto dim : llvm::zip(true_type.getShape(), false_type.getShape())) {
      dims.push_back(std::get<0>(dim) == std::get<1>(dim)
                         ? std::get<0>(dim)
                         : ShapedType::kDynamicSize);
    }
    output_type = RankedTensorType::get(dims, true_type.getElementType());
  }
  inferredReturnTypes.assign({output_type});
  return success();
}
// Shaped-type-component inference for SelectOp. Calls inferReturnTypes and
// fails unless it succeeds with exactly one inferred type.
// NOTE(review): this copy of the file is truncated here: the declaration of
// the `inferredShapedTypeComponents` parameter has lost its type, the
// initializers of `result_tensor_type` and `element_type` are absent, and the
// statement that emits `{shape, element_type}` has lost its leading part --
// confirm against the upstream source.
LogicalResult SelectOp::inferReturnTypeComponents(
mlir::MLIRContext* ctx, llvm::Optional<mlir::Location> loc,
ValueShapeRange operands, mlir::DictionaryAttr attributes,
mlir::RegionRange regions,
inferredShapedTypeComponents) {
llvm::SmallVector<Type, 4> inferredReturnTypes;
// Delegate to the type-level inference first and propagate its failure.
const LogicalResult infer_types_status = inferReturnTypes(
ctx, loc, operands, attributes, regions, inferredReturnTypes);
if (infer_types_status.failed()) return infer_types_status;
// Exactly one result type is expected.
if (inferredReturnTypes.size() != 1) return failure();
auto result_tensor_type =
if (!result_tensor_type) return failure();
mlir::Type element_type =
{result_tensor_type.getShape(), element_type});
return success();
// Reifies the result shape of SelectOp by deriving it from operands[1] via
// deriveShapeFromOperand.
// NOTE(review): the deriveShapeFromOperand call is cut off after `operands[1],`
// in this copy of the file, and the closing brace is absent -- confirm the
// remaining argument(s) against the upstream source.
LogicalResult SelectOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
// For this select op, the first operand may be a scalar, so the shape is taken
// from the second operand instead.
return deriveShapeFromOperand(&builder, getOperation(), operands[1],
// SetDimensionSizeOp

// Verifies a SetDimensionSizeOp: when the `size` operand is a ranked tensor it
// must be rank-0; the remaining checks are delegated to VerifyDimAttr.
static LogicalResult Verify(SetDimensionSizeOp op) {
  auto rankedSizeTy = op.size().getType().dyn_cast<RankedTensorType>();
  if (rankedSizeTy && rankedSizeTy.getRank() != 0)
    return op.emitOpError() << "size operand should be of rank-0";

  return VerifyDimAttr(op);
}
// Folds SetDimensionSizeOp:
//   * a constant (DenseElementsAttr) first operand folds to that attribute;
//   * a constant splat `size` equal to the result type's extent along
//     `dimension()` folds to the operand.
// Returns an empty result otherwise.
OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
  if (auto constInput = operands[0].dyn_cast_or_null<DenseElementsAttr>())
    return constInput;

  auto sizeAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  if (!sizeAttr) return {};
  if (!sizeAttr.isSplat()) return {};

  // The extent can only be looked up on a ranked result type.
  auto rankedResultTy = getType().dyn_cast<RankedTensorType>();
  if (!rankedResultTy) return {};

  const int64_t currentExtent = rankedResultTy.getDimSize(dimension());
  const int64_t requestedExtent =
      sizeAttr.getSplatValue<IntegerAttr>().getInt();
  if (currentExtent == requestedExtent) return operand();
  return {};
}
// PadOp

// Verifies a PadOp:
//   * the padding value is a rank-0 tensor;
//   * edge_padding_low, edge_padding_high and interior_padding each have as
//     many elements as the operand has dimensions;
//   * operand and result ranks match;
//   * every result extent equals
//       in + low + high + max(in - 1, 0) * interior.
static LogicalResult Verify(PadOp op) {
  auto input_type = op.operand().getType().cast<RankedTensorType>();
  auto pad_type = op.padding_value().getType().cast<RankedTensorType>();

  if (pad_type.getRank() != 0) {
    return op.emitOpError(
        llvm::formatv("padding value type should be a rank-0 "
                      "tensor, is rank {0}",
                      pad_type.getRank()));
  }

  const auto& padding_low = op.edge_padding_low();
  if (padding_low.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "edge_padding_low length ({0}) must match operand rank ({1})",
        padding_low.getType().getNumElements(), input_type.getRank()));
  }

  const auto& padding_high = op.edge_padding_high();
  if (padding_high.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "edge_padding_high length ({0}) must match operand rank ({1})",
        padding_high.getType().getNumElements(), input_type.getRank()));
  }

  const auto& padding_interior = op.interior_padding();
  if (padding_interior.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "interior_padding length ({0}) must match operand rank ({1})",
        padding_interior.getType().getNumElements(), input_type.getRank()));
  }

  auto input_shape = input_type.getShape();
  auto output_shape =
      op.getResult().getType().cast<RankedTensorType>().getShape();
  if (input_shape.size() != output_shape.size()) {
    // {0} is the operand rank and {1} the result rank; using {0} twice would
    // print the operand rank in both places.
    return op.emitOpError(
        llvm::formatv("operand rank ({0}) and result rank({1}) should match",
                      input_shape.size(), output_shape.size()));
  }

  for (int i = 0, e = input_shape.size(); i < e; i++) {
    int64_t padding_low_val = padding_low.getValues<APInt>()[i].getSExtValue();
    int64_t padding_high_val =
        padding_high.getValues<APInt>()[i].getSExtValue();
    int64_t padding_interior_val =
        padding_interior.getValues<APInt>()[i].getSExtValue();
    // Interior padding is inserted between adjacent elements only, hence the
    // (in - 1) factor, clamped at zero for empty dimensions.
    int64_t expected_output =
        input_shape[i] + padding_low_val + padding_high_val +
        std::max<int64_t>(input_shape[i] - 1, 0LL) * padding_interior_val;
    if (expected_output != output_shape[i]) {
      return op.emitOpError(llvm::formatv(
          "expected output shape's dimension #{0} to be {1} but found {2}", i,
          expected_output, output_shape[i]));
    }
  }

  return success();
}
// Folds PadOp:
//   * all-zero padding is an identity and folds to the operand;
//   * constant input and padding value with a static result shape fold to a
//     materialized DenseElementsAttr.
// Any negative padding entry is not supported by the folder and yields an
// empty result.
OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
  // If all padding is zero then it is an identity pad.
  auto is_zero = [](const APInt& i) { return i == 0; };
  if (llvm::all_of(edge_padding_low().getValues<APInt>(), is_zero) &&
      llvm::all_of(edge_padding_high().getValues<APInt>(), is_zero) &&
      llvm::all_of(interior_padding().getValues<APInt>(), is_zero))
    return operand();

  // If any padding is negative then it isn't supported by the folder (yet).
  // A single negative entry is enough to bail out: the index arithmetic below
  // is unsigned and would otherwise compute an out-of-range result index.
  auto is_negative = [](const APInt& i) { return i.slt(0); };
  if (llvm::any_of(edge_padding_low().getValues<APInt>(), is_negative) ||
      llvm::any_of(edge_padding_high().getValues<APInt>(), is_negative) ||
      llvm::any_of(interior_padding().getValues<APInt>(), is_negative))
    return {};

  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  RankedTensorType return_type = getType().dyn_cast_or_null<RankedTensorType>();
  if (!input || !input.getType().hasRank() || !padding || !return_type ||
      !return_type.hasStaticShape())
    return {};

  // Fill the full result tensor with the padding value.
  llvm::SmallVector<Attribute, 4> result(return_type.getNumElements(),
                                         padding.getValues<Attribute>()[0]);

  // Advances `index` to the next position of a row-major walk over `shape`,
  // wrapping back to all zeros after the last position.
  auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
                       llvm::ArrayRef<int64_t> shape) {
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      ++index[i];
      if (index[i] < shape[i]) return;
      index[i] = 0;
    }
  };

  // Iterate over all elements of the input tensor and copy it to the correct
  // location in the output tensor.
  llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
  uint64_t num_elements = input.getNumElements();
  for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) {
    // Linearize the padded coordinate: along each dimension the element lands
    // at low + idx * (interior + 1).
    uint64_t result_idx = 0;
    uint64_t idx_multiplyer = 1;
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      result_idx +=
          (edge_padding_low().getValues<int64_t>()[i] +
           index[i] * (interior_padding().getValues<int64_t>()[i] + 1)) *
          idx_multiplyer;
      idx_multiplyer *= return_type.getDimSize(i);
    }
    result[result_idx] = input.getValues<Attribute>()[index];
    next_index(index, input.getType().getShape());
  }
  return DenseElementsAttr::get(return_type, result);
}
// DynamicPadOp
// Canonicalization-pattern hook for DynamicPadOp; receives the pattern set to
// populate and the MLIR context.
// NOTE(review): no body or closing brace is visible in this copy of the file;
// the function text looks truncated -- confirm against the upstream source.
void DynamicPadOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
// Verifies the statically checkable parts of a DynamicPadOp: the padding value
// is rank-0, the three padding operands each have as many elements as the
// operand rank, and (when the result is ranked) operand and result ranks
// match. Succeeds immediately when the operand is unranked.
// NOTE(review): the initializers of `padding_low_type`, `padding_high_type`
// and `interior_padding_type` are absent in this copy of the file, as are the
// closing braces -- confirm against the upstream source.
static LogicalResult Verify(DynamicPadOp op) {
auto input_type = op.operand().getType().dyn_cast<RankedTensorType>();
// If operand is unranked, there is very little to verify statically.
if (!input_type) return success();
int input_rank = input_type.getRank();
auto pad_type = op.padding_value().getType().cast<RankedTensorType>();
if (pad_type.getRank() != 0) {
return op.emitOpError() << "padding value type should be a rank-0";
auto padding_low_type =
if (padding_low_type.getNumElements() != input_rank) {
return op.emitOpError()
<< "edge_padding_low length(" << padding_low_type.getNumElements()
<< ") must match operand rank(" << input_rank << ").";
auto padding_high_type =
if (padding_high_type.getNumElements() != input_rank) {
return op.emitOpError()
<< "edge_padding_high length(" << padding_high_type.getNumElements()
<< ") must match operand rank(" << input_rank << ").";
auto interior_padding_type =
if (interior_padding_type.getNumElements() != input_rank) {
return op.emitOpError()
<< "edge_padding_interior length("
<< interior_padding_type.getNumElements()
<< ") must match operand rank(" << input_rank << ").";
auto output_type = op.getResult().getType().dyn_cast<RankedTensorType>();
// If result is unranked, there is very little to verify statically.
if (!output_type) return success();
int output_rank = output_type.getRank();
if (input_rank != output_rank) {
return op.emitOpError() << "operand rank(" << input_rank
<< ") must match result(" << output_rank << ").";
return success();
// Builds IR computing the result shape of DynamicPadOp, one extent per operand
// dimension, from the operand's dimension size and the low/high/interior
// padding values extracted at that dimension's index. Fails when the operand
// is not a ranked tensor.
// NOTE(review): several statements are visibly truncated in this copy of the
// file (the initializer of `shape_scalar_type`, the end of the
// `to_shape_scalar_type` lambda, the inner operation of the MulIOp, the
// statements that combine the sizes and store them, and the final shape
// construction) -- confirm against the upstream source.
LogicalResult DynamicPadOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
DynamicPadOp::Adaptor adaptor(operands);
Value operand = adaptor.operand();
Value edge_padding_low = adaptor.edge_padding_low();
Value edge_padding_high = adaptor.edge_padding_high();
Value interior_padding = adaptor.interior_padding();
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
// Unranked pad operands are not supported at the moment.
if (!operand_type) return failure();
auto loc = this->getLoc();
SmallVector<Value, 4> shape_values;
Type shape_scalar_type =
auto to_shape_scalar_type = [&](Value v) {
return MaybeCastTo(builder, loc, v, shape_scalar_type);
// Constants 0 and 1 in the shape scalar type, used by the clamp below.
Value zero =
to_shape_scalar_type(builder.create<arith::ConstantIndexOp>(loc, 0));
Value one =
to_shape_scalar_type(builder.create<arith::ConstantIndexOp>(loc, 1));
for (int idx : llvm::seq<int>(0, operand_type.getShape().size())) {
Value value_dim =
to_shape_scalar_type(builder.create<tensor::DimOp>(loc, operand, idx));
// Index used to extract this dimension's entry from each padding tensor.
Value offset = builder.create<arith::ConstantIndexOp>(loc, idx);
Value value_low =
builder.create<tensor::ExtractOp>(loc, edge_padding_low, offset);
Value value_high =
builder.create<tensor::ExtractOp>(loc, edge_padding_high, offset);
Value value_interior =
builder.create<tensor::ExtractOp>(loc, interior_padding, offset);
// output_size = input_size + padding_low + padding_high + interior *
// max(input_size - 1, 0)
Value value_dim_less_than_one = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, value_dim, one);
Value interior_size = builder.create<arith::MulIOp>(
loc, value_interior,
loc, value_dim_less_than_one, zero,
builder.create<arith::SubIOp>(loc, value_dim, one)));
loc, builder.create<arith::AddIOp>(loc, interior_size, value_dim),
return success();
// ReshapeOp

// Verifies a ReshapeOp. Nothing is checked unless the operand is a ranked
// tensor with a static shape; in that case the result must hold the same
// number of elements as the operand.
static LogicalResult Verify(ReshapeOp op) {
  auto operandTy = op.operand().getType().dyn_cast<RankedTensorType>();
  const bool operandIsStatic = operandTy && operandTy.hasStaticShape();
  // A dynamically shaped (or unranked) operand leaves nothing to verify.
  if (!operandIsStatic) return success();

  // With a statically shaped operand the element counts must agree.
  auto resultTy = op.getType().cast<RankedTensorType>();
  assert(resultTy && resultTy.hasStaticShape() &&
         "result type must be statically shaped");

  const int64_t resultCount = resultTy.getNumElements();
  const int64_t operandCount = operandTy.getNumElements();
  if (resultCount == operandCount) return success();

  return op.emitOpError()
         << "number of output elements (" << resultCount
         << ") doesn't match expected number of elements ("
         << operandCount << ")";
}
// Folds ReshapeOp:
//   * a reshape to the operand's own type folds to the operand;
//   * a reshape whose operand is produced by another ReshapeOp returns its own
//     result (in-place fold);
//   * a constant DenseElementsAttr operand folds to the reshaped attribute.
// Returns an empty result otherwise.
// NOTE(review): in the reshape-of-reshape branch `prev_op` is declared but not
// used before `return getResult();` in this copy of the file, and closing
// braces are absent; a statement that rewires this op's operand looks to be
// missing -- confirm against the upstream source.
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
if (getOperand().getType() == getType()) {
return getOperand();
if (auto prev_op = getOperand().getDefiningOp<ReshapeOp>()) {
return getResult();
if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
return elements.reshape(getResult().getType().cast<ShapedType>());
return {};
// Adds the ReshapeOp canonicalization patterns to `results`.
void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                            MLIRContext* context) {
  results.insert<IdentityBroadcastReshape,
                 IdentityBroadcastInDimReshape,
                 EliminateRedundantReshape,
                 EliminateIdentityReshape>(context);
}
// ReplicaId Op
// Return-type inference for ReplicaIdOp. The visible arguments describe an
// empty shape with an unsigned 32-bit integer element type.
// NOTE(review): the statement that builds the type and appends it to
// `inferredReturnTypes` has lost its leading part in this copy of the file
// (unbalanced `)))`), and the closing brace is absent -- confirm against the
// upstream source.
LogicalResult ReplicaIdOp::inferReturnTypes(
MLIRContext* context, Optional<Location>, ValueRange operands,
DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
/*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
return success();
// If Op
// Shared branch verifier for conditional ops: `region` must take no arguments
// and the types it returns must equal the result types of `op`. `branchName`
// is only used as the prefix of the emitted diagnostics.
// NOTE(review): the initializer of `branchReturnTypes` is absent in this copy
// of the file, as is the closing brace -- presumably it reads the operand
// types of the region's terminator; confirm against the upstream source.
static LogicalResult VerifyConditionalBranch(Operation* op, Region& region,
llvm::Twine branchName) {
if (region.getNumArguments() != 0)
return op->emitOpError()
<< branchName << " must have 0 arguments, but found "
<< region.getNumArguments();
TypeRange branchReturnTypes =
if (branchReturnTypes != op->getResultTypes())
return op->emitOpError()
<< branchName << " returned types (" << branchReturnTypes
<< ") do not match op result types (" << op->getResultTypes() << ")";
return success();
// Verifies an IfOp: the true branch is checked first, then the false branch,
// both through VerifyConditionalBranch. The first rejected branch makes the
// whole verification fail.
static LogicalResult Verify(IfOp op) {
  LogicalResult true_status = VerifyConditionalBranch(
      op, op.true_branch(), /*branchName=*/"true_branch");
  if (failed(true_status)) return failure();

  LogicalResult false_status = VerifyConditionalBranch(
      op, op.false_branch(), /*branchName=*/"false_branch");
  if (failed(false_status)) return failure();

  return success();
}
// Rewrite pattern: when the predicate of `ifOp` is a constant, replaces the op
// with the contents of the branch selected by that constant. Fails to match
// when the predicate is not a constant.
static LogicalResult InlineIfConstantCondition(IfOp ifOp,
                                               PatternRewriter& rewriter) {
  DenseIntElementsAttr constant_pred;
  const bool pred_is_constant =
      matchPattern(ifOp.pred(), m_Constant(&constant_pred));
  if (!pred_is_constant) return failure();

  // The splat value of the predicate picks the region to inline.
  const bool take_true_branch =
      constant_pred.getSplatValue<BoolAttr>().getValue();
  Region& selected =
      take_true_branch ? ifOp.true_branch() : ifOp.false_branch();
  ReplaceOpWithRegion(rewriter, ifOp, selected);
  return success();
}
// Canonicalization hook for IfOp; receives the pattern set to populate and the
// MLIRContext.
void IfOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
// Case Op
// Verifies a CaseOp by running VerifyConditionalBranch on every branch region
// in order. Diagnostics name the branch as "branch <index>".
static LogicalResult Verify(CaseOp op) {
  unsigned branch_index = 0;
  for (Region& branch : op.branches()) {
    LogicalResult status = VerifyConditionalBranch(
        op, branch, /*branchName=*/"branch " + Twine(branch_index));
    if (failed(status)) return failure();
    ++branch_index;
  }
  return success();
}
// Rewrite pattern: when the index of `caseOp` is a constant, replaces the op
// with the contents of the selected branch region. Fails to match when the
// index is not constant or the selected region does not consist of a single
// block.
static LogicalResult InlineCaseConstantCondition(CaseOp caseOp,
PatternRewriter& rewriter) {
DenseIntElementsAttr index_attr;
if (!matchPattern(caseOp.index(), m_Constant(&index_attr))) {
return failure();
int64_t index =
// For an OOB index, the last branch is executed as the default branch:
if (index < 0 || index >= caseOp.getNumRegions())
index = caseOp.getNumRegions() - 1;
Region& region = caseOp.getRegion(index);
// Only single-block regions are inlined.
if (!llvm::hasSingleElement(region)) return failure();
ReplaceOpWithRegion(rewriter, caseOp, region);
return success();
// Canonicalization hook for CaseOp; receives the pattern set to populate and
// the MLIRContext.
void CaseOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
// SqrtOp
// Constant folder for SqrtOp. Gives up (returns a null result) unless the
// operand is a DenseElementsAttr, the element type is f32 or f64, and the
// result type has a static shape. A negative input element also aborts the
// fold.
OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
auto val = operands[0].dyn_cast_or_null<DenseElementsAttr>();
if (!val) return {};
auto type = getElementTypeOrSelf(getType());
// Only 32- and 64-bit floats are handled.
if (!type.isF32() && !type.isF64()) return {};
auto shaped_type = getType().cast<ShapedType>();
if (!shaped_type.hasStaticShape()) return {};
int bit_width = type.getIntOrFloatBitWidth();
llvm::SmallVector<APFloat, 4> values;
for (auto it : val.getValues<APFloat>()) {
// The value is read as float for f32 and as double for f64.
double value = bit_width == 32 ? it.convertToFloat() : it.convertToDouble();
// Do not fold the square root of a negative number.
if (value < 0) return {};
value = std::sqrt(value);
if (bit_width == 32)
return DenseFPElementsAttr::get(shaped_type, values);
// UnaryOps
// Generic constant folder for unary element-wise ops.
//   Op          - the op class being folded.
//   ElementType - element type class the result must have for the fold to
//                 apply (e.g. FloatType, IntegerType).
//   ValType     - value type used to iterate the constant (e.g. APFloat).
//   Convert     - functor type; NOTE(review): presumably applied to every
//                 element, the loop body is not shown here - confirm.
// Returns a null Attribute when the operand is not a DenseElementsAttr, the
// result shape is not static, or the element type does not match.
template <typename Op, typename ElementType = Type, typename ValType,
typename Convert>
static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
if (!attrs[0]) return {};
DenseElementsAttr val = attrs[0].dyn_cast<DenseElementsAttr>();
if (!val) return {};
ShapedType type = op->getType().template cast<ShapedType>();
if (!type.hasStaticShape()) {
return {};
Type etype = type.getElementType();
// Only fold when the result element type is of the requested kind.
if (!etype.isa<ElementType>()) {
return {};
SmallVector<ValType, 6> values;
for (const auto v : val.getValues<ValType>()) {
return DenseElementsAttr::get(type, values);
// Functor over APFloat. As written here it copies its argument and returns the
// copy. NOTE(review): the name suggests a rounding step on `r` - confirm.
struct round {
APFloat operator()(const APFloat& f) {
APFloat r = f;
return r;
// Functor used to constant-fold NotOp on integer constants.
//
// Returns the bitwise complement of `i`. For 1-bit (boolean) values this is
// the logical negation; for wider integers it is the element-wise bitwise NOT.
// The previous implementation returned `i == 0` widened to the bit width,
// which folded e.g. not(5 : i32) to 0 instead of ~5.
struct logical_not {
  APInt operator()(const APInt& i) { return ~i; }
};
template <typename FloatOrInt>
struct sign {
APFloat compute(const APFloat& f) {
if (f.isZero() || f.isNaN()) return f;
double value = f.isNegative() ? -1.0 : 1.0;
APFloat val(value);
bool unused;
val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused);
return val;
APInt compute(const APInt& i) {
APInt r = i;
if (r == 0) return r;
if (r.isNegative()) {
return APInt(r.getBitWidth(), -1, /*isSigned=*/true);
return APInt(r.getBitWidth(), 1, /*isSigned=*/true);
FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); }
// Defines Op::fold for a unary op that folds both float and integer constants.
// `Func` is a class template instantiated with APFloat or APInt depending on
// the result element type; any other element type is not folded.
#define UNARY_FOLDER(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
return UnaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \
return UnaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs); \
return {}; \
// Defines Op::fold for a unary op that only folds integer constants. `Func` is
// a plain functor type operating on APInt.
#define UNARY_FOLDER_INT(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \
return UnaryFolder<Op, IntegerType, APInt, Func>(this, attrs); \
return {}; \
// Defines Op::fold for a unary op that only folds floating-point constants.
// `Func` is a plain functor type operating on APFloat.
#define UNARY_FOLDER_FLOAT(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
return {}; \
// Folder definitions: NegOp and SignOp fold float and integer constants, NotOp
// folds integer constants only.
UNARY_FOLDER(NegOp, std::negate);
UNARY_FOLDER(SignOp, sign);
UNARY_FOLDER_INT(NotOp, logical_not);
// BinaryOps
namespace {

// Returns a tensor type with element type `element_type` derived from `x`:
// a RankedTensorType with the shape of `x` when `x` is ranked, otherwise an
// UnrankedTensorType. `builder` is not used.
static Type UpdateResultElementType(Builder* builder, Type x,
                                    Type element_type) {
  if (auto ranked = x.dyn_cast<RankedTensorType>()) {
    return RankedTensorType::get(ranked.getShape(), element_type);
  }
  return UnrankedTensorType::get(element_type);
}

}  // namespace
// Generic constant folder for binary element-wise ops.
//   Op          - the op class being folded.
//   ElementType - element type class the result must have for the fold to
//                 apply (e.g. FloatType, IntegerType).
//   ValType     - value type used to iterate both constants.
//   Convert     - default-constructible functor combining one lhs and one rhs
//                 element into a result element.
// Returns a null Attribute unless both operands are DenseElementsAttr, the
// result type has a static shape and its element type is an ElementType.
template <typename Op, typename ElementType = Type, typename ValType,
          typename Convert>
static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
  Attribute lhs_attr = attrs[0];
  Attribute rhs_attr = attrs[1];
  if (!lhs_attr || !rhs_attr) return {};

  DenseElementsAttr lhs = lhs_attr.dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhs = rhs_attr.dyn_cast<DenseElementsAttr>();
  if (!lhs || !rhs) return {};

  ShapedType result_type = op->getType().template cast<ShapedType>();
  if (!result_type.hasStaticShape()) return {};

  // Only fold when the result element type is of the requested kind.
  Type element_type = result_type.getElementType();
  if (!element_type.isa<ElementType>()) return {};

  // Combine the operands pairwise, in element order.
  SmallVector<ValType, 6> folded;
  for (const auto pair :
       llvm::zip(lhs.getValues<ValType>(), rhs.getValues<ValType>())) {
    const ValType& left = std::get<0>(pair);
    const ValType& right = std::get<1>(pair);
    folded.push_back(Convert()(left, right));
  }
  return DenseElementsAttr::get(result_type, folded);
}
// Division functor used by the DivOp folder. The primary template defers to
// std::divides.
template <typename T>
struct divide : std::divides<T> {};

// APInt specialization: signed division.
//
// APInt::sdiv must not be called with a zero divisor (it asserts, and may trap
// in builds without assertions), so that case is handled explicitly. Integer
// division by zero has an implementation-defined result in XLA; -1 (all bits
// set) is returned here, matching what the XLA backends produce.
template <>
struct divide<APInt> {
  APInt operator()(const APInt& a, const APInt& b) const {
    if (b.isZero()) return APInt(a.getBitWidth(), -1, /*isSigned=*/true);
    return a.sdiv(b);
  }
};
// Remainder functor used by the RemOp folder. The primary template defers to
// std::modulus.
template <typename T>
struct remainder : std::modulus<T> {};

// APInt specialization: signed remainder.
//
// APInt::srem must not be called with a zero divisor (it asserts, and may trap
// in builds without assertions), so that case is handled explicitly. Integer
// remainder by zero has an implementation-defined result in XLA; the dividend
// is returned here, matching what the XLA backends produce.
template <>
struct remainder<APInt> {
  APInt operator()(const APInt& a, const APInt& b) const {
    if (b.isZero()) return a;
    return a.srem(b);
  }
};
// APFloat specialization of the remainder functor. As written here it copies
// `a` and returns the copy. NOTE(review): a modulus step involving `b` is
// expected between the copy and the return - confirm.
template <>
struct remainder<APFloat> {
APFloat operator()(const APFloat& a, const APFloat& b) const {
APFloat result(a);
return result;
// Maximum functor. The primary template compares with operator< and returns
// the first argument unless it is smaller than the second.
template <typename T>
struct max {
  T operator()(const T& a, const T& b) const { return (a < b) ? b : a; }
};

// APInt specialization: signed maximum.
template <>
struct max<APInt> {
  APInt operator()(const APInt& lhs, const APInt& rhs) const {
    return llvm::APIntOps::smax(lhs, rhs);
  }
};
// Minimum functor. The primary template compares with operator< and returns
// the first argument unless the second is smaller.
template <typename T>
struct min {
  T operator()(const T& a, const T& b) const { return (b < a) ? b : a; }
};

// APInt specialization: signed minimum.
template <>
struct min<APInt> {
  APInt operator()(const APInt& lhs, const APInt& rhs) const {
    return llvm::APIntOps::smin(lhs, rhs);
  }
};
// Folder body shared by the binary ops: dispatches to BinaryFolder with
// APFloat for float results and APInt for integer results, and does not fold
// any other element type. `Func` is a class template instantiated with the
// value type. NOTE(review): MulOp::fold uses this body under the name
// BINARY_FOLDER_INTERNAL(Op, Func) - confirm the macro header precedes it.
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \
return BinaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs); \
return {};
// Defines Op::fold for a binary op.
#define BINARY_FOLDER(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
// Addition, subtraction and multiplication use the std:: versions of the ops.
// The remaining ops behave differently for signed and unsigned integers, so
// their APInt functors need a dedicated implementation. Currently these
// replicate signed integer behavior.
BINARY_FOLDER(SubOp, std::minus);
BINARY_FOLDER(DivOp, divide);
BINARY_FOLDER(RemOp, remainder);
// Folder for AddOp. Besides the case where both operands are constant, it
// folds x + 0 => x when one operand is a constant zero splat.
OpFoldResult AddOp::fold(ArrayRef<Attribute> attrs) {
if (attrs[0] && attrs[1]) {
// Handle special case where one operand is 0: x + 0 => x
if (attrs[0] || attrs[1]) {
// Use whichever operand is constant; it must be a splat to qualify.
SplatElementsAttr attr = attrs[0] ? attrs[0].dyn_cast<SplatElementsAttr>()
: attrs[1].dyn_cast<SplatElementsAttr>();
if (!attr) return {};
// The value to forward is the other operand.
Value result = attrs[0] ? rhs() : lhs();
if (attr.getElementType().isa<FloatType>()) {
// NOTE(review): isZero() accepts both +0.0 and -0.0, and x + (+0.0) differs
// from x when x is -0.0 - confirm signed zeros may be ignored here.
if (attr.getSplatValue<APFloat>().isZero()) return result;
} else if (attr.getElementType().isa<IntegerType>()) {
if (attr.getSplatValue<APInt>().isZero()) return result;
return {};
// Folder for MulOp.
//
// When both operands are constant the product is folded element-wise through
// BINARY_FOLDER_INTERNAL. Otherwise, when one operand is a constant splat equal
// to one, the other operand is returned (x * 1 => x). Returns a null result
// when nothing can be folded.
OpFoldResult MulOp::fold(ArrayRef<Attribute> attrs) {
  if (attrs[0] && attrs[1]) {
    BINARY_FOLDER_INTERNAL(MulOp, std::multiplies);
  }

  // Handle special case where one operand is 1: x * 1 => x
  if (attrs[0] || attrs[1]) {
    SplatElementsAttr attr = attrs[0] ? attrs[0].dyn_cast<SplatElementsAttr>()
                                      : attrs[1].dyn_cast<SplatElementsAttr>();
    if (!attr) return {};

    // The value to forward is the non-constant operand.
    Value result = attrs[0] ? rhs() : lhs();
    if (attr.getElementType().isa<FloatType>()) {
      // isExactlyValue compares in the splat's own float semantics, so no
      // conversion to double is required.
      if (attr.getSplatValue<APFloat>().isExactlyValue(1.0)) return result;
    } else if (attr.getElementType().isa<IntegerType>()) {
      // isOne works for every bit width: it recognises `true` for i1 (whose
      // sign-extended value is -1) and does not assert above 64 bits.
      if (attr.getSplatValue<APInt>().isOne()) return result;
    }
  }
  return {};
}
// SliceOp
// Returns the output dimension size of a slice along one dimension, i.e. the
// number of indices in [start, end) visited with step `stride`
// (ceil((end - start) / stride)).
//
// Returns -1 if the arguments are illegal: the input dimension is dynamic
// (-1), the range is not contained in [0, input_dim], start > end, or the
// stride is not strictly positive.
static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
                             int64_t stride) {
  const bool dynamic_input = input_dim == -1;
  const bool bad_range = start < 0 || start > end || end > input_dim;
  if (dynamic_input || bad_range || stride <= 0) return -1;

  // Ceiling division written so that it cannot overflow for large strides.
  const int64_t extent = end - start;
  return extent / stride + (extent % stride != 0 ? 1 : 0);
}
// Return-type inference for SliceOp. Validates the start/limit/stride
// attributes against the operand type and derives one result dimension per
// operand dimension with InferSliceDim. Emits an (optional) error and fails
// when the attributes are malformed.
LogicalResult SliceOp::inferReturnTypes(
MLIRContext* context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type>& inferredReturnTypes) {
SliceOpAdaptor slice(operands, attributes);
// TODO(jpienaar): Update this code after refactoring verify.
if (failed(slice.verify(location.getValueOr(UnknownLoc::get(context))))) {
return failure();
Type ty = slice.operand().getType();
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
if (!ranked_ty) {
// The operand type is unranked, so the best we can infer for the result
// type is an unranked tensor with the same element type as the operand
// type.
return success();
// start_indices must be a rank-1 attribute with one entry per operand
// dimension.
ShapedType attr_ty = slice.start_indices().getType();
if (attr_ty.getRank() != 1) {
return emitOptionalError(location, "start_indices has rank ",
attr_ty.getRank(), " instead of required rank 1");
int64_t rank = ranked_ty.getRank();
if (attr_ty.getNumElements() != rank) {
return emitOptionalError(
location, "the number of elements in start_indices (",
attr_ty.getNumElements(), ") does not match the rank of the operand (",
rank, ")");
// All three attributes must share the same signless i64 type.
if (!attr_ty.getElementType().isSignlessInteger(64) ||
slice.limit_indices().getType() != attr_ty ||
slice.strides().getType() != attr_ty) {
// Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp
// having been verified at this point. Emit an error message that matches
// the one that would be reported by AllTypesMatch for a more consistent
// user experience.
// TODO(b/171567182): Clean this up after AllTypesMatch has been refactored.
return emitOptionalError(location,
"failed to verify that all of {start_indices, "
"limit_indices, strides} have same type");
SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
SmallVector<int64_t, 4> stride_vals(slice.strides().getValues<int64_t>());
// Infer every result dimension independently.
SmallVector<int64_t, 4> shape;
for (int64_t i = 0, e = rank; i != e; i++) {
shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
{RankedTensorType::get(shape, ranked_ty.getElementType())});
return success();
// Recursively copies the elements selected by a strided slice into
// `out_values`, outermost dimension first.
//   values  - random-access iterator positioned at the start of the current
//             sub-array.
//   sizes   - per-dimension element stride of the input (number of elements
//             spanned by one step along that dimension).
//   starts, limits, strides - slice bounds for the remaining dimensions; all
//             three must have the same length.
// Does nothing when no dimensions remain.
template <typename I, typename E>
static void SliceElements(I values, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> starts, ArrayRef<int64_t> limits,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<E>* out_values) {
  assert(starts.size() == limits.size());
  assert(starts.size() == strides.size());
  if (starts.empty()) return;

  const int64_t limit = limits.front();
  const int64_t stride = strides.front();

  // Innermost dimension: copy the selected elements directly. The index is
  // kept 64-bit so large offsets are not truncated.
  if (starts.size() == 1) {
    for (int64_t i = starts.front(); i < limit; i += stride) {
      out_values->push_back(*(values + i));
    }
    return;
  }

  // Outer dimension: recurse into every selected sub-array.
  for (int64_t start = starts.front(); start < limit; start += stride) {
    auto begin = values + start * sizes.front();
    SliceElements<I, E>(begin, sizes.drop_front(), starts.drop_front(),
                        limits.drop_front(), strides.drop_front(), out_values);
  }
}
// Constant-folds `op` given an iterator `values` over the operand's elements.
//   I - element iterator type, E - element value type.
// Returns a null Attribute when the operand type does not have a static shape.
template <typename I, typename E>
static Attribute FoldSlice(SliceOp* op, I values) {
auto start = llvm::to_vector<6>(op->start_indices().getValues<int64_t>());
auto limit = llvm::to_vector<6>(op->limit_indices().getValues<int64_t>());
auto stride = llvm::to_vector<6>(op->strides().getValues<int64_t>());
// Note: despite its name this is the type of the operand, not of the result.
auto result_type = op->operand().getType().cast<ShapedType>();
if (!result_type.hasStaticShape()) return {};
auto shape = result_type.getShape();
int64_t count = result_type.getNumElements();
// Operand without elements: handled separately.
if (count == 0) {
return DenseElementsAttr::get<E>(
// Compute the striding for each dimension.
llvm::SmallVector<int64_t, 6> sizes;
for (auto v : shape) {
count = count / v;
llvm::SmallVector<E, 6> out_values;
SliceElements<I, E>(values, sizes, start, limit, stride, &out_values);
return DenseElementsAttr::get(op->getResult().getType().cast<ShapedType>(),
// Folder for SliceOp.
//
// A slice whose operand and result have the same static shape is a no-op and
// folds to its operand. A slice of an integer or float DenseElementsAttr is
// evaluated with FoldSlice. Anything else is left alone.
OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
  auto input_type = getOperand().getType().cast<ShapedType>();
  auto output_type = getResult().getType().cast<ShapedType>();

  // No-op slice: the whole operand is selected.
  const bool both_static =
      input_type.hasStaticShape() && output_type.hasStaticShape();
  if (both_static && input_type.getShape() == output_type.getShape())
    return getOperand();

  if (operands.empty() || !operands.front()) return {};

  // Evaluate for statically valued inputs.
  DenseElementsAttr constant = operands.front().dyn_cast<DenseElementsAttr>();
  if (!constant) return {};

  Type element_type = constant.getType().getElementType();
  if (element_type.isa<IntegerType>())
    return FoldSlice<DenseElementsAttr::IntElementIterator, APInt>(
        this, constant.value_begin<APInt>());
  if (element_type.isa<FloatType>())
    return FoldSlice<DenseElementsAttr::FloatElementIterator, APFloat>(
        this, constant.value_begin<APFloat>());

  return {};
}
namespace {
// In cases where a concat is fed into a slice, it is possible the concat
// can be simplified or bypassed. This checks which inputs to the concat are
// used by the slice, either reducing the number of concatenated values or
// entirely removing the concat.
struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
using OpRewritePattern<SliceOp>::OpRewritePattern;
// Matches slice(concat(...)) with a statically shaped result and rewrites it
// into a slice of a concat over only the operands the slice reads from.
LogicalResult matchAndRewrite(SliceOp slice,
PatternRewriter& rewriter) const override {
auto result_ty = slice.getType().cast<ShapedType>();
if (!result_ty.hasStaticShape()) {
return failure();
auto slice_input = slice.operand();
auto slice_input_ty = slice_input.getType().cast<ShapedType>();
auto concat = slice_input.getDefiningOp<ConcatenateOp>();
if (!concat) {
return failure();
// Slice bounds along the concatenation dimension.
auto dimension = concat.dimension();
auto start = slice.start_indices().getValues<APInt>();
auto limit = slice.limit_indices().getValues<APInt>();
auto slice_start = (*(start.begin() + dimension)).getSExtValue();
auto slice_limit = (*(limit.begin() + dimension)).getSExtValue();
// We need to determine what inputs from the concat affect the slice, and
// how the bounds of the slice need to be updated for the minimally required
// inputs.
// running_size: offset of the current concat operand along `dimension`.
// front_offset: offset of the first required operand.
// [subset_start, subset_end): range of required concat operands.
int64_t running_size = 0;
int64_t front_offset = slice_input_ty.getShape()[dimension];
auto subset_start = concat.operand_end();
auto subset_end = concat.operand_end();
for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
auto input = *it;
ShapedType input_ty = input.getType().cast<ShapedType>();
// Offsets cannot be computed through a dynamic dimension.
if (input_ty.isDynamicDim(dimension)) {
return failure();
auto dim_size = input_ty.getShape()[dimension];
// If this position is in the slice it's the start of the subset and we
// need to update the start and limit values.
if (running_size + dim_size > slice_start &&
subset_start == concat.operand_end()) {
subset_start = it;
front_offset = running_size;
// Determine the last required offset.
if (running_size < slice_limit) {
subset_end = it + 1;
running_size += dim_size;
auto subset_size = subset_end - subset_start;
// We need all inputs so no optimization.
if (subset_size == concat.getNumOperands()) {
return failure();
// If there's nothing to slice that means the output is an empty tensor and
// there is dead code. We do nothing here and rely on other passes to clean
// this up.
if (subset_size == 0) {
return failure();
// A multi-operand subset is only rewritten when the concat has a single use.
if (subset_size > 1 && !concat.getResult().hasOneUse()) {
return failure();
// Build the reduced concat and shift the slice bounds by the dropped prefix.
auto concat_range = OperandRange(subset_start, subset_end);
auto new_concat = rewriter.create<ConcatenateOp>(
concat.getLoc(), concat_range, concat.dimension());
llvm::SmallVector<APInt, 6> new_start(start);
llvm::SmallVector<APInt, 6> new_limit(limit);
new_start[dimension] -= front_offset;
new_limit[dimension] -= front_offset;
auto attr_type = slice.start_indices().getType().cast<ShapedType>();
auto create = rewriter.create<SliceOp>(
slice.getLoc(), new_concat,
DenseIntElementsAttr::get(attr_type, new_start),
DenseIntElementsAttr::get(attr_type, new_limit), slice.strides());
rewriter.replaceOp(slice, create.getResult());
return success();
} // namespace
// Canonicalization hook for SliceOp; receives the pattern set to populate and
// the MLIRContext.
void SliceOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
// SortOp
// Builder for SortOp: stores `dimension` as an i64 attribute and `is_stable`
// as a bool attribute, and declares one result type per operand, equal to that
// operand's type.
void SortOp::build(OpBuilder& builder, OperationState& state,
ValueRange operands, int64_t dimension, bool is_stable) {
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
state.addAttribute("is_stable", builder.getBoolAttr(is_stable));
for (Value operand : operands) state.addTypes(operand.getType());
// Verifies a SortOp: at least one input; for ranked inputs, identical shapes
// and a `dimension` attribute within [-rank, rank); and a comparator block
// taking two rank-0 tensor arguments per operand.
static LogicalResult Verify(SortOp op) {
Operation::operand_range operands = op.operands();
if (operands.empty()) return op.emitOpError("requires at least one input");
// TODO(antiagainst): verify partially dynamic shapes
// Shape checks only run when every operand is ranked.
if (llvm::all_of(operands, [](Value operand) {
return operand.getType().cast<ShapedType>().hasRank();
})) {
ArrayRef<int64_t> input_shape =
if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
return operand.getType().cast<ShapedType>().getShape() != input_shape;
return op.emitOpError("requires all inputs to have the same dimensions");
int64_t rank = input_shape.size();
int64_t cmp_dim = op.dimension();
if (cmp_dim < -rank || cmp_dim >= rank)
return op.emitOpError("dimension attribute value must be in range [-")
<< rank << ", " << rank << "), but found " << cmp_dim;
// The comparator receives two block arguments per operand.
Block& block = op.comparator().front();
size_t num_operands = op.getOperation()->getNumOperands();
if (block.getNumArguments() != 2 * num_operands)
return op.emitOpError("comparator block should have ")
<< 2 * num_operands << " arguments";
// Arguments 2*i and 2*i+1 must both be rank-0 tensors of `element_type`.
for (const auto& indexed_operand : llvm::enumerate(operands)) {
int index = indexed_operand.index();
Type element_type =
Type tensor_type = RankedTensorType::get({}, element_type);
for (int i : {2 * index, 2 * index + 1}) {
Type arg_type = block.getArgument(i).getType();
if (arg_type != tensor_type)
return op.emitOpError("comparator block argument #")
<< i << " should be of type " << tensor_type << " but got "
<< arg_type;
return success();
/// Drops the operands if the results are not used and they are not used in
/// op.comparator().
///
/// Fails to match when no operand qualifies. Otherwise a new SortOp is created
/// from the remaining operands, the comparator region is moved into it and the
/// original op is replaced.
static LogicalResult SortDropEmptyUseArgs(SortOp op,
PatternRewriter& rewriter) {
DenseSet<unsigned> erased_args;
unsigned num_operands = op.getNumOperands();
// An operand qualifies when its result and both of its comparator block
// arguments (2*i and 2*i+1) are unused.
for (unsigned i = 0; i < num_operands; ++i) {
if (!op.getResult(i).use_empty()) continue;
Block& block = op.comparator().front();
if (!block.getArgument(i * 2).use_empty()) continue;
if (!block.getArgument(i * 2 + 1).use_empty()) continue;
if (erased_args.empty()) return failure();
SmallVector<Value> new_operands;
SmallVector<unsigned> erased_block_args;
// Record the two comparator block arguments of every dropped operand.
for (const auto& en : llvm::enumerate(op.operands())) {
if (erased_args.contains(en.index())) {
erased_block_args.push_back(en.index() * 2);
erased_block_args.push_back(en.index() * 2 + 1);
} else {
auto new_op = rewriter.create<SortOp>(op.getLoc(), new_operands,
op.dimension(), op.is_stable());
// Move the comparator of the old op into the new one.
Region& region = new_op.comparator();
rewriter.inlineRegionBefore(op.comparator(), region, region.end());
SmallVector<Value> results;
for (unsigned i = 0, j = 0; i < num_operands; ++i) {
if (erased_args.contains(i)) {
} else {
rewriter.replaceOp(op, results);
return success();
/// Set the sorting dimension to the last dimension if it's not set and the rank
/// is known.
///
/// Fails to match when the first result is not a ranked shaped type or when
/// the dimension attribute is already set (i.e. is not -1). Otherwise the op
/// is replaced by a new SortOp with dimension = rank - 1 that takes over the
/// original comparator region.
static LogicalResult SortOpInferDefaultDimension(SortOp op,
                                                 PatternRewriter& rewriter) {
  auto ty = op.getResultTypes()[0].dyn_cast<ShapedType>();
  // getRank() is only valid on ranked types, so unranked results are skipped.
  if (!ty || !ty.hasRank()) {
    return failure();
  }
  if (op.dimension() != -1) {
    return failure();
  }

  IntegerAttr dim = rewriter.getI64IntegerAttr(ty.getRank() - 1);
  auto new_op = rewriter.create<SortOp>(op.getLoc(), op.getResultTypes(),
                                        op.operands(), dim, op.is_stableAttr());
  // Move the comparator of the old op into the new one.
  Region& region = new_op.comparator();
  rewriter.inlineRegionBefore(op.comparator(), region, region.end());
  rewriter.replaceOp(op, new_op.getResults());
  return success();
}
// Canonicalization hook for SortOp; receives the pattern set to populate. The
// context parameter is unnamed and therefore unused.
void SortOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* /*context*/) {
// TransposeOp
// Folder for TransposeOp: a transpose whose permutation maps every position to
// itself is the identity and folds to its operand. Any other permutation is
// not folded.
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  uint64_t position = 0;
  for (const APInt& target : permutation().getValues<APInt>()) {
    if (target != position) return {};
    ++position;
  }
  return getOperand();
}
// transpose(transpose(X)) => transpose(X)
//
// Fails to match when the operand is not produced by a TransposeOp. Otherwise
// the op is replaced by a new TransposeOp; the lambda below maps an index
// through the inner transpose's permutation.
static LogicalResult EliminateRedundantTranspse(TransposeOp op,
PatternRewriter& rewriter) {
auto tranpose_operand = op.operand().getDefiningOp<TransposeOp>();
if (!tranpose_operand) {
return failure();
auto operand_permutation = tranpose_operand.permutation().getValues<APInt>();
auto new_permutation =
[&operand_permutation](const APInt& index) -> APInt {
return operand_permutation[index.getSExtValue()];
rewriter.replaceOpWithNewOp<TransposeOp>(op, op.getResult().getType(),
return success();
// transpose(broadcast_in_dim(X)) => broadcast_in_dim(X)
//
// Fails to match when the operand is not produced by a BroadcastInDimOp.
// Otherwise new broadcast dimensions are computed by searching the transpose
// permutation for each original broadcast dimension.
static LogicalResult EliminateBroadcastInDimTranspose(
TransposeOp op, PatternRewriter& rewriter) {
auto broadcast_in_dim_op = op.operand().getDefiningOp<BroadcastInDimOp>();
if (!broadcast_in_dim_op) {
return failure();
DenseIntElementsAttr broadcast_dimensions =
DenseIntElementsAttr permutation = op.permutation();
SmallVector<int64_t> new_broadcast_dimensions;
for (auto dimension : broadcast_dimensions.getValues<int64_t>()) {
// Look for the permutation entry equal to this broadcast dimension.
int64_t index = 0;
for (auto p : permutation.getValues<int64_t>()) {
if (p == dimension) {
op, op->getResultTypes(), broadcast_in_dim_op.operand(),
return success();
// Canonicalization hook for TransposeOp; receives the pattern set to populate.
// The context parameter is unnamed and therefore unused.
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* /*context*/) {
// Materializes the result shape of a transpose as IR values: each operand
// dimension is queried with tensor.dim and stored at the result position that
// the permutation assigns to it. Fails for unranked operands.
LogicalResult TransposeOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
TransposeOp::Adaptor adaptor(operands);
Value operand = adaptor.operand();
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
// Unranked types are not supported at the moment.
if (!operand_type) return failure();
Location loc = this->getLoc();
SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
SmallVector<Value, 4> shape_values(permutation.size());
// Dimension values are cast to the index type.
Type shape_scalar_type = builder.getIndexType();
auto to_shape_scalar_type = [&](Value v) {
return MaybeCastTo(builder, loc, v, shape_scalar_type);
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
int64_t idx = element.index();
// Position of operand dimension `idx` within the permutation, i.e. the
// result dimension it ends up in.
auto* it = std::find(permutation.begin(), permutation.end(), idx);
Value value_dim = to_shape_scalar_type(
builder.createOrFold<tensor::DimOp>(loc, operand, element.index()));
shape_values[std::distance(permutation.begin(), it)] = value_dim;
Value output_shape = builder.create<tensor::FromElementsOp>(
return success();
// Method for InferTypeOpInterface: infer the return type from the operand type
// and the permutation.
//
// For a ranked operand the permutation attribute must be present, have rank 1
// and have as many entries as the operand has dimensions; otherwise an
// (optional) error is emitted and inference fails. The result keeps the
// operand's element type.
LogicalResult TransposeOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> loc, ValueShapeRange operands,
DictionaryAttr attributes, RegionRange,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnTypes) {
auto type = operands[0].getType();
auto rankedTy = type.dyn_cast<RankedTensorType>();
// Operand is not a ranked tensor: it must at least be a shaped type.
if (!rankedTy) {
auto shapedTy = type.dyn_cast<ShapedType>();
if (!shapedTy)
return emitOptionalError(loc,
"expected shaped type operand, got: ", type);
return success();
auto permutation = attributes.getAs<DenseIntElementsAttr>("permutation");
int64_t rank = rankedTy.getRank();
if (!permutation)
return emitOptionalError(loc,
"missing permutation attribute on TransposeOp");
if (permutation.getType().getRank() != 1)
return emitOptionalError(loc, "TransposeOp permutation has rank ",
" instead of rank 1");
if (permutation.size() != rank)
return emitOptionalError(loc, "TransposeOp operand rank ", rank,
" does not match permutation size ",
SmallVector<int64_t> resultShape;
ArrayRef<int64_t> inputShape = rankedTy.getShape();
for (int64_t dim : permutation.getValues<int64_t>()) {
// NOTE(review): only too-large entries are rejected here; negative
// permutation entries pass this check - confirm they are validated elsewhere.
if (dim >= rank) return failure();
inferredReturnTypes.emplace_back(resultShape, rankedTy.getElementType());
return success();
// TriangularSolveOp
// Verifies a TriangularSolveOp. Checks run in this order and stop as soon as
// the information needed for the next one is unavailable (unranked type):
//   1. `a` has rank >= 2 and its two minor dimensions are equal;
//   2. `a` and `b` have the same rank, matching shared dimension and equal
//      leading batch dimensions;
//   3. the result type equals the type of `b`.
static LogicalResult Verify(TriangularSolveOp op) {
  // Nothing can be checked when `a` is unranked.
  auto matrix_type = op.a().getType().dyn_cast<RankedTensorType>();
  if (!matrix_type) return success();

  const auto matrix_rank = matrix_type.getRank();
  if (matrix_rank < 2) {
    return op.emitOpError()
           << "operand 'a' must have rank >= 2, but got " << matrix_type;
  }

  // `a` must be square in its two minor dimensions.
  const auto matrix_rows = matrix_type.getDimSize(matrix_rank - 2);
  const auto matrix_cols = matrix_type.getDimSize(matrix_rank - 1);
  if (matrix_rows != matrix_cols) {
    return op.emitOpError() << "two minor dimensions of operand 'a' must have "
                               "equal size, but got "
                            << matrix_type;
  }

  // The remaining checks need a ranked `b`.
  auto rhs_type = op.b().getType().dyn_cast<RankedTensorType>();
  if (!rhs_type) return success();

  const auto rhs_rank = rhs_type.getRank();
  if (matrix_rank != rhs_rank) {
    return op.emitOpError() << "operands must have equal rank, but got "
                            << matrix_type << " and " << rhs_type;
  }

  // The dimension of `b` that is contracted with `a` depends on the side.
  const auto shared_dim_index = rhs_rank - (op.left_side() ? 2 : 1);
  if (matrix_cols != rhs_type.getDimSize(shared_dim_index)) {
    return op.emitOpError() << "shared dimension of operands 'a' and 'b' does "
                               "not match, but got "
                            << matrix_type << " and " << rhs_type;
  }

  // Everything except the two minor dimensions is a batch dimension.
  auto matrix_batch = matrix_type.getShape().drop_back(2);
  auto rhs_batch = rhs_type.getShape().drop_back(2);
  if (matrix_batch != rhs_batch) {
    return op.emitOpError()
           << "leading batch dimensions of the operands must be same, but got "
           << matrix_type << " and " << rhs_type;
  }

  // Result and argument b must have same shape.
  auto result_type = op.getType().dyn_cast<RankedTensorType>();
  if (!result_type) return success();
  if (result_type != rhs_type) {
    return op.emitOpError()
           << "result and operand 'b' must have same shape, but got "
           << result_type << " and " << rhs_type;
  }
  return success();
}
// GetTupleElementOp
// Builder for GetTupleElementOp. When `tuple` has a TupleType, the result type
// is the type of the element at `index`; otherwise the type of `tuple` itself
// is used as the result type.
void GetTupleElementOp::build(OpBuilder& builder, OperationState& result,
Value tuple, int32_t index) {
if (auto tuple_type = tuple.getType().dyn_cast<TupleType>()) {
auto element_type = tuple_type.getType(index);
build(builder, result, element_type, tuple,
build(builder, result, tuple.getType(), tuple,
// TupleOp
// Builder for TupleOp: the result is a tuple type built from `types`, and
// `values` become the operands.
void TupleOp::build(OpBuilder& builder, OperationState& result,
ValueRange values) {
SmallVector<Type, 4> types;
for (auto val : values) {
build(builder, result, builder.getTupleType(types), values);
// UnaryEinsumOp
// Canonicalization hook for UnaryEinsumOp; receives the pattern set to
// populate and the MLIRContext.
void UnaryEinsumOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
// CompareOp
// Builder for CompareOp that derives the result type from `lhs`: same
// rankedness/shape as `lhs`, with an i1 element type.
void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
Value rhs, StringAttr comparison_direction,
StringAttr compare_type) {
auto new_type =
UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
build(builder, result, new_type, lhs, rhs, comparison_direction,
// Infers the result components of a compare: an i1 element type combined with
// the shape of the first operand. When the first operand is unranked there is
// no shape to copy, so unranked components carrying only the i1 element type
// are produced instead. Always succeeds.
LogicalResult CompareOp::inferReturnTypeComponents(
    mlir::MLIRContext* ctx, llvm::Optional<mlir::Location>,
    ValueShapeRange operands, mlir::DictionaryAttr, mlir::RegionRange,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferredReturnTypes) {
  OpBuilder builder(ctx);
  auto arg_ty = operands.front().getType().cast<TensorType>();

  // getShape() is only valid on ranked types.
  if (!arg_ty.hasRank()) {
    Type i1_type = builder.getI1Type();
    inferredReturnTypes.push_back(mlir::ShapedTypeComponents(i1_type));
    return success();
  }

  inferredReturnTypes.push_back({arg_ty.getShape(), builder.getI1Type()});
  return success();
}
// Reifies the result shape of a compare by deriving it from the first operand
// through deriveShapeFromOperand.
LogicalResult CompareOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return deriveShapeFromOperand(&builder, getOperation(), operands.front(),
// "Less than" predicate for the compare folder. The primary template defers to
// std::less.
template <typename T>
struct less : std::less<T> {};

// APInt specialization: signed comparison.
template <>
struct less<APInt> {
  bool operator()(const APInt& lhs, const APInt& rhs) const {
    return lhs.slt(rhs);
  }
};
// "Less than or equal" predicate for the compare folder. The primary template
// defers to std::less_equal.
template <typename T>
struct less_equal : std::less_equal<T> {};

// APInt specialization: signed comparison.
template <>
struct less_equal<APInt> {
  bool operator()(const APInt& lhs, const APInt& rhs) const {
    return lhs.sle(rhs);
  }
};
// "Greater than" predicate for the compare folder. The primary template defers
// to std::greater.
template <typename T>
struct greater : std::greater<T> {};

// APInt specialization: signed comparison.
template <>
struct greater<APInt> {
  bool operator()(const APInt& lhs, const APInt& rhs) const {
    return lhs.sgt(rhs);
  }
};
// "Greater than or equal" predicate for the compare folder. The primary
// template defers to std::greater_equal.
template <typename T>
struct greater_equal : std::greater_equal<T> {};

// APInt specialization: signed comparison.
template <>
struct greater_equal<APInt> {
  bool operator()(const APInt& lhs, const APInt& rhs) const {
    return lhs.sge(rhs);
  }
};
// Constant folder for CompareOp over one element kind.
//   Op          - unused by the body.
//   ElementType - element type class the first operand must have.
//   SrcType     - value type used to iterate both constants.
//   Convert     - default-constructible predicate applied to each lhs/rhs
//                 element pair.
// Returns a null Attribute unless both operands are DenseElementsAttr and the
// first operand has a static shape with an ElementType element type; otherwise
// returns a dense attribute of the op's result type holding the predicate
// results.
template <typename Op, typename ElementType, typename SrcType, typename Convert>
static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) {
  Attribute lhs_attr = attrs[0];
  Attribute rhs_attr = attrs[1];
  if (!lhs_attr || !rhs_attr) return {};

  DenseElementsAttr lhs = lhs_attr.dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhs = rhs_attr.dyn_cast<DenseElementsAttr>();
  if (!lhs || !rhs) return {};

  ShapedType input_type =
      op.getOperand(0).getType().template cast<ShapedType>();
  if (!input_type.hasStaticShape()) return {};
  if (!input_type.getElementType().isa<ElementType>()) return {};

  // Evaluate the predicate pairwise, in element order.
  SmallVector<bool, 6> outcomes;
  for (const auto pair :
       llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) {
    const SrcType& left = std::get<0>(pair);
    const SrcType& right = std::get<1>(pair);
    outcomes.push_back(Convert()(left, right));
  }

  auto output_type = op.getType().cast<ShapedType>();
  return DenseElementsAttr::get(output_type, outcomes);
}
// Folder for CompareOp. Only applies to statically shaped results. In order:
//   1. comparing a value with itself (non-float element types only);
//   2. i1 operands: `x != false` and `x == true` fold to `x`;
//   3. both operands constant: element-wise evaluation via CompareFolder.
OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
auto result_ty = getType().cast<ShapedType>();
if (!result_ty.hasStaticShape()) return {};
auto direction = comparison_direction();
auto lhs_ty = getElementTypeOrSelf(lhs());
// Same value on both sides. Float element types (including complex types with
// a float element type) are excluded from this shortcut.
if (lhs() == rhs() && !lhs_ty.isa<FloatType>() &&
(!lhs_ty.isa<ComplexType>() ||
!lhs_ty.cast<ComplexType>().getElementType().isa<FloatType>())) {
// LE, EQ and GE are true for identical operands; every other direction is
// false.
if (direction == "LE" || direction == "EQ" || direction == "GE") {
return DenseIntElementsAttr::get(result_ty, {true});
return DenseIntElementsAttr::get(result_ty, {false});
auto op_el_type = lhs().getType().cast<ShapedType>().getElementType();
// Fold tensor<*xi1> != false to just return tensor<*xi1>
if (direction == "NE" && op_el_type.isInteger(1)) {
DenseIntElementsAttr cst_attr;
if (matchPattern(lhs(), m_Constant(&cst_attr))) {
if (cst_attr.isSplat() && !cst_attr.getSplatValue<bool>()) {
return rhs();
if (matchPattern(rhs(), m_Constant(&cst_attr))) {
if (cst_attr.isSplat() && !cst_attr.getSplatValue<bool>()) {
return lhs();
// Fold tensor<*xi1> == True to just return tensor<*xi1>
if (direction == "EQ" && op_el_type.isInteger(1)) {
DenseIntElementsAttr cst_attr;
if (matchPattern(lhs(), m_Constant(&cst_attr))) {
if (cst_attr.isSplat() && cst_attr.getSplatValue<bool>()) {
return rhs();
if (matchPattern(rhs(), m_Constant(&cst_attr))) {
if (cst_attr.isSplat() && cst_attr.getSplatValue<bool>()) {
return lhs();
// The remaining folds need both operands to be constant.
if (!operands[0] || !operands[1]) {
return {};
// For a matching direction, try the float folder first and then the integer
// folder; the first one that produces an attribute wins.
#define COMPARE_FOLDER(Op, comparison, Func) \
if (direction == comparison) { \
if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
*this, operands)) \
return folded; \
if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>( \
*this, operands)) \
return folded; \
COMPARE_FOLDER(CompareOp, "EQ", std::equal_to);
COMPARE_FOLDER(CompareOp, "NE", std::not_equal_to);
COMPARE_FOLDER(CompareOp, "LT", less);
COMPARE_FOLDER(CompareOp, "LE", less_equal);
COMPARE_FOLDER(CompareOp, "GT", greater);
COMPARE_FOLDER(CompareOp, "GE", greater_equal);
return {};
// ScatterOp
// Evaluates `region` on constant `inputs` by folding its operations in order.
// Returns the attributes collected for the ReturnOp, or an empty vector when
// the argument count does not match, an op fails to fold, or a fold result is
// not an Attribute.
llvm::SmallVector<Attribute, 4> evaluateMhloRegion(Region& region,
ArrayRef<Attribute> inputs) {
if (region.getNumArguments() != inputs.size()) return {};
// Maps each SSA value seen so far to its constant value.
llvm::DenseMap<Value, Attribute> values;
// Seed the map with the region arguments.
for (auto it : llvm::zip(region.getArguments(), inputs)) {
values.try_emplace(std::get<0>(it), std::get<1>(it));
for (auto& op : region.getOps()) {
// Per-op operand attributes (shadows the function parameter `inputs`).
llvm::SmallVector<Attribute, 4> inputs;
for (auto& operand : op.getOpOperands()) {
// The return op ends the evaluation: its operand attributes are the result.
if (isa<ReturnOp>(op)) return inputs;
llvm::SmallVector<OpFoldResult, 4> results;
if (failed(op.fold(inputs, results))) return {};
// Record the folded results; anything that is not an Attribute aborts.
for (auto it : llvm::zip(op.getResults(), results)) {
if (!std::get<1>(it).is<Attribute>()) return {};
values.insert({std::get<0>(it), std::get<1>(it).get<Attribute>()});
// No ReturnOp was reached.
return {};
// Constant folder for the scatter op. Folds only when the base, the indices
// and the update operands are all constant ranked tensors; returns a null
// OpFoldResult otherwise (including when a computed index is out of bounds or
// the update computation cannot be evaluated).
OpFoldResult ScatterOp::fold(ArrayRef<Attribute> operands) {
auto base = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto index = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
auto update = operands[2].dyn_cast_or_null<DenseElementsAttr>();
if (!base || !index || !update) return {};
auto base_type = base.getType().dyn_cast<RankedTensorType>();
auto index_type = index.getType().dyn_cast<RankedTensorType>();
auto update_type = update.getType().dyn_cast<RankedTensorType>();
if (!base_type || !index_type || !update_type) return {};
// Add the virtual trailing dimension of size 1 if index_vector_dim equals to
// index_type.rank.
const int64_t index_vector_dim =
if (index_vector_dim == index_type.getRank()) {
auto index_shape = index_type.getShape().vec();
index_type =
RankedTensorType::get(index_shape, index_type.getElementType());
index = index.reshape(index_type).cast<DenseIntElementsAttr>();
// Increment the multi-dimensional index vector based on the limits for each
// dimension specified by shape and returns false if the index rolled around
// with true otherwise.
auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
llvm::ArrayRef<int64_t> shape) {
// Walk from the innermost dimension outwards, resetting dimensions to 0.
for (int64_t i = index.size() - 1; i >= 0; --i) {
if (index[i] < shape[i]) return true;
index[i] = 0;
return false;
// Iterate over all elements of the update tensor, then find the corresponding
// value in the indices tensor to determine which location we have to update
// in the base/result tensor.
// `results` starts as a copy of the base values and is updated in place.
llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>());
llvm::SmallVector<uint64_t, 8> update_index(update_type.getRank(), 0);
llvm::SmallVector<uint64_t, 8> index_index;
llvm::SmallVector<uint64_t, 8> base_index;
do {
// Compute the index for the slice of the indices tensor for this update
// value.
if (index_vector_dim == 0) index_index.push_back(0);
for (int64_t i = 0; i < update_index.size(); ++i) {
if (llvm::count(scatter_dimension_numbers().getUpdateWindowDims(), i) ==
if (index_index.size() == index_vector_dim) index_index.push_back(0);
// Compute the index for the given update value in the base tensor.
base_index.assign(base_type.getRank(), 0);
uint64_t index_count = index_type.getShape()[index_vector_dim];
for (uint64_t i = 0; i < index_count; ++i) {
uint64_t operand_dim =
index_index[index_vector_dim] = i;
base_index[operand_dim] +=
uint64_t update_window_dim_index = 0;
auto inserted_window_dims =
auto update_window_dims = scatter_dimension_numbers().getUpdateWindowDims();
// Dimensions listed in inserted_window_dims are skipped here.
for (uint64_t i = 0; i < base_index.size(); ++i) {
if (llvm::count(inserted_window_dims, i)) continue;
base_index[i] +=
// Compute the linear index for the index into the base tensor.
// The last dimension has multiplier 1 (row-major accumulation).
int64_t linear_base_index = 0;
int64_t linear_base_index_multiplyer = 1;
for (int64_t i = base_index.size() - 1; i >= 0; --i) {
// Out of bound index have backend specific behaviour so avoid folding it.
// NOTE(review): base_index holds uint64_t, so `base_index[i] < 0` can never
// be true; only the upper-bound comparison is effective here.
if (base_index[i] < 0 || base_index[i] >= base_type.getShape()[i])
return {};
linear_base_index += base_index[i] * linear_base_index_multiplyer;
linear_base_index_multiplyer *= base_type.getShape()[i];
// Evaluate update computation and update the value with the newly computed
// attribute in the base tensor.
// Both region inputs are built as rank-0 tensors of the base element type.
auto lhs = DenseElementsAttr::get(
RankedTensorType::get({}, base_type.getElementType()),
auto rhs = DenseElementsAttr::get(
RankedTensorType::get({}, base_type.getElementType()),
auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs});
// Exactly one non-null value is required, otherwise give up on folding.
if (new_value.size() != 1 || !new_value[0]) return {};
results[linear_base_index] =
} while (next_index(update_index, update_type.getShape()));
return DenseElementsAttr::get(base_type, results);
// WhileOp
// Verifier for WhileOp. Checks, in order:
//   * the operand count equals the argument count of the condition block and
//     of the body block,
//   * each operand type equals the matching condition/body block argument
//     type,
//   * the condition block returns a single tensor value with i1 elements,
//   * the body block returns as many values as the op has operands, with
//     matching types.
// Emits an op error and fails on the first violation.
static LogicalResult verify(WhileOp whileOp) {
if (whileOp->getNumOperands() != whileOp.cond().front().getNumArguments())
return whileOp.emitOpError()
<< "mismatch in operand count (" << whileOp->getNumOperands()
<< ") vs the condition block argument count ("
<< whileOp.cond().front().getNumArguments() << ")";
if (whileOp->getNumOperands() != whileOp.body().front().getNumArguments())
return whileOp.emitOpError()
<< "mismatch in operand count (" << whileOp->getNumOperands()
<< ") vs the body block argument count ("
<< whileOp.body().front().getNumArguments() << ")";
// Compare each operand type with the matching block argument types.
for (const auto& enumeratedOperands :
whileOp.body().front().getArgumentTypes()))) {
int argCount = enumeratedOperands.index();
const auto& operands = enumeratedOperands.value();
Type operandType = std::get<0>(operands);
Type condType = std::get<1>(operands);
Type bodyType = std::get<2>(operands);
if (operandType != condType)
return whileOp.emitOpError()
<< "type mismatch between operand #" << argCount
<< " and the matching condition block argument: " << operandType
<< " vs " << condType;
if (operandType != bodyType)
return whileOp.emitOpError()
<< "type mismatch between operand #" << argCount
<< " and the matching body block argument: " << operandType
<< " vs " << bodyType;
// Check the return type for the condition block.
auto condReturnOp = cast<ReturnOp>(whileOp.cond().front().back());
if (condReturnOp->getNumOperands() != 1)
return condReturnOp.emitOpError()
<< "expects a single operand for while condition body return, got "
<< condReturnOp->getNumOperands();
auto operandType =
// The rank-0 requirement is currently disabled (see the TODO); only the i1
// element type is enforced.
if (!operandType || // TODO(b/210930774): operandType.getRank() != 0 ||
!operandType.getElementType().isa<IntegerType>() ||
operandType.getElementType().cast<IntegerType>().getWidth() != 1)
return condReturnOp.emitOpError()
<< "expects a zero-ranked tensor of i1, got "
<< condReturnOp->getOperand(0).getType();
// Check the return type for the body block.
auto bodyReturnOp = cast<ReturnOp>(whileOp.body().front().back());
if (bodyReturnOp->getNumOperands() != whileOp->getNumOperands())
return bodyReturnOp.emitOpError()
<< "expects body to return a many value as the operands ("
<< whileOp->getNumOperands() << "), got "
<< bodyReturnOp->getNumOperands();
// Each value returned by the body must have the type of the matching while
// operand.
for (const auto& enumeratedOperandTypes : llvm::enumerate(llvm::zip(
bodyReturnOp->getOperandTypes(), whileOp->getOperandTypes()))) {
Type operandType = std::get<0>(enumeratedOperandTypes.value());
Type returnType = std::get<1>(enumeratedOperandTypes.value());
if (operandType != returnType)
return bodyReturnOp.emitOpError()
<< "type mismatch between operand #"
<< enumeratedOperandTypes.index()
<< " and the enclosing WhileOp returned value: " << operandType
<< " vs " << returnType;
return success();
/// Print a `while` op.
///
/// op ::= `mhlo.while` `(` assignment-list `)` `:` types attribute-dict
/// `cond` region
/// `do` region
/// assignment-list ::= assignment | assignment `,` assignment-list
/// assignment ::= ssa-value `=` ssa-value
///
/// The parenthesized list is built from the body block arguments zipped with
/// the op operands. The `: types` section is emitted only when the op has
/// operands, and both regions are printed without their entry block
/// arguments.
void printWhileOp(WhileOp op, OpAsmPrinter& p) {
p << '(';
llvm::zip(op.getBody()->getArguments(), op->getOperands()), p,
[&](auto zip) {
p << " = ";
p << ")";
if (op->getNumOperands()) {
p << " : ";
llvm::interleaveComma(op->getOperandTypes(), p);
// Region 0 is printed after `cond`, region 1 after `do`.
p << " cond ";
p.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false);
p << " do ";
p.printRegion(op->getRegion(1), /*printEntryBlockArgs=*/false);
// Parses the custom `while` syntax: a parenthesized assignment list, an
// optional `: types` section (only when operands were parsed), an optional
// attribute dictionary, then the `cond` and `do` regions. The parsed types
// become the result types and are also used to resolve the operands and to
// type the region arguments.
ParseResult parseWhileOp(OpAsmParser& parser, OperationState& result) {
llvm::SMLoc loc = parser.getCurrentLocation();
// Parse the operands of the while: these are of the form:
// %iter_arg = %init_val
// where %iter_arg is the name of the block argument in the cond/body blocks
// and %init_val is the actual operand.
SmallVector<OpAsmParser::OperandType> operands;
SmallVector<OpAsmParser::OperandType> iterArgs;
if (parser.parseLParen()) return failure();
do {
// A `)` here or after an assignment terminates the list.
if (succeeded(parser.parseOptionalRParen())) break;
OpAsmParser::OperandType operand, iterArg;
if (parser.parseOperand(iterArg) || parser.parseEqual() ||
return failure();
if (succeeded(parser.parseOptionalRParen())) break;
} while (true);
// The type list is only present when at least one operand was parsed.
if (!operands.empty()) {
if (parser.parseColon() || parser.parseTypeList(result.types))
return failure();
// Both regions take `iterArgs` as entry block arguments, typed with
// `result.types`.
if (parser.resolveOperands(operands, result.types, loc, result.operands) ||
parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
parser.parseKeyword("cond") ||
parser.parseRegion(*result.addRegion(), iterArgs, result.types) ||
parser.parseKeyword("do") ||
parser.parseRegion(*result.addRegion(), iterArgs, result.types))
return failure();
return success();
// Canonicalization for WhileOp that removes loop-carried values which never
// change across iterations. Reports a match failure when no such value is
// found; otherwise creates a new WhileOp from the reduced operand list.
static LogicalResult whileCanonicalization(WhileOp whileOp,
PatternRewriter& rewriter) {
// Turn loop invariant values into implicit capture.
// Check if there is at least one value is forwarded from one iteration to the
// next, or one of the yielded value is an implicit capture already. Otherwise
// there is nothing to do here.
Block* cond = whileOp.getBody(0);
Block* body = whileOp.getBody(1);
auto bodyReturnOp = cast<ReturnOp>(body->getTerminator());
// A value counts as invariant when the third zipped element equals either
// the while operand or the body block argument.
if (!llvm::any_of(llvm::zip(whileOp->getOperands(), body->getArguments(),
[&](auto zip) {
return (std::get<0>(zip) == std::get<2>(zip) ||
std::get<1>(zip) == std::get<2>(zip));
return rewriter.notifyMatchFailure(whileOp, "no loop invariant found");
SmallVector<Value> newOperands, resultsToReplace;
SmallVector<unsigned> invariantArgIdxs;
// Classify every loop-carried value by walking operands, block arguments,
// returned values and results in lock-step.
for (const auto& enumeratedOperands : llvm::enumerate(llvm::zip(
whileOp.getOperands(), cond->getArguments(), body->getArguments(),
bodyReturnOp->getOperands(), whileOp->getResults()))) {
const auto& operands = enumeratedOperands.value();
Value whileOperand = std::get<0>(operands);
BlockArgument condBlockArg = std::get<1>(operands);
BlockArgument bodyBlockArg = std::get<2>(operands);
Value bodyReturnOperand = std::get<3>(operands);
Value whileResult = std::get<4>(operands);
// Same test as in the early-exit check above.
bool forwarded = (whileOperand == bodyReturnOperand ||
bodyBlockArg == bodyReturnOperand);
if (forwarded) {
// Indices are visited in reverse order.
for (int idx : llvm::reverse(invariantArgIdxs))
WhileOp newWhileOp = rewriter.create<WhileOp>(
whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands);
for (auto results : llvm::zip(resultsToReplace, newWhileOp->getResults()))
return success();
// Hook through which WhileOp contributes its canonicalization patterns to
// `results`.
// NOTE(review): presumably this registers whileCanonicalization -- confirm.
void WhileOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
using mlir::hlo::parseWindowAttributes;
using mlir::hlo::printWindowAttributes;
} // namespace mhlo
} // namespace mlir
#include "mlir-hlo/Dialect/mhlo/IR/"
namespace mlir {
namespace mhlo {
// mhlo Dialect Interfaces
namespace {
struct HLOInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
// Allow all call operations to be inlined.
bool isLegalToInline(Operation* call, Operation* callable,
bool wouldBeCloned) const final {
return true;
// We don't have any special restrictions on what can be inlined into
// destination regions (e.g. while/conditional bodies). Always allow it.
bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
BlockAndValueMapping& valueMapping) const final {
return true;
// Operations in mhlo dialect are always legal to inline since they are
// pure.
bool isLegalToInline(Operation*, Region*, bool,
BlockAndValueMapping&) const final {
return true;
} // end anonymous namespace
// mhlo Dialect Constructor
//
// Initializes the base Dialect with the dialect namespace, the context and
// the TypeID of MhloDialect.
MhloDialect::MhloDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context, TypeID::get<MhloDialect>()) {
// NOTE(review): GET_OP_LIST presumably selects the op-list section of the
// generated include that follows -- confirm against the TableGen output.
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/"
#include "mlir-hlo/Dialect/mhlo/IR/"
/// Parses a type of the mhlo dialect. `token` is the only keyword accepted;
/// any other keyword emits an "unknown mhlo type" error and yields a null
/// type. A failure to parse a keyword at all also yields a null type.
Type MhloDialect::parseType(DialectAsmParser& parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword)) return Type();

  if (keyword != "token") {
    parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << keyword;
    return nullptr;
  }
  return TokenType::get(getContext());
}
// Prints a type of the mhlo dialect: TokenType is written as `token`; the
// `<unknown mhlo type>` marker is the fallback text for anything else.
void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
if (type.isa<TokenType>()) {
os << "token";
os << "<unknown mhlo type>";
// Attribute parsing entry point for the dialect. The leading keyword is read
// here and the TableGen-generated parser dispatches on it to the individual
// attribute classes. A null attribute is returned when no keyword can be
// parsed or (after emitting "unknown mhlo attribute") when the generated
// parser does not recognize the keyword.
Attribute MhloDialect::parseAttribute(DialectAsmParser& parser,
                                      Type type) const {
  StringRef mnemonic;
  if (failed(parser.parseKeyword(&mnemonic))) return Attribute();

  Attribute parsed;
  // A result with a value means the generated parser handled this keyword.
  if (generatedAttributeParser(parser, mnemonic, type, parsed).hasValue())
    return parsed;

  parser.emitError(parser.getNameLoc(), "unknown mhlo attribute");
  return Attribute();
}
// Entry point for Attribute printing, TableGen generated code will handle the
// dispatch to the individual classes.
void MhloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const {
// The generated printer reports through `result` whether it handled `attr`.
LogicalResult result = generatedAttributePrinter(attr, os);
/// Helpers for attributes parsing.

/// Parses a square-bracketed list of integers such as `[]`, `[3]` or
/// `[0, 1, 2]` into `dims`.
///
/// `dims` is cleared first so that it holds exactly the parsed values on
/// success. Each value is parsed into a local and then appended, so the vector
/// is never accessed while empty. The comma between elements is consumed when
/// present, which matches the comma-separated form written by the printers.
///
/// Returns failure if the opening `[` is missing or an element is not an
/// integer.
static ParseResult parseDims(AsmParser& parser, SmallVector<int64_t>& dims) {
  dims.clear();
  if (parser.parseLSquare()) return failure();
  while (failed(parser.parseOptionalRSquare())) {
    int64_t value = 0;
    if (parser.parseInteger(value)) return failure();
    dims.push_back(value);
    // The separator is optional; the loop condition handles the closing `]`.
    (void)parser.parseOptionalComma();
  }
  return success();
}
/// Parses a dimension list with parseDims and additionally requires it to hold
/// at least `min_elements` entries; a shorter list is reported as an error at
/// the current parser location.
static ParseResult parseDimsWithMinimumElements(AsmParser& parser,
                                                SmallVector<int64_t>& dims,
                                                int min_elements) {
  if (failed(parseDims(parser, dims))) return failure();

  // Enough elements: nothing more to check.
  if (dims.size() >= min_elements) return success();

  return parser.emitError(parser.getCurrentLocation())
         << "expected at least " << min_elements << " element(s), found "
         << dims.size();
}
/// Parse a custom attribute that resembles a struct of the form
/// <
/// foo = something_parsed_by_custom_parser,
/// bar = something_parsed_by_different_custom_parser,
/// baz something_parsed_by_another_custom_parser
/// >
/// The optional argument `parse_equal` array can be used to denote if
/// '=' follows the keyword (see baz in the example above) for a field. If
/// not provided, all fields must be followed by a '='.
///
/// `keywords[i]` is parsed by `parseFuncs[i]`; the two arrays must have the
/// same size. The opening '<' is not parsed here; parsing stops at the closing
/// '>'. Fields are matched in any order; a repeated keyword is an error, and
/// so is a token that matches none of the keywords.
static ParseResult parseStruct(
AsmParser& parser, ArrayRef<StringRef> keywords,
ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs,
ArrayRef<bool> parse_equal = {}) {
assert(keywords.size() == parseFuncs.size());
assert(parse_equal.empty() || parse_equal.size() == keywords.size());
// seen[i] records whether keywords[i] has already been parsed.
SmallVector<bool> seen(keywords.size(), false);
while (failed(parser.parseOptionalGreater())) {
bool foundOne = false;
for (const auto& it : llvm::enumerate(keywords)) {
size_t index = it.index();
StringRef keyword = it.value();
if (succeeded(parser.parseOptionalKeyword(keyword))) {
if (seen[index]) {
return parser.emitError(parser.getCurrentLocation())
<< "duplicated `" << keyword << "` entry";
// '=' is required unless parse_equal says otherwise for this field.
if (parse_equal.empty() || parse_equal[index]) {
if (failed(parser.parseEqual())) return failure();
if (failed(parseFuncs[index]())) return failure();
// Without a ',' after the field, the struct has to end with '>' here.
if (failed(parser.parseOptionalComma())) return parser.parseGreater();
seen[index] = true;
foundOne = true;
// No keyword matched: report the full list of accepted keywords.
if (!foundOne) {
auto parseError = parser.emitError(parser.getCurrentLocation())
<< "expected one of: ";
llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
parseError << '`' << kw << '`';
return parseError;
return success();
// Printing helpers for optional fields of struct-like attributes.
//
// Scalar overload: emits `<separator><name> = <field>` and switches
// `separator` to ", " for the next field. A field equal to 0 is treated as
// absent and prints nothing.
template <typename T>
static void printField(AsmPrinter& printer, StringRef name, T field,
                       StringRef& separator) {
  if (field == 0) return;
  printer << separator << name << " = " << field;
  separator = ", ";
}
// Array overload: emits `<separator><name> = [a, b, ...]` and switches
// `separator` to ", " for the next field. An empty array is treated as absent
// and prints nothing.
template <typename T>
static void printField(AsmPrinter& printer, StringRef name, ArrayRef<T> field,
                       StringRef& separator) {
  if (field.empty()) return;
  printer << separator << name << " = [";
  llvm::interleaveComma(field, printer);
  printer << "]";
  separator = ", ";
}
// Prints a struct-like attribute body enclosed in `<` and `>`. Every element
// of `printFields` is a (field name, field value) pair that is forwarded to
// printField; the shared `separator` starts empty, so no separator precedes
// the first printed field.
// NOTE(review): `name` is not referenced in the statements shown here.
template <typename... Ts>
static void printStruct(AsmPrinter& printer, StringRef name,
Ts... printFields) {
printer << "<";
StringRef separator = "";
// Fold expression to print each entry in the parameter pack.
// TODO(mhlo-team): this can be simplified when TF moves to C++17.
// Array-initializer trick: expanding the pack inside a braced list evaluates
// the printField calls in order.
using unused = int[];
(void)unused{0, (printField(printer, std::get<0>(printFields),
std::get<1>(printFields), separator),
printer << ">";
// Custom printer and parser for ScatterDimensionNumbersAttr.
//
// The printer writes the struct-like form via printStruct, so empty arrays
// and zero-valued integers are omitted from the output.
void ScatterDimensionNumbersAttr::print(AsmPrinter& printer) const {
printStruct(printer, "scatter",
std::make_pair("update_window_dims", getUpdateWindowDims()),
std::make_pair("inserted_window_dims", getInsertedWindowDims()),
std::make_pair("index_vector_dim", getIndexVectorDim()));
// Parses the struct-like form `<field = value, ...>`. Every field is
// optional: arrays default to empty and index_vector_dim defaults to 0.
// Returns a null attribute on failure.
Attribute ScatterDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
if (failed(parser.parseLess())) return {};
SmallVector<int64_t> update_window_dims;
SmallVector<int64_t> inserted_window_dims;
SmallVector<int64_t> scatter_dims_to_operand_dims;
int64_t index_vector_dim = 0;
// Keyword i is handled by callback i.
if (failed(parseStruct(
{"update_window_dims", "inserted_window_dims",
"scatter_dims_to_operand_dims", "index_vector_dim"},
{[&]() { return parseDims(parser, update_window_dims); },
[&]() { return parseDims(parser, inserted_window_dims); },
[&]() { return parseDims(parser, scatter_dims_to_operand_dims); },
[&]() { return parser.parseInteger(index_vector_dim); }}))) {
<< "failed parsing scatter dimension numbers attribute";
return {};
return ScatterDimensionNumbersAttr::get(
parser.getContext(), update_window_dims, inserted_window_dims,
scatter_dims_to_operand_dims, index_vector_dim);
// Custom printer and parser for GatherDimensionNumbersAttr.
//
// The printer writes the struct-like form via printStruct, so empty arrays
// and a zero index_vector_dim are omitted from the output.
void GatherDimensionNumbersAttr::print(AsmPrinter& printer) const {
printStruct(printer, "gather", std::make_pair("offset_dims", getOffsetDims()),
std::make_pair("collapsed_slice_dims", getCollapsedSliceDims()),
std::make_pair("start_index_map", getStartIndexMap()),
std::make_pair("index_vector_dim", getIndexVectorDim()));
// Parses the struct-like form `<field = value, ...>`. Every field is
// optional: arrays default to empty and index_vector_dim defaults to 0.
// Returns a null attribute on failure.
Attribute GatherDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
if (failed(parser.parseLess())) return {};
SmallVector<int64_t> offset_dims;
SmallVector<int64_t> collapsed_slice_dims;
SmallVector<int64_t> start_index_map;
int64_t index_vector_dim = 0;
// Keyword i is handled by callback i.
if (failed(parseStruct(
{"offset_dims", "collapsed_slice_dims", "start_index_map",
{[&]() { return parseDims(parser, offset_dims); },
[&]() { return parseDims(parser, collapsed_slice_dims); },
[&]() { return parseDims(parser, start_index_map); },
[&]() { return parser.parseInteger(index_vector_dim); }}))) {
<< "failed parsing gather dimension numbers attribute";
return {};
return GatherDimensionNumbersAttr::get(parser.getContext(), offset_dims,
collapsed_slice_dims, start_index_map,
// Custom printer and parser for DotDimensionNumbersAttr.
//
// The printer passes (field name, value) pairs for the dimension arrays under
// the "dot" name.
void DotDimensionNumbersAttr::print(AsmPrinter& printer) const {
printer, "dot",
std::make_pair("lhs_batching_dimensions", getLhsBatchingDimensions()),
std::make_pair("rhs_batching_dimensions", getRhsBatchingDimensions()),
// Parses the struct-like form `<field = [..], ...>` with four optional
// dimension arrays, each defaulting to empty. Returns a null attribute on
// failure.
Attribute DotDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
if (failed(parser.parseLess())) return {};
SmallVector<int64_t> lhs_batching_dimensions;
SmallVector<int64_t> rhs_batching_dimensions;
SmallVector<int64_t> lhs_contracting_dimensions;
SmallVector<int64_t> rhs_contracting_dimensions;
// Keyword i is handled by callback i.
if (failed(parseStruct(
{"lhs_batching_dimensions", "rhs_batching_dimensions",
"lhs_contracting_dimensions", "rhs_contracting_dimensions"},
{[&]() { return parseDims(parser, lhs_batching_dimensions); },
[&]() { return parseDims(parser, rhs_batching_dimensions); },
[&]() { return parseDims(parser, lhs_contracting_dimensions); },
[&]() { return parseDims(parser, rhs_contracting_dimensions); }}))) {
<< "failed parsing dot dimension numbers attribute";
return {};
return DotDimensionNumbersAttr::get(
parser.getContext(), lhs_batching_dimensions, rhs_batching_dimensions,
lhs_contracting_dimensions, rhs_contracting_dimensions);
namespace {
enum NonSpatialDim : int64_t {
IOBatch = -1, // Input or output batch dimension
IOFeature = -2, // Input or output feature dimension
KIFeature = -3, // Kernel input feature dimension
KOFeature = -4, // Kernel output feature dimensions.
struct DenseMapInfoNonSpatialDim {
static inline NonSpatialDim getEmptyKey() {
return NonSpatialDim(DenseMapInfo<int64_t>::getEmptyKey());
static inline NonSpatialDim getTombstoneKey() {
return NonSpatialDim(DenseMapInfo<int64_t>::getTombstoneKey());
static unsigned getHashValue(const NonSpatialDim& Key) {
return DenseMapInfo<int64_t>::getHashValue(Key);
static bool isEqual(const NonSpatialDim& LHS, const NonSpatialDim& RHS) {
return LHS == RHS;
char NonSpatialDimToString(NonSpatialDim dim) {
switch (dim) {
case IOBatch:
return 'b';
case IOFeature:
return 'f';
case KIFeature:
return 'i';
case KOFeature:
return 'o';
llvm_unreachable("Unknown NonSpatialDim");
} // namespace
// Custom printer and parser for convolution attribute.
//
// Prints three bracketed dimension lists separated by `x` and `->`, for
// example `[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]`. Each list position holds
// a spatial dimension number, the letter of a non-spatial dimension, or `?`
// when nothing was assigned to it.
void printConvolutionDimensions(AsmPrinter& p, ConvDimensionNumbersAttr dnums) {
// TODO(b/202040055): we should check the attribute invariant and print the
// "raw" form if they are violated, otherwise we'll crash here.
constexpr int64_t kUnknownDim = std::numeric_limits<int64_t>::min();
// Prints one bracketed list from the spatial dimension positions and the
// (position, tag) pairs of the non-spatial dimensions.
auto print_dim =
[&](ArrayRef<int64_t> spatial_dims,
ArrayRef<std::pair<int64_t, NonSpatialDim>> non_spatial_dims) {
// The list length is one more than the largest position mentioned.
int64_t num_dims = 0;
if (!spatial_dims.empty()) {
num_dims =
*std::max_element(spatial_dims.begin(), spatial_dims.end()) + 1;
for (const auto& dim : non_spatial_dims) {
num_dims = std::max(num_dims, dim.first + 1);
llvm::SmallVector<int64_t> dims(num_dims, kUnknownDim);
// Fill each element of dims with a (< 0) NonSpatialDim enum or a (>=0)
// spatial dimension index.
for (const std::pair<int64_t, NonSpatialDim>& non_spatial_dim :
non_spatial_dims) {
dims[non_spatial_dim.first] = non_spatial_dim.second;
for (const auto& spatial_dim : llvm::enumerate(spatial_dims)) {
dims[spatial_dim.value()] = static_cast<int64_t>(spatial_dim.index());
// Each dimension numbers will be printed as a comma separated list
// surrounded by square brackets, e.g., [b, 0, 1, 2, f]
p << '[';
llvm::interleaveComma(dims, p, [&](int64_t dim) {
if (dim == kUnknownDim) {
p << "?";
} else if (dim >= 0) {
p << dim;
} else {
p << NonSpatialDimToString(static_cast<NonSpatialDim>(dim));
p << ']';
// Input dimensions.
{{dnums.getInputBatchDimension(), IOBatch},
{dnums.getInputFeatureDimension(), IOFeature}});
p << "x";
// Kernel dimensions.
{{dnums.getKernelInputFeatureDimension(), KIFeature},
{dnums.getKernelOutputFeatureDimension(), KOFeature}});
p << "->";
// Output dimensions.
{{dnums.getOutputBatchDimension(), IOBatch},
{dnums.getOutputFeatureDimension(), IOFeature}});
// Custom printer and parser for ConvDimensionNumbersAttr.
//
// The printer wraps the compressed form written by printConvolutionDimensions
// in `<` and `>`.
void ConvDimensionNumbersAttr::print(AsmPrinter& printer) const {
printer << "<";
printConvolutionDimensions(printer, *this);
printer << ">";
// Parses the "raw" struct form of the convolution dimension numbers, i.e. the
// part that follows the `raw` keyword in `#mhlo.conv<raw field = value, ...>`
// (the caller has already consumed `<` and `raw`; the closing `>` is consumed
// here by parseStruct). This form exists alongside the compressed format so
// that tests can cover impossible/invalid internal representations of the
// attribute.
//
// Every field is optional: integers default to 0 and arrays to empty. On
// success `dnums` holds the parsed attribute; on failure an error is emitted
// and `dnums` is left untouched.
static ParseResult parseConvolutionDimensionsRaw(
    AsmParser& parser, ConvDimensionNumbersAttr& dnums) {
  int64_t input_batch_dimension = 0;
  int64_t input_feature_dimension = 0;
  SmallVector<int64_t> input_spatial_dimensions;
  int64_t kernel_input_feature_dimension = 0;
  int64_t kernel_output_feature_dimension = 0;
  SmallVector<int64_t> kernel_spatial_dimensions;
  int64_t output_batch_dimension = 0;
  int64_t output_feature_dimension = 0;
  SmallVector<int64_t> output_spatial_dimensions;
  // Keyword i is handled by callback i; both lists must stay in sync.
  if (failed(parseStruct(
          parser,
          {"input_batch_dimension", "input_feature_dimension",
           "input_spatial_dimensions", "kernel_input_feature_dimension",
           "kernel_output_feature_dimension", "kernel_spatial_dimensions",
           "output_batch_dimension", "output_feature_dimension",
           "output_spatial_dimensions"},
          {
              [&]() { return parser.parseInteger(input_batch_dimension); },
              [&]() { return parser.parseInteger(input_feature_dimension); },
              [&]() { return parseDims(parser, input_spatial_dimensions); },
              [&]() {
                return parser.parseInteger(kernel_input_feature_dimension);
              },
              [&]() {
                return parser.parseInteger(kernel_output_feature_dimension);
              },
              [&]() { return parseDims(parser, kernel_spatial_dimensions); },
              [&]() { return parser.parseInteger(output_batch_dimension); },
              [&]() { return parser.parseInteger(output_feature_dimension); },
              [&]() { return parseDims(parser, output_spatial_dimensions); },
          }))) {
    // Name the attribute that is actually being parsed (this message used to
    // say "dot dimension numbers", copied from the dot attribute parser).
    parser.emitError(parser.getCurrentLocation())
        << "failed parsing convolution dimension numbers attribute";
    return failure();
  }
  dnums = ConvDimensionNumbersAttr::get(
      parser.getBuilder().getContext(), input_batch_dimension,
      input_feature_dimension, input_spatial_dimensions,
      kernel_input_feature_dimension, kernel_output_feature_dimension,
      kernel_spatial_dimensions, output_batch_dimension,
      output_feature_dimension, output_spatial_dimensions);
  return success();
}
// Parses the compressed convolution dimension syntax: three bracketed lists
// separated by the keyword `x` and an arrow, e.g.
// `[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]`. The first and third lists must
// name the batch and feature dimensions, the second one the kernel input and
// output feature dimensions. On success `dnums` holds the parsed attribute.
ParseResult parseConvolutionDimensions(AsmParser& parser,
ConvDimensionNumbersAttr& dnums) {
// Parsing a single set of dim numbers gives the spatial dimensions as a
// single ArrayRef<int64_t> and a list of non-spatial dimensions as
// IntegerAttrs (indexed by the NonSpatialDim enum).
using parse_dim_result_t =
llvm::SmallDenseMap<NonSpatialDim, int64_t, 4,
// Note that the allowed_non_spatial_dims is a set (as opposed to unordered
// set) because its used to print a list of allowed non spatial dims in the
// error messages, so making it a set keeps the error messages deterministic.
auto parse_dims =
[&](std::set<NonSpatialDim, std::greater<>> allowed_non_spatial_dims,
parse_dim_result_t& parsed_dims) -> ParseResult {
// First element: spatial dimension positions; second element: positions of
// the non-spatial dimensions.
auto& spatial_dims = std::get<0>(parsed_dims);
auto& non_spatial_dims = std::get<1>(parsed_dims);
// Parse the starting [
if (parser.parseLSquare()) {
return failure();
// Maps a parsed spatial dimension number to its position in the list.
llvm::SmallDenseMap<int64_t, int64_t> spatial_dims_map;
constexpr int64_t kInvalidDimension = -1;
// Keep track of the maximum spatial dimension parsed as we expect to see
// all the dimensions from 0 to maximum dimension parsed.
int64_t max_parsed_spatial_dim = kInvalidDimension;
// Position of the element currently being parsed.
int64_t index = 0;
do {
int64_t spatial_dim;
auto dim_location = parser.getCurrentLocation();
OptionalParseResult parseResult =
if (parseResult.hasValue()) {
if (parseResult.getValue().failed()) {
return failure();
// We were successful in parsing an integer. Check if it is a valid
// dimension (non-negative and no duplicate) and add its index to the
// spatial dims map.
if (spatial_dim < 0)
return parser.emitError(dim_location)
<< "Unexpected dimension " << spatial_dim;
if (!spatial_dims_map
.insert(std::pair<int64_t, int64_t>(spatial_dim, index))
return parser.emitError(dim_location)
<< "Duplicate entries for spatial dimension " << spatial_dim;
max_parsed_spatial_dim = std::max(spatial_dim, max_parsed_spatial_dim);
} else if (!parser.parseOptionalQuestion()) {
// Do nothing other than increment `index` at the bottom of the loop;
// '?' means "unknown dimension", and it's not represented in the
// return value of this function.
} else {
// We did not parse an integer or question mark. We expect a keyword
// token.
StringRef keyword;
if (parser.parseKeyword(&keyword)) {
return failure();
// Non-spatial dimensions are spelled with a single letter.
if (keyword.size() != 1 || allowed_non_spatial_dims.empty()) {
return parser.emitError(dim_location, "Unexpected keyword ")
<< keyword;
// Check if the keyword matches one of the allowed non-spatial dims.
// If so, add it to the non_spatial dims and remove it from the
// allowed set so that it won't be allowed again.
bool is_allowed = false;
for (NonSpatialDim allowed : allowed_non_spatial_dims) {
if (keyword[0] == NonSpatialDimToString(allowed)) {
non_spatial_dims.insert({allowed, index});
is_allowed = true;
if (!is_allowed) {
mlir::InFlightDiagnostic diag =
parser.emitError(dim_location, "Unexpected dimension ");
diag << keyword << ", expecting ";
allowed_non_spatial_dims, diag,
[&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
return diag;
} while (parser.parseOptionalComma().succeeded());
// Make sure all expected non-spatial dimensions are parsed.
if (!allowed_non_spatial_dims.empty()) {
mlir::InFlightDiagnostic diag =
parser.emitError(parser.getCurrentLocation(), "Expected dimensions ");
allowed_non_spatial_dims, diag,
[&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
diag << " not specified";
return diag;
// parse ending ]
if (parser.parseRSquare()) {
return failure();
// Number of expected spatial dimensions is one more than the maximum parsed
// spatial dimension. For example, if we parse [0, 3, 2, b, i, 1], then the
// maximum parsed spatial dimension is 3 and the number of expected spatial
// dimensions is 4.
int64_t num_spatial_dimensions = max_parsed_spatial_dim + 1;
// Store spatial dimensions in a vector which maps spatial dim (vector
// index) -> index in the tensor dimensions. For example, for parsed
// dimension numbers [0, 3, 2, b, i, 1] the spatial dimension vector would
// be [0, 5, 2, 1].
//
// Get all the unspecified spatial dimensions to throw a more descriptive
// error later.
llvm::SmallVector<int64_t> unspecified_spatial_dims;
constexpr int kPrintUnspecifiedDimsMax = 10;
for (int dim = 0; dim < num_spatial_dimensions; ++dim) {
auto it = spatial_dims_map.find(dim);
if (it == spatial_dims_map.end()) {
// Have an upper bound on the number of unspecified dimensions to print
// in the error message.
if (unspecified_spatial_dims.size() < kPrintUnspecifiedDimsMax)
spatial_dims[dim] = it->second;
// Verify that we got all spatial dimensions between 0 and maximum parsed
// spatial dimension.
if (!unspecified_spatial_dims.empty()) {
mlir::InFlightDiagnostic diag = parser.emitError(
parser.getCurrentLocation(), "Expected spatial dimensions ");
llvm::interleaveComma(unspecified_spatial_dims, diag);
diag << " not specified";
return diag;
return success();
// `parsed_dims` is reused for all three lists; the values of each list are
// copied out before the next list is parsed.
parse_dim_result_t parsed_dims;
// Input dimensions.
if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
return failure();
llvm::SmallVector<int64_t> input_spatial_dimensions = parsed_dims.first;
int64_t input_batch_dimension = parsed_dims.second[IOBatch];
int64_t input_feature_dimension = parsed_dims.second[IOFeature];
if (parser.parseKeyword("x")) return failure();
// Kernel dimensions.
if (parse_dims({KIFeature, KOFeature}, parsed_dims)) {
return failure();
llvm::SmallVector<int64_t> kernel_spatial_dimensions = parsed_dims.first;
int64_t kernel_input_feature_dimension = parsed_dims.second[KIFeature];
int64_t kernel_output_feature_dimension = parsed_dims.second[KOFeature];
if (parser.parseArrow()) {
return failure();
// Output dimensions.
if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
return failure();
llvm::SmallVector<int64_t> output_spatial_dimensions = parsed_dims.first;
int64_t output_batch_dimension = parsed_dims.second[IOBatch];
int64_t output_feature_dimension = parsed_dims.second[IOFeature];
dnums = ConvDimensionNumbersAttr::get(
parser.getBuilder().getContext(), input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions);
return success();
// Parses a ConvDimensionNumbersAttr. After the opening `<`, an optional `raw`
// keyword selects the struct form handled by parseConvolutionDimensionsRaw;
// otherwise the compressed form is parsed and must be followed by `>`.
// Returns a null attribute on any failure.
Attribute ConvDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
  if (failed(parser.parseLess())) return {};

  ConvDimensionNumbersAttr result;
  const bool is_raw = succeeded(parser.parseOptionalKeyword("raw"));
  if (is_raw) {
    // No `>` is parsed here on the raw path.
    if (failed(parseConvolutionDimensionsRaw(parser, result))) return {};
    return result;
  }

  if (failed(parseConvolutionDimensions(parser, result)) ||
      failed(parser.parseGreater()))
    return {};
  return result;
}
// Custom printer and parser for ArgResultAliasAttr.
// Keywords of the textual form, shared by the printer and the parser.
constexpr char kMustAlias[] = "must_alias";
constexpr char kResult[] = "result_index";
constexpr char kArgTupleIndices[] = "tuple_indices";
// Prints `<[tuple_indices = [...], ]result_index = [...][, must_alias]>`,
// where the bracketed list after `result_index` starts with the result index
// and continues with the result tuple indices, if any.
void ArgResultAliasAttr::print(AsmPrinter& printer) const {
printer << "<";
// The attribute can have empty tuple indices. Only print argument tuple
// indices if they are non-empty.
if (!getArgTupleIndices().empty())
printer << kArgTupleIndices << " = [" << getArgTupleIndices() << "], ";
// Print the result index followed by any result tuple indices if present.
printer << kResult << " = [";
printer << getResultIndex();
if (!getResultTupleIndices().empty()) {
printer << ", " << getResultTupleIndices();
printer << "]";
// Print the "must_alias" keyword if this is a must alias, otherwise skip.
if (getIsMustAlias()) printer << ", " << kMustAlias;
printer << ">";
// Parses an ArgResultAliasAttr from its custom textual form, starting at the
// opening `<`. The fields are read through parseStruct using the keywords
// kArgTupleIndices, kResult and kMustAlias. Returns an empty Attribute when
// parsing fails.
// NOTE(review): several lines of this definition appear to be missing from this
// view (closing braces of the lambdas and blocks, the receiver of the `<<`
// diagnostic, and the final argument of ArgResultAliasAttr::get) -- confirm
// against the complete file before editing the code.
Attribute ArgResultAliasAttr::parse(AsmParser& parser, Type type) {
if (failed(parser.parseLess())) return {};
llvm::SmallVector<int64_t> arg_tuple_indices;
// The first element of result indices holds the aliased result index and the
// remaining elements are the result tuple indices.
llvm::SmallVector<int64_t> result_indices;
bool is_must_alias = false;
// This conveys to parseStruct that keyword "must_alias" (3rd field) is not
// followed by a "=", but other fields are.
llvm::SmallVector<bool, 3> parse_equal = {true, true, false};
// One callback per keyword, in the same order as the keyword list; each one
// stores what it parsed into the locals declared above.
if (failed(
parseStruct(parser, {kArgTupleIndices, kResult, kMustAlias},
{[&]() { return parseDims(parser, arg_tuple_indices); },
[&]() {
// Since the first element is the index of result, at
// least one element is expected.
return parseDimsWithMinimumElements(
parser, result_indices, /*min_elements=*/1);
[&]() {
// Always succeeds: reaching this callback means the keyword "must_alias"
// was parsed, so only the flag needs to be recorded.
is_must_alias = true;
return success();
parse_equal))) {
// NOTE(review): the object this message is streamed into is not visible in
// this view; presumably a parser error diagnostic -- confirm.
<< "failed parsing argument-result alias attribute";
return {};
// Split the parsed list: element 0 is the result index, the rest are the
// result tuple indices.
// NOTE(review): element 0 is read unconditionally. If parseStruct can succeed
// without the `result_index` field being present in the input, the vector
// would be empty here -- confirm that this cannot happen.
int64_t result_index = result_indices[0];
auto result_tuple_indices =
ArrayRef<int64_t>{result_indices.begin() + 1, result_indices.end()};
// NOTE(review): the argument list of this call is cut off in this view;
// presumably `is_must_alias` is passed last -- confirm.
return ArgResultAliasAttr::get(parser.getContext(), arg_tuple_indices,
result_index, result_tuple_indices,
// Returns the element type reached by following `indices` through nested tuple
// types, starting from `type`. Each index selects one element of the tuple
// reached so far. With empty `indices`, `type` itself is returned.
//
// Returns a null Type when the path is invalid, i.e. when a step is applied to
// a non-tuple type or when an index is negative or not smaller than the size of
// the tuple it is applied to.
static Type GetTypeFromTupleIndices(Type type, ArrayRef<int64_t> indices) {
  Type current = type;
  for (int64_t index : indices) {
    TupleType tuple_type = current.dyn_cast<TupleType>();
    if (!tuple_type) return {};
    // Reject negative indices explicitly rather than relying on the implicit
    // signed-to-unsigned conversion in the bounds comparison.
    if (index < 0 || static_cast<size_t>(index) >= tuple_type.size()) return {};
    current = tuple_type.getType(index);
  }
  return current;
}
// Verifies an ArgResultAliasAttr named `attr_name` that is attached to argument
// `arg_index` of `op`. The checks run in this order and the first violation is
// reported through op->emitOpError():
//   1. `op` implements FunctionOpInterface;
//   2. no argument tuple index, result tuple index or result index is negative;
//   3. the result index is smaller than the number of function results;
//   4. the argument-side and result-side types selected by the indices exist;
//   5. those two types have compatible shapes and equal element types.
// Returns success() when all checks pass.
// NOTE(review): two statements in step 4 appear truncated in this view (see the
// notes below) -- confirm against the complete file before editing the code.
static LogicalResult VerifyArgResultAliasAttr(StringAttr attr_name,
ArgResultAliasAttr alias_attr,
unsigned arg_index,
Operation* op) {
// The attribute can only be applied to function-like operations.
if (!isa<mlir::FunctionOpInterface>(op))
return op->emitOpError() << "attribute " << attr_name
<< " can only be used on function-like operations";
// Verify there are no negative indices.
// Both tuple-index lists are scanned as one concatenated range; the result
// index is a separate scalar and is checked on its own.
auto tuple_indices = llvm::concat<const int64_t>(
alias_attr.getArgTupleIndices(), alias_attr.getResultTupleIndices());
if (llvm::any_of(tuple_indices, [](const int64_t val) { return val < 0; }) ||
alias_attr.getResultIndex() < 0)
return op->emitOpError()
<< "attribute " << attr_name
<< " expects all argument and result indices to be >= 0";
// Verify that the result index is not out of range. Since the attribute is a
// function argument attribute, the argument index is always correct when this
// verifier is called.
FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
FunctionType ftype = funcOp.getType().cast<FunctionType>();
if (alias_attr.getResultIndex() >= ftype.getNumResults())
return op->emitOpError() << "attribute " << attr_name
<< " result index is out of range, must be <"
<< ftype.getNumResults();
// Verify that argument and result types pointed to by the indices are valid
// and compatible.
// NOTE(review): the second argument of this call is not visible in this view;
// presumably the attribute's argument tuple indices -- confirm.
Type arg_type = GetTypeFromTupleIndices(ftype.getInput(arg_index),
if (!arg_type)
return op->emitOpError() << "attribute " << attr_name
<< " argument tuple indices are invalid";
// NOTE(review): the initializer of `result_type` is not visible in this view;
// presumably the analogous lookup on the aliased function result -- confirm.
Type result_type =
if (!result_type)
return op->emitOpError()
<< "attribute " << attr_name << " result tuple indices are invalid";
// The aliased types must have compatible shapes and the same element type.
if (failed(mlir::verifyCompatibleShape(arg_type, result_type)) ||
getElementTypeOrSelf(arg_type) != getElementTypeOrSelf(result_type))
return op->emitOpError() << "attribute " << attr_name
<< " aliases do not have compatible types, "
<< arg_type << " vs. " << result_type;
return success();
// Shape inference
// Derives a result shape from `operand` by creating a shape::ShapeOfOp on it
// with `builder`, located at `op`'s location.
// Emits an op error on `op` and returns failure() when `operand`'s type is not
// a ShapedType; returns success() otherwise.
LogicalResult deriveShapeFromOperand(
OpBuilder* builder, Operation* op, Value operand,
SmallVectorImpl<Value>* reifiedReturnShapes) {
auto shaped_ty = operand.getType().dyn_cast<ShapedType>();
if (!shaped_ty) {
op->emitOpError() << "operand is not a shaped type";
return failure();
// NOTE(review): the call that receives the braced list below is not visible in
// this view; presumably the single ShapeOfOp result is stored into
// `reifiedReturnShapes` -- confirm against the complete file.
{builder->create<shape::ShapeOfOp>(op->getLoc(), operand)});
return success();
// MHLO Dialect Hooks
// Builds an mhlo::ConstOp of type `type` at `loc` holding `value`.
// Only ElementsAttr values can be materialized; for any other attribute kind
// no operation is created and nullptr is returned.
Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
                                            Type type, Location loc) {
  // Unlike the standard dialect constant, which accepts all attributes, HLO
  // dialect constants are restricted to ElementsAttr.
  if (auto elements = value.dyn_cast<ElementsAttr>()) {
    return builder.create<mhlo::ConstOp>(loc, type, elements);
  }
  return nullptr;
}
// Dialect hook that verifies the attribute `attr` attached to a region
// argument of `op`. Only ArgResultAliasAttr values are inspected; they are
// passed to VerifyArgResultAliasAttr together with the attribute name and
// `arg_index`, and failure() is returned when that check fails. Every other
// attribute is accepted. `region_index` is not referenced in the visible lines.
LogicalResult MhloDialect::verifyRegionArgAttribute(Operation* op,
unsigned region_index,
unsigned arg_index,
NamedAttribute attr) {
if (auto alias_attr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) {
// NOTE(review): the end of this call is not visible in this view; presumably
// `op` is passed as the last argument -- confirm.
if (failed(VerifyArgResultAliasAttr(attr.getName(), alias_attr, arg_index,
return failure();
return success();
// Dialect hook that verifies the attribute `attr` attached to `op` itself.
// The only constraint enforced is that an ArgResultAliasAttr value may appear
// solely on operations implementing FunctionOpInterface; an op error is
// returned otherwise. All other attributes are accepted.
LogicalResult MhloDialect::verifyOperationAttribute(Operation* op,
                                                    NamedAttribute attr) {
  // Attributes of any other kind carry no constraint here.
  const bool is_alias_attr = attr.getValue().isa<ArgResultAliasAttr>();
  if (!is_alias_attr) return success();

  // Alias attributes are fine on function-like operations.
  if (isa<mlir::FunctionOpInterface>(op)) return success();

  return op->emitOpError() << "attribute " << attr.getName()
                           << " can only be used on function-like operations";
}
} // namespace mhlo
} // namespace mlir