| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // This file defines the operations used in the MHLO dialect. |
| |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| |
| #include <assert.h> |
| #include <stddef.h> |
| #include <stdint.h> |
| |
| #include <algorithm> |
| #include <cstdint> |
| #include <functional> |
| #include <numeric> |
| #include <set> |
| #include <unordered_map> |
| #include <utility> |
| |
| #include "llvm/ADT/APFloat.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/Twine.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/ADT/iterator_range.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/MathExtras.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" |
| #include "mlir-hlo/utils/convert_op_folder.h" |
| #include "mlir-hlo/utils/hlo_utils.h" |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
| #include "mlir/Dialect/Shape/IR/Shape.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/IR/Dialect.h" |
| #include "mlir/IR/FunctionInterfaces.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/IR/Types.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/InliningUtils.h" |
| |
| namespace mlir { |
| #include "hlo_patterns.cc.inc" |
| } // namespace mlir |
| |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.cc.inc" |
| #define GET_ATTRDEF_CLASSES |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.cc.inc" |
| |
| namespace mlir { |
| namespace mhlo { |
| namespace { |
| void createArgs(ArrayRef<OpAsmParser::UnresolvedOperand> operands, |
| ArrayRef<Type> types, |
| SmallVector<OpAsmParser::Argument>& args) { |
| for (auto argAndType : llvm::zip(operands, types)) { |
| auto& arg = args.emplace_back(); |
| arg.ssaName = std::get<0>(argAndType); |
| arg.type = std::get<1>(argAndType); |
| } |
| } |
| |
| const auto hasDuplicates = [](SmallVector<int64_t>& nums) { |
| if (!llvm::is_sorted(nums)) std::sort(nums.begin(), nums.end()); |
| auto* last = std::unique(nums.begin(), nums.end()); |
| return last != nums.end(); |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Utilities for the canonicalize patterns |
| //===----------------------------------------------------------------------===// |
| |
| // This is an upper limit on how many elements can be folded by an op folder. |
| // This limit doesn't apply to some special cases like adding a zero, |
| // multiplying by one, doing many operations with splats. |
| constexpr int64_t kFoldOpEltLimit = 65536; |
| |
| // Clamps value to the range [lower, upper]. Requires lower <= upper. |
| template <typename T> |
| static T clamp(const T& value, const T& lower, const T& upper) { |
| assert(lower <= upper); |
| return std::max(lower, std::min(value, upper)); |
| } |
| |
| // Verifies that dimension attribute for the op correctly indexes in operand or |
| // result shape. |
| template <typename OpT> |
| static LogicalResult verifyDimAttr(OpT op) { |
| int64_t rank = -1; |
| if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) { |
| rank = ty.getRank(); |
| } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) { |
| rank = ty.getRank(); |
| } else { |
| return success(); |
| } |
| |
| int64_t dim = op.dimension(); |
| if (dim < 0 || dim >= rank) |
| return op.emitOpError() << "requires dimension attribute in range [0, " |
| << rank << "); found (" << dim << ")"; |
| return success(); |
| } |
| |
| // Given the start indices and slice sizes for a dynamic-slice that can be |
| // converted to a static slice, returns the limits for the static slice. |
| DenseIntElementsAttr buildSliceLimits(DenseIntElementsAttr startIndices, |
| DenseIntElementsAttr sliceSizes, |
| Builder* builder) { |
| SmallVector<int64_t, 4> sliceLimits; |
| for (int64_t i = 0; i < sliceSizes.getNumElements(); ++i) { |
| int64_t startIndex = startIndices.getValues<IntegerAttr>()[i].getInt(); |
| int64_t sliceSize = sliceSizes.getValues<IntegerAttr>()[i].getInt(); |
| sliceLimits.push_back(startIndex + sliceSize); |
| } |
| return builder->getI64TensorAttr(sliceLimits); |
| } |
| |
| /// Replaces the given op with the contents of the given single-block region, |
| /// using the operands of the block terminator to replace operation results. |
| static void replaceOpWithRegion(PatternRewriter& rewriter, Operation* op, |
| Region& region, ValueRange blockArgs = {}) { |
| assert(llvm::hasSingleElement(region) && "expected single-block region"); |
| Block* block = ®ion.front(); |
| Operation* terminator = block->getTerminator(); |
| ValueRange results = terminator->getOperands(); |
| rewriter.mergeBlockBefore(block, op, blockArgs); |
| rewriter.replaceOp(op, results); |
| rewriter.eraseOp(terminator); |
| } |
| |
| #include "mhlo_canonicalize.inc" |
| |
| // Check if the dimension size is dynamic. |
| inline static bool isDynamicDimSize(int64_t val) { |
| return val == ShapedType::kDynamicSize; |
| } |
| |
| // Common shape function helper for RngNormal and RngUniform. |
| static LogicalResult rngInferReturnTypeComponents( |
| MLIRContext* context, Optional<Location> location, ValueRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| if (operands.size() != 3) |
| return emitOptionalError(location, "expected 3 operands"); |
| |
| SmallVector<int64_t> shapeVector; |
| Value shapeOperand = operands[2]; |
| auto shapeOperandType = shapeOperand.getType().cast<ShapedType>(); |
| Type elementType = getElementTypeOrSelf(operands[1]); |
| |
| // Operand `shape` (1D by ODS) may be a constant or not, if `shape` is: |
| // 1, not constant and have dynimic dim (tensor<?x>): infer tensor<*x>. |
| // 2. not constant nor dynimic (e.g. tensor<3xi64>): infer tensor<?x?x?x>. |
| // 3. constant (e.g. dense<[2, 3, 5]>): infer tensor<2x3x5x>. |
| |
| // Match to check whether the `shape` operand is a constant. |
| DenseIntElementsAttr shape; |
| if (!matchPattern(shapeOperand, m_Constant(&shape))) { |
| int size = shapeOperandType.getDimSize(0); |
| if (isDynamicDimSize(size)) { |
| inferredReturnShapes.emplace_back(elementType); |
| return success(); |
| } |
| shapeVector.resize(size, ShapedType::kDynamicSize); |
| inferredReturnShapes.emplace_back(shapeVector, elementType); |
| return success(); |
| } |
| |
| // `shape` operand is a constant. |
| shapeVector.reserve(shape.size()); |
| for (const APInt& fp : shape.getValues<APInt>()) |
| shapeVector.push_back(fp.getSExtValue()); |
| inferredReturnShapes.emplace_back(shapeVector, elementType); |
| return success(); |
| } |
| |
| // Returns a new scalar integer value having type `type`. Here `type` must be |
| // an integer or index type. |
| Value maybeCastTo(OpBuilder& b, Location loc, Value value, Type type) { |
| if (type == value.getType()) return value; |
| assert(type.isIndex() || value.getType().isIndex()); |
| return b.create<arith::IndexCastOp>(loc, type, value); |
| } |
| |
| DenseElementsAttr reshape(DenseElementsAttr attr, ShapedType newType) { |
| // TODO(b/232866626): DenseElementsAttr::reshape is broken for bool splats. |
| // Once that ticket is fixed, we can remove this conditional. |
| if (attr.isSplat() && newType.getElementType().isInteger(/*width=*/1)) { |
| auto splatValue = attr.getValues<bool>()[0]; |
| return DenseElementsAttr::get(newType, {splatValue}); |
| } |
| return attr.reshape(newType); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Utilities for verifiers |
| //===----------------------------------------------------------------------===// |
| |
| // Convert a 1D dense int64 attribute to a list of values. |
| SmallVector<int64_t> convertDenseIntAttr( |
| llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr) { |
| if (!optionalAttr.has_value()) return SmallVector<int64_t>{}; |
| |
| mlir::DenseIntElementsAttr attr = *optionalAttr; |
| auto values = attr.getValues<int64_t>(); |
| return {values.begin(), values.end()}; |
| } |
| |
| // Convert a 1D or Nx2 dense int64 attribute to a list of tuples. |
| FailureOr<SmallVector<std::pair<int64_t, int64_t>>> convertNx2Attribute( |
| llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr, Location loc) { |
| if (!optionalAttr.has_value()) |
| return SmallVector<std::pair<int64_t, int64_t>>{}; |
| mlir::DenseIntElementsAttr attr = *optionalAttr; |
| |
| auto attrType = attr.getType().cast<RankedTensorType>(); // ensured by ODS. |
| if (attrType.getRank() > 1) { |
| if (attrType.getRank() != 2 || attrType.getShape()[1] != 2) |
| return (mlir::emitError(loc) << "expects the shape of padding-attribute " |
| "to be {N, 2}, but got {" |
| << attrType.getShape() << "}.", |
| failure()); |
| } else { |
| // Padding values can be provided as a 1D vector as well. |
| if (attr.getValues<int64_t>().size() % 2 != 0) |
| return (mlir::emitError(loc) |
| << "expects the padding-entries to have even number of " |
| "elements, but got " |
| << attr.getValues<int64_t>().size() << " elements.", |
| failure()); |
| } |
| |
| auto it = attr.getValues<int64_t>().begin(); |
| SmallVector<std::pair<int64_t, int64_t>> out(attr.getNumElements() / 2); |
| for (auto& item : out) { |
| int64_t first = *it; |
| ++it; |
| int64_t second = *it; |
| ++it; |
| item = {first, second}; |
| } |
| return out; |
| } |
| |
| // If a window with the given bound in some dimension is dilated with the given |
| // dilation factor in that dimension, then the value returned is the bound for |
| // the array in that dimension after dilation. |
| // |
| // For a 1D array with 3 entries 1, 2, 3, a dilation factor of 2 yields a new |
| // window with values 1, x, 2, x, 3, where x indicates holes left by the |
| // dilation. So DilatedBound(3, 2) == 5. |
| int64_t dilatedBound(int64_t bound, int64_t dilation) { |
| assert(bound >= 0 && "The dimension to dialate must be >= 0"); |
| if (bound == 0) return 0; |
| |
| // Suppose the array has three entries 123 and the dilation factor is 4. Then |
| // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except |
| // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we |
| // add 1 to account for the final input element. |
| return (bound - 1) * dilation + 1; |
| } |
| |
| // Returns the number of valid positions of a window with the given size and |
| // stride within an array with the given bound. This is the bound of an output |
| // array with one element per valid position of the window. |
| // |
| // For example, for arguments of (bound=5, window_size=2, stride=2), the |
| // returned value is 2. There are valid positions at offset 0 and offset 2, |
| // while offset 4 is not valid since the window's last entry would be at 5, |
| // which is beyond the bound of 5. |
| int64_t stridedBound(int64_t bound, int64_t windowSize, int64_t stride) { |
| assert(windowSize >= 0 && "Expected window size to be >= 0"); |
| assert(bound >= 0 && "Expected bound to be >= 0"); |
| |
| if (bound == 0 || windowSize > bound) return 0; |
| |
| // Without considering stride, the maximum valid offset is bound - |
| // window_size. Taking stride into account, the valid offsets then have the |
| // form q * stride for q = 0, ..., Q such that q * stride <= bound - |
| // window_size. This implies that Q equals floor(bound - window_size / |
| // stride). There are Q + 1 valid values of q, yielding the formula below. |
| return (bound - windowSize) / stride + 1; |
| } |
| |
| // WindowDimension described how the kernel window moves across the base area |
| // in a particular dimension. |
| // Describes the windowing in an operation such as convolution. |
| // The window is moved across a base area and for each position of the |
| // window a computation is performed. The field below describes the |
| // window and the movement of the window across a base area. |
| struct WindowDimension { |
| int64_t size = 0; |
| int64_t stride = 1; |
| int64_t paddingLow = 0; |
| int64_t paddingHigh = 0; |
| int64_t windowDilation = 1; |
| int64_t baseDilation = 1; |
| bool windowReversal = false; |
| }; |
| |
| // Verifies various properties of window-attributes (viz., stride, padding, |
| // lhs_dilation and rhs_dilation) and collects all the window-attributes for |
| // each kernel spatial dimensions. |
| FailureOr<SmallVector<WindowDimension>> |
| verifyWindowAttributesAndInferWindowDimensions( |
| ArrayRef<int64_t> windowDimensions, ArrayRef<int64_t> windowStrides, |
| ArrayRef<std::pair<int64_t, int64_t>> padding, |
| ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation, |
| Location loc) { |
| const auto verifySize = [&](const size_t attrSize, |
| StringRef attrName) -> LogicalResult { |
| if (attrSize == 0 || attrSize == windowDimensions.size()) return success(); |
| return mlir::emitError(loc) |
| << "expects " << attrName |
| << " to have same dimension-size as size of " |
| "window dimensions " |
| "(" |
| << windowDimensions.size() << "), but got: " << attrSize << "."; |
| }; |
| |
| if (failed(verifySize(windowStrides.size(), "window-strides"))) |
| return failure(); |
| if (failed(verifySize(lhsDilation.size(), "base-dilation factors"))) |
| return failure(); |
| if (failed(verifySize(rhsDilation.size(), "window-dilation factors"))) |
| return failure(); |
| if (failed(verifySize(padding.size(), "padding-entries"))) return failure(); |
| |
| SmallVector<WindowDimension> window(windowDimensions.size()); |
| for (size_t i = 0; i < windowDimensions.size(); i++) { |
| WindowDimension& dim = window[i]; |
| |
| dim.size = windowDimensions[i]; |
| if (!isDynamicDimSize(dim.size) && dim.size <= 0) |
| return (mlir::emitError(loc) |
| << "expects window to have positive value for " << i |
| << "-th window dimension, but got " << dim.size << ".", |
| failure()); |
| |
| if (!windowStrides.empty()) dim.stride = windowStrides[i]; |
| if (dim.stride <= 0) |
| return (mlir::emitError(loc) |
| << "expects window to have positive stride for " << i |
| << "-th window dimension, but got " << dim.stride << ".", |
| failure()); |
| |
| if (!lhsDilation.empty()) dim.baseDilation = lhsDilation[i]; |
| if (dim.baseDilation <= 0) |
| return (mlir::emitError(loc) << "expects window to have positive base " |
| "dilation factor for " |
| << i << "-th window dimension, but got " |
| << dim.baseDilation << ".", |
| failure()); |
| |
| if (!rhsDilation.empty()) dim.windowDilation = rhsDilation[i]; |
| if (dim.windowDilation <= 0) |
| return (mlir::emitError(loc) << "expects window to have positive window " |
| "dilation factor for " |
| << i << "-th window dimension, but got " |
| << dim.windowDilation << ".", |
| failure()); |
| |
| if (!padding.empty()) { |
| dim.paddingLow = padding[i].first; |
| dim.paddingHigh = padding[i].second; |
| } |
| } |
| |
| return window; |
| } |
| |
| // Infer the shape of the output window. |
| // Foreach dimension d, |
| // output-window-shape[d] = |
| // stridedBound(padding_low + dilatedBound(base_shape[d]) + |
| // padding_high, |
| // dilatedBound(window_shape[d])) |
| // where (padding_low, padding_high) is the padding-pair for d. |
| SmallVector<int64_t> inferWindowOutputShape( |
| const ArrayRef<int64_t> baseShape, const ArrayRef<WindowDimension> window) { |
| assert(baseShape.size() == window.size() && |
| "Size of window dimensions must match the size of base shape."); |
| |
| SmallVector<int64_t> outputDimensions(window.size()); |
| for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i) { |
| if (isDynamicDimSize(baseShape[i]) || isDynamicDimSize(window[i].size)) { |
| outputDimensions[i] = ShapedType::kDynamicSize; |
| } else { |
| const auto& dim = window[i]; |
| |
| const int64_t dilatedBase = dilatedBound(baseShape[i], dim.baseDilation); |
| const int64_t paddedDilatedBase = |
| dim.paddingLow + dilatedBase + dim.paddingHigh; |
| const int64_t dilatedWindow = dilatedBound(dim.size, dim.windowDilation); |
| |
| outputDimensions[i] = |
| stridedBound(paddedDilatedBase, dilatedWindow, dim.stride); |
| } |
| } |
| |
| return outputDimensions; |
| } |
| |
| // Return true if type1 and type2 are tensors and have the same |
| // element-type, else return false. With float element-types, ignore comparing |
| // floating-point precision if ignoreFpPrecision is True. |
| bool tensorsHaveSameElType(Type type1, Type type2, bool ignoreFpPrecision) { |
| auto tensorTy1 = type1.dyn_cast<TensorType>(); |
| auto tensorTy2 = type2.dyn_cast<TensorType>(); |
| |
| if (!tensorTy1 || !tensorTy2) return false; |
| |
| if (ignoreFpPrecision && tensorTy1.getElementType().isa<FloatType>() && |
| tensorTy2.getElementType().isa<FloatType>()) |
| return true; |
| |
| return tensorTy1.getElementType() == tensorTy2.getElementType(); |
| } |
| |
| // Return true if type1 and type2 are shape-compatible and have same element |
| // type. If 'ignoreFpPrecision' is True, then allow floats with different |
| // precisions while checking element-types. |
| bool compatibleShapeAndElementType(Type type1, Type type2, |
| bool ignoreFpPrecision = false) { |
| if (failed(verifyCompatibleShape(type1, type2))) return false; |
| return tensorsHaveSameElType(type1.cast<ShapedType>(), |
| type2.cast<ShapedType>(), ignoreFpPrecision); |
| } |
| |
| LogicalResult verifyReducerShape( |
| Location loc, Block& block, ArrayRef<TensorType> inputArgTypes, |
| ArrayRef<TensorType> initValueTypes, int64_t numInputs, |
| ArrayRef<int64_t> allowedDimensions, bool allInputsUnranked, |
| SmallVectorImpl<TensorType>& accumulatorSubShapes) { |
| // Check that the number of reduction-region arguments matches with that of |
| // reduce-op's arguments. |
| if (static_cast<int64_t>(block.getArguments().size()) != numInputs * 2) |
| return mlir::emitError(loc) |
| << "Reduction-region must take " << numInputs * 2 |
| << " parameters, but takes " << block.getArguments().size() |
| << " parameter(s)"; |
| |
| // Check if the reduction-region produces non-zero outputs. |
| if (block.getTerminator()->getOperands().empty()) |
| return mlir::emitError(loc) |
| << "The reduction-region expected to return some value(s)"; |
| |
| // Check that the reduction-region returns list- of tensors. |
| // The number of result-tensors must match the `numInputs`. |
| if (static_cast<int64_t>(block.getTerminator()->getOperands().size()) != |
| numInputs) |
| return mlir::emitError(loc) |
| << "Reduction-region here must produce " << numInputs |
| << " tensors, but produces " |
| << block.getTerminator()->getOperands().size() << " instead"; |
| |
| for (Value retOperand : block.getTerminator()->getOperands()) { |
| auto tensorTy = retOperand.getType().dyn_cast<TensorType>(); |
| if (!tensorTy) |
| return mlir::emitError(loc) << "Reduction-region here must produce " |
| "tensor-typed result(s), but " |
| "produces " |
| << retOperand.getType() << " instead"; |
| |
| accumulatorSubShapes.push_back(tensorTy); |
| } |
| |
| // Consider typical reduce-* op syntax: |
| // |
| // op(I(i), V(j)): |
| // block(BI(i), BV(j)): |
| // ... some computation ... |
| // return(R(i)) |
| // |
| // where |
| // I(i) : i-th input of op |
| // V(j) : j-th init-value of op |
| // BI(i) : i-th input of reducer-function |
| // BV(j) : j-th init-value of reducer-function |
| // R(i) : i-th return-type |
| // |
| // Note that: |I(i)| == V(j)| == |BI(i)| == |BV(j)| == |R(i)| |
| // |
| // Here are the type-constraints among V(j), BI(i), BV(j), and R(i). |
| // C1 : Check that BI(i) and R(i) have same shape and element-type. |
| // C2 : Check that BV(j) and R(i) have same shape and element-type. |
| // C3 : Check that V(j) and R(i) have same shape and element-type. |
| // |
| // From C1, C2, and C3, we can infer that V(j), BI(i), BV(j), and R(i) all |
| // have compatible shapes and element-types. |
| // The next check, C4, adds constraints on how the type if I(i) is related |
| // to any_of(V(j), BI(i), BV(j), and R(i)), say BV(j); |
| // |
| // C4.1 : Check that I(i) and BV(j) have same element-type. |
| // C4.2 : Check that shape of BV(j) is a 'sub-sequence' of |
| // 'allowedDimensions'. 'allowedDimensions' is a list of dimensions |
| // which any of BI(i), BV(j), and R(i) is allowed to have. |
| for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { |
| // Check C1. |
| if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx], |
| block.getArgument(inputIdx).getType())) |
| return mlir::emitError(loc) |
| << "The type of reduction-region's parameter at index " << inputIdx |
| << " is different than the corresponding result type: " |
| << block.getArgument(inputIdx).getType() << " vs " |
| << accumulatorSubShapes[inputIdx]; |
| |
| // Check C2. |
| if (!compatibleShapeAndElementType( |
| accumulatorSubShapes[inputIdx], |
| block.getArgument(numInputs + inputIdx).getType(), |
| /*ignoreFpPrecision=*/true)) |
| return mlir::emitError(loc) |
| << "The type of reduction-region's parameter at index " |
| << numInputs + inputIdx |
| << " is different than the corresponding result type: " |
| << block.getArgument(numInputs + inputIdx).getType() << " vs " |
| << accumulatorSubShapes[inputIdx]; |
| |
| // Check C3. |
| if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx], |
| initValueTypes[inputIdx], |
| /*ignoreFpPrecision=*/true)) |
| return mlir::emitError(loc) |
| << "The type of reduction-region's result type at index " |
| << inputIdx |
| << " differs from the op's corresponding init-value type: " |
| << accumulatorSubShapes[inputIdx] << " vs " |
| << initValueTypes[inputIdx]; |
| |
| // Check C4.1. |
| if (!tensorsHaveSameElType( |
| inputArgTypes[inputIdx], |
| block.getArgument(numInputs + inputIdx).getType(), true)) |
| return mlir::emitError(loc) |
| << "The element-type of reduction-region's argument at index " |
| << numInputs + inputIdx << " is expected to be " |
| << inputArgTypes[inputIdx].getElementType() << ", but got " |
| << block.getArgument(numInputs + inputIdx).getType() |
| << " as its type."; |
| |
| // Check C4.2. |
| Type blockArgType = block.getArgument(numInputs + inputIdx).getType(); |
| auto blockArgTensorTy = blockArgType.cast<TensorType>(); |
| |
| if (allInputsUnranked || !blockArgTensorTy.hasRank()) return success(); |
| |
| auto argShape = blockArgTensorTy.getShape(); |
| if (argShape.size() > allowedDimensions.size()) |
| return mlir::emitError(loc) |
| << "The rank of reduction-region's argument at index " |
| << numInputs + inputIdx |
| << " is expected to be <= " << allowedDimensions.size() << ", got " |
| << argShape.size(); |
| |
| int64_t argShapeIdx = 0; |
| for (int64_t outputShapeIdx = 0; |
| outputShapeIdx < static_cast<int64_t>(allowedDimensions.size()) && |
| argShapeIdx < static_cast<int64_t>(argShape.size()); |
| outputShapeIdx++) |
| if (allowedDimensions[outputShapeIdx] == argShape[argShapeIdx]) |
| argShapeIdx++; |
| |
| if (argShapeIdx != static_cast<int64_t>(argShape.size())) |
| return mlir::emitError(loc) |
| << "The shape of reduction-region's argument at index " |
| << numInputs + inputIdx |
| << " is not compatible with that of reduce-op's input-parameter " |
| "at index " |
| << inputIdx; |
| } |
| |
| return success(); |
| } |
| |
| unsigned potentiallyComplexBitwidth(Type type) { |
| auto complexTy = type.dyn_cast<ComplexType>(); |
| return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth() |
| : type.getIntOrFloatBitWidth(); |
| } |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // ReduceScatterOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ReduceScatterOp::verify() { |
| if (failed(mlir::hlo::verifyReplicaGroups(*this, /*is_uniform_sized=*/true))) |
| return failure(); |
| auto operandType = operand().getType().cast<TensorType>(); |
| bool operandTypeRanked = operandType.isa<RankedTensorType>(); |
| Block& block = computation().front(); |
| SmallVector<TensorType> accumulatorSubshapes; |
| if (failed(verifyReducerShape( |
| this->getLoc(), block, {operandType}, |
| {RankedTensorType::get({}, operandType.getElementType())}, |
| /*numInputs=*/1, /*allowedDimensions=*/{}, |
| /*allInputsUnranked=*/!operandTypeRanked, accumulatorSubshapes))) |
| return failure(); |
| |
| return mlir::hlo::verifyReduceScatter( |
| *this, |
| /*operand_types=*/{operand().getType()}, |
| /*result_types=*/{getType()}, |
| /*scatter_dimension=*/scatter_dimension()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CompatibleOperandsAndResultType |
| //===----------------------------------------------------------------------===// |
| |
| // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that |
| // support quantization or sparsity. |
| #define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op) \ |
| LogicalResult Op::inferReturnTypeComponents( \ |
| MLIRContext* context, Optional<Location> location, \ |
| ValueShapeRange operands, DictionaryAttr attributes, \ |
| RegionRange regions, \ |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { \ |
| return inferReturnTypeComponentsFromOperands(context, location, operands, \ |
| attributes, regions, \ |
| inferredReturnShapes); \ |
| } |
| |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AllReduceOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CopyOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DivOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DomainOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ExpOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Expm1Op) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(FloorOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Log1pOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogisticOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MaxOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MinOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MulOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NegOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NotOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(OrOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PopulationCountOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PowOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ReducePrecisionOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RemOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ReverseOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundNearestEvenOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RsqrtOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftLeftOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightArithmeticOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightLogicalOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SignOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SineOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SqrtOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SubtractOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanhOp) |
| INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp) |
| |
| //===----------------------------------------------------------------------===// |
| // ConstantOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.empty() && "constant has no operands"); |
| |
| // Return the held attribute value. |
| return value(); |
| } |
| |
| // Builds a constant op with the specified attribute `value`. |
| void ConstantOp::build(OpBuilder& /*builder*/, OperationState& result, |
| Attribute value) { |
| Type type; |
| if (auto elemAttr = value.dyn_cast<ElementsAttr>()) { |
| type = elemAttr.getType(); |
| } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() || |
| value.isa<IntegerAttr>()) { |
| // All XLA types must be tensor types. In the build() method, we want to |
| // provide more flexibility by allowing attributes of scalar types. But we |
| // need to wrap it up with ElementsAttr to construct valid XLA constants. |
| type = |
| RankedTensorType::get(/*shape=*/{}, value.cast<TypedAttr>().getType()); |
| value = DenseElementsAttr::get(type.cast<TensorType>(), value); |
| } |
| |
| // TODO: support other XLA specific types. |
| assert(type && "unsupported attribute type for building mhlo.constant"); |
| result.types.push_back(type); |
| result.addAttribute("value", value); |
| } |
| |
| LogicalResult ConstantOp::inferReturnTypes( |
| MLIRContext*, Optional<Location>, ValueRange operands, |
| DictionaryAttr attributes, RegionRange, |
| SmallVectorImpl<Type>& inferredReturnTypes) { |
| ConstantOpAdaptor adaptor(operands, attributes); |
| Type type = adaptor.value().getType(); |
| inferredReturnTypes.push_back(type); |
| return success(); |
| } |
| |
| bool ConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { |
| if (l.size() != r.size() || l.size() != 1) return false; |
| auto lhsTy = l.front().cast<TensorType>(); |
| auto rhsTy = r.front().cast<TensorType>(); |
| // For comparisons of the uniform quantized element based tensor type, use the |
| // storage type since the constant value will be stored through the underlying |
| // storage type. |
| if (auto rhsElemTy = |
| rhsTy.getElementType().dyn_cast<quant::QuantizedType>()) { |
| rhsTy = getSameShapeTensorType(rhsTy, rhsElemTy.getStorageType()); |
| } |
| return lhsTy == rhsTy; |
| } |
| |
| ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) { |
| // Parse the generic form. |
| if (succeeded(parser.parseOptionalLParen())) { |
| if (parser.parseRParen()) return failure(); |
| if (parser.parseOptionalAttrDict(result.attributes)) return failure(); |
| if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || |
| parser.parseArrow()) |
| return failure(); |
| Type resultTy; |
| if (parser.parseType(resultTy)) { |
| return failure(); |
| } |
| result.addTypes(resultTy); |
| return success(); |
| } |
| |
| ElementsAttr valueAttr; |
| if (parser.parseOptionalAttrDict(result.attributes)) return failure(); |
| |
| if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value", |
| result.attributes)) { |
| return failure(); |
| } |
| result.addTypes(valueAttr.getType()); |
| return success(); |
| } |
| |
| /// Print a `constant` op. |
| /// |
| /// op ::= attr-dict $value |
| /// |
| /// When the `value` and `output` have different type, it just uses the default |
| /// operator assembly format as a fallback. |
| void ConstantOp::print(::mlir::OpAsmPrinter& p) { |
| // If not all types are the same, use generic form. |
| if (value().getType() != getType()) { |
| p.printGenericOp(getOperation(), /*printOpName=*/false); |
| return; |
| } |
| |
| p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); |
| p << ' '; |
| p.printStrippedAttrOrType(valueAttr()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CustomCallOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult CustomCallOp::verify() { |
| // If both operand and result layout attributes are not specified then nothing |
| // to verify. |
| if (!operand_layouts().has_value() && !result_layouts().has_value()) |
| return success(); |
| |
| // Layout constraints for either both operands & results or none should be |
| // specified. |
| if (operand_layouts().has_value() != result_layouts().has_value()) |
| return emitOpError() << "Layout attributes should be specified for " |
| "either both operands and results or none."; |
| |
| // Helper function to verify types and the corresponding layouts. |
| auto verifyTypesAndLayouts = |
| [this](TypeRange types, mlir::ArrayAttr layouts, |
| const std::string& valueName) -> LogicalResult { |
| if (types.size() != layouts.size()) |
| return emitOpError() << "Number of " << valueName |
| << "s must match the number of " << valueName |
| << " layouts, " << types.size() |
| << " != " << layouts.size(); |
| |
| for (const auto& indexedTypeAndLayout : |
| llvm::enumerate(llvm::zip(types, layouts))) { |
| // Get index for more descriptive error message. |
| auto index = indexedTypeAndLayout.index(); |
| |
| auto type = std::get<0>(indexedTypeAndLayout.value()); |
| auto layout = std::get<1>(indexedTypeAndLayout.value()) |
| .cast<DenseIntElementsAttr>(); |
| |
| if (type.isa<TupleType>()) |
| return emitOpError() << "Tuple types are not fully supported with " |
| "layout constraints yet"; |
| auto tensorType = type.dyn_cast<TensorType>(); |
| |
| // For non-tensor types such as !mhlo.token, the layout should be empty. |
| if (!tensorType) { |
| if (layout.empty()) continue; |
| return emitOpError() |
| << "Only tensor types can have non-empty layout: " << valueName |
| << " #" << index << " of type " << type << " has layout " |
| << layout; |
| } |
| |
| // For unranked tensors, we cannot verify the compatibility with layout |
| // any further. |
| if (!tensorType.hasRank()) continue; |
| |
| // Layout must be a permutation of [0, N) where N is the rank of the |
| // tensor type. |
| std::vector<int64_t> range(tensorType.getRank()); |
| std::iota(range.begin(), range.end(), 0); |
| if (tensorType.getRank() != layout.size() || |
| !std::is_permutation(range.begin(), range.end(), layout.begin())) |
| return emitOpError() << "incorrect layout " << layout << " for type " |
| << type << ", layout must be a permutation of [0, " |
| << tensorType.getRank() << ")"; |
| } |
| return success(); |
| }; |
| |
| // At this point both `operand_layouts` and `result_layouts` are defined. |
| ArrayAttr operandLayouts = this->operand_layouts().getValue(); |
| ArrayAttr resultLayouts = this->result_layouts().getValue(); |
| |
| // Full support for layouts for arbitrary nesting of tuples is not |
| // supported yet. |
| // |
| // If result does not have any tuples, then i-th element of `result_layouts` |
| // specifies the layout constraints on i-th result. |
| // |
| // For the common case of a single tuple result packing non-tuple values, the |
| // i-th element of `result_layouts` specifies layout for i-th element of the |
| // result tuple. |
| TypeRange resultTypes; |
| if (getNumResults() == 1 && getResult(0).getType().isa<TupleType>()) |
| resultTypes = getResult(0).getType().cast<TupleType>().getTypes(); |
| else |
| resultTypes = getResultTypes(); |
| |
| // Verify that operands and operand layouts match. |
| if (failed( |
| verifyTypesAndLayouts(getOperandTypes(), operandLayouts, "operand"))) |
| return failure(); |
| |
| // Verify that results and result layouts match. |
| return verifyTypesAndLayouts(resultTypes, resultLayouts, "result"); |
| } |
| |
| void CustomCallOp::getEffects( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>& |
| effects) { |
| // CustomCall has "all possible effects" unless the has_side_effect is present |
| // and set to false. |
| auto hasSideEffect = (*this)->getAttrOfType<BoolAttr>("has_side_effect"); |
| if (hasSideEffect && !hasSideEffect.getValue()) return; |
| effects.emplace_back(MemoryEffects::Allocate::get()); |
| effects.emplace_back(MemoryEffects::Free::get()); |
| effects.emplace_back(MemoryEffects::Write::get()); |
| effects.emplace_back(MemoryEffects::Read::get()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CholeskyOp |
| //===----------------------------------------------------------------------===// |
| |
| // The following properties are already enforced by the ODS: |
| // P0. a.element_type is floating or complex |
| // We intend to verify the following properties |
| // P1. The 'a' argument to Cholesky must have rank >= 2, got shape %s |
| // P2. The two minor dimensions of 'a' must have equal size, got %s. |
| LogicalResult CholeskyOp::inferReturnTypeComponents( |
| MLIRContext*, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| CholeskyOp::Adaptor adaptor(operands, attributes, regions); |
| Type aType = adaptor.a().getType(); |
| RankedTensorType aRankedType = aType.dyn_cast<RankedTensorType>(); |
| if (!aRankedType) { |
| inferredReturnShapes.emplace_back( |
| aType.cast<TensorType>().getElementType()); |
| return success(); |
| } |
| |
| ArrayRef<int64_t> aShape = aRankedType.getShape(); |
| if (aShape.size() < 2) { |
| return emitOptionalError( |
| location, "argument 'a' must have rank >= 2, got shape ", aShape, "."); |
| } |
| |
| int64_t lastDim = aShape[aShape.size() - 1]; |
| int64_t penultimateDim = aShape[aShape.size() - 2]; |
| if (!isDynamicDimSize(lastDim) && !isDynamicDimSize(penultimateDim) && |
| lastDim != penultimateDim) { |
| return emitOptionalError( |
| location, "minor dimensions of 'a' must have equal size, got shape ", |
| aShape, "."); |
| } |
| inferredReturnShapes.emplace_back(aRankedType.getShape(), |
| aRankedType.getElementType()); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DotOp |
| //===----------------------------------------------------------------------===// |
| namespace { |
| bool dimCompatible(int64_t a, int64_t b) { |
| return isDynamicDimSize(a) || isDynamicDimSize(b) || a == b; |
| } |
| |
| ShapedType inferDotReturnType(ShapedType lhs, ShapedType rhs) { |
| auto elementType = lhs.getElementType(); |
| if (!lhs.hasRank() || !rhs.hasRank()) { |
| return UnrankedTensorType::get(elementType); |
| } |
| |
| // vector dot vector |
| if (1 == lhs.getRank() && 1 == rhs.getRank() && |
| dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) { |
| return RankedTensorType::get({}, elementType); |
| } |
| // matrix dot vector |
| if (2 == lhs.getRank() && 1 == rhs.getRank() && |
| dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) { |
| return RankedTensorType::get({lhs.getDimSize(0)}, elementType); |
| } |
| // vector dot matrix |
| if (1 == lhs.getRank() && 2 == rhs.getRank() && |
| dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) { |
| return RankedTensorType::get({rhs.getDimSize(1)}, elementType); |
| } |
| // matrix dot matrix |
| if (2 == lhs.getRank() && 2 == rhs.getRank() && |
| dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) { |
| int64_t shape[2] = {lhs.getDimSize(0), rhs.getDimSize(1)}; |
| return RankedTensorType::get(shape, elementType); |
| } |
| return {}; |
| } |
| } // namespace |
| |
| LogicalResult DotOp::inferReturnTypes( |
| MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr, |
| RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) { |
| DotOp::Adaptor op(operands); |
| auto lhsType = op.lhs().getType().cast<ShapedType>(); |
| auto rhsType = op.rhs().getType().cast<ShapedType>(); |
| inferredReturnTypes.push_back(inferDotReturnType(lhsType, rhsType)); |
| return success(); |
| } |
| |
| LogicalResult DotOp::verify() { |
| auto lhsType = lhs().getType().cast<ShapedType>(); |
| auto rhsType = rhs().getType().cast<ShapedType>(); |
| auto resultType = getType().cast<ShapedType>(); |
| auto expectReturnType = inferDotReturnType(lhsType, rhsType); |
| if (!expectReturnType) { |
| return emitError() << "Unexpected operands type: " << lhsType << " and " |
| << rhsType; |
| } |
| if (resultType.hasRank() && expectReturnType.hasRank()) { |
| if (resultType.getShape() != expectReturnType.getShape()) { |
| return emitError() << "Unexpected result type: has " << resultType |
| << " but inferred " << expectReturnType |
| << " from operands " << lhsType << " and " << rhsType; |
| } |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DotGeneralOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult DotGeneralOp::verify() { |
| auto dimNumbers = this->dot_dimension_numbers(); |
| |
| ArrayRef<int64_t> lhsBatchingDims = dimNumbers.getLhsBatchingDimensions(); |
| ArrayRef<int64_t> rhsBatchingDims = dimNumbers.getRhsBatchingDimensions(); |
| ArrayRef<int64_t> lhsContractingDims = |
| dimNumbers.getLhsContractingDimensions(); |
| ArrayRef<int64_t> rhsContractingDims = |
| dimNumbers.getRhsContractingDimensions(); |
| |
| if (lhsBatchingDims.size() != rhsBatchingDims.size()) { |
| return emitOpError() << "lhs and rhs should have the same number of " |
| "batching dimensions"; |
| } |
| if (lhsContractingDims.size() != rhsContractingDims.size()) { |
| return emitOpError() << "lhs and rhs should have the same number of " |
| "contracting dimensions"; |
| } |
| |
| llvm::SmallDenseSet<int64_t> dimSet; |
| |
| auto checkDimsDistinct = |
| [this](ArrayRef<int64_t> batchingDims, ArrayRef<int64_t> contractingDims, |
| llvm::SmallDenseSet<int64_t>& dimSet, llvm::StringRef lhs, |
| llvm::StringRef rhs) -> LogicalResult { |
| auto dims = llvm::concat<const int64_t>(batchingDims, contractingDims); |
| for (auto dim : dims) { |
| auto [_, wasInserted] = dimSet.insert(dim); |
| if (!wasInserted) { |
| return emitOpError() << "has duplicated dimension from " << lhs |
| << " and " << rhs << ": " << dim; |
| } |
| } |
| return success(); |
| }; |
| |
| if (failed(checkDimsDistinct(lhsBatchingDims, lhsContractingDims, dimSet, |
| "lhs_batching_dimensions", |
| "lhs_contracting_dimensions"))) { |
| return failure(); |
| } |
| dimSet.clear(); |
| if (failed(checkDimsDistinct(rhsBatchingDims, rhsContractingDims, dimSet, |
| "rhs_batching_dimensions", |
| "rhs_contracting_dimensions"))) { |
| return failure(); |
| } |
| |
| auto checkDimsInRange = [this](int64_t rank, ArrayRef<int64_t> dims, |
| llvm::StringRef dimName) -> LogicalResult { |
| auto inRange = [&](int64_t i) -> bool { return 0 <= i && i < rank; }; |
| const auto* dimsNotInRange = |
| std::find_if_not(dims.begin(), dims.end(), inRange); |
| if (dimsNotInRange != dims.end()) { |
| return emitOpError() << dimName << " value: " << *dimsNotInRange |
| << " is out of range: " |
| << "[0, " << rank << ")"; |
| } |
| return success(); |
| }; |
| |
| auto lhsType = this->lhs().getType().dyn_cast<RankedTensorType>(); |
| auto rhsType = this->rhs().getType().dyn_cast<RankedTensorType>(); |
| |
| if (lhsType) { |
| if (failed(checkDimsInRange(lhsType.getRank(), lhsBatchingDims, |
| "lhs_batching_dimensions")) || |
| failed(checkDimsInRange(lhsType.getRank(), lhsContractingDims, |
| "lhs_contracting_dimensions"))) { |
| return failure(); |
| } |
| } |
| if (rhsType) { |
| if (failed(checkDimsInRange(rhsType.getRank(), rhsBatchingDims, |
| "rhs_batching_dimensions")) || |
| failed(checkDimsInRange(rhsType.getRank(), rhsContractingDims, |
| "rhs_contracting_dimensions"))) { |
| return failure(); |
| } |
| } |
| |
| if (lhsType && rhsType) { |
| // Dimension sizes must be compatible for lhs/rhs. |
| auto lhsShape = lhsType.getShape(); |
| auto rhsShape = rhsType.getShape(); |
| |
| for (auto [lhs, rhs] : llvm::zip(lhsBatchingDims, rhsBatchingDims)) { |
| if (lhsShape[lhs] != rhsShape[rhs]) { |
| return emitOpError() << "batching dimension sizes must match for " |
| "lhs/rhs"; |
| } |
| } |
| for (auto [lhs, rhs] : llvm::zip(lhsContractingDims, rhsContractingDims)) { |
| if (lhsShape[lhs] != rhsShape[rhs]) { |
| return emitOpError() << "contracting dimension sizes must match for " |
| "lhs/rhs"; |
| } |
| } |
| } |
| return success(); |
| } |
| |
| namespace { |
| // Handle the generic case of DotGeneral and convert to a regulat DotOp. |
| struct DotGeneralToDot : public OpRewritePattern<DotGeneralOp> { |
| using OpRewritePattern<DotGeneralOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(DotGeneralOp dot, |
| PatternRewriter& rewriter) const override { |
| auto lhs = dot.lhs(); |
| auto rhs = dot.rhs(); |
| auto lhsTy = lhs.getType().cast<ShapedType>(); |
| auto rhsTy = rhs.getType().cast<ShapedType>(); |
| |
| if (lhsTy.getRank() != 2) return failure(); |
| if (rhsTy.getRank() != 2) return failure(); |
| |
| auto nums = dot.dot_dimension_numbers(); |
| if (!nums.getLhsBatchingDimensions().empty()) return failure(); |
| if (!nums.getRhsBatchingDimensions().empty()) return failure(); |
| |
| auto lhsContract = nums.getLhsContractingDimensions(); |
| auto rhsContract = nums.getRhsContractingDimensions(); |
| if (lhsContract.size() != 1 || rhsContract.size() != 1) return failure(); |
| |
| if (lhsContract.front() != 1) return failure(); |
| if (rhsContract.front() != 0) return failure(); |
| |
| rewriter.replaceOpWithNewOp<mhlo::DotOp>( |
| dot, dot.getType(), lhs, rhs, |
| dot.precision_config().getValueOr(nullptr)); |
| |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void DotGeneralOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<DotGeneralToDot>(context); |
| } |
| |
| LogicalResult DotGeneralOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| auto lhsType = lhs().getType().dyn_cast<ShapedType>(); |
| auto rhsType = rhs().getType().dyn_cast<ShapedType>(); |
| if (!lhsType || !rhsType) { |
| return failure(); |
| } |
| |
| Adaptor adaptor(operands); |
| auto dimNumbers = dot_dimension_numbers(); |
| SmallVector<Value> dimensions; |
| for (const int64_t lhsDim : dimNumbers.getLhsBatchingDimensions()) { |
| dimensions.push_back( |
| builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), lhsDim)); |
| } |
| |
| for (int64_t i = 0; i < lhsType.getRank(); i++) { |
| if (!llvm::is_contained(dimNumbers.getLhsContractingDimensions(), i) && |
| !llvm::is_contained(dimNumbers.getLhsBatchingDimensions(), i)) { |
| dimensions.push_back( |
| builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), i)); |
| } |
| } |
| for (int64_t i = 0; i < rhsType.getRank(); i++) { |
| if (!llvm::is_contained(dimNumbers.getRhsContractingDimensions(), i) && |
| !llvm::is_contained(dimNumbers.getRhsBatchingDimensions(), i)) { |
| dimensions.push_back( |
| builder.create<tensor::DimOp>(getLoc(), adaptor.rhs(), i)); |
| } |
| } |
| |
| reifiedReturnShapes.push_back( |
| builder.create<tensor::FromElementsOp>(getLoc(), dimensions)); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FftOp |
| //===----------------------------------------------------------------------===// |
| |
| // We intend to verify the following properties |
| // P1. 1 <= rank <= 3 |
| // P2. Element types agree with fft_type |
| // P3. Operand shape dimensions agree with fft_length for the given fft_type |
| LogicalResult FftOp::inferReturnTypeComponents( |
| MLIRContext*, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| FftOp::Adaptor adaptor(operands, attributes, regions); |
| auto fftLength = adaptor.fft_length().getValues<int64_t>(); |
| int64_t fftRank = fftLength.size(); |
| |
| // P1. |
| if (fftRank > 3 || fftRank < 1) { |
| return emitOptionalError(location, "rank must be between 1 and 3, but got ", |
| fftRank, "."); |
| } |
| |
| // P2. Element type agreement |
| // FFT : C -> C |
| // IFFT : C -> C |
| // RFFT : R -> C |
| // IRFFT : C -> R |
| auto fftType = adaptor.fft_type(); |
| auto operandType = adaptor.operand().getType().cast<TensorType>(); |
| Type operandElementType = operandType.getElementType(); |
| // Check the input element type and infer return element type |
| if (fftType == FftType::RFFT) { |
| if (!operandElementType.isF32() && !operandElementType.isF64()) { |
| return emitOptionalError( |
| location, "RFFT requires f32 or f64 input type, but is given ", |
| operandElementType, "."); |
| } |
| } else { |
| if (!operandElementType.isa<ComplexType>()) { |
| return emitOptionalError( |
| location, stringifyFftType(fftType), |
| " takes a complex tensor as input, but is given ", operandType, "."); |
| } |
| } |
| // Generate the output element type |
| Type resultElementType = operandElementType; |
| if (fftType == FftType::RFFT) { // RFFT : R -> C |
| resultElementType = ComplexType::get(resultElementType); |
| } else if (fftType == FftType::IRFFT) { // IRFFT : C -> R |
| resultElementType = operandElementType.cast<ComplexType>().getElementType(); |
| } |
| |
| // P3. Check input shape and infer return shape |
| operandType = operandType.dyn_cast<RankedTensorType>(); |
| if (!operandType) { |
| inferredReturnShapes.emplace_back(resultElementType); |
| return success(); |
| } |
| auto operandShape = operandType.getShape(); |
| if (static_cast<int64_t>(operandShape.size()) < fftRank) { |
| return emitOptionalError( |
| location, "operand rank must not be less than fft rank of ", fftRank, |
| " for operand of type ", operandType, "."); |
| } |
| |
| SmallVector<int64_t> resultShape = to_vector(operandShape); |
| |
| if (fftType == FftType::RFFT) { |
| auto shapeBack = operandShape.take_back(fftRank); |
| for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) { |
| if (operandDim != fftDim) { |
| return emitOptionalError( |
| location, |
| "RFFT requires innermost dimensions match fft_length. Got: ", |
| operandShape, " but wanted ", fftLength, "."); |
| } |
| } |
| if (fftLength[fftRank - 1] != 0) { |
| resultShape[resultShape.size() - 1] = fftLength[fftRank - 1] / 2 + 1; |
| } |
| } |
| if (fftType == FftType::IRFFT) { |
| auto shapeBack = operandShape.take_back(fftRank).drop_back(); |
| for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) { |
| if (operandDim != fftDim) { |
| return emitOptionalError(location, |
| "IRFFT requires non-final dimensions " |
| "match fft_length. Got: ", |
| operandShape, " but wanted ", fftLength, |
| ", and ", operandDim, " != ", fftDim, "."); |
| } |
| } |
| if ((operandShape[operandShape.size() - 1] != 0 || |
| fftLength[fftRank - 1] != 0) && |
| operandShape[operandShape.size() - 1] != fftLength[fftRank - 1] / 2 + 1) |
| return emitOptionalError(location, |
| "IRFFT requires innermost dimension match " |
| "fft_length[-1]/2+1. Got: ", |
| operandShape, " but fft_length is ", fftLength, |
| "."); |
| resultShape[resultShape.size() - 1] = fftLength[fftRank - 1]; |
| } |
| |
| inferredReturnShapes.emplace_back(resultShape, resultElementType); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // GatherOp |
| //===----------------------------------------------------------------------===// |
| |
| // Converts gather ops to slice ops in case we have a single set of constant |
| // indices. |
| struct GatherSlice : public OpRewritePattern<GatherOp> { |
| using OpRewritePattern<GatherOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(GatherOp gather, |
| PatternRewriter& rewriter) const override { |
| DenseIntElementsAttr index; |
| if (!matchPattern(gather.start_indices(), m_Constant(&index))) |
| return failure(); |
| |
| const auto& dnums = gather.dimension_numbers(); |
| if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1) |
| return failure(); |
| |
| // TODO(tberghammer): Remove when the verifier catches this case what is |
| // invalid if all previous condition holds. |
| if (index.getNumElements() != |
| static_cast<int64_t>(dnums.getStartIndexMap().size())) |
| return failure(); |
| |
| RankedTensorType operandType = |
| gather->getOperand(0).getType().dyn_cast<RankedTensorType>(); |
| if (!operandType || !operandType.hasStaticShape()) return failure(); |
| |
| auto sliceEnd = |
| llvm::to_vector<8>(gather.slice_sizes().getValues<int64_t>()); |
| llvm::SmallVector<int64_t, 8> sliceStart(sliceEnd.size(), 0); |
| for (auto it : |
| llvm::zip(dnums.getStartIndexMap(), index.getValues<APInt>())) { |
| int64_t mapIndex = std::get<0>(it); |
| // Clamp the indices within bounds to faithfully mirror gather semantics. |
| int64_t offset = |
| clamp(std::get<1>(it).getSExtValue(), static_cast<int64_t>(0), |
| operandType.getDimSize(mapIndex) - sliceEnd[mapIndex]); |
| sliceStart[mapIndex] += offset; |
| sliceEnd[mapIndex] += offset; |
| } |
| |
| llvm::SmallVector<int64_t, 8> sliceStride(sliceEnd.size(), 1); |
| llvm::SmallVector<int64_t, 8> sliceShape(sliceEnd.size()); |
| for (size_t i = 0; i < sliceEnd.size(); ++i) { |
| sliceShape[i] = sliceEnd[i] - sliceStart[i]; |
| } |
| Type elementType = gather.getType().cast<TensorType>().getElementType(); |
| auto sliceType = RankedTensorType::get(sliceShape, elementType); |
| Value result = rewriter.create<SliceOp>( |
| gather.getLoc(), sliceType, gather.getOperand(0), |
| rewriter.getI64TensorAttr(sliceStart), |
| rewriter.getI64TensorAttr(sliceEnd), |
| rewriter.getI64TensorAttr(sliceStride)); |
| |
| auto collapsedSliceDims = dnums.getCollapsedSliceDims(); |
| if (!collapsedSliceDims.empty()) { |
| llvm::SmallVector<int64_t, 8> reshapeShape; |
| for (size_t i = 0; i < sliceShape.size(); ++i) { |
| if (llvm::count(collapsedSliceDims, i) == 0) { |
| reshapeShape.push_back(sliceShape[i]); |
| } |
| } |
| auto reshapeType = RankedTensorType::get(reshapeShape, elementType); |
| result = rewriter.create<ReshapeOp>(gather.getLoc(), reshapeType, result); |
| } |
| |
| result.setType(gather.getType()); |
| rewriter.replaceOp(gather, result); |
| return success(); |
| } |
| }; |
| |
| void GatherOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<GatherSlice>(context); |
| } |
| |
| namespace { |
| |
| // following https://www.tensorflow.org/xla/operation_semantics#gather |
| // The bounds for the output array along dimension i is computed as follows: |
| // (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some k) |
| // then we pick |
| // the corresponding dimension bounds out of start_indices.shape, skipping |
| // index_vector_dim |
| // (i.e. pick start_indices.shape.dims[k] if k < index_vector_dim and |
| // start_indices.shape.dims[k+1] otherwise). |
| // (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some k) |
| // then we pick |
| // the corresponding bound out of slice_sizes after accounting for |
| // collapsed_slice_dims |
| // (i.e. we pick adjusted_slice_sizes[k] where adjusted_slice_sizes is |
| // slice_sizes with the bounds at indices collapsed_slice_dims removed). |
| |
| void getSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc, |
| ValueRange operands, |
| SmallVectorImpl<Value>& sliceSizes) { |
| for (int64_t val : gather->slice_sizes().getValues<int64_t>()) { |
| sliceSizes.push_back(builder.create<arith::ConstantIndexOp>(loc, val)); |
| } |
| } |
| |
| void getSliceSizeValues(DynamicGatherOp* /*dGather*/, OpBuilder& builder, |
| Location loc, ValueRange operands, |
| SmallVectorImpl<Value>& sliceSizeValues) { |
| DynamicGatherOp::Adaptor adaptor(operands); |
| Value sliceSizes = adaptor.slice_sizes(); |
| auto sliceSizesTy = sliceSizes.getType().cast<ShapedType>(); |
| for (int64_t i = 0; i < sliceSizesTy.getDimSize(0); ++i) { |
| Value idx = builder.create<arith::ConstantIndexOp>(loc, i); |
| sliceSizeValues.push_back( |
| builder.create<tensor::ExtractOp>(loc, sliceSizes, idx)); |
| } |
| } |
| |
| // Verify the following properties: |
| // P1. Verify no repeat in start_index_map. |
| // P2. Verify 0 <= start_index_map[i] < rank(operand), for every i. |
| // P3. Verify 0 <= index_vector_dim <= rank(start_indices). |
| // P4. Verify size(start_index_map) == shape(start_indices)[index_vector_dim]. |
| // P5. Verify offset_dims is_sorted and no repeated. |
| // P6. Verify collapsed_slice_dims is_sorted and no repeated. |
| // P7. Verify rank(operand) == size(offset_dims) + size(collapsed_slice_dims). |
| // P8. Verify slice_sizes has rank of 1. |
| // P9. Verify size(slice_sizes) == rank(operand). |
| // P10. Verify 0 <= collapsed_slice_dims[i] < size(slice_sizes) for all items. |
| static LogicalResult verifyGather( |
| ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape, |
| ShapeAdaptor sliceSizesShape, GatherDimensionNumbersAttr dimensionNumbers, |
| llvm::function_ref<InFlightDiagnostic()> errorEmitter) { |
| int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim(); |
| |
| // Check startIndexMap |
| auto startIndexMap = to_vector(dimensionNumbers.getStartIndexMap()); |
| // P1. |
| if (hasDuplicates(startIndexMap)) |
| return errorEmitter() << "expects start_index_map to not repeat, got: [" |
| << startIndexMap << "]"; |
| |
| // P2. |
| for (int i = 0; i < startIndexMap.size(); ++i) |
| if (startIndexMap[i] < 0 || |
| (operandShape.hasRank() && startIndexMap[i] >= operandShape.getRank())) |
| return errorEmitter() |
| << "start_index_map[" << i << "]: " << startIndexMap[i] |
| << " is out of bounds for " |
| << "operand rank " << operandShape.getRank(); |
| |
| if (startIndicesShape.hasRank()) { |
| // P3. |
| // index_vector_dim == start_indices.rank implies a trailing 1 on the shape |
| // of start_indices. |
| if (indexVectorDim > startIndicesShape.getRank() || indexVectorDim < 0) |
| return errorEmitter() << "index_vector_dim " << indexVectorDim |
| << " is out of bounds for start indices with rank " |
| << startIndicesShape.getRank(); |
| |
| bool impliedTrailingDim = indexVectorDim == startIndicesShape.getRank(); |
| if (impliedTrailingDim || !startIndicesShape.isDynamicDim(indexVectorDim)) { |
| int64_t effectiveDimSize; |
| if (impliedTrailingDim) |
| effectiveDimSize = 1; |
| else |
| effectiveDimSize = startIndicesShape.getDimSize(indexVectorDim); |
| // P4. |
| if (effectiveDimSize != |
| static_cast<int64_t>(dimensionNumbers.getStartIndexMap().size())) |
| return errorEmitter() << "start_index_map size (" |
| << dimensionNumbers.getStartIndexMap().size() |
| << ") is not equal to size of index dimension (" |
| << indexVectorDim << ") of start_indices (" |
| << effectiveDimSize << ")"; |
| } |
| } |
| |
| // P5. |
| auto offsetDims = to_vector(dimensionNumbers.getOffsetDims()); |
| if (!llvm::is_sorted(offsetDims)) |
| return errorEmitter() << "expects offset_dims to be sorted, got: [" |
| << offsetDims << "]"; |
| if (hasDuplicates(offsetDims)) |
| return errorEmitter() << "expects offset_dims to not repeat, got: [" |
| << offsetDims << "]"; |
| |
| // P6. |
| auto collapsedSliceDims = to_vector(dimensionNumbers.getCollapsedSliceDims()); |
| if (!llvm::is_sorted(collapsedSliceDims)) |
| return errorEmitter() << "expects collapsed_slice_dims to be sorted, got: [" |
| << collapsedSliceDims << "]"; |
| if (hasDuplicates(collapsedSliceDims)) |
| return errorEmitter() |
| << "expects collapsed_slice_dims to not repeat, got: [" |
| << collapsedSliceDims << "]"; |
| |
| // P7. |
| int64_t impliedOperandRank = dimensionNumbers.getOffsetDims().size() + |
| dimensionNumbers.getCollapsedSliceDims().size(); |
| if (operandShape.hasRank() && operandShape.getRank() != impliedOperandRank) |
| return errorEmitter() << "offset_dims size (" |
| << dimensionNumbers.getOffsetDims().size() |
| << ") plus collapse_slice_dims size (" |
| << dimensionNumbers.getCollapsedSliceDims().size() |
| << ") is not equal to operand rank (" |
| << operandShape.getRank() << ")"; |
| |
| // P8. |
| // This should be fully expressible with type constraints, but it isn't |
| // obvious how to do that with the current infrastructure. |
| if (sliceSizesShape.hasRank() && sliceSizesShape.getRank() != 1) |
| return errorEmitter() << "slice_sizes.rank != 1"; |
| if (sliceSizesShape.hasStaticShape()) { |
| int64_t sliceSize = sliceSizesShape.getNumElements(); |
| |
| // P9. |
| if (sliceSize != impliedOperandRank) |
| return errorEmitter() << "slice_sizes size (" << sliceSize |
| << ") not equal to (implied) operand rank (" |
| << impliedOperandRank << ")"; |
| |
| // P10. |
| for (auto dim : dimensionNumbers.getCollapsedSliceDims()) |
| if (dim < 0 || dim >= sliceSize) |
| return errorEmitter() << "collapsed dimension " << dim |
| << " is out of bounds for slice_sizes.size (" |
| << sliceSize << ")"; |
| } |
| |
| return success(); |
| } |
| |
| // Verify the following properties: |
| // P1. Verifications by verifyGather(). |
| // P2. Verify slice_sizes[i] <= 1 for i in collapsed_slice_dims. |
| // P3. Verify 0 <= slice_sizes[i] < shape(operand)[i], for every i. |
| static LogicalResult verifyStaticGather( |
| ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape, |
| DenseIntElementsAttr sliceSizes, |
| GatherDimensionNumbersAttr dimensionNumbers, |
| llvm::function_ref<InFlightDiagnostic()> errorEmitter) { |
| // P1. |
| // For some reason the getType call is necessary here |
| if (failed(verifyGather( |
| /*operandShape=*/operandShape, |
| /*startIndicesShape=*/startIndicesShape, |
| /*sliceSizesShape=*/sliceSizes.getType(), dimensionNumbers, |
| errorEmitter))) |
| return failure(); |
| |
| // P2. |
| for (auto dim : dimensionNumbers.getCollapsedSliceDims()) { |
| int64_t sliceDimSize = sliceSizes.getValues<int64_t>()[dim]; |
| if (sliceDimSize > 1) { |
| return errorEmitter() << "slice_sizes collapsed dimension " << dim |
| << " should <= 1 but got " << sliceDimSize; |
| } |
| } |
| |
| // P3. |
| if (operandShape.hasRank()) { |
| for (const auto& it : llvm::enumerate(sliceSizes.getValues<int64_t>())) { |
| if (operandShape.isDynamicDim(it.index())) continue; |
| auto operandDimSize = operandShape.getDimSize(it.index()); |
| auto sliceDimSize = it.value(); |
| if (sliceDimSize < 0 || sliceDimSize > operandDimSize) |
| return errorEmitter() << "slice size (" << sliceDimSize |
| << ") is out of bounds for operand dimension (" |
| << operandDimSize << ") at index " << it.index(); |
| } |
| } |
| return success(); |
| } |
| |
| template <typename dimTy> |
| static void inferGatherShape( |
| int64_t resultRank, llvm::function_ref<dimTy(int64_t)> getStartIndicesDim, |
| llvm::function_ref<dimTy(int64_t)> getSliceDim, |
| GatherDimensionNumbersAttr dimensionNumbers, |
| SmallVectorImpl<dimTy>& shape) { |
| ArrayRef<int64_t> collapsedSliceDims = |
| dimensionNumbers.getCollapsedSliceDims(); |
| int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim(); |
| |
| // We don't necessarily know the rank of sliceSizes, but we do know that it |
| // can't be larger than the highest collapsed dimension. So go through those |
| // and populate the leading dimensions of adjustedSliceSizes. The trailing |
| // dimensions can just be adjusted by an offset. |
| const auto* maxCollapsedDimIt = |
| std::max_element(collapsedSliceDims.begin(), collapsedSliceDims.end()); |
| int64_t maxCollapsedDim = -1; |
| if (maxCollapsedDimIt != collapsedSliceDims.end()) |
| maxCollapsedDim = *maxCollapsedDimIt; |
| |
| SmallVector<dimTy> adjustedSliceSizePrefix; |
| for (int dimIndex = 0; dimIndex <= maxCollapsedDim; ++dimIndex) { |
| if (llvm::is_contained(collapsedSliceDims, dimIndex)) continue; |
| adjustedSliceSizePrefix.push_back(getSliceDim(dimIndex)); |
| } |
| auto getAdjustedSliceDim = [&](int64_t index) -> dimTy { |
| if (index < static_cast<int64_t>(adjustedSliceSizePrefix.size())) |
| return adjustedSliceSizePrefix[index]; |
| return getSliceDim(index + collapsedSliceDims.size()); |
| }; |
| |
| ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims(); |
| |
| // Dimensions in the output that aren't offset dimensions are called batch |
| // dimensions. |
| SmallVector<int64_t> batchDims; |
| for (int dim = 0; dim < resultRank; ++dim) |
| if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim); |
| |
| for (int i = 0; i < resultRank; ++i) { |
| const auto* offsetDimsIt = |
| std::find(offsetDims.begin(), offsetDims.end(), i); |
| if (offsetDimsIt != offsetDims.end()) { |
| auto index = std::distance(offsetDims.begin(), offsetDimsIt); |
| shape.push_back(getAdjustedSliceDim(index)); |
| continue; |
| } |
| auto* batchDimsIt = std::find(batchDims.begin(), batchDims.end(), i); |
| assert(batchDimsIt != batchDims.end()); |
| auto index = std::distance(batchDims.begin(), batchDimsIt); |
| // This can never run into the special case where start_indices gets |
| // implicitly expanded with a trailing 1 if |
| // index_vector_dim = start_indices.rank because then index would equal |
| // index_vector_dim, which means we'd be looking at index+1, which would be |
| // out of bounds anyway. |
| if (index >= indexVectorDim) ++index; |
| shape.push_back(getStartIndicesDim(index)); |
| } |
| } |
| |
| // Verify the following properties: |
| // P1. Verify 0 <= offset_dims[i] < output_shape_rank, for every i. |
| // (output_shape_rank = size(offset_dims) + rank(start_indices) -1) |
| static LogicalResult inferGatherReturnTypeComponents( |
| ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape, |
| llvm::function_ref<int64_t(int64_t)> getSliceDim, |
| GatherDimensionNumbersAttr dimensionNumbers, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes, |
| llvm::function_ref<InFlightDiagnostic()> errorEmitter) { |
| Type elementType = operandShape.getElementType(); |
| |
| // We need this to determine the result rank. We could still place bounds on |
| // the result rank if that was something ShapedTypeComponents could express. |
| if (!startIndicesShape.hasRank()) { |
| inferredReturnShapes.push_back(elementType); |
| return success(); |
| } |
| |
| ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims(); |
| int64_t startIndicesRank = startIndicesShape.getRank(); |
| // If index_vector_dim == start_indices.rank, then an implicit trailing 1 is |
| // appended to start_indices shape. |
| if (dimensionNumbers.getIndexVectorDim() == startIndicesRank) |
| ++startIndicesRank; |
| int64_t resultRank = offsetDims.size() + startIndicesRank - 1; |
| // P1. |
| for (int i = 0; i < offsetDims.size(); ++i) |
| if (offsetDims[i] < 0 || offsetDims[i] >= resultRank) |
| return errorEmitter() << "offset_dims[" << i << "]: " << offsetDims[i] |
| << " is out of bounds for " |
| << "implied result rank " << resultRank; |
| |
| auto getStartIndicesDim = [&](int64_t index) { |
| return startIndicesShape.getDimSize(index); |
| }; |
| |
| SmallVector<int64_t> shape; |
| inferGatherShape<int64_t>(resultRank, getStartIndicesDim, getSliceDim, |
| dimensionNumbers, shape); |
| |
| inferredReturnShapes.emplace_back(shape, elementType); |
| return success(); |
| } |
| |
| template <typename Op> |
| LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| // No support for unranked gather output shape a.t.m. |
| auto resultTy = |
| op->getResult().getType().template dyn_cast<RankedTensorType>(); |
| if (!resultTy) return failure(); |
| |
| typename Op::Adaptor adaptor(operands); |
| Value startIndices = adaptor.start_indices(); |
| |
| Location loc = op->getLoc(); |
| int resultRank = resultTy.getRank(); |
| Type shapeElTy = startIndices.getType().cast<ShapedType>().getElementType(); |
| auto toShapeElType = [&](Value v) { |
| return maybeCastTo(builder, loc, v, shapeElTy); |
| }; |
| |
| SmallVector<Value, 4> sliceSizes; |
| getSliceSizeValues(op, builder, loc, operands, sliceSizes); |
| llvm::transform(sliceSizes, sliceSizes.begin(), |
| [&](Value v) { return toShapeElType(v); }); |
| |
| auto getStartIndicesDim = [&](int64_t index) { |
| return toShapeElType( |
| builder.create<tensor::DimOp>(loc, startIndices, index)); |
| }; |
| SmallVector<Value, 4> shapeValues; |
| auto getSliceDim = [&sliceSizes](int64_t index) -> Value { |
| return sliceSizes[index]; |
| }; |
| inferGatherShape<Value>(resultRank, getStartIndicesDim, getSliceDim, |
| op->dimension_numbers(), shapeValues); |
| |
| Value outputShape = builder.create<tensor::FromElementsOp>( |
| loc, RankedTensorType::get({resultRank}, shapeElTy), shapeValues); |
| reifiedReturnShapes.push_back(outputShape); |
| |
| return success(); |
| } |
| |
| } // namespace |
| |
| LogicalResult GatherOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| return reifyGatherShape(this, builder, operands, reifiedReturnShapes); |
| } |
| |
| // The following properties are already enforced by the ODS: |
| // P0. Verify the start_indices has element type of integer. |
| // Verify the following properties: |
| // Verifications by verifyStaticGather() and verifyGather() inside it. |
| // Verifications by inferGatherReturnTypeComponents. |
| LogicalResult GatherOp::inferReturnTypeComponents( |
| MLIRContext* context, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| // TODO(zhouxin) remove this comment after the ordering issue is clear. |
| // This can get called before other op verify methods, so we have to do a |
| // bunch of verification up front. With a better story for ordering and/or |
| // multi-phase op verification, this should hopefully all go away. |
| Location loc = location.getValueOr(UnknownLoc::get(context)); |
| auto errorEmitter = [&loc]() { |
| return mlir::emitError(loc) |
| << "'" << GatherOp::getOperationName() << "' op "; |
| }; |
| GatherOp::Adaptor adaptor(operands, attributes, regions); |
| if (failed(adaptor.verify(loc))) return failure(); |
| |
| // We want the ShapeAdaptors, so can't route via the adaptor :-/ |
| ShapeAdaptor operandShape = operands.getShape(0); |
| ShapeAdaptor startIndicesShape = operands.getShape(1); |
| GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers(); |
| DenseIntElementsAttr sliceSizesAttr = adaptor.slice_sizes(); |
| |
| if (failed(verifyStaticGather(/*operandShape=*/operandShape, |
| /*startIndicesShape=*/startIndicesShape, |
| /*sliceSizes=*/sliceSizesAttr, dimensionNumbers, |
| errorEmitter))) |
| return failure(); |
| |
| auto getSliceDim = [&sliceSizesAttr](int64_t index) -> int64_t { |
| return sliceSizesAttr.getValues<int64_t>()[index]; |
| }; |
| |
| return inferGatherReturnTypeComponents(operandShape, startIndicesShape, |
| getSliceDim, dimensionNumbers, |
| inferredReturnShapes, errorEmitter); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DynamicGatherOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult DynamicGatherOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| return reifyGatherShape(this, builder, operands, reifiedReturnShapes); |
| } |
| |
| LogicalResult DynamicGatherOp::inferReturnTypeComponents( |
| MLIRContext* context, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| // This can get called before other op verify methods, so we have to do a |
| // bunch of verification up front. With a better story for ordering and/or |
| // multi-phase op verification, this should hopefully all go away. |
| Location loc = location.getValueOr(UnknownLoc::get(context)); |
| auto errorEmitter = [&loc]() { |
| return mlir::emitError(loc) |
| << "'" << DynamicGatherOp::getOperationName() << "' op "; |
| }; |
| DynamicGatherOp::Adaptor adaptor(operands, attributes, regions); |
| if (failed(adaptor.verify(loc))) return failure(); |
| |
| // We want the ShapeAdaptors, so can't route via the adaptor :-/ |
| ShapeAdaptor operandShape = operands.getShape(0); |
| ShapeAdaptor startIndicesShape = operands.getShape(1); |
| ShapeAdaptor sliceSizesShape = operands.getShape(2); |
| GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers(); |
| |
| if (failed(verifyGather(/*operandShape=*/operandShape, |
| /*startIndicesShape=*/startIndicesShape, |
| /*sliceSizesShape=*/sliceSizesShape, dimensionNumbers, |
| errorEmitter))) |
| return failure(); |
| |
| auto getSliceDim = [](int64_t index) { return ShapedType::kDynamicSize; }; |
| return inferGatherReturnTypeComponents(operandShape, startIndicesShape, |
| getSliceDim, dimensionNumbers, |
| inferredReturnShapes, errorEmitter); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // GetDimensionSizeOp |
| //===----------------------------------------------------------------------===// |
| // |
| LogicalResult GetDimensionSizeOp::verify() { return verifyDimAttr(*this); } |
| |
| /// Fold get_dimension_size when the said shape dimension is a constant. |
| OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) { |
| RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>(); |
| if (!type) return {}; |
| |
| int32_t dim = dimension(); |
| if (type.isDynamicDim(dim)) return {}; |
| // The result type is always is a 0-d i32 tensor. |
| return DenseIntElementsAttr::get<int32_t>( |
| getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // IotaOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult IotaOp::verify() { |
| auto shape = getType().cast<ShapedType>(); |
| if (!shape.hasRank()) return success(); |
| |
| if (shape.getRank() == 0) return emitOpError() << "does not support scalars."; |
| |
| auto iotaDimension = this->iota_dimension(); |
| if (static_cast<int64_t>(iotaDimension) >= shape.getRank() || |
| iotaDimension < 0) |
| return emitOpError() |
| << "iota dimension cannot go beyond the output rank or be negative."; |
| return success(); |
| } |
| |
| // Iota operations across multiple dimensions can be reduced to an iota and a |
| // ranked broadcast. |
| struct IotaBroadcast : public OpRewritePattern<IotaOp> { |
| using OpRewritePattern<IotaOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(IotaOp iota, |
| PatternRewriter& rewriter) const override { |
| auto resultTy = iota.getType().cast<ShapedType>(); |
| if (!resultTy.hasRank() || resultTy.getRank() < 2) { |
| return failure(); |
| } |
| |
| auto iotaDimension = iota.iota_dimension(); |
| |
| auto iotaType = RankedTensorType::get({resultTy.getDimSize(iotaDimension)}, |
| resultTy.getElementType()); |
| |
| auto newIota = rewriter.create<IotaOp>(iota.getLoc(), iotaType, |
| rewriter.getI64IntegerAttr(0)); |
| |
| auto broadcastAttr = DenseIntElementsAttr::get( |
| RankedTensorType::get({1}, rewriter.getIntegerType(64)), |
| {iotaDimension}); |
| rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, resultTy, newIota, |
| broadcastAttr); |
| return success(); |
| } |
| }; |
| |
| void IotaOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<IotaBroadcast>(context); |
| } |
| |
| OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) { |
| auto dimension = iota_dimension(); |
| auto resultTy = getResult().getType().cast<ShapedType>(); |
| if (resultTy.hasRank() && resultTy.getDimSize(dimension) == 1) { |
| Builder builder(getContext()); |
| return builder.getZeroAttr(resultTy); |
| } |
| |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DynamicIotaOp |
| //===----------------------------------------------------------------------===// |
| |
| // Does the same as PatternRewriter::replaceOpWithNewOp, but with a twist. |
| // |
| // Sometimes, we want to replace an op with a new op and simultaneously refine |
| // the result type from a dynamically-shaped type to a statically-shaped type. |
| // (Search for usages of this function for examples). |
| // |
| // Oftentimes, this works just fine because MHLO is designed to accommodate |
| // this kind of type refinements. But sometimes, this doesn't work - when |
| // the op is used outside of the MHLO dialect (e.g. in func.return). In these |
| // cases, we insert a tensor.cast to smooth things out. |
| template <typename OpTy, typename... Args> |
| OpTy refineOpWithNewOp(PatternRewriter& rewriter, Operation* op, |
| Args&&... args) { |
| auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...); |
| |
| llvm::SmallVector<Value> replacementResults; |
| assert(op->getNumResults() == newOp->getNumResults() && |
| "replacement op doesn't match results of original op"); |
| for (auto [opResult, newOpResult] : |
| llvm::zip(op->getResults(), newOp->getResults())) { |
| Value replacementResult = newOpResult; |
| if (llvm::any_of(opResult.getUsers(), [&](Operation* user) { |
| return user->getDialect() != op->getDialect(); |
| })) { |
| replacementResult = rewriter.create<tensor::CastOp>( |
| op->getLoc(), opResult.getType(), newOpResult); |
| } |
| replacementResults.push_back(replacementResult); |
| } |
| |
| rewriter.replaceOp(op, replacementResults); |
| return newOp; |
| } |
| |
| namespace { |
| |
| struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> { |
| using OpRewritePattern<DynamicIotaOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(DynamicIotaOp iota, |
| PatternRewriter& rewriter) const override { |
| // Result type has static shape, replace with iota. |
| auto resultTy = iota.getType().cast<ShapedType>(); |
| if (resultTy.hasStaticShape()) { |
| rewriter.replaceOpWithNewOp<IotaOp>(iota, resultTy, |
| iota.iota_dimension()); |
| return success(); |
| } |
| |
| // Output shape is constant, compute result type with static shape, then |
| // replace with iota. |
| DenseIntElementsAttr outputShapeAttr; |
| if (matchPattern(iota.output_shape(), m_Constant(&outputShapeAttr))) { |
| SmallVector<int64_t> outputShape; |
| for (APInt dim : outputShapeAttr.getValues<APInt>()) { |
| outputShape.push_back(dim.getSExtValue()); |
| } |
| resultTy = RankedTensorType::get(outputShape, resultTy.getElementType()); |
| refineOpWithNewOp<IotaOp>(rewriter, iota, resultTy, |
| iota.iota_dimension()); |
| return success(); |
| } |
| |
| return rewriter.notifyMatchFailure( |
| iota, "requires static shape or constant output shape"); |
| } |
| }; |
| |
| // Dynamic Iota operations across multiple dimensions can be reduced to an iota |
| // and a ranked broadcast. |
| struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> { |
| using OpRewritePattern<DynamicIotaOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(DynamicIotaOp iota, |
| PatternRewriter& rewriter) const override { |
| auto resultTy = iota.getType().cast<ShapedType>(); |
| if (!resultTy.hasRank() || resultTy.getRank() < 2) { |
| return failure(); |
| } |
| |
| auto iotaDimension = iota.iota_dimension(); |
| auto iotaDimensionInt = iotaDimension; |
| |
| auto convertedShape = rewriter.create<arith::IndexCastOp>( |
| iota.getLoc(), |
| RankedTensorType::get( |
| iota.output_shape().getType().cast<ShapedType>().getShape(), |
| rewriter.getI64Type()), |
| iota.output_shape()); |
| |
| auto slicedShape = rewriter.create<SliceOp>( |
| iota.getLoc(), convertedShape, |
| rewriter.getI64TensorAttr(iotaDimensionInt), |
| rewriter.getI64TensorAttr(iotaDimensionInt + 1), |
| rewriter.getI64TensorAttr(1)); |
| |
| auto convertedSlicedShape = rewriter.create<arith::IndexCastOp>( |
| iota.getLoc(), |
| RankedTensorType::get( |
| {1}, |
| iota.output_shape().getType().cast<ShapedType>().getElementType()), |
| slicedShape); |
| |
| auto iotaType = RankedTensorType::get( |
| {resultTy.getDimSize(iotaDimensionInt)}, resultTy.getElementType()); |
| |
| auto newIota = rewriter.create<DynamicIotaOp>( |
| iota.getLoc(), iotaType, convertedSlicedShape, |
| rewriter.getI64IntegerAttr(0)); |
| |
| auto broadcastAttr = DenseIntElementsAttr::get( |
| RankedTensorType::get({1}, rewriter.getIntegerType(64)), |
| {iotaDimension}); |
| rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>( |
| iota, resultTy, newIota, iota.output_shape(), broadcastAttr); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void DynamicIotaOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<DynamicIotaIsStatic>(context); |
| results.add<DynamicIotaBroadcast>(context); |
| } |
| |
| static Value castToIndexTensor(OpBuilder& builder, Location loc, |
| Value shapeOp) { |
| ShapedType resultTy = shape::getExtentTensorType( |
| builder.getContext(), shapeOp.getType().cast<ShapedType>().getDimSize(0)); |
| if (shapeOp.getType() == resultTy) return shapeOp; // Nothing to do. |
| return builder.create<arith::IndexCastOp>(loc, resultTy, shapeOp); |
| } |
| |
| LogicalResult DynamicIotaOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| DynamicIotaOp::Adaptor adaptor(operands); |
| reifiedReturnShapes.push_back( |
| castToIndexTensor(builder, getLoc(), adaptor.output_shape())); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DynamicUpdateSliceOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult DynamicUpdateSliceOp::verify() { |
| OperandRange indices = start_indices(); |
| if (indices.size() <= 1) return success(); |
| |
| // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it |
| // is OK to cast indices to ShapedType here. |
| auto idxTensor = indices.take_front().front().getType().cast<ShapedType>(); |
| Type firstElemTy = idxTensor.getElementType(); |
| Type elemTy; |
| |
| for (auto idx : llvm::drop_begin(indices, 1)) { |
| idxTensor = idx.getType().cast<ShapedType>(); |
| elemTy = idxTensor.getElementType(); |
| |
| if (firstElemTy != elemTy) { |
| return emitOpError() << "start indices must have same element type " |
| "(encountered mismatch: " |
| << firstElemTy << " vs " << elemTy << ")"; |
| } |
| } |
| return success(); |
| } |
| |
| OpFoldResult DynamicUpdateSliceOp::fold(ArrayRef<Attribute> operands) { |
| auto operandShape = this->operand().getType().cast<RankedTensorType>(); |
| auto updateShape = this->update().getType().cast<RankedTensorType>(); |
| |
| // If any of the dimensions are length-0, the update does nothing. |
| for (auto dim : updateShape.getShape()) { |
| if (dim == 0) { |
| return this->operand(); |
| } |
| } |
| |
| if (operandShape != updateShape || !operandShape.hasStaticShape()) { |
| return {}; |
| } |
| |
| // Ensure that indices are 0 constants. The 0 check mostly ensures |
| // correctness. For non-constants, the pattern does not fold to avoid hiding |
| // the behavior of incorrect user input. |
| for (Value index : this->start_indices()) { |
| DenseIntElementsAttr deAttr; |
| if (!matchPattern(index, m_Constant(&deAttr))) return {}; |
| if (!deAttr.getSplatValue<IntegerAttr>().getValue().isZero()) return {}; |
| } |
| return this->update(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AbsOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult AbsOp::inferReturnTypes( |
| MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr, |
| RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) { |
| auto operandTy = (*operands.begin()).getType().cast<ShapedType>(); |
| Type elementTy = operandTy.getElementType(); |
| if (auto complexTy = elementTy.dyn_cast<ComplexType>()) { |
| elementTy = complexTy.getElementType(); |
| } |
| |
| Type resultTy; |
| if (auto rankedOperandTy = operandTy.dyn_cast<RankedTensorType>()) { |
| resultTy = RankedTensorType::get(operandTy.getShape(), elementTy, |
| rankedOperandTy.getEncoding()); |
| } else if (operandTy.hasRank()) { |
| resultTy = RankedTensorType::get(operandTy.getShape(), elementTy); |
| } else { |
| resultTy = UnrankedTensorType::get(elementTy); |
| } |
| inferredReturnTypes.push_back(resultTy); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CollectivePermuteOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult CollectivePermuteOp::verify() { |
| return mlir::hlo::verifyCollectivePermuteSourceTargetPairs( |
| *this, source_target_pairs()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConvolutionOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| // Checks: |
| // P1. Same sizes for input, kernel and output spatial_dims. |
| // P2. Spatial and non-spatial dimentions (for input,kernel, &output) should |
| // be unique and in range [0, num_dims), where num_dims = rank of input |
| // (lhs/rhs) tensors. |
| // |
| // Note that the spatial + non-spatial dimensions may not cover all the |
| // dimensions in the range [0,num) because of the presence of 'unknown' |
| // dimensions (ref. cl/415132294). |
| LogicalResult isSpatialDimensionsValid(ConvolutionOp op) { |
| auto inputSpatialDimensions = |
| op.dimension_numbers().getInputSpatialDimensions(); |
| auto kernelSpatialDimensions = |
| op.dimension_numbers().getKernelSpatialDimensions(); |
| auto outputSpatialDimensions = |
| op.dimension_numbers().getOutputSpatialDimensions(); |
| |
| // P1. |
| if ((inputSpatialDimensions.size() != kernelSpatialDimensions.size()) || |
| (inputSpatialDimensions.size() != outputSpatialDimensions.size())) |
| return op.emitOpError() << "expects the same size for input, kernel and " |
| "output spatial-dimensions, but got " |
| << inputSpatialDimensions.size() << ", " |
| << kernelSpatialDimensions.size() << ", and " |
| << outputSpatialDimensions.size() << " resp."; |
| |
| // P2. |
| SmallVector<int64_t> inputDnums(inputSpatialDimensions.size() + 2); |
| inputDnums[0] = op.dimension_numbers().getInputBatchDimension(); |
| inputDnums[1] = op.dimension_numbers().getInputFeatureDimension(); |
| std::copy(inputSpatialDimensions.begin(), inputSpatialDimensions.end(), |
| inputDnums.begin() + 2); |
| |
| SmallVector<int64_t> windowDnums(kernelSpatialDimensions.size() + 2); |
| windowDnums[0] = op.dimension_numbers().getKernelInputFeatureDimension(); |
| windowDnums[1] = op.dimension_numbers().getKernelOutputFeatureDimension(); |
| std::copy(kernelSpatialDimensions.begin(), kernelSpatialDimensions.end(), |
| windowDnums.begin() + 2); |
| |
| SmallVector<int64_t> outputDnums(outputSpatialDimensions.size() + 2); |
| outputDnums[0] = op.dimension_numbers().getOutputBatchDimension(); |
| outputDnums[1] = op.dimension_numbers().getOutputFeatureDimension(); |
| std::copy(outputSpatialDimensions.begin(), outputSpatialDimensions.end(), |
| outputDnums.begin() + 2); |
| |
| auto numDims = op.lhs().getType().cast<RankedTensorType>().getRank(); |
| const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; }; |
| |
| if (!llvm::all_of(inputDnums, inRange) || |
| !llvm::all_of(windowDnums, inRange) || |
| !llvm::all_of(outputDnums, inRange)) |
| return op.emitOpError() << "expects input, kernel, and output " |
| "dimension-numbers to be in-range [0, " |
| << numDims << ")."; |
| |
| if (hasDuplicates(inputDnums)) |
| return op.emitOpError() |
| << "expects input dimension-numbers to be unique, got {" |
| << inputDnums << "}."; |
| |
| if (hasDuplicates(windowDnums)) |
| return op.emitOpError() |
| << "expects kernel dimension-numbers to be unique, got {" |
| << windowDnums << "}."; |
| |
| if (hasDuplicates(outputDnums)) |
| return op.emitOpError() |
| << "expects output dimension-numbers to be unique, got {" |
| << outputDnums << "}."; |
| |
| return success(); |
| } |
| |
| // Verifies the following properties: |
| // P1. The input, kernel, and output spatial-dimentions are valid. |
| // P2. Given, |
| // input-dimensions: b * input-spatial-dims * f |
| // kernel-dimensions: kernel-spatial-dims * i * o |
| // output-dimensions: b' * out-spatial-dims * f' |
| // where b = input-batch-dims |
| // where f = input-feature-dims |
| // where i = kernel-input-feature-dims |
| // where o = kernel-output-feature-dims |
| // where b' = output-batch-dims |
| // where f' = output-feature-dims |
| // Check the following properties w.r.t feature_group_count (fgc) and |
| // batch_group_count (bgc). |
| // fgc > 0, bgc > 1 and !(fgc > 1 && bgc > 1) |
| // b % bgc == 0 |
| // f % fgc == 0 and i = f / fgc |
| // o (or f') % bgc == 0 and o (or f') % fgc == 0 |
| LogicalResult verifyConvolutionAttributes(ConvolutionOp op) { |
| // P1. |
| if (failed(isSpatialDimensionsValid(op))) return failure(); |
| |
| // P2. |
| const int64_t featureGroupCount = op.feature_group_count(); |
| const int64_t batchGroupCount = op.batch_group_count(); |
| |
| if (featureGroupCount <= 0) |
| return op.emitOpError() |
| << "expects feature_group_count to be a positive number, got " |
| << featureGroupCount << "."; |
| |
| if (batchGroupCount <= 0) |
| return op.emitOpError() |
| << "expects batch_group_count to be a positive number, got " |
| << batchGroupCount << "."; |
| |
| if (batchGroupCount > 1 && featureGroupCount > 1) |
| return op.emitOpError() |
| << "expects batch_group_count and feature_group_count not to be " |
| "both greater than 1. Got " |
| << batchGroupCount << " and " << featureGroupCount << " resp."; |
| |
| auto lhsType = op.lhs().getType().cast<RankedTensorType>(); |
| const int64_t inputFeatures = |
| lhsType.getShape()[op.dimension_numbers().getInputFeatureDimension()]; |
| const int64_t inputBatch = |
| lhsType.getShape()[op.dimension_numbers().getInputBatchDimension()]; |
| |
| auto rhsType = op.rhs().getType().cast<RankedTensorType>(); |
| const int64_t kernelInputFeatures = |
| rhsType |
| .getShape()[op.dimension_numbers().getKernelInputFeatureDimension()]; |
| const int64_t kernelOutputFeatures = |
| rhsType |
| .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()]; |
| |
| if (!isDynamicDimSize(kernelOutputFeatures)) { |
| if (kernelOutputFeatures % batchGroupCount != 0) |
| return op.emitOpError() << "expects output feature dimension size (" |
| << kernelOutputFeatures |
| << ") to be a multiple of " |
| "batch_group_count. Got batch_group_count = " |
| << batchGroupCount << "."; |
| |
| if (kernelOutputFeatures % featureGroupCount != 0) |
| return op.emitOpError() |
| << "expects kernel output feature dimension (" |
| << kernelOutputFeatures |
| << ") to be divisible by " |
| "feature_group_count. For feature_group_count = " |
| << featureGroupCount << "."; |
| } |
| |
| if (!isDynamicDimSize(inputFeatures)) { |
| if (inputFeatures % featureGroupCount != 0) |
| return op.emitOpError() |
| << "expects input feature dimension (" << inputFeatures |
| << ") to be a multiple of " |
| "feature_group_count. Got feature_group_count = " |
| << featureGroupCount << "."; |
| |
| if (!isDynamicDimSize(kernelInputFeatures) && |
| inputFeatures / featureGroupCount != kernelInputFeatures) |
| return op.emitOpError() |
| << "expects input feature dimension (" << inputFeatures |
| << ") / " |
| "feature_group_count = kernel input feature dimension (" |
| << kernelInputFeatures |
| << "). Got feature_group_count = " << featureGroupCount << "."; |
| } |
| |
| if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0) |
| return op.emitOpError() << "expects input batch dimension (" << inputBatch |
| << ") to be divisible by " |
| "batch_group_count. Got batch_group_count = " |
| << batchGroupCount << "."; |
| |
| return success(); |
| } |
| |
| // Infer the return-shape of ConvolutionOp. |
| // Precondition: |
| // 1. Input args to ConvolutionOp 'op' are RankedTypes. |
| // 2. rank-of(input-type) == rank-of(output-type) |
| SmallVector<int64_t> inferConvolutionOpReturnShape( |
| ConvolutionOp op, const ArrayRef<WindowDimension> window) { |
| // We keep the 'unknown' dimensions (cl/415132294) as it is in the |
| // output-shape. To do that we initilize the output dimensions with the shape |
| // of the return-type and updates only the spatial + non-spatial dimensions. |
| // Precondition 2 ensures that size of output-shape == size of input-shape. |
| SmallVector<int64_t> outputDimensions = |
| to_vector(op.getResult().getType().cast<ShapedType>().getShape()); |
| |
| // Infer the output spatial dimensions. |
| auto lhsType = op.lhs().getType().cast<RankedTensorType>(); |
| auto inputSpatialDims = op.dimension_numbers().getInputSpatialDimensions(); |
| auto numSpatialDims = inputSpatialDims.size(); |
| SmallVector<int64_t> inputSpatialDimVals(numSpatialDims); |
| for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i) |
| inputSpatialDimVals[i] = lhsType.getShape()[inputSpatialDims[i]]; |
| |
| auto windowOutputShape = inferWindowOutputShape(inputSpatialDimVals, window); |
| |
| for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i) |
| outputDimensions[op.dimension_numbers().getOutputSpatialDimensions()[i]] = |
| windowOutputShape[i]; |
| |
| // Infer the output-batch-dimension and output-feature-dimension. |
| auto rhsType = op.rhs().getType().cast<RankedTensorType>(); |
| const int64_t inputBatch = |
| lhsType.getShape()[op.dimension_numbers().getInputBatchDimension()]; |
| const int64_t kernelOutputFeatures = |
| rhsType |
| .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()]; |
| |
| outputDimensions[op.dimension_numbers().getOutputBatchDimension()] = |
| isDynamicDimSize(inputBatch) ? ShapedType::kDynamicSize |
| : inputBatch / op.batch_group_count(); |
| outputDimensions[op.dimension_numbers().getOutputFeatureDimension()] = |
| kernelOutputFeatures; |
| |
| return outputDimensions; |
| } |
| |
| // Some mhlo.convolutions are dot products, specifically when there is no |
| // padding and no spatial dimensions. DotGeneralOp is general enough that it |
| // can sufficiently describe it. |
| struct ConvolutionIsDot : public OpRewritePattern<mhlo::ConvolutionOp> { |
| using OpRewritePattern<mhlo::ConvolutionOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(mhlo::ConvolutionOp op, |
| PatternRewriter& rewriter) const override { |
| auto lhs = op.lhs(); |
| auto rhs = op.rhs(); |
| auto lhsTy = lhs.getType().cast<RankedTensorType>(); |
| auto rhsTy = rhs.getType().cast<RankedTensorType>(); |
| auto resultTy = op.getType().cast<RankedTensorType>(); |
| |
| if (lhsTy.getRank() != 2) return failure(); |
| if (rhsTy.getRank() != 2) return failure(); |
| |
| if (op.batch_group_count() != 1) return failure(); |
| |
| // There should not be any padding if this is a matmul. |
| auto dNums = op.dimension_numbers(); |
| assert(!op.padding() || op.padding()->empty()); |
| assert(dNums.getKernelSpatialDimensions().empty()); |
| |
| auto lhsBatchDim = dNums.getInputBatchDimension(); |
| auto rhsBatchDim = dNums.getKernelOutputFeatureDimension(); |
| auto lhsContractDim = dNums.getInputFeatureDimension(); |
| auto rhsContractDim = dNums.getKernelInputFeatureDimension(); |
| auto outBatchDim = dNums.getOutputBatchDimension(); |
| auto outFeatureDim = dNums.getOutputFeatureDimension(); |
| |
| // If the input features are not grouped then we can directly convert to an |
| // mhlo.dot_general. |
| if (op.feature_group_count() == 1) { |
| // We can swap the lhs and rhs sides to avoid a transpose. |
| if (outBatchDim == 1 && outFeatureDim == 0) { |
| std::swap(lhs, rhs); |
| std::swap(outBatchDim, outFeatureDim); |
| std::swap(lhsContractDim, rhsContractDim); |
| } |
| |
| auto dotNums = DotDimensionNumbersAttr::get( |
| op.getContext(), {}, {}, {lhsContractDim}, {rhsContractDim}); |
| auto dotOp = rewriter.create<mhlo::DotGeneralOp>( |
| op.getLoc(), op.getType(), lhs, rhs, dotNums, |
| op.precision_config().getValueOr(nullptr)); |
| |
| rewriter.replaceOp(op, dotOp.getResult()); |
| return success(); |
| } |
| |
| int64_t featureGroupCount = op.feature_group_count(); |
| int64_t lhsBatchSize = lhsTy.getDimSize(lhsBatchDim); |
| int64_t lhsContractSize = lhsTy.getDimSize(lhsContractDim); |
| int64_t rhsBatchSize = rhsTy.getDimSize(rhsBatchDim); |
| int64_t rhsContractSize = rhsTy.getDimSize(rhsContractDim); |
| |
| llvm::SmallVector<int64_t> lhsShape; |
| llvm::SmallVector<int64_t> rhsShape; |
| lhsShape.resize(3, lhsBatchSize); |
| rhsShape.resize(3, rhsContractSize); |
| lhsShape[lhsContractDim] = featureGroupCount; |
| lhsShape[lhsContractDim + 1] = lhsContractSize / featureGroupCount; |
| rhsShape[rhsContractDim] = featureGroupCount; |
| rhsShape[rhsContractDim + 1] = rhsBatchSize / featureGroupCount; |
| |
| lhsTy = RankedTensorType::get(lhsShape, lhsTy.getElementType()); |
| rhsTy = RankedTensorType::get(rhsShape, rhsTy.getElementType()); |
| |
| lhs = rewriter.create<mhlo::ReshapeOp>(op.getLoc(), lhsTy, lhs); |
| rhs = rewriter.create<mhlo::ReshapeOp>(op.getLoc(), rhsTy, rhs); |
| |
| auto dotTy = RankedTensorType::get( |
| {featureGroupCount, lhsBatchSize, rhsBatchSize / featureGroupCount}, |
| resultTy.getElementType()); |
| |
| auto dotNums = DotDimensionNumbersAttr::get( |
| op.getContext(), {lhsContractDim}, {rhsContractDim}, |
| {lhsContractDim + 1}, {rhsContractDim == 0 ? 2 : 0}); |
| auto dotOp = rewriter.create<mhlo::DotGeneralOp>( |
| op.getLoc(), dotTy, lhs, rhs, dotNums, |
| op.precision_config().getValueOr(nullptr)); |
| |
| llvm::SmallVector<int64_t> perms; |
| perms.resize(3, dNums.getOutputBatchDimension() == 0 ? 0 : 2); |
| perms[0] = dNums.getOutputFeatureDimension(); |
| perms[2] = dNums.getOutputFeatureDimension() + 1; |
| |
| auto transposeTy = RankedTensorType::get( |
| {dotTy.getDimSize(perms[0]), dotTy.getDimSize(perms[1]), |
| dotTy.getDimSize(perms[2])}, |
| dotTy.getElementType()); |
| auto transposeOp = rewriter.create<mhlo::TransposeOp>( |
| op.getLoc(), transposeTy, dotOp, rewriter.getI64TensorAttr(perms)); |
| |
| rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, resultTy, transposeOp); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void ConvolutionOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<ConvolutionIsDot>(context); |
| } |
| |
| /* |
| * We intend to verify the following properties |
| * P1. Verify the input, kernel types. |
| * P2. Verify the convolution atributes. |
| * P3. Verify and collect the window atributes. |
| * P4. Verify the return shape. |
| * TODO(b/232574102): Verify the element-type of return-value. |
| */ |
| LogicalResult ConvolutionOp::verify() { |
| auto lhsType = lhs().getType().dyn_cast<RankedTensorType>(); |
| auto rhsType = rhs().getType().dyn_cast<RankedTensorType>(); |
| |
| if (!lhsType || !rhsType) return success(); |
| |
| // P1. |
| int numDims = lhsType.getRank(); |
| if (numDims != rhsType.getRank()) |
| return emitOpError() |
| << "expects convolution arguments to have same number of " |
| "dimensions. Got: " |
| << lhsType << " and " << rhsType << "."; |
| |
| if (numDims < 2) |
| return emitOpError() |
| << "expects convolution arguments to have >= 2 dimensions. " |
| "Got: " |
| << lhsType << " and " << rhsType << "."; |
| |
| // P2. |
| if (failed(verifyConvolutionAttributes(*this))) return failure(); |
| |
| // P3. |
| auto kernelSpatialDimensions = |
| dimension_numbers().getKernelSpatialDimensions(); |
| SmallVector<int64_t> windowDimensions(kernelSpatialDimensions.size()); |
| for (size_t i = 0; i < windowDimensions.size(); i++) |
| windowDimensions[i] = rhsType.getShape()[kernelSpatialDimensions[i]]; |
| |
| auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc()); |
| if (failed(paddingOrErr)) return failure(); |
| SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr; |
| |
| auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions( |
| windowDimensions, convertDenseIntAttr(window_strides()), padding, |
| convertDenseIntAttr(lhs_dilation()), convertDenseIntAttr(rhs_dilation()), |
| getLoc()); |
| if (failed(windowOrErr)) return failure(); |
| |
| // P4. |
| auto actualReturnType = getResult().getType().cast<TensorType>(); |
| auto actualReturnElementType = actualReturnType.getElementType(); |
| if (!actualReturnType.hasRank()) return success(); |
| |
| auto actualReturnRankedType = actualReturnType.cast<RankedTensorType>(); |
| if (numDims != actualReturnRankedType.getRank()) |
| return emitOpError() << "expects rank of convolution return-type to be " |
| "equal to input-ranks (" |
| << numDims << "), but got " |
| << actualReturnRankedType.getRank() << "."; |
| |
| auto expectedReturnShape = inferConvolutionOpReturnShape(*this, *windowOrErr); |
| auto expectedReturnType = |
| RankedTensorType::get(expectedReturnShape, actualReturnElementType); |
| if (failed(verifyCompatibleShape(expectedReturnType, actualReturnRankedType))) |
| return emitOpError() |
| << "has shape mismatch between the expected return-type (" |
| << expectedReturnType << ") and actual return-type (" |
| << actualReturnRankedType << ")."; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConvertOp |
| //===----------------------------------------------------------------------===// |
| |
| void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand, |
| Type resultElementTy) { |
| Type resultTy; |
| Type operandTy = operand.getType(); |
| if (auto rankedTy = operandTy.dyn_cast<RankedTensorType>()) { |
| resultTy = RankedTensorType::get(rankedTy.getShape(), resultElementTy); |
| } else { |
| resultTy = UnrankedTensorType::get(resultElementTy); |
| } |
| build(builder, result, resultTy, operand); |
| } |
| |
| OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { |
| auto operandTy = getOperand().getType().cast<TensorType>(); |
| auto resultTy = getResult().getType().cast<TensorType>(); |
| if (operandTy == resultTy) return getOperand(); |
| |
| // If the result has non-static shape, a convert op is necessary to go from |
| // static shape to non-static shape. |
| if (!resultTy.hasStaticShape()) return {}; |
| |
| // If the operand is constant, we can do the conversion now. |
| auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>(); |
| if (!elementsAttr) return {}; |
| |
| // Prevent folding if the result is too large. |
| if (elementsAttr.getNumElements() > kFoldOpEltLimit) return {}; |
| return hlo::convertElementsAttr(elementsAttr, |
| getElementTypeOrSelf(getResult())); |
| } |
| |
| namespace { |
| |
| struct EliminateRedundantConvert : public OpRewritePattern<ConvertOp> { |
| using OpRewritePattern<ConvertOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(ConvertOp op, |
| PatternRewriter& rewriter) const override { |
| auto convertOp = op.operand().getDefiningOp<ConvertOp>(); |
| if (!convertOp) { |
| return failure(); |
| } |
| auto firstType = |
| convertOp.operand().getType().cast<TensorType>().getElementType(); |
| auto secondType = |
| op.operand().getType().cast<TensorType>().getElementType(); |
| auto thirdType = |
| op.getResult().getType().cast<TensorType>().getElementType(); |
| auto loc = rewriter.getFusedLoc({convertOp->getLoc(), op->getLoc()}); |
| if (firstType.isa<FloatType>() && secondType.isa<FloatType>() && |
| thirdType.isa<FloatType>()) { |
| // fold when the second float type's width is longer than first, |
| // like fp16 -> fp32 -> fp64, bf16 -> fp32 -> fp16 |
| if (secondType.cast<FloatType>().getWidth() > |
| firstType.cast<FloatType>().getWidth()) { |
| Value result = rewriter.create<ConvertOp>(loc, op.getResult().getType(), |
| convertOp.operand()); |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| } else if (firstType.isa<IntegerType>() && secondType.isa<IntegerType>() && |
| thirdType.isa<IntegerType>()) { |
| // fold when the second integer type's width is longer than first, |
| // like i16 -> i32 -> i64, u16 -> i32 -> u32 |
| if (secondType.cast<IntegerType>().getWidth() > |
| firstType.cast<IntegerType>().getWidth()) { |
| Value result = rewriter.create<ConvertOp>(loc, op.getResult().getType(), |
| convertOp.operand()); |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| } |
| return failure(); |
| } |
| }; |
| |
| } // namespace |
| |
| void ConvertOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<EliminateIdentityConvert>(context); |
| results.add<EliminateRedundantConvert>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // GetTupleElementOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult GetTupleElementOp::verify() { |
| auto indexVal = index(); |
| auto operandType = getOperand().getType().cast<TupleType>(); |
| if (indexVal >= operandType.size()) { |
| return emitOpError( |
| llvm::formatv("index {0} is out of bounds of operand with size {1}", |
| indexVal, operandType.size())); |
| } |
| |
| auto expectedType = operandType.getType(indexVal); |
| if (getType() != expectedType) { |
| return emitOpError(llvm::formatv("has return type {0}, but expected {1}", |
| getType(), expectedType)); |
| } |
| return success(); |
| } |
| |
| OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) { |
| if (auto tupleOp = getOperand().getDefiningOp<mhlo::TupleOp>()) { |
| return tupleOp.getOperand(index()); |
| } |
| |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TupleOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult TupleOp::verify() { |
| auto opType = getType().dyn_cast<TupleType>(); |
| if (!opType) return emitOpError("tuple op with non-tuple result"); |
| if (getNumOperands() != opType.size()) |
| return emitOpError( |
| "number of operands to tuple expected to match number of types in " |
| "resultant tuple type"); |
| for (const auto& it : |
| llvm::enumerate(llvm::zip_first(getOperandTypes(), opType.getTypes()))) { |
| if (std::get<0>(it.value()) != std::get<1>(it.value())) |
| return emitOpError("has return type mismatch at ") |
| << it.index() << "th value (" << std::get<0>(it.value()) |
| << " != " << std::get<1>(it.value()) << ")"; |
| } |
| return success(); |
| } |
| |
| namespace { |
| |
| // Pattern for unpacking and repacking the same tuple. |
| struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> { |
| using OpRewritePattern<TupleOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(TupleOp op, |
| PatternRewriter& rewriter) const override { |
| if (op.val().empty()) return failure(); |
| |
| Value firstElement = op.val().front(); |
| auto firstElementOp = firstElement.getDefiningOp<GetTupleElementOp>(); |
| if (!firstElementOp || firstElementOp.indexAttr().getInt() != 0) |
| return failure(); |
| |
| Value tuplePredecessor = firstElementOp.getOperand(); |
| if (tuplePredecessor.getType() != op.getType()) return failure(); |
| |
| for (const auto& elementAndIdx : llvm::enumerate(op.val().drop_front(1))) { |
| auto elementOp = elementAndIdx.value().getDefiningOp<GetTupleElementOp>(); |
| if (!elementOp || |
| elementOp.indexAttr().getInt() != |
| static_cast<int64_t>(elementAndIdx.index() + 1) || |
| elementOp.getOperand() != tuplePredecessor) |
| return failure(); |
| } |
| |
| rewriter.replaceOp(op, tuplePredecessor); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void TupleOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<UnpackRepackSameTuple>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AllToAllOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult AllToAllOp::inferReturnTypeComponents( |
| MLIRContext*, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| AllToAllOp::Adaptor adaptor(operands, attributes, regions); |
| Type operandType = adaptor.operand().getType(); |
| RankedTensorType operandRankedType = operandType.dyn_cast<RankedTensorType>(); |
| if (!operandRankedType) { |
| inferredReturnShapes.emplace_back( |
| operandType.cast<TensorType>().getElementType()); |
| return success(); |
| } |
| |
| int64_t inputRank = operandRankedType.getRank(); |
| int64_t splitDimension = static_cast<int64_t>(adaptor.split_dimension()); |
| int64_t concatDimension = static_cast<int64_t>(adaptor.concat_dimension()); |
| if (splitDimension >= inputRank || splitDimension < 0) { |
| return emitOptionalError(location, "AllToAll split_dimension ", |
| splitDimension, |
| " is out-of-bounds for input rank ", inputRank); |
| } |
| if (concatDimension >= inputRank || concatDimension < 0) { |
| return emitOptionalError(location, "AllToAll concat_dimension ", |
| concatDimension, |
| " is out-of-bounds for input rank ", inputRank); |
| } |
| |
| // If operand is ranked, size of split dimension should be a multiple of split |
| // count. |
| int64_t splitCount = adaptor.split_count(); |
| auto splitDimSize = operandRankedType.getDimSize(splitDimension); |
| if (splitDimSize % splitCount != 0) { |
| return emitOptionalError( |
| location, "split dimension has size ", splitDimSize, |
| ", expected to be a multiple of split_count ", splitCount); |
| } |
| SmallVector<int64_t> resultShape(operandRankedType.getShape().begin(), |
| operandRankedType.getShape().end()); |
| resultShape[splitDimension] /= splitCount; |
| resultShape[concatDimension] *= splitCount; |
| inferredReturnShapes.emplace_back(resultShape, |
| operandRankedType.getElementType()); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AllGatherOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult AllGatherOp::verify() { |
| // If operand and result are both ranked, then the size of the gather |
| // dimension in the result should be a multiple of the size of the gather |
| // dimension in the operand. |
| auto operandType = operand().getType().dyn_cast<RankedTensorType>(); |
| auto resultType = getType().dyn_cast<RankedTensorType>(); |
| uint64_t allGatherDimIndex = all_gather_dim(); |
| if (!operandType || !resultType || |
| operandType.isDynamicDim(allGatherDimIndex) || |
| resultType.isDynamicDim(allGatherDimIndex)) |
| return success(); |
| if (operandType.getDimSize(allGatherDimIndex) == 0) |
| return emitOpError() << "operand gather dimension cannot be zero."; |
| if ((resultType.getDimSize(allGatherDimIndex) % |
| operandType.getDimSize(allGatherDimIndex)) != 0) |
| return emitOpError() |
| << "result gather dimension has size " |
| << resultType.getDimSize(allGatherDimIndex) |
| << ", expected to be a multiple of operand gather dimension size " |
| << operandType.getDimSize(allGatherDimIndex); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BatchNormGradOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult BatchNormGradOp::verify() { |
| // The following properties are already enforced by the ODS: |
| // 1. Inputs 'operand' & 'grad_output' and outputs 'grad_operand', |
| // are ranked-tensors with floating-point (fp) type. |
| // 2. The shapes of inputs 'operand' & 'grad_output' match. |
| // 3. Inputs 'scale', 'mean', 'variance' and Outputs 'grad_scale', |
| // 'grad_offset' are all 1D fp tensors with same shape. |
| // 4. The element-types of input 'operand' and outputs 'grad_scale', |
| // 'grad_offset' match. |
| // 5. The type of input 'operand' and output 'grad_operand' match. |
| // |
| // We intend to verify the following properties |
| // P1. Inputs 'operand' & 'grad_output' has the same shape with fp |
| // element-types, ignoring fp-precision : Inferred from (1) & (2). |
| // P2. The feature dimension 'feature_index' is a valid index in 'operand': |
| // Inferred from check C2 below. |
| // P3. Inputs 'scale', 'mean', 'variance' must be 1D tensors with same shape |
| // and fp element-type (ignoring precision) and the number of elements |
| // in its sole-dimension == number of features in the 'operand's |
| // feature-dimension 'feature_index': Inferred from (3) and check C3 |
| // below. |
| // P4. Outputs 'grad_scale' & 'grad_offset' are 1D tensors with |
| // element-type == element-type of(operand) and same shape as any of |
| // the inputs 'scale', 'mean', or 'variance': Inferred from (3), (4) and |
| // check C3 below. |
| // P5. The type (shape + element-type) of input 'operand' and |
| // output 'grad_operand' must match: Inferred from (5). |
| |
| // C2. |
| auto operandType = operand().getType().cast<RankedTensorType>(); |
| if (static_cast<int64_t>(feature_index()) >= operandType.getRank()) |
| return emitOpError() << "expects feature_index to be smaller " |
| "than the rank of operand type; got feature_index " |
| << feature_index() << ", and rank " |
| << operandType.getRank() << "."; |
| |
| if (static_cast<int64_t>(feature_index()) < 0) |
| return emitOpError() << "expects feature_index to be a " |
| << "non-negative number, got " |
| << static_cast<int64_t>(feature_index()) << "."; |
| |
| auto gradOutputType = grad_output().getType().cast<RankedTensorType>(); |
| if (operandType.getRank() != gradOutputType.getRank()) |
| return emitOpError() << "expects 'operand' and 'grad_output' to have the " |
| "same rank. but got rank(oprand) " |
| << operandType.getRank() << " and rank(grad_output) " |
| << gradOutputType.getRank() << "."; |
| |
| // C3. |
| const int64_t featureCount = operandType.getShape()[feature_index()]; |
| const int64_t scaleShape = |
| scale().getType().cast<RankedTensorType>().getShape()[0]; |
| if (scaleShape != featureCount) |
| return emitOpError() << "expects the size of scale factor to be " |
| "same as the feature count," |
| " but the size of scale factor is " |
| << scaleShape << " and the feature count is " |
| << featureCount << "."; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BatchNormTrainingOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult BatchNormTrainingOp::verify() { |
| // The following properties are already enforced by the ODS: |
| // 1. 'operand' and 'output' are ranked tensors. |
| // 2. 'scale', 'offset', 'batch_mean', 'batch_var' are 1D tensors. |
| // 3. Types of 'operand' and 'output' matches. |
| // 4. Same element-types for 'operand', 'batch_mean', & 'batch_var'. |
| // 5. Same shapes for 'scale', 'offset', 'batch_mean', & 'batch_var'. |
| |
| auto operandType = operand().getType().cast<RankedTensorType>(); |
| if (static_cast<int64_t>(feature_index()) >= operandType.getRank()) |
| return emitOpError() << "expects feature_index to be smaller " |
| "than the rank of operand type; got feature_index " |
| << feature_index() << ", and rank " |
| << operandType.getRank() << "."; |
| |
| if (static_cast<int64_t>(feature_index()) < 0) |
| return emitOpError() << "expects feature_index to be a " |
| << "non-negative number, got " |
| << static_cast<int64_t>(feature_index()) << "."; |
| |
| // Note:A valid value of feature-index implies 'operand_type.getRank() >=1'. |
| |
| const int64_t featureCount = operandType.getShape()[feature_index()]; |
| const int64_t scaleShape = |
| scale().getType().cast<RankedTensorType>().getShape()[0]; |
| // Check number of elements in input 'scale' equals feature_count. |
| // Together with (5) implies that 'scale', 'offset', 'batch_mean', & |
| // 'batch_var' all have the same shape. |
| if (scaleShape != featureCount) |
| return emitOpError() << "expects the size of scale factor to be " |
| "same as the feature count," |
| " but the size of scale factor is " |
| << scaleShape << " and the feature count is " |
| << featureCount << "."; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BatchNormInferenceOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult BatchNormInferenceOp::verify() { |
| // The following properties are already enforced by the ODS: |
| // 1. 'operand' and 'result' are ranked tensors. |
| // 2. 'scale', 'offset', 'mean', 'variance' are 1D tensors. |
| // 3. Types of 'operand' and 'result' matches. |
| // 4. Same shapes for 'scale', 'offset', 'mean', & 'variance'. |
| |
| auto operandType = operand().getType().cast<RankedTensorType>(); |
| if (static_cast<int64_t>(feature_index()) >= operandType.getRank()) |
| return emitOpError() << "expects feature_index to be smaller " |
| "than the rank of operand type; got feature_index " |
| << feature_index() << ", and rank " |
| << operandType.getRank() << "."; |
| |
| if (static_cast<int64_t>(feature_index()) < 0) |
| return emitOpError() << "expects feature_index to be a " |
| << "non-negative number, got " |
| << static_cast<int64_t>(feature_index()) << "."; |
| |
| // Note:A valid value of feature-index implies 'operand_type.getRank() >=1'. |
| |
| const int64_t featureCount = operandType.getShape()[feature_index()]; |
| const int64_t scaleSize = |
| scale().getType().cast<RankedTensorType>().getShape()[0]; |
| // Check number of elements in input 'scale' equals feature_count. |
| // Together with (4) implies that 'scale', 'offset', 'mean', & |
| // 'variance' all have the same shape. |
| if (scaleSize != featureCount) |
| return emitOpError() << "expects the size of scale factor to be " |
| "same as the feature count," |
| " but the size of scale factor is " |
| << scaleSize << " and the feature count is " |
| << featureCount << "."; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BitcastConvertOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult BitcastConvertOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| auto operandType = operands[0].getType().dyn_cast<RankedTensorType>(); |
| auto resultType = getType().dyn_cast<RankedTensorType>(); |
| |
| // Only ranked tensors are supported. |
| if (!operandType || !resultType) return failure(); |
| |
| // Shape-changing bitcast convert is not implemented. |
| // TODO(kramerb): This could be done by adjusting the last dimension. |
| DataLayout dataLayout = DataLayout::closest(*this); |
| unsigned operandElementSize = |
| dataLayout.getTypeSizeInBits(operandType.getElementType()); |
| unsigned resultElementSize = |
| dataLayout.getTypeSizeInBits(resultType.getElementType()); |
| if (operandElementSize != resultElementSize) return failure(); |
| |
| return ::mlir::mhlo::deriveShapeFromOperand( |
| &builder, getOperation(), operands.front(), &reifiedReturnShapes); |
| } |
| |
| /* |
| * We intend to verify the following properties |
| * P1. We cannot convert between complex and real types (cf xla) |
| * P3. The dimensions of the operand and the target |
| * shape must match, except that the shape with the smaller element bitwidth has |
| * an appropriately-sized additional innermost dimension, e.g. |
| * ... x f32 => [bitcast_convert] => ... x 4 x i8 |
| * ... x 4 x i8 => [bitcast_convert] => ... x f32 |
| */ |
| LogicalResult BitcastConvertOp::verify() { |
| auto operandTensorType = operand().getType().cast<TensorType>(); |
| auto targetTensorType = getResult().getType().cast<TensorType>(); |
| |
| // P1. |
| auto targetElt = targetTensorType.getElementType(); |
| auto operandElt = operandTensorType.getElementType(); |
| if (targetElt.isa<ComplexType>() != operandElt.isa<ComplexType>()) { |
| return emitOpError() |
| << "cannot convert between real and complex types, but got: " |
| << operandTensorType << " and " << targetTensorType; |
| } |
| |
| auto targetEltBitwidth = potentiallyComplexBitwidth(targetElt); |
| auto operandEltBitwidth = potentiallyComplexBitwidth(operandElt); |
| |
| // P2. |
| auto operandType = operandTensorType.dyn_cast<RankedTensorType>(); |
| auto targetType = targetTensorType.dyn_cast<RankedTensorType>(); |
| if (!operandType || !targetType) return success(); |
| |
| auto targetShape = targetType.getShape(); |
| auto operandShape = operandType.getShape(); |
| ArrayRef<int64_t> smallerEltShape, biggerEltShape; |
| Type smallerElt, biggerElt; |
| if (operandEltBitwidth < targetEltBitwidth) { |
| smallerEltShape = operandShape; |
| smallerElt = operandElt; |
| biggerEltShape = targetShape; |
| biggerElt = targetElt; |
| } else { |
| smallerEltShape = targetShape; |
| smallerElt = targetElt; |
| biggerEltShape = operandShape; |
| biggerElt = operandElt; |
| } |
| |
| ArrayRef<int64_t> smallerEltPrefix; |
| auto smallerEltBitwidth = std::min(targetEltBitwidth, operandEltBitwidth); |
| auto biggerEltBitwidth = std::max(targetEltBitwidth, operandEltBitwidth); |
| if (operandEltBitwidth != targetEltBitwidth) { |
| if (smallerEltShape.empty()) { |
| return emitOpError() << "does not allow the smaller element type to be " |
| "part of a 0d tensor, but got: " |
| << operandType << " and " << targetType << "."; |
| } |
| smallerEltPrefix = smallerEltShape.drop_back(); |
| if (!isDynamicDimSize(smallerEltShape.back()) && |
| smallerEltShape.back() * smallerEltBitwidth != biggerEltBitwidth) { |
| return emitOpError() << "requires compatible bitwidths. " |
| << "Got: " << operandType << " and " << targetType |
| << ", but " << smallerEltBitwidth << " * " |
| << smallerEltShape.back() |
| << " != " << biggerEltBitwidth << "."; |
| } |
| } else { |
| smallerEltPrefix = smallerEltShape; |
| } |
| |
| for (auto it : llvm::zip(smallerEltPrefix, biggerEltShape)) { |
| auto targetDim = std::get<0>(it); |
| auto operandDim = std::get<1>(it); |
| if (!isDynamicDimSize(targetDim) && !isDynamicDimSize(operandDim)) { |
| if (targetDim != operandDim) { |
| return emitOpError() << "operand and result shapes must match except " |
| "for the innermost dimension of the shape with " |
| "the smaller element type. Got: " |
| << operandType << " and " << targetType << "."; |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BroadcastOp |
| //===----------------------------------------------------------------------===// |
| |
| // TODO(b/129012527) These should be expressed as type constraints. |
| LogicalResult BroadcastOp::verify() { |
| auto sizes = broadcast_sizes(); |
| auto sizesType = sizes.getType(); |
| auto sizesRank = sizesType.getRank(); |
| if (sizesRank != 1) { |
| return emitOpError(llvm::formatv( |
| "broadcast_sizes has rank {0} instead of rank 1", sizesRank)); |
| } |
| |
| return success(); |
| } |
| |
| OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> attrs) { |
| auto type = getType().cast<RankedTensorType>(); |
| auto sizesType = broadcast_sizes().getType(); |
| if (sizesType.getNumElements() == 0) { |
| return getOperand(); |
| } |
| |
| // Constant fold when an operand is a splat tensor attribute. |
| if (!attrs[0] || !type.hasStaticShape()) return {}; |
| auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>(); |
| if (!splatOperandAttr) return {}; |
| |
| // Handle complex type |
| if (type.getElementType().isa<ComplexType>()) { |
| ComplexType complex = type.getElementType().cast<ComplexType>(); |
| if (complex.getElementType().isa<FloatType>()) { |
| return DenseElementsAttr::get( |
| type, {splatOperandAttr.getSplatValue<std::complex<APFloat>>()}); |
| } |
| if (complex.getElementType().isa<IntegerType>()) { |
| return DenseElementsAttr::get( |
| type, {splatOperandAttr.getSplatValue<std::complex<APInt>>()}); |
| } |
| return {}; |
| } |
| |
| return SplatElementsAttr::get( |
| type, splatOperandAttr.getSplatValue<mlir::Attribute>()); |
| } |
| |
| LogicalResult BroadcastOp::inferReturnTypeComponents( |
| MLIRContext*, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| BroadcastOp::Adaptor adaptor(operands, attributes, regions); |
| Value operand = adaptor.operand(); |
| auto operandType = operand.getType().dyn_cast<RankedTensorType>(); |
| if (!operandType) return failure(); |
| |
| Type elementTy = operandType.getElementType(); |
| auto dimensionAttr = adaptor.broadcast_sizes(); |
| for (int64_t size : dimensionAttr.getValues<int64_t>()) { |
| if (size < 0) |
| return emitOptionalError(location, |
| "Broadcast with negative dimension size ", size); |
| } |
| SmallVector<int64_t> shapeValues(dimensionAttr.getValues<int64_t>()); |
| llvm::append_range(shapeValues, operandType.getShape()); |
| |
| inferredReturnShapes.emplace_back(shapeValues, elementTy); |
| return success(); |
| } |
| |
| LogicalResult BroadcastOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| BroadcastOp::Adaptor adaptor(operands); |
| Value operand = adaptor.operand(); |
| |
| auto operandType = operand.getType().dyn_cast<RankedTensorType>(); |
| // Unranked tensors are not supported. |
| if (!operandType) return failure(); |
| |
| Location loc = getLoc(); |
| SmallVector<Value, 4> shapeValues; |
| |
| // Collect the broadcast sizes. |
| for (const auto& size : broadcast_sizes()) { |
| shapeValues.push_back( |
| builder.create<arith::ConstantIndexOp>(loc, size.getZExtValue())); |
| } |
| |
| // Collect the operand sizes. |
| for (auto index : llvm::seq<int64_t>(0, operandType.getRank())) { |
| shapeValues.push_back( |
| builder.createOrFold<tensor::DimOp>(loc, operand, index)); |
| } |
| |
| reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>( |
| loc, |
| RankedTensorType::get({static_cast<int64_t>(shapeValues.size())}, |
| builder.getIndexType()), |
| shapeValues)); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BroadcastInDimOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult BroadcastInDimOp::verify() { |
| auto operandType = operand().getType().dyn_cast<RankedTensorType>(); |
| if (!operandType) { |
| // The following verification checks all depend on knowing the rank of |
| // the operand. Bail out now if we don't know the rank of the operand. |
| return success(); |
| } |
| |
| auto operandRank = operandType.getRank(); |
| if (!broadcast_dimensions()) { |
| if (operandRank == 0) { |
| return success(); |
| } |
| return emitOpError( |
| llvm::formatv("broadcast_dimensions is absent, but required because " |
| "operand has non-zero rank ({0})", |
| operandRank)); |
| } |
| |
| auto dimensions = broadcast_dimensions(); |
| auto dimensionsType = broadcast_dimensions().getType(); |
| auto dimensionsRank = dimensionsType.getRank(); |
| if (dimensionsRank != 1) { |
| return emitOpError(llvm::formatv( |
| "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank)); |
| } |
| |
| auto dimensionsSize = dimensionsType.getNumElements(); |
| if (dimensionsSize != operandRank) { |
| return emitOpError(llvm::formatv( |
| "broadcast_dimensions size ({0}) does not match operand rank ({1})", |
| dimensionsSize, operandRank)); |
| } |
| |
| auto resultType = getResult().getType().cast<RankedTensorType>(); |
| auto resultRank = resultType.getRank(); |
| if (resultRank < operandRank) { |
| return emitOpError( |
| llvm::formatv("result rank ({0}) is less than operand rank ({1})", |
| resultRank, operandRank)); |
| } |
| |
| for (int i = 0; i != dimensionsSize; ++i) { |
| auto dimIndex = dimensions.getValues<int64_t>()[i]; |
| if (dimIndex >= resultRank) { |
| return emitOpError( |
| llvm::formatv("broadcast_dimensions contains invalid value {0} for " |
| "result with rank {1}", |
| dimIndex, resultRank)); |
| } |
| |
| if (!operandType.isDynamicDim(i)) { |
| auto dimSize = operandType.getDimSize(i); |
| auto resultDimSize = resultType.getDimSize(dimIndex); |
| if (dimSize != 1 && dimSize != resultDimSize) { |
| return emitOpError( |
| llvm::formatv("size of operand dimension {0} ({1}) is not equal to " |
| "1 or size of result dimension {2} ({3})", |
| i, dimSize, dimIndex, resultDimSize)); |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) { |
| auto type = getType().cast<RankedTensorType>(); |
| if (type == getOperand().getType()) { |
| auto broadcastValues = broadcast_dimensions().getValues<int64_t>(); |
| if (!std::equal(broadcastValues.begin(), broadcastValues.end(), |
| llvm::seq<int64_t>(0, type.getRank()).begin())) { |
| return {}; |
| } |
| return getOperand(); |
| } |
| |
| // Constant fold when an operand is a splat tensor attribute. |
| if (!attrs[0] || !type.hasStaticShape()) return {}; |
| auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>(); |
| if (!splatOperandAttr) return {}; |
| |
| // Handle complex type |
| if (type.getElementType().isa<ComplexType>()) { |
| ComplexType complex = type.getElementType().cast<ComplexType>(); |
| if (complex.getElementType().isa<FloatType>()) { |
| return DenseElementsAttr::get( |
| type, {splatOperandAttr.getSplatValue<std::complex<APFloat>>()}); |
| } |
| if (complex.getElementType().isa<IntegerType>()) { |
| return DenseElementsAttr::get( |
| type, {splatOperandAttr.getSplatValue<std::complex<APInt>>()}); |
| } |
| return {}; |
| } |
| |
| return SplatElementsAttr::get( |
| type, splatOperandAttr.getSplatValue<mlir::Attribute>()); |
| } |
| |
| // Simplify BroadcastInDim has the following behaviors: replace BroadcastInDim |
| // with Reshape or Transpose if they are equivalent or replace |
| // BroadcastInDim(BroadcastInDim(X)) with BroadcastInDim(X) |
| class BroadcastInDimSimplifier : public OpRewritePattern<BroadcastInDimOp> { |
| public: |
| using OpRewritePattern<BroadcastInDimOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(BroadcastInDimOp op, |
| PatternRewriter& rewriter) const override { |
| auto operandType = op.operand().getType().dyn_cast<RankedTensorType>(); |
| auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>(); |
| if (!operandType || !resultType) { |
| return failure(); |
| } |
| auto bsDimIndices = op.broadcast_dimensions().getValues<int64_t>(); |
| if (operandType.hasStaticShape() && resultType.hasStaticShape()) { |
| bool sameTotalElements = |
| operandType.getNumElements() == resultType.getNumElements(); |
| // BroadcastInDim equivalent to reshape |
| if (llvm::is_sorted(bsDimIndices) && sameTotalElements) { |
| rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand()); |
| return success(); |
| } |
| // BroadcastInDim equivalent to transpose |
| if (operandType.getRank() == resultType.getRank() && sameTotalElements) { |
| rewriter.replaceOpWithNewOp<TransposeOp>(op, op.getType(), op.operand(), |
| op.broadcast_dimensions()); |
| return success(); |
| } |
| } |
| // eliminate redundant BroadcastInDim |
| if (auto broadcastInDimOp = llvm::dyn_cast_or_null<BroadcastInDimOp>( |
| op.operand().getDefiningOp())) { |
| auto newIndices = |
| broadcastInDimOp.broadcast_dimensions() |
| .mapValues(op.broadcast_dimensions().getElementType(), |
| [&bsDimIndices](const APInt& dim) -> APInt { |
| return APInt(dim.getBitWidth(), |
| bsDimIndices[dim.getSExtValue()], true); |
| }) |
| .cast<DenseIntElementsAttr>(); |
| rewriter.replaceOpWithNewOp<BroadcastInDimOp>( |
| op, op.getType(), broadcastInDimOp.operand(), newIndices); |
| return success(); |
| } |
| return failure(); |
| } |
| }; |
| |
| void BroadcastInDimOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<BroadcastInDimSimplifier>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DynamicBroadcastInDimOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult DynamicBroadcastInDimOp::verify() { |
| auto operandType = operand().getType().dyn_cast<RankedTensorType>(); |
| auto resultType = getResult().getType().dyn_cast<RankedTensorType>(); |
| |
| // If either the operand or result are unranked, there is very little |
| // to verify statically. |
| if (!operandType || !resultType) { |
| return success(); |
| } |
| |
| auto outputDimensionsType = |
| output_dimensions().getType().cast<RankedTensorType>(); |
| auto outputDimensionsSize = outputDimensionsType.getDimSize(0); |
| auto operandRank = operandType.getRank(); |
| auto resultRank = resultType.getRank(); |
| |
| // Verify broadcast_dimensions. |
| auto bcastDimensions = broadcast_dimensions(); |
| auto bcastDimensionsType = broadcast_dimensions().getType(); |
| auto bcastDimensionsRank = bcastDimensionsType.getRank(); |
| // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1. |
| if (bcastDimensionsRank != 1) { |
| return emitOpError( |
| llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1", |
| bcastDimensionsRank)); |
| } |
| |
| auto bcastDimensionsSize = bcastDimensionsType.getNumElements(); |
| if (bcastDimensionsSize != operandRank) { |
| return emitOpError(llvm::formatv( |
| "broadcast_dimensions size ({0}) does not match operand rank ({1})", |
| bcastDimensionsSize, operandRank)); |
| } |
| |
| if (resultRank < operandRank) { |
| return emitOpError( |
| llvm::formatv("result rank ({0}) is less than operand rank ({1})", |
| resultRank, operandRank)); |
| } |
| |
| for (int i = 0; i != bcastDimensionsSize; ++i) { |
| auto dimIndex = bcastDimensions.getValues<int64_t>()[i]; |
| if (dimIndex >= resultRank) { |
| return emitOpError( |
| llvm::formatv("broadcast_dimensions contains invalid value {0} for " |
| "result with rank {1}", |
| dimIndex, resultRank)); |
| } |
| |
| auto dimSize = operandType.getDimSize(i); |
| auto resultDimSize = resultType.getDimSize(dimIndex); |
| // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we |
| // add a manual check for this. |
| if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) { |
| return emitOpError( |
| llvm::formatv("size of operand dimension {0} ({1}) is not compatible " |
| "with size of result dimension {2} ({3})", |
| i, dimSize, dimIndex, resultDimSize)); |
| } |
| } |
| |
| if (outputDimensionsSize != resultRank) { |
| return emitOpError( |
| llvm::formatv("result rank ({0}) is not equal to number of output " |
| "dimensions ({1})", |
| resultRank, outputDimensionsSize)); |
| } |
| |
| // Verify that the known expanding and non-expanding dimensions are a subset |
| // of the operand's dimensions. |
| int64_t numKnownExpansionBehavior = 0; |
| DenseSet<int64_t> knownExpansionBehavior; |
| auto collectExpansionBehaviorDims = |
| [&](const Optional<DenseIntElementsAttr>& attr) { |
| if (!attr) return; |
| for (const APInt& it : *attr) { |
| numKnownExpansionBehavior++; |
| knownExpansionBehavior.insert(it.getLimitedValue()); |
| } |
| }; |
| collectExpansionBehaviorDims(known_expanding_dimensions()); |
| collectExpansionBehaviorDims(known_nonexpanding_dimensions()); |
| if (knownExpansionBehavior.size() != numKnownExpansionBehavior) { |
| return emitOpError( |
| "duplicate expansion hint for at least one operand dimension"); |
| } |
| for (int64_t i : knownExpansionBehavior) { |
| if (i < 0 || i >= operandRank) { |
| return emitOpError( |
| llvm::formatv("hint for expanding dimension {0} does not refer to a " |
| "valid operand dimension", |
| i)); |
| } |
| } |
| |
| return success(); |
| } |
| |
| namespace { |
| // If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary |
| // BroadcastInDimOp. |
| class DynamicBroadcastInDimOpNotActuallyDynamic |
| : public OpRewritePattern<DynamicBroadcastInDimOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, |
| PatternRewriter& rewriter) const override { |
| auto type = op.getType().dyn_cast<RankedTensorType>(); |
| auto operandType = op.operand().getType().dyn_cast<RankedTensorType>(); |
| auto* outputDimOp = op.output_dimensions().getDefiningOp(); |
| if (!type || !operandType || !operandType.hasStaticShape()) { |
| return rewriter.notifyMatchFailure(op, "requires operand static shape"); |
| } |
| // output has static shape, replace with broadcast_in_dim |
| if (type.hasStaticShape()) { |
| rewriter.replaceOpWithNewOp<BroadcastInDimOp>(op, type, op.operand(), |
| op.broadcast_dimensions()); |
| return success(); |
| } |
| // output_dimensions are constant, set output shape with output_dimensions, |
| // then replace with broadcast_in_dim |
| if (outputDimOp && outputDimOp->hasTrait<mlir::OpTrait::ConstantLike>()) { |
| DenseIntElementsAttr shapeAttr; |
| if (matchPattern(outputDimOp, m_Constant(&shapeAttr))) { |
| SmallVector<int64_t> outputShape; |
| for (APInt shape : shapeAttr.getValues<APInt>()) { |
| outputShape.push_back(shape.getZExtValue()); |
| } |
| refineOpWithNewOp<BroadcastInDimOp>( |
| rewriter, op, |
| RankedTensorType::get(outputShape, type.getElementType()), |
| op.operand(), op.broadcast_dimensions()); |
| return success(); |
| } |
| } |
| return rewriter.notifyMatchFailure( |
| op, "requires output static shape or constant broadcast dimensions"); |
| } |
| }; |
| |
| class ChainedDynamicBroadcastInDimCanonicalization |
| : public OpRewritePattern<DynamicBroadcastInDimOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast, |
| PatternRewriter& rewriter) const override { |
| auto precedingBcast = |
| bcast.operand().getDefiningOp<DynamicBroadcastInDimOp>(); |
| if (!precedingBcast) return failure(); |
| |
| // Compose broadcast dimensions. |
| DenseIntElementsAttr precedingBcastDims = |
| precedingBcast.broadcast_dimensions(); |
| DenseIntElementsAttr bcastDims = bcast.broadcast_dimensions(); |
| SmallVector<APInt, 4> composition; |
| for (APInt precedingDim : precedingBcastDims) { |
| composition.push_back( |
| bcastDims.getValues<APInt>()[precedingDim.getZExtValue()]); |
| } |
| auto composedBcastDims = |
| DenseIntElementsAttr::get(precedingBcastDims.getType(), composition); |
| |
| rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>( |
| bcast, bcast.getType(), precedingBcast.operand(), |
| bcast.output_dimensions(), composedBcastDims); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void DynamicBroadcastInDimOp::getCanonicalizationPatterns( |
| RewritePatternSet& results, MLIRContext* context) { |
| results.add<ChainedDynamicBroadcastInDimCanonicalization, |
| DynamicBroadcastInDimOpNotActuallyDynamic, |
| DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2, |
| DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>( |
| context); |
| } |
| |
| LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| DynamicBroadcastInDimOp::Adaptor adaptor(operands); |
| reifiedReturnShapes.push_back( |
| castToIndexTensor(builder, getLoc(), adaptor.output_dimensions())); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ClampOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ClampOp::verify() { |
| auto operandType = operand().getType().cast<RankedTensorType>(); |
| auto operandShape = operandType.getShape(); |
| auto minType = min().getType().cast<RankedTensorType>(); |
| |
| auto minShape = minType.getShape(); |
| if (failed(verifyCompatibleShape(minType, operandType)) && |
| minType.getRank() != 0) { |
| return emitOpError(llvm::formatv( |
| "min shape [{0}] is not scalar and is not compatible to operand shape " |
| "[{1}]", |
| llvm::make_range(minShape.begin(), minShape.end()), |
| llvm::make_range(operandShape.begin(), operandShape.end()))); |
| } |
| |
| auto maxType = max().getType().cast<RankedTensorType>(); |
| auto maxShape = maxType.getShape(); |
| if (failed(verifyCompatibleShape(maxType, operandType)) && |
| maxType.getRank() != 0) { |
| return emitOpError(llvm::formatv( |
| "max shape [{0}] is not scalar and is not compatible to operand shape " |
| "[{1}]", |
| llvm::make_range(maxShape.begin(), maxShape.end()), |
| llvm::make_range(operandShape.begin(), operandShape.end()))); |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult ClampOp::inferReturnTypeComponents( |
| MLIRContext*, Optional<Location> /*location*/, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| ClampOp::Adaptor adaptor(operands, attributes, regions); |
| RankedTensorType operandType = |
| adaptor.operand().getType().cast<RankedTensorType>(); |
| inferredReturnShapes.emplace_back(operandType.getShape(), |
| operandType.getElementType()); |
| return success(); |
| } |
| |
| LogicalResult ClampOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| // For `mhlo.clamp`, the first operand may be a scalar. |
| return deriveShapeFromOperand(&builder, getOperation(), operands[1], |
| &reifiedReturnShapes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ComplexOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ComplexOp::inferReturnTypes( |
| MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr, |
| RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) { |
| TensorType operandType = operands[0].getType().cast<TensorType>(); |
| ComplexType elementTy = ComplexType::get(operandType.getElementType()); |
| inferredReturnTypes.push_back(getSameShapeTensorType(operandType, elementTy)); |
| return success(); |
| } |
| |
| OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) { |
| auto realOp = getOperand(0).getDefiningOp<mhlo::RealOp>(); |
| auto imagOp = getOperand(1).getDefiningOp<mhlo::ImagOp>(); |
| if (realOp && imagOp && realOp.getOperand() == imagOp.getOperand()) { |
| return realOp.getOperand(); |
| } |
| |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ImagOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| Type createRealType(TensorType type) { |
| auto elementTy = type.getElementType(); |
| if (auto complexTy = elementTy.dyn_cast<ComplexType>()) { |
| elementTy = complexTy.getElementType(); |
| } |
| return getSameShapeTensorType(type, elementTy); |
| } |
| } // namespace |
| |
| LogicalResult ImagOp::inferReturnTypes( |
| MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr, |
| RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) { |
| inferredReturnTypes.push_back( |
| createRealType(operands[0].getType().cast<TensorType>())); |
| return success(); |
| } |
| |
| OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) { |
| if (auto complexOp = getOperand().getDefiningOp<mhlo::ComplexOp>()) { |
| return complexOp.getOperand(1); |
| } |
| |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // IsFiniteOp |
| //===----------------------------------------------------------------------===// |
| |
| TensorType getSameShapeTensorType(TensorType tensorType, Type elementType) { |
| if (auto rankedTensorTy = tensorType.dyn_cast<RankedTensorType>()) { |
| return RankedTensorType::get(rankedTensorTy.getShape(), elementType, |
| rankedTensorTy.getEncoding()); |
| } |
| if (auto unrankedTensorTy = tensorType.dyn_cast<UnrankedTensorType>()) { |
| return UnrankedTensorType::get(elementType); |
| } |
| llvm_unreachable("unhandled type"); |
| } |
| |
| LogicalResult IsFiniteOp::inferReturnTypes( |
| MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr, |
| RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) { |
| auto argTy = operands.front().getType().cast<TensorType>(); |
| Builder b(ctx); |
| inferredReturnTypes.push_back(getSameShapeTensorType(argTy, b.getI1Type())); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RealOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult RealOp::inferReturnTypes( |
| MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr, |
| RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) { |
| inferredReturnTypes.push_back( |
| createRealType(operands[0].getType().cast<TensorType>())); |
| return success(); |
| } |
| |
| OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) { |
| if (auto complexOp = getOperand().getDefiningOp<mhlo::ComplexOp>()) { |
| return complexOp.getOperand(0); |
| } |
| |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConcatenateOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class SingleOperandConcatenateToCast : public OpRewritePattern<ConcatenateOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(ConcatenateOp op, |
| PatternRewriter& rewriter) const override { |
| if (op.val().size() != 1) return failure(); |
| |
| rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), |
| op.val().front()); |
| return success(); |
| } |
| }; |
| |
| class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(ConcatenateOp op, |
| PatternRewriter& rewriter) const override { |
| auto axis = op.dimension(); |
| llvm::SmallVector<Value, 6> newOperands; |
| for (auto operand : op.getOperands()) { |
| auto ty = operand.getType().cast<ShapedType>(); |
| if (!ty.hasRank() || ty.getDimSize(axis) != 0) { |
| newOperands.push_back(operand); |
| } |
| } |
| |
| if (!newOperands.empty() && newOperands.size() < op.getNumOperands()) { |
| rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(), |
| newOperands, op.dimension()); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| class ConcatenateForwarding : public OpRewritePattern<ConcatenateOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(ConcatenateOp op, |
| PatternRewriter& rewriter) const override { |
| auto getFlattenedOperands = [&](const Value& val) -> ValueRange { |
| auto definingOp = dyn_cast_or_null<ConcatenateOp>(val.getDefiningOp()); |
| // To avoid inflate the memory footprint, only flatten the ConcatenateOp |
| // when it has only one use. |
| if (definingOp && definingOp->hasOneUse() && |
| definingOp.dimension() == op.dimension()) |
| return definingOp.val(); |
| return val; |
| }; |
| |
| bool needToFlatten = false; |
| int operandCount = 0; |
| llvm::for_each(op.val(), [&](Value val) { |
| auto result = getFlattenedOperands(val); |
| if (result.size() != 1 || result[0] != val) needToFlatten = true; |
| operandCount += result.size(); |
| }); |
| |
| if (!needToFlatten) return failure(); |
| |
| llvm::SmallVector<Value, 6> newOperands; |
| newOperands.reserve(operandCount); |
| |
| for (auto operand : op.val()) { |
| auto flattenedOperands = getFlattenedOperands(operand); |
| newOperands.append(flattenedOperands.begin(), flattenedOperands.end()); |
| } |
| |
| rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(), |
| newOperands, op.dimension()); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| LogicalResult ConcatenateOp::inferReturnTypes( |
| MLIRContext*, Optional<Location> location, ValueRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<Type>& inferredReturnTypes) { |
| if (operands.empty()) { |
| return failure(); |
| } |
| |
| auto dimensionAttr = attributes.get("dimension").cast<IntegerAttr>(); |
| auto dimension = dimensionAttr.getInt(); |
| |
| auto firstType = (*operands.begin()).getType().cast<ShapedType>(); |
| auto outElement = firstType.getElementType(); |
| |
| // Find the first ranked input to determine the output rank. |
| for (auto type : operands.getTypes()) { |
| auto shapedType = type.cast<ShapedType>(); |
| if (shapedType.hasRank()) { |
| firstType = shapedType; |
| break; |
| } |
| } |
| |
| // If all inputs are unranked, the result must be unranked. |
| if (!firstType.hasRank()) { |
| inferredReturnTypes.push_back(UnrankedTensorType::get(outElement)); |
| return success(); |
| } |
| |
| auto outShape = llvm::to_vector<6>(firstType.getShape()); |
| |
| // Determine what the non-concatenate dimensions should be. |
| for (auto type : operands.getTypes()) { |
| auto shapedTy = type.cast<ShapedType>(); |
| if (!shapedTy.hasRank()) { |
| continue; |
| } |
| |
| for (const auto& it : llvm::enumerate(shapedTy.getShape())) { |
| // If a dimension is not dynamic, the output shape should match. |
| if (ShapedType::isDynamic(outShape[it.index()])) { |
| outShape[it.index()] = it.value(); |
| } |
| } |
| } |
| |
| outShape[dimension] = 0; |
| |
| for (auto operand : operands.getTypes()) { |
| auto type = operand.cast<ShapedType>(); |
| if (!type.hasRank()) { |
| inferredReturnTypes.push_back(UnrankedTensorType::get(outElement)); |
| return success(); |
| } |
| |
| // If the dimension is dynamic we know the output dimension is dynamic. |
| auto dim = type.getShape()[dimension]; |
| if (ShapedType::isDynamic(dim)) { |
| outShape[dimension] = ShapedType::kDynamicSize; |
| break; |
| } |
| |
| outShape[dimension] += dim; |
| } |
| |
| inferredReturnTypes.push_back(RankedTensorType::get(outShape, outElement)); |
| |
| return success(); |
| } |
| |
| void ConcatenateOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<ConcatenateOperandRemoval, ConcatenateForwarding, |
| SingleOperandConcatenateToCast>(context); |
| } |
| |
| template <typename T> |
| static Attribute foldConcatenateHelper(ConcatenateOp* op, |
| ArrayRef<Attribute> operands) { |
| auto axis = op->dimension(); |
| auto type = op->getType().cast<ShapedType>(); |
| auto shape = type.getShape(); |
| |
| size_t topSize = 1; |
| for (int i = 0, e = axis; i < e; i++) { |
| topSize = topSize * shape[i]; |
| } |
| |
| // Prevent folding if the result is too large. |
| if (type.getNumElements() > kFoldOpEltLimit) return {}; |
| |
| SmallVector<T, 6> values; |
| for (size_t i = 0; i < topSize; i++) { |
| for (auto operand : operands) { |
| DenseElementsAttr attr = operand.cast<DenseElementsAttr>(); |
| size_t bottomSize = attr.getNumElements() / topSize; |
| auto iter = attr.getValues<T>().begin() + i * bottomSize; |
| values.append(iter, iter + bottomSize); |
| } |
| } |
| |
| return DenseElementsAttr::get(type, values); |
| } |
| |
| static Attribute foldConcatenate(ConcatenateOp* op, |
| ArrayRef<Attribute> operands) { |
| for (auto operand : operands) { |
| if (!operand) return {}; |
| } |
| |
| auto type = op->getResult().getType().cast<ShapedType>(); |
| auto etype = type.getElementType(); |
| if (etype.isa<IntegerType>()) { |
| return foldConcatenateHelper<APInt>(op, operands); |
| } |
| |
| if (etype.isa<FloatType>()) { |
| return foldConcatenateHelper<APFloat>(op, operands); |
| } |
| |
| return {}; |
| } |
| |
| OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) { |
| if (getNumOperands() == 1) return getOperand(0); |
| |
| ShapedType type = getResult().getType().cast<ShapedType>(); |
| if (!type.hasStaticShape()) return {}; |
| |
| auto axis = dimension(); |
| if (auto attr = foldConcatenate(this, operands)) { |
| return attr; |
| } |
| |
| for (auto operand : getOperands()) { |
| auto ty = operand.getType().cast<ShapedType>(); |
| if (ty.getDimSize(axis) != 0) { |
| return {}; |
| } |
| } |
| |
| return DenseElementsAttr::get(type, ArrayRef<Attribute>()); |
| } |
| |
| LogicalResult ConcatenateOp::verify() { |
| RankedTensorType firstRankedType; |
| int firstRankedIndex; |
| int numOperands = getNumOperands(); |
| int64_t concatDimension = static_cast<int64_t>(dimension()); |
| if (concatDimension < 0) { |
| return emitOpError( |
| llvm::formatv("dimension {0} is negative", concatDimension)); |
| } |
| for (int i = 0; i < numOperands; i++) { |
| auto secondType = getOperand(i).getType().dyn_cast<ShapedType>(); |
| if (!secondType.hasRank()) { |
| continue; |
| } |
| |
| if (!firstRankedType) { |
| firstRankedType = secondType.cast<RankedTensorType>(); |
| firstRankedIndex = i; |
| if (firstRankedType.getRank() == 0) |
| return emitOpError( |
| llvm::formatv("rank-0 values cannot be concatenated")); |
| if (concatDimension >= firstRankedType.getRank()) { |
| return emitOpError( |
| llvm::formatv("dimension {0} is out-of-bounds for input rank {1}", |
| concatDimension, firstRankedType.getRank())); |
| } |
| continue; |
| } |
| |
| if (firstRankedType.getRank() != secondType.getRank()) { |
| return emitOpError(llvm::formatv( |
| "operands ({0}) and ({1}) do not match rank", firstRankedIndex, i)); |
| } |
| |
| auto firstShape = firstRankedType.getShape(); |
| auto secondShape = secondType.getShape(); |
| for (int d = 0; d < firstRankedType.getRank(); ++d) { |
| if (!ShapedType::isDynamic(firstShape[d]) && |
| !ShapedType::isDynamic(secondShape[d]) && |
| firstShape[d] != secondShape[d] && d != concatDimension) { |
| return emitOpError(llvm::formatv( |
| "shapes of operand ({0}) and ({1}) do not match at non-concat " |
| "index: ({2}) != ({3}) at non-concat index {4}", |
| firstRankedIndex, i, |
| llvm::make_range(firstShape.begin(), firstShape.end()), |
| llvm::make_range(secondShape.begin(), secondShape.end()), d)); |
| } |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult ConcatenateOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| ConcatenateOp::Adaptor adaptor(operands); |
| auto inputs = adaptor.val(); |
| |
| auto operandType = inputs[0].getType().dyn_cast<RankedTensorType>(); |
| // Not support unranked type a.t.m. |
| if (!operandType) return failure(); |
| |
| Location loc = this->getLoc(); |
| Type shapeScalarType = builder.getIndexType(); |
| auto toShapeScalarType = [&](Value v) { |
| return maybeCastTo(builder, loc, v, shapeScalarType); |
| }; |
| |
| SmallVector<SmallVector<Value, 4>, 4> allShapeValues; |
| for (size_t inputId = 0; inputId < inputs.size(); ++inputId) { |
| Value operand = inputs[inputId]; |
| auto operandType = operand.getType().dyn_cast<RankedTensorType>(); |
| if (!operandType) return failure(); |
| |
| SmallVector<Value, 4> shapeVals; |
| for (const auto& element : llvm::enumerate(operandType.getShape())) { |
| Value valueDim = toShapeScalarType( |
| builder.create<tensor::DimOp>(loc, operand, element.index())); |
| shapeVals.push_back(valueDim); |
| } |
| allShapeValues.emplace_back(std::move(shapeVals)); |
| } |
| |
| int axis = this->dimension(); |
| auto& shapeValues = allShapeValues[0]; |
| for (size_t vecId = 1; vecId < allShapeValues.size(); ++vecId) { |
| auto& otherShapeValues = allShapeValues[vecId]; |
| if (otherShapeValues.size() != shapeValues.size()) { |
| this->emitOpError() |
| << "Concatenate expects all operands must be of the same rank"; |
| return failure(); |
| } |
| shapeValues[axis] = builder.create<arith::AddIOp>(loc, shapeValues[axis], |
| otherShapeValues[axis]); |
| } |
| |
| Value outputShape = builder.create<tensor::FromElementsOp>( |
| loc, |
| RankedTensorType::get({static_cast<int64_t>(shapeValues.size())}, |
| shapeScalarType), |
| shapeValues); |
| reifiedReturnShapes.push_back(outputShape); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DynamicReshapeOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult DynamicReshapeOp::verify() { |
| auto resultType = result().getType().dyn_cast<RankedTensorType>(); |
| auto outputShapeType = output_shape().getType().dyn_cast<RankedTensorType>(); |
| if (resultType && outputShapeType && outputShapeType.hasStaticShape() && |
| outputShapeType.getDimSize(0) != resultType.getRank()) { |
| return emitError() << "output should have a rank equal to the number of " |
| "elements in output_shape"; |
| } |
| return success(); |
| } |
| |
| LogicalResult DynamicReshapeOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| DynamicReshapeOp::Adaptor adaptor(operands); |
| reifiedReturnShapes.push_back( |
| castToIndexTensor(builder, getLoc(), adaptor.output_shape())); |
| return success(); |
| } |
| |
| namespace { |
| class DynamicReshapeOpNotActuallyDynamic |
| : public OpRewritePattern<DynamicReshapeOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DynamicReshapeOp op, |
| PatternRewriter& rewriter) const override { |
| auto type = op.result().getType().dyn_cast<RankedTensorType>(); |
| if (!type || !type.hasStaticShape()) { |
| return rewriter.notifyMatchFailure(op, "requires static shape tensor"); |
| } |
| rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand()); |
| return success(); |
| } |
| }; |
| |
| // Canonicalizes |
| // %0 = some_op(%tensor) |
| // %1 = "mhlo.dynamic_reshape"(%0, %shape) |
| // (tensor<?xT>, tensor<1xindex>) -> tensor<?xT> |
| // ... uses of %1. |
| // |
| // into |
| // |
| // ... uses of %0. |
| // This canonicalization is only correct if the input is correct! |
| // TODO(b/178779691): Use a more sophisticated canonicalization that preserves |
| // errors in input, and still allows us to get rid of redundant reshapes. |
| class RemoveRedundantRank1DynamicReshape |
| : public OpRewritePattern<DynamicReshapeOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DynamicReshapeOp op, |
| PatternRewriter& rewriter) const override { |
| auto type = op.result().getType().dyn_cast<RankedTensorType>(); |
| if (!type || type.getRank() != 1 || type.hasStaticShape()) { |
| return rewriter.notifyMatchFailure( |
| op, "requires rank 1 shape tensor with dynamic dimension"); |
| } |
| auto operandType = op.operand().getType().dyn_cast<RankedTensorType>(); |
| if (!operandType || operandType.getRank() != 1 || |
| operandType.hasStaticShape()) { |
| return rewriter.notifyMatchFailure( |
| op, "requires rank 1 shape tensor with dynamic dimension"); |
| } |
| rewriter.replaceOp(op, {op.operand()}); |
| return success(); |
| } |
| }; |
| |
| // Canonicalizes |
| // %0 = "mhlo.dynamic_reshape"(%tensor, %shape) |
| // %1 = same_operands_and_result_shape_op(%tensor) |
| // %2 = "mhlo.dynamic_reshape"(%1, %shape) |
| // ... uses of %2. |
| // |
| // into |
| // |
| // %0 = "mhlo.dynamic_reshape"(%tensor, %shape) |
| // %1 = same_operands_and_result_shape_op(%tensor) |
| // ... uses of %1. |
| class DynamicReshapeOpSameShapeOpResult |
| : public OpRewritePattern<DynamicReshapeOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(DynamicReshapeOp op, |
| PatternRewriter& rewriter) const override { |
| Operation* defOp = op.operand().getDefiningOp(); |
| if (!defOp || |
| !defOp->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) { |
| return failure(); |
| } |
| Operation* inputDefOp = defOp->getOperand(0).getDefiningOp(); |
| if (!inputDefOp) { |
| return failure(); |
| } |
| auto reshape = dyn_cast<DynamicReshapeOp>(*inputDefOp); |
| if (reshape && reshape.output_shape() == op.output_shape()) { |
| rewriter.replaceOp(op, {defOp->getResult(0)}); |
| return success(); |
| } |
| return failure(); |
| } |
| }; |
| } // namespace |
| |
| void DynamicReshapeOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| // clang-format off |
| results.add< |
| DynamicReshapeOpNotActuallyDynamic, |
| DynamicReshapeOpSameShapeOpResult, |
| RemoveRedundantDynamicBroadcast, |
| RemoveRedundantDynamicReshape, |
| RemoveRedundantRank1DynamicReshape, |
| ShapeOfDynamicReshape |
| >(context); |
| // clang-format on |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DynamicSliceOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| // Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops. |
| // This canonicalization is applied the case when the `begin` input values are |
| // compile time constants and thus can be made into a tensor. |
| struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> { |
| using OpRewritePattern<DynamicSliceOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(DynamicSliceOp dynamicSlice, |
| PatternRewriter& rewriter) const override { |
| Value input = dynamicSlice.operand(); |
| auto inputTensor = input.getType().dyn_cast<RankedTensorType>(); |
| if (!inputTensor || !inputTensor.hasStaticShape()) return failure(); |
| |
| auto sliceSizes = dynamicSlice.slice_sizes().getValues<int64_t>(); |
| SmallVector<int64_t, 4> tempStartIndices; |
| for (const auto& indexAndSliceStart : |
| llvm::enumerate(dynamicSlice.start_indices())) { |
| APInt val; |
| Value start = indexAndSliceStart.value(); |
| int64_t index = indexAndSliceStart.index(); |
| if (!matchPattern(start, m_ConstantInt(&val))) { |
| return failure(); |
| } |
| // Clamp the indices within bounds to faithfully mirror dynamic slice |
| // semantics. |
| int64_t clampedStart = |
| clamp(val.getSExtValue(), static_cast<int64_t>(0), |
| inputTensor.getDimSize(index) - sliceSizes[index]); |
| tempStartIndices.push_back(clampedStart); |
| } |
| |
| // At this point we've determined that the start indices are all constants; |
| // pack them into a single tensor. |
| auto loc = dynamicSlice.getLoc(); |
| int64_t inputRank = inputTensor.getRank(); |
| auto sliceStartIndices = rewriter.getI64TensorAttr(tempStartIndices); |
| DenseIntElementsAttr sliceLimits = buildSliceLimits( |
| sliceStartIndices, dynamicSlice.slice_sizes(), &rewriter); |
| DenseIntElementsAttr sliceStrides = |
| rewriter.getI64TensorAttr(SmallVector<int64_t, 4>(inputRank, 1)); |
| auto result = rewriter.create<SliceOp>(loc, input, sliceStartIndices, |
| sliceLimits, sliceStrides); |
| rewriter.replaceOp(dynamicSlice, {result}); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void DynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<DynamicSliceToSlice>(context); |
| } |
| |
| // Verifies that the number of slice sizes and the number of start indices match |
| LogicalResult DynamicSliceOp::verify() { |
| int numSliceSizes = slice_sizes().getNumElements(); |
| int numStartIndices = start_indices().size(); |
| if (numStartIndices != numSliceSizes) { |
| return emitOpError() << "has mismatched number of slice sizes (" |
| << numSliceSizes << ") and number of start indices (" |
| << numStartIndices << ")"; |
| } |
| auto operandType = operand().getType().dyn_cast<RankedTensorType>(); |
| if (!operandType) return failure(); |
| |
| if (operandType.getRank() != numStartIndices) { |
| return emitOpError() << "has mismatched number of start indices (" |
| << numStartIndices << ") and the rank of operand (" |
| << operandType.getRank() << ")"; |
| } |
| |
| for (int i = 0; i < numSliceSizes; ++i) { |
| int64_t sliceSize = slice_sizes().getValues<int64_t>()[i]; |
| if (sliceSize < 0) { |
| return emitOpError() << "has negative size index to dynamic slice: " |
| << sliceSize; |
| } |
| if (!operandType.isDynamicDim(i)) { |
| int64_t dimSize = operandType.getDimSize(i); |
| if (sliceSize > dimSize) { |
| return emitOpError() << "has slice size " << sliceSize |
| << " greater than dimension size " << dimSize |
| << " in dimension " << i << " of operand"; |
| } |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult DynamicSliceOp::inferReturnTypeComponents( |
| MLIRContext*, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| DynamicSliceOp::Adaptor adaptor(operands, attributes, regions); |
| Value operand = adaptor.operand(); |
| auto operandType = operand.getType().dyn_cast<RankedTensorType>(); |
| if (!operandType) return failure(); |
| |
| auto sliceSizes = adaptor.slice_sizes(); |
| Type elementTy = operandType.getElementType(); |
| inferredReturnShapes.emplace_back(sliceSizes.getValues<int64_t>(), elementTy); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RealDynamicSliceOp |
| //===----------------------------------------------------------------------===// |
| // Verifies that operand rank matches start_indices/limit_indices/strides size |
| LogicalResult RealDynamicSliceOp::verify() { |
| auto inputType = operand().getType().dyn_cast<RankedTensorType>(); |
| // If operand is unranked, there is very little to verify statically. |
| if (!inputType) return success(); |
| int inputRank = inputType.getRank(); |
| |
| auto startType = start_indices().getType().cast<RankedTensorType>(); |
| auto limitType = limit_indices().getType().cast<RankedTensorType>(); |
| auto stridesType = strides().getType().cast<RankedTensorType>(); |
| |
| if (inputRank != startType.getNumElements()) { |
| return emitOpError() << "has mismatched number of operand rank (" |
| << inputRank << ") and start_indices size (" |
| << startType.getNumElements() << ")"; |
| } |
| |
| if (inputRank != limitType.getNumElements()) { |
| return emitOpError() << "has mismatched number of operand rank (" |
| << inputRank << ") and limit_indices size (" |
| << limitType.getNumElements() << ")"; |
| } |
| |
| if (inputRank != stridesType.getNumElements()) { |
| return emitOpError() << "has mismatched number of operand rank (" |
| << inputRank << ") and strides size (" |
| << stridesType.getNumElements() << ")"; |
| } |
| |
| return success(); |
| } |
| |
| namespace { |
| // Canonicalizes RealDynamicSlice ops that can be replaced instead with Slice |
| // ops. This canonicalization is applied the case when the `begin` input values |
| // are compile time constants and thus can be made into a tensor. |
| struct RealDynamicSliceIsStatic : public OpRewritePattern<RealDynamicSliceOp> { |
| using OpRewritePattern<RealDynamicSliceOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(RealDynamicSliceOp realDynamicSlice, |
| PatternRewriter& rewriter) const override { |
| Location loc = realDynamicSlice.getLoc(); |
| Value input = realDynamicSlice.operand(); |
| Value output = realDynamicSlice.result(); |
| auto inputTy = input.getType().dyn_cast<RankedTensorType>(); |
| auto outputTy = output.getType().dyn_cast<RankedTensorType>(); |
| |
| if (!inputTy || !outputTy || !inputTy.hasStaticShape() || |
| !outputTy.hasStaticShape()) { |
| return failure(); |
| } |
| |
| int64_t inputRank = inputTy.getRank(); |
| |
| auto startVal = realDynamicSlice.start_indices(); |
| auto limitVal = realDynamicSlice.limit_indices(); |
| auto strideVal = realDynamicSlice.strides(); |
| auto startOp = startVal.getDefiningOp<mlir::arith::ConstantOp>(); |
| auto limitOp = limitVal.getDefiningOp<mlir::arith::ConstantOp>(); |
| auto strideOp = strideVal.getDefiningOp<mlir::arith::ConstantOp>(); |
| if (!startOp || !limitOp || !strideOp) return failure(); |
| |
| auto startAttr = |
| startOp.getValue().dyn_cast_or_null<DenseIntElementsAttr>(); |
| auto limitAttr = |
| limitOp.getValue().dyn_cast_or_null<DenseIntElementsAttr>(); |
| auto strideAttr = |
| strideOp.getValue().dyn_cast_or_null<DenseIntElementsAttr>(); |
| if (!startAttr || !limitAttr || !strideAttr) return failure(); |
| |
| SmallVector<int64_t, 4> tempStartIndices; |
| SmallVector<int64_t, 4> tempLimitIndices; |
| SmallVector<int64_t, 4> tempStride; |
| for (int64_t dimIdx = 0; dimIdx < inputRank; dimIdx++) { |
| int64_t start = startAttr.getValues<IntegerAttr>()[dimIdx].getInt(); |
| tempStartIndices.push_back(start); |
| int64_t limit = limitAttr.getValues<IntegerAttr>()[dimIdx].getInt(); |
| tempLimitIndices.push_back(limit); |
| int64_t end = strideAttr.getValues<IntegerAttr>()[dimIdx].getInt(); |
| tempStride.push_back(end); |
| } |
| |
| DenseIntElementsAttr sliceStartIndices = |
| rewriter.getI64TensorAttr(tempStartIndices); |
| DenseIntElementsAttr sliceLimitIndices = |
| rewriter.getI64TensorAttr(tempLimitIndices); |
| DenseIntElementsAttr sliceStrides = rewriter.getI64TensorAttr(tempStride); |
| auto result = rewriter.create<SliceOp>(loc, input, sliceStartIndices, |
| sliceLimitIndices, sliceStrides); |
| rewriter.replaceOp(realDynamicSlice, {result}); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void RealDynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<RealDynamicSliceIsStatic, RealDSliceToSlice>(context); |
| } |
| |
| LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| RealDynamicSliceOp::Adaptor adaptor(operands); |
| Value operand = adaptor.operand(); |
| Value startIndices = adaptor.start_indices(); |
| Value limitIndices = adaptor.limit_indices(); |
| Value strides = adaptor.strides(); |
| |
| auto operandType = operand.getType().dyn_cast<RankedTensorType>(); |
| // Not support unranked type a.t.m. |
| if (!operandType) return failure(); |
| |
| Location loc = this->getLoc(); |
| SmallVector<Value, 4> shapeValues; |
| shapeValues.reserve(operandType.getRank()); |
| Type shapeScalarType = |
| startIndices.getType().cast<ShapedType>().getElementType(); |
| Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
| one = maybeCastTo(builder, loc, one, shapeScalarType); |
| for (const auto& element : llvm::enumerate(operandType.getShape())) { |
| Value offset = builder.create<arith::ConstantIndexOp>(loc, element.index()); |
| Value valueStart = |
| builder.create<tensor::ExtractOp>(loc, startIndices, offset); |
| Value valueLimit = |
| builder.create<tensor::ExtractOp>(loc, limitIndices, offset); |
| Value valueStride = builder.create<tensor::ExtractOp>(loc, strides, offset); |
| // size = (limit - start + stride - 1) / stride |
| shapeValues.push_back(builder.create<arith::DivSIOp>( |
| loc, |
| builder.create<arith::SubIOp>( |
| loc, |
| builder.create<arith::AddIOp>( |
| loc, valueStride, |
| builder.create<arith::SubIOp>(loc, valueLimit, valueStart)), |
| one), |
| valueStride)); |
| } |
| |
| reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>( |
| loc, |
| RankedTensorType::get({static_cast<int64_t>(shapeValues.size())}, |
| shapeScalarType), |
| shapeValues)); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // InfeedOp |
| //===----------------------------------------------------------------------===// |
| |
| // Checks that the result type is of the form `zero_or_more_type(s), |
| // mhlo::token` |
| LogicalResult InfeedOp::verify() { |
| auto resultTypes = getResultTypes(); |
| if (resultTypes.empty()) |
| return emitOpError() |
| << "result is expected to be at least of size 1, but got " |
| << resultTypes.size(); |
| |
| if (!resultTypes[resultTypes.size() - 1].isa<TokenType>()) |
| return emitOpError() << "last element of result types is expected to " |
| "be of token type, but got " |
| << resultTypes[resultTypes.size() - 1]; |
| |
| // Verify layout attribute |
| constexpr char kLayoutAttr[] = "layout"; |
| if (!getOperation()->hasAttr(kLayoutAttr)) return success(); |
| |
| mlir::ArrayAttr layout = |
| getOperation()->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr); |
| if (!layout) |
| return emitOpError() << "layout-attribute expected to be of array-type."; |
| |
| if (layout.size() != resultTypes.size() - 1) { |
| return emitOpError() << "layout-attribute size must be " |
| << resultTypes.size() - 1 |
| << " (which is the number of " |
| "op-results - 1 (for token result)), but got " |
| << layout.size(); |
| } |
| |
| for (auto childLayout : layout) { |
| mlir::ArrayAttr childLayoutArr = childLayout.dyn_cast<mlir::ArrayAttr>(); |
| if (!childLayoutArr) { |
| return emitOpError() << "layout-attribute expected to have " |
| "elements of type array, but got " |
| << childLayout; |
| } |
| |
| for (auto i : childLayoutArr) { |
| mlir::IntegerAttr attr = i.dyn_cast<mlir::IntegerAttr>(); |
| if (!attr) { |
| return emitOpError() << "layout-attribute's leaf elements are " |
| "expected to be of type integer, but got " |
| << i; |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MapOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult MapOp::verify() { |
| // Checks if the number of `operands` match the arity of the map `computation` |
| // region. |
| auto& computationBlock = computation().front(); |
| auto computationArgs = computationBlock.getArguments(); |
| if (operands().size() != computationArgs.size()) |
| return emitOpError() << "expects number of operands to match the arity " |
| "of map computation, but got: " |
| << operands().size() << " and " |
| << computationArgs.size(); |
| |
| // The parameters of computation should all be scalars and match the element |
| // type of operands. |
| for (const auto& indexedArg : llvm::enumerate(computationArgs)) { |
| auto argType = indexedArg.value().getType().dyn_cast<TensorType>(); |
| if (!argType || argType.getRank() != 0) |
| return emitOpError() |
| << "computation arguments must be 0-rank tensor, but got: arg #" |
| << indexedArg.index() << " of type " |
| << indexedArg.value().getType(); |
| auto operandElemTy = operands()[indexedArg.index()] |
| .getType() |
| .cast<TensorType>() |
| .getElementType(); |
| if (argType.getElementType() != operandElemTy) { |
| return emitOpError() |
| << "element type of operands and computation arguments must " |
| "match, but got: " |
| << operandElemTy << " and " << argType.getElementType(); |
| } |
| } |
| |
| // Mapped computation must return single output |
| auto computationOutputs = computationBlock.getTerminator()->getOperands(); |
| if (computationOutputs.size() != 1) |
| return emitOpError() << "computation must return single output, but got: " |
| << computationOutputs.size(); |
| |
| // The output of computation must be scalar and have the same element type |
| // as op result. |
| auto computationOutputType = |
| computationOutputs[0].getType().dyn_cast<TensorType>(); |
| if (!computationOutputType || computationOutputType.getRank() != 0) |
| return emitOpError() << "computation must return 0-rank tensor, but got: " |
| << computationOutputs[0].getType(); |
| |
| auto resultType = getType().cast<TensorType>(); |
| if (computationOutputType.getElementType() != resultType.getElementType()) |
| return emitOpError() << "element type of result and computation output " |
| "must match, but got: " |
| << resultType.getElementType() << " and " |
| << computationOutputType.getElementType(); |
| |
| // Checks that the requested map dimension numbers are monotonically |
| // increasing. |
| DenseIntElementsAttr dimensions = this->dimensions(); |
| for (const auto& indexedValue : |
| llvm::enumerate(dimensions.getValues<int64_t>())) { |
| if (indexedValue.value() != static_cast<int64_t>(indexedValue.index())) |
| return emitOpError() << "requires monotonically increasing dimension " |
| "numbers, but got: " |
| << dimensions; |
| } |
| |
| // Checks that number of dimensions of operands matches the size of |
| // `dimensions` since we currently only support mapping across all |
| // dimensions: i.e., scalar map functions. |
| auto operandType = operands()[0].getType().cast<TensorType>(); |
| if (operandType.hasRank()) { |
| if (dimensions.size() != |
| static_cast<int64_t>(operandType.getShape().size())) |
| return emitOpError() |
| << "applied to a subset of dimensions currently not supported: " |
| "operand dimensions = " |
| << operandType.getShape().size() |
| << ", requested map dimensions size = " << dimensions.size(); |
| } |
| |
| return success(); |
| } |
| |
| OpFoldResult MapOp::fold(ArrayRef<Attribute> operands) { |
| mlir::Block& bb = computation().front(); |
| mlir::Operation& frontOp = bb.front(); |
| |
| auto retOp = mlir::dyn_cast<ReturnOp>(frontOp); |
| if (!retOp) return nullptr; |
| if (retOp.results().size() != 1) return nullptr; |
| |
| for (mlir::BlockArgument barg : bb.getArguments()) { |
| if (barg == retOp.results()[0]) return getOperands()[barg.getArgNumber()]; |
| } |
| return nullptr; |
| } |
| |
| LogicalResult MapOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| return deriveShapeFromOperand(&builder, getOperation(), operands.front(), |
| &reifiedReturnShapes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RecvOp |
| //===----------------------------------------------------------------------===// |
| |
| // Checks that the result type is of the form `zero_or_more_type(s), |
| // mhlo::token` |
| LogicalResult RecvOp::verify() { |
| auto resultTypes = getResultTypes(); |
| if (resultTypes.empty()) |
| return emitOpError() |
| << "result is expected to be at least of size 1, but got " |
| << resultTypes.size(); |
| if (!resultTypes[resultTypes.size() - 1].isa<TokenType>()) |
| return emitOpError() << "last element of result types is expected to " |
| "be of token type, but got " |
| << resultTypes[resultTypes.size() - 1]; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CopyOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); } |
| |
| //===----------------------------------------------------------------------===// |
| // ReduceWindowOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| // Infer the return-type of ReduceWindowOp. |
| SmallVector<TensorType> inferReduceWindowOpReturnType( |
| ArrayRef<TensorType> inputTypes, ArrayRef<TensorType> initTypes, |
| const ArrayRef<WindowDimension> window) { |
| SmallVector<TensorType> outputTypes; |
| for (size_t i = 0; i < inputTypes.size(); ++i) { |
| if (!inputTypes[i].hasRank()) { |
| outputTypes.push_back( |
| UnrankedTensorType::get(initTypes[i].getElementType())); |
| continue; |
| } |
| |
| outputTypes.push_back(RankedTensorType::get( |
| inferWindowOutputShape(inputTypes[i].getShape(), window), |
| initTypes[i].getElementType())); |
| } |
| |
| return outputTypes; |
| } |
| } // namespace |
| |
| // We intend to verify the following properties |
| // P1. The sizes of 'inputs' and 'init_values' must be at least 1. |
| // P2. All `inputs` need to have compatible shapes. |
| // P3. size-of(window_dimension) == rank-of(input), |
| // where input is an element of 'inputs'. |
| // P4. Verify and collect the window atributes. |
| // P5. Verify the inner block defining the reducer function. |
| // P6. Verify the return type. |
| LogicalResult ReduceWindowOp::verify() { |
| // P1. |
| // Note that the ODS ensures that there are even number of operands; Check if |
| // that number is not zero. |
| if (getOperands().empty()) |
| return emitOpError() << "expects the size of operands to be >= 2."; |
| |
| // Collect the input and init-value operands. Note that the operand-type is |
| // enforced as "TensorType" by ODS. |
| int64_t numInputs = getNumOperands() / 2; |
| auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range( |
| getOperandTypes(), |
| [](Type t) -> TensorType { return t.cast<TensorType>(); })); |
| ArrayRef<TensorType> inputTypes(operandTensorTypes.begin(), |
| operandTensorTypes.begin() + numInputs); |
| ArrayRef<TensorType> initTypes(operandTensorTypes.begin() + numInputs, |
| operandTensorTypes.end()); |
| |
| // P2. |
| if (failed(verifyCompatibleShapes(operands().getTypes()))) |
| return emitOpError() << "requires same shape for all inputs"; |
| |
| // P3. |
| SmallVector<int64_t> windowDims = |
| convertDenseIntAttr(this->window_dimensions()); |
| for (const auto inputType : inputTypes) { |
| if (!inputType.hasRank()) continue; |
| if (inputType.getRank() != static_cast<int64_t>(windowDims.size())) |
| return emitOpError() |
| << "expects window-dimensions size == input rank, but got " |
| "window-dimensions size: " |
| << windowDims.size() << " and input: " << inputType |
| << " with rank = " << inputType.getRank() << "."; |
| } |
| |
| // P4. |
| auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc()); |
| if (failed(paddingOrErr)) return failure(); |
| SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr; |
| |
| auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions( |
| windowDims, convertDenseIntAttr(window_strides()), padding, |
| /*lhs_dilation=*/convertDenseIntAttr(base_dilations()), |
| /*rhs_dilation=*/convertDenseIntAttr(this->window_dilations()), getLoc()); |
| if (failed(windowOrErr)) return failure(); |
| |
| // P5. |
| bool allInputsUnranked = |
| llvm::all_of(inputTypes, [](TensorType t) { return !t.hasRank(); }); |
| |
| Block& block = body().front(); |
| SmallVector<TensorType> accumulatorSubshapes; |
| if (failed(verifyReducerShape(this->getLoc(), block, inputTypes, initTypes, |
| numInputs, windowDims, allInputsUnranked, |
| accumulatorSubshapes))) |
| return failure(); |
| |
| // P6. |
| if (numInputs != getNumResults()) |
| return emitOpError() << "expects " << numInputs |
| << " result values, but got " << getNumResults() |
| << "."; |
| |
| // The result-type is enforced as "TensorType" by ODS. |
| auto resultTensorTypes = llvm::to_vector<4>(llvm::map_range( |
| getResultTypes(), |
| [](Type t) -> TensorType { return t.cast<TensorType>(); })); |
| |
| // Check if the element-type of results match with the ones derived from |
| // the reducer-block. Already ensured that |accumulator_subshapes| == |
| // num_inputs == num_of_results. |
| for (int64_t shapeIdx = 0; |
| shapeIdx < static_cast<int64_t>(accumulatorSubshapes.size()); |
| shapeIdx++) { |
| if (accumulatorSubshapes[shapeIdx].getElementType() != |
| resultTensorTypes[shapeIdx].getElementType()) { |
| return emitError() |
| << "expects the element-type of reduce-op's return-value at index " |
| << shapeIdx |
| << " to match the element-type of reducer-block's " |
| "corresponding return-value, but got " |
| << resultTensorTypes[shapeIdx].getElementType() << " and " |
| << accumulatorSubshapes[shapeIdx].getElementType() << " resp."; |
| } |
| } |
| |
| // Check if the shape of results match with the ones derived from |
| // the input-types and wndow-attributes. |
| auto inferredReturnTypes = inferReduceWindowOpReturnType( |
| inputTypes, accumulatorSubshapes, *windowOrErr); |
| |
| for (size_t i = 0; i < getNumResults(); i++) { |
| if (failed(verifyCompatibleShape(resultTensorTypes[i], |
| inferredReturnTypes[i]))) { |
| return emitOpError() |
| << "expects result at index " << i |
| << " to have compatible shape with the corresponding " |
| "inferred type, but got " |
| << resultTensorTypes[i] << " and " << inferredReturnTypes[i] |
| << " resp."; |
| } |
| } |
| |
| return success(); |
| } |
| |
| // Get the operation used for reduction applied to `result_index`th result. Its |
| // expected to be a binary operation that consumes `result_index`th and |
| // `result_index + operands().size`th arguments of the body. |
| Operation* ReduceWindowOp::getReductionOp(int resultIndex) { |
| auto returnOp = cast<ReturnOp>(body().front().getTerminator()); |
| Operation* computeOp = returnOp.results()[resultIndex].getDefiningOp(); |
| if (computeOp->getNumOperands() != 2) return nullptr; |
| auto arg0 = computeOp->getOperand(0).dyn_cast<BlockArgument>(); |
| auto arg1 = computeOp->getOperand(1).dyn_cast<BlockArgument>(); |
| if (!arg0 || !arg1) return nullptr; |
| int64_t arg0Num = arg0.getArgNumber(); |
| int64_t arg1Num = arg1.getArgNumber(); |
| int64_t otherArgIndex = resultIndex + operands().size(); |
| if (arg0Num == resultIndex && arg1Num == otherArgIndex) return computeOp; |
| if (arg0Num == otherArgIndex && arg1Num == resultIndex && |
| computeOp->hasTrait<mlir::OpTrait::IsCommutative>()) |
| return computeOp; |
| return nullptr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ReducePrecisionOp |
| //===----------------------------------------------------------------------===// |
| |
| // The following property is already enforced by the ODS: |
| // P0. operand element type is float |
| // P1. mantissa_bits >= 0 |
| // We intend to verify the following properties |
| // P2. exponent_bits >= 1 |
| LogicalResult ReducePrecisionOp::verify() { |
| if (exponent_bits() < 1) { |
| return emitOpError() << "exponent_bits must be at least 1."; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ReverseOp |
| //===----------------------------------------------------------------------===// |
| |
| template <typename T> |
| static Attribute foldReverseHelper(DenseElementsAttr& attr, ShapedType& type, |
| DenseIntElementsAttr& dims) { |
| int64_t numElements = attr.getNumElements(); |
| // No-op if the tensor has 0 elements. |
| // No-op if the result of folding is too large. |
| if (numElements == 0 || numElements > kFoldOpEltLimit) return {}; |
| |
| SmallVector<T> result(attr.getValues<T>().begin(), attr.getValues<T>().end()); |
| |
| size_t rank = type.getRank(); |
| SmallVector<int64_t> stride(rank + 1, numElements); |
| for (size_t i = 0; i < rank; i++) { |
| if (type.getDimSize(i) == 0) return {}; |
| stride[i + 1] = stride[i] / type.getDimSize(i); |
| } |
| |
| for (auto dim : dims.getValues<int64_t>()) { |
| // For example, given: |
| // * tensor: tensor<2x3x2xi32> |
| // [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9,10], [11, 12]]] |
| // * dim: [1] |
| // |
| // We're going to reverse the tensor with respect to dim as follows: |
| // 1) Split the tensor into blocks, i.e. smaller tensors whose type is |
| // derived from the tensor by dropping the first `dim` dimensions, i.e. |
| // tensor<3x2xi32> for the running example. |
| // 2) Split each block into windows, i.e. even smaller tensors whose type |
| // is derived from the block by dropping the first dimension of the |
| // block, i.e. tensor<2xi32> for the running example. |
| // 3) Within each block, swap windows but don't change the order of |
| // elements within the windows: 0th window goes to N-1st spot, 1st window |
| // goes to N-2nd spot etc. |
| // |
| // For the running example, the result will be: |
| // [[[5, 6], [3, 4], [1, 2]], [[11, 12], [9, 10], [7, 8]]]. |
| // |
| // Note how elements within windows haven't changed their order with respect |
| // to each other and how blocks haven't changed their order with respect to |
| // each other. |
| int64_t numWindows = type.getDimSize(dim); |
| int64_t windowSize = stride[dim] / numWindows; |
| |
| for (int64_t index = 0; index < numElements; index++) { |
| int64_t blockNumber = index / stride[dim]; |
| int64_t windowNumber = (index % stride[dim]) / windowSize; |
| int64_t reversedWindowNumber = numWindows - windowNumber - 1; |
| if (windowNumber >= reversedWindowNumber) continue; |
| int64_t reversedIndex = blockNumber * stride[dim] + |
| reversedWindowNumber * windowSize + |
| index % windowSize; |
| std::swap(result[index], result[reversedIndex]); |
| } |
| } |
| return DenseElementsAttr::get(type, result); |
| } |
| |
| OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) { |
| Value input = operand(); |
| |
| // No dimensions to reverse. |
| DenseIntElementsAttr dims = dimensions(); |
| if (dims.getNumElements() == 0) return input; |
| |
| // If size of all dimensions to reverse equals 1, then the reverse is a no-op. |
| // Eg. Reverse dimensions {0,1} of a 1x1x2 tensor |
| auto shapedType = input.getType().cast<ShapedType>(); |
| if (llvm::all_of(dims.getValues<int64_t>(), [&](int64_t dim) { |
| return shapedType.getDimSize(dim) == 1; |
| })) |
| return input; |
| |
| // If the operand is a static shaped tensor of constants, return reversed |
| // tensor |
| DenseElementsAttr inputAttr = |
| operands.begin()->dyn_cast_or_null<DenseElementsAttr>(); |
| if (inputAttr && shapedType.hasStaticShape()) { |
| auto etype = shapedType.getElementType(); |
| if (etype.isa<IntegerType>()) |
| return foldReverseHelper<APInt>(inputAttr, shapedType, dims); |
| if (etype.isa<FloatType>()) |
| return foldReverseHelper<APFloat>(inputAttr, shapedType, dims); |
| } |
| |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ReduceOp |
| //===----------------------------------------------------------------------===// |
| |
| // Returns the result type after reducing operand of the given type across the |
| // specified dimensions. |
| static TensorType getReduceResultType(Type operandTy, |
| DenseIntElementsAttr dimensions, |
| Builder* builder) { |
| Type elementTy = getElementTypeOrSelf(operandTy); |
| |
| auto rankedTy = operandTy.dyn_cast<RankedTensorType>(); |
| if (!rankedTy) return UnrankedTensorType::get(elementTy); |
| |
| int64_t rank = rankedTy.getRank(); |
| llvm::SmallVector<bool, 4> dimsMask(rank, false); |
| for (int64_t dim : dimensions.getValues<int64_t>()) dimsMask[dim] = true; |
| |
| SmallVector<int64_t, 4> shape; |
| for (int64_t i = 0; i < rank; ++i) { |
| if (!dimsMask[i]) shape.push_back(rankedTy.getDimSize(i)); |
| } |
| |
| return RankedTensorType::get(shape, elementTy); |
| } |
| |
| void ReduceOp::build(OpBuilder& builder, OperationState& state, |
| ValueRange operands, ValueRange initValues, |
| DenseIntElementsAttr dimensions) { |
| SmallVector<Type, 1> resultTy; |
| resultTy.reserve(operands.size()); |
| |
| for (Value operand : operands) { |
| resultTy.push_back( |
| getReduceResultType(operand.getType(), dimensions, &builder)); |
| } |
| build(builder, state, resultTy, operands, initValues, dimensions); |
| } |
| |
| LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands, |
| SmallVectorImpl<OpFoldResult>& results) { |
| // No dimensions to reduce. |
| if (dimensions().getNumElements() == 0) { |
| for (Value operand : this->operands()) { |
| results.push_back(operand); |
| } |
| return success(); |
| } |
| |
| // If all returned values in the ReduceOp region exists outside |
| // the region replace the ReduceOp with those values. |
| mlir::Block& bb = this->body().front(); |
| SmallVector<Value> replacedResults; |
| if (auto retOp = mlir::dyn_cast<ReturnOp>(bb.back())) { |
| for (Value result : retOp.results()) { |
| if (result.getParentRegion() == retOp->getParentRegion()) |
| return failure(); |
| replacedResults.push_back(result); |
| } |
| |
| results.insert(results.end(), replacedResults.begin(), |
| replacedResults.end()); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| |
| bool hasSameOperandAndResultTypes(Operation& op) { |
| Type expected; |
| if (op.getNumResults() != 0) expected = op.getResult(0).getType(); |
| if (op.getNumOperands() != 0) expected = op.getOperand(0).getType(); |
| if (!expected) return false; |
| |
| auto typeMatch = [&](Type actual) { return actual == expected; }; |
| return llvm::all_of(op.getOperandTypes(), typeMatch) && |
| llvm::all_of(op.getResultTypes(), typeMatch); |
| } |
| |
| // Checks the following eligibility criteria for compact printing of |
| // mhlo.reduce: |
| // E1. The reduce-op wraps a single inner-op in the associated region. |
| // E2. The single operation is a commutative binary-op from mhlo dialect, zero |
| // region, producing single result such that the operands and result all |
| // have the same type. |
| // E3. The reduce-op consist of at least one input-operand; The operand-types of |
| // inner-op should be derived trivially from the element-type of reduce-op's |
| // first input-operand. |
| // E4. The arguments of the region's only basic block are forwarded perfectly |
| // to inner-op's operands. |
| // E5. The reduce-op, inner-op, blocks arguments, and the return-op all have the |
| // same location. |
| // E6. The single operation result is perfectly forwarded to the reduce op |
| // return. |
| static bool isEligibleForCompactPrint(ReduceOp op) { |
| // Check E1. |
| auto& block = op.body().front(); |
| if (!hasSingleElement(block.without_terminator())) return false; |
| |
| Operation& innerOp = *block.begin(); |
| |
| // Check E2. |
| if (innerOp.getDialect() != op->getDialect()) return false; |
| |
| if (innerOp.getNumOperands() != 2 || |
| !innerOp.hasTrait<mlir::OpTrait::OneResult>() || |
| !hasSameOperandAndResultTypes(innerOp) || |
| !innerOp.hasTrait<mlir::OpTrait::IsCommutative>() || |
| !innerOp.hasTrait<mlir::OpTrait::ZeroRegions>()) |
| return false; |
| |
| // Check E3. |
| if (op.operands().empty()) return false; |
| |
| auto elemType = |
| op.operands()[0].getType().cast<TensorType>().getElementType(); |
| auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType); |
| if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false; |
| |
| // Check E4. |
| if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false; |
| |
| // Check E5. |
| auto retOp = dyn_cast<ReturnOp>(block.getTerminator()); |
| if (!retOp) return false; |
| |
| auto blockArgLoc = block.getArgument(0).getLoc(); |
| if (blockArgLoc != block.getArgument(1).getLoc()) return false; |
| |
| if (innerOp.getLoc() != op.getLoc() || retOp.getLoc() != op.getLoc() || |
| blockArgLoc != op.getLoc()) |
| return false; |
| |
| // Check E6. |
| return llvm::equal(innerOp.getResults(), retOp.getOperands()); |
| } |
| |
| void ReduceOp::print(OpAsmPrinter& p) { |
| { |
| // Print the pairs of operands under the form: |
| // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) |
| StringRef comma = ""; |
| int numOperandPairs = getNumOperands() / 2; |
| for (int opId : llvm::seq<int>(0, numOperandPairs)) { |
| p << comma << "(" << getOperand(opId) |
| << " init: " << getOperand(opId + numOperandPairs) << ")"; |
| comma = ", "; |
| } |
| } |
| |
| // If the reduce-op is eligible for compact printing, we emit the one-liner: |
| // mhlo.reduce applies <inner-op> across dimensions = [...] : <func-type> |
| // Note: We are not printing the function type of reduction operation. We |
| // have some simplifying assumptions (refer to IsEligibleForCompactPrint::E3) |
| // to derive the type from that of reduce-op. |
| if (isEligibleForCompactPrint(*this)) { |
| Operation& innerOp = body().front().front(); |
| p << " applies "; |
| printEscapedString(innerOp.getName().getStringRef(), p.getStream()); |
| |
| p << " across dimensions = ["; |
| llvm::interleaveComma(dimensions().getValues<int64_t>(), p); |
| p << "]"; |
| p << " : "; |
| p.printFunctionalType(*this); |
| } else { |
| p << " across dimensions = ["; |
| llvm::interleaveComma(dimensions().getValues<int64_t>(), p); |
| p << "]"; |
| p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); |
| p << " : "; |
| p.printFunctionalType(*this); |
| p.printNewline(); |
| p << " reducer"; |
| { |
| // Print the pairs of block operands under the form: |
| // (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc): |
| Block& reducer = body().front(); |
| int numOperandPairs = getNumOperands() / 2; |
| for (int opId : llvm::seq<int>(0, numOperandPairs)) { |
| p << "("; |
| p.printRegionArgument(reducer.getArgument(opId)); |
| p << ", "; |
| p.printRegionArgument(reducer.getArgument(opId + numOperandPairs)); |
| p << ") "; |
| } |
| } |
| p << ' '; |
| p.printRegion(body(), /*printEntryBlockArgs=*/false); |
| } |
| } |
| |
| ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) { |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| Location currLocation = parser.getEncodedSourceLoc(loc); |
| |
| // Parse the operands of reduce-op, this is a list of pair under the form: |
| // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) |
| // Each input to reduce is paired with its init value, even though in memory |
| // they are stored with the input first and the init values after. |
| SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; |
| SmallVector<OpAsmParser::UnresolvedOperand, 2> initOperands; |
| do { |
| (void)parser.parseOptionalComma(); |
| if (parser.parseOptionalLParen()) break; |
| OpAsmParser::UnresolvedOperand operand, initOperand; |
| if (parser.parseOperand(operand) || parser.parseKeyword("init") || |
| parser.parseColon() || parser.parseOperand(initOperand) || |
| parser.parseRParen()) |
| return failure(); |
| operands.push_back(operand); |
| initOperands.push_back(initOperand); |
| } while (true); |
| operands.append(initOperands); |
| |
| // Check if we are parsing the compact version of reduce-op: |
| // mhlo.reduce applies <inner-op> across dimensions = [...] : <func-type> |
| // else parse the "region-based" variant. |
| if (failed(parser.parseOptionalKeyword("applies"))) { |
| // Parse the inner-op dimensions, reduce-op's function-type and |
| // optional location. |
| SmallVector<int64_t> dimensions; |
| auto parseDim = [&]() -> ParseResult { |
| if (parser.parseInteger(dimensions.emplace_back())) return failure(); |
| return success(); |
| }; |
| |
| FunctionType reduceOpFntype; |
| if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || |
| parser.parseEqual() || |
| parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, |
| parseDim) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColon() || parser.parseType(reduceOpFntype) || |
| parser.parseKeyword("reducer")) |
| return failure(); |
| OpBuilder builder(parser.getBuilder().getContext()); |
| result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions)); |
| |
| // Parse the "reducer" region now. |
| SmallVector<OpAsmParser::UnresolvedOperand, 2> reducerOperands; |
| SmallVector<OpAsmParser::UnresolvedOperand, 2> reducerInitOperands; |
| SmallVector<Type, 2> reducerTypes; |
| SmallVector<Type, 2> reducerInitTypes; |
| SmallVector<Optional<Location>, 2> reducerLocs; |
| SmallVector<Optional<Location>, 2> reducerInitLocs; |
| auto parseBlockOperand = |
| [&](SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands, |
| SmallVectorImpl<Type>& types, |
| SmallVectorImpl<Optional<Location>>& locs) -> ParseResult { |
| OpAsmParser::UnresolvedOperand operand; |
| Type type; |
| Optional<Location> loc; |
| if (parser.parseOperand(operand, /*allowResultNumber=*/false) || |
| parser.parseColon() || parser.parseType(type) || |
| parser.parseOptionalLocationSpecifier(loc)) |
| return failure(); |
| operands.push_back(operand); |
| types.push_back(type); |
| locs.push_back(loc); |
| return success(); |
| }; |
| do { |
| if (failed(parser.parseOptionalLParen())) break; |
| if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) || |
| parser.parseComma() || |
| parseBlockOperand(reducerInitOperands, reducerInitTypes, |
| reducerInitLocs) || |
| parser.parseRParen()) |
| return failure(); |
| } while (true); |
| reducerOperands.append(reducerInitOperands); |
| reducerTypes.append(reducerInitTypes); |
| reducerLocs.append(reducerInitLocs); |
| result.addTypes(reduceOpFntype.getResults()); |
| SmallVector<OpAsmParser::Argument> reducerArgs; |
| createArgs(reducerOperands, reducerTypes, reducerArgs); |
| |
| // Derive the SSA-values for reduce-op's operands and parse the region, and |
| // the optional trailing location. |
| Optional<Location> trailingLoc; |
| if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc, |
| result.operands) || |
| parser.parseRegion(*result.addRegion(), reducerArgs)) |
| return failure(); |
| // Set the individual block arguments. |
| for (auto argAndLoc : |
| llvm::zip(result.regions.front()->front().getArguments(), reducerLocs)) |
| if (std::get<1>(argAndLoc)) |
| std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).getValue()); |
| result.location = trailingLoc.getValueOr(currLocation); |
| return success(); |
| } |
| |
| // Parse the inner-op name and check if the contract on inner-op |
| // mentioned in "isEligibleForCompactPrint::E2" for pretty-priting is met. |
| FailureOr<OperationName> innerOpNameInfo = parser.parseCustomOperationName(); |
| if (failed(innerOpNameInfo)) return failure(); |
| |
| StringRef innerOpName = innerOpNameInfo->getStringRef(); |
| Dialect* innerOpDialect = innerOpNameInfo->getDialect(); |
| if (!innerOpDialect || !innerOpDialect->getNamespace().equals("mhlo") || |
| !innerOpNameInfo->hasTrait<mlir::OpTrait::NOperands<2>::Impl>() || |
| !innerOpNameInfo->hasTrait<mlir::OpTrait::OneResult>() || |
| !innerOpNameInfo->hasTrait<mlir::OpTrait::IsCommutative>() || |
| !innerOpNameInfo->hasTrait<mlir::OpTrait::ZeroRegions>()) { |
| parser.emitError(loc, |
| "expected the inner-op to be a commutative binary-op from " |
| "mhlo dialect, zero region, producing single result"); |
| return failure(); |
| } |
| |
| // Parse the inner-op dimensions, reduce-op's function-type and |
| // optional location. |
| SmallVector<int64_t> dimensions; |
| auto parseDim = [&]() -> ParseResult { |
| if (parser.parseInteger(dimensions.emplace_back())) return failure(); |
| return success(); |
| }; |
| |
| Optional<Location> explicitLoc; |
| FunctionType reduceOpFntype; |
| if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || |
| parser.parseEqual() || |
| parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || |
| parser.parseColon() || parser.parseType(reduceOpFntype) || |
| parser.parseOptionalLocationSpecifier(explicitLoc)) |
| return failure(); |
| |
| if (!reduceOpFntype || reduceOpFntype.getInputs().empty()) { |
| if (!reduceOpFntype) return parser.emitError(loc, "expected function type"); |
| return parser.emitError(loc, |
| "input types missing in reduce-op function type"); |
| } |
| |
| // If location of reduce-op is explicitly provided, then use it; Else use |
| // the parser's current location. |
| Location reduceOpLoc = explicitLoc.getValueOr(currLocation); |
| |
| // Derive the SSA-values for reduce-op's operands. |
| if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc, |
| result.operands)) |
| return failure(); |
| |
| // Derive the type of inner-op from that of reduce-op's input operand. |
| auto innerOpType = RankedTensorType::get( |
| /*shape=*/{}, getElementTypeOrSelf(reduceOpFntype.getInput(0))); |
| |
| // Add a region for reduce-op. |
| Region& region = *result.addRegion(); |
| |
| // Create a basic-block inside reduce-op's region. |
| Block& block = region.emplaceBlock(); |
| auto lhs = block.addArgument(innerOpType, reduceOpLoc); |
| auto rhs = block.addArgument(innerOpType, reduceOpLoc); |
| |
| // Create and insert an "inner-op" operation in the block. |
| OpBuilder builder(parser.getBuilder().getContext()); |
| builder.setInsertionPointToStart(&block); |
| |
| OperationState innerOpState(reduceOpLoc, innerOpName); |
| innerOpState.operands.push_back(lhs); |
| innerOpState.operands.push_back(rhs); |
| innerOpState.addTypes(innerOpType); |
| |
| Operation* innerOp = builder.create(innerOpState); |
| |
| // Insert a return statement in the block returning the inner-op's result. |
| builder.create<ReturnOp>(innerOp->getLoc(), innerOp->getResults()); |
| |
| // Populate the reduce-op operation-state with result-type, location, and |
| // dimension attribute. |
| result.addTypes(reduceOpFntype.getResults()); |
| result.location = innerOp->getLoc(); |
| result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions)); |
| |
| return success(); |
| } |
| |
| LogicalResult ReduceOp::verify() { |
| // Check that there are even number of operands and >= 2. |
| if (getNumOperands() % 2 != 0 || getOperands().empty()) |
| return emitOpError() << "expects the size of operands to be even and >= 2"; |
| |
| // Collect the input and init-value operands. Note that the operand-type is |
| // enforced as "TensorType" by ODS. |
| int64_t numInputs = getNumOperands() / 2; |
| auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range( |
| getOperandTypes(), |
| [](Type t) -> TensorType { return t.cast<TensorType>(); })); |
| ArrayRef<TensorType> inputArgTypes(operandTensorTypes.begin(), |
| operandTensorTypes.begin() + numInputs); |
| ArrayRef<TensorType> initValueTypes(operandTensorTypes.begin() + numInputs, |
| operandTensorTypes.end()); |
| |
| // Check for unranked tensors in input operands. |
| int64_t rankedInputIdx = -1; |
| for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { |
| if (inputArgTypes[inputIdx].hasRank()) { |
| rankedInputIdx = inputIdx; |
| break; |
| } |
| } |
| |
| bool allInputsUnranked = (rankedInputIdx == -1); |
| |
| // Check that all input operands have compatible shapes. The element types may |
| // be different. |
| if (!allInputsUnranked) { |
| for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { |
| if (failed(mlir::verifyCompatibleShape(inputArgTypes[rankedInputIdx], |
| inputArgTypes[inputIdx]))) { |
| return emitOpError() |
| << "expects all inputs to have compatible shapes. Shape at" |
| << " input-index " << inputIdx |
| << " is not compatible with shape at input-index " |
| << rankedInputIdx; |
| } |
| } |
| } |
| |
| // Check that |
| // 1. the dimensions of reduce-op are in-bounds for the given shape. |
| // 2. the dimension-attribute have no duplicate entries. |
| DenseSet<int64_t> dimensionsToReduceSet; |
| for (int64_t dimension : dimensions().getValues<int64_t>()) { |
| if ((!allInputsUnranked && |
| dimension >= inputArgTypes[rankedInputIdx].getRank()) || |
| dimension < 0) { |
| return emitError() << "Out-of-bounds dimension " << dimension |
| << " for input-tensor rank: " |
| << inputArgTypes[rankedInputIdx].getRank(); |
| } |
| |
| if (!dimensionsToReduceSet.insert(dimension).second) { |
| return emitError() << "Duplicate reduction dimension: " << dimension; |
| } |
| } |
| |
| // Verify the inner block defining the reducer function. |
| SmallVector<int64_t> newDimensions; |
| if (!allInputsUnranked) { |
| for (int inputIdx = 0; inputIdx < inputArgTypes[rankedInputIdx].getRank(); |
| ++inputIdx) { |
| if (!dimensionsToReduceSet.count(inputIdx)) { |
| newDimensions.push_back( |
| inputArgTypes[rankedInputIdx].getDimSize(inputIdx)); |
| } |
| } |
| } |
| |
| Block& block = body().front(); |
| SmallVector<TensorType> accumulatorSubShapes; |
| if (failed(verifyReducerShape(this->getLoc(), block, inputArgTypes, |
| initValueTypes, numInputs, newDimensions, |
| allInputsUnranked, accumulatorSubShapes))) |
| return failure(); |
| |
| // Check if the reduce-op's result-type matches with the one derived from |
| // the reducer-block and dimensions attribute. |
| if (getResults().size() != accumulatorSubShapes.size()) |
| return emitError() << "Unexpected number of reduce-op's returned values: " |
| << getResults().size() << " vs " |
| << accumulatorSubShapes.size() << " (expected)"; |
| |
| for (int64_t shapeIdx = 0; |
| shapeIdx < static_cast<int64_t>(accumulatorSubShapes.size()); |
| shapeIdx++) { |
| // The result-type is enforced as "TensorType" by ODS. |
| auto opResultType = getResult(shapeIdx).getType().cast<TensorType>(); |
| |
| // Check element-type. |
| if (accumulatorSubShapes[shapeIdx].getElementType() != |
| opResultType.getElementType()) { |
| return emitError() |
| << "Unexpected element-type for reduce-op's return value at index " |
| << shapeIdx << ": " << opResultType.getElementType() << " vs " |
| << accumulatorSubShapes[shapeIdx].getElementType() |
| << " (expected)"; |
| } |
| |
| // Check shape. |
| if (!allInputsUnranked && opResultType.hasRank() && |
| failed(verifyCompatibleShape(newDimensions, opResultType.getShape()))) { |
| Type expectedResultType = RankedTensorType::get( |
| newDimensions, accumulatorSubShapes[shapeIdx].getElementType()); |
| return emitError() |
| << "Unexpected type for reduce-op's return value at index " |
| << shapeIdx << ": " << opResultType << " vs " << expectedResultType |
| << " (expected)"; |
| } |
| } |
| |
| return success(); |
| } |
| |
| // Enable constant folding to occur within the region of the ReduceOp |
| // by replacing block argument uses with constants if: |
| // 1. All the ReduceOp operands are splat constants. |
| // 2. The ReduceOp region consists of a single logical AND or logical OR. |
| // The pattern leverages the idempotent property of the AND and OR operators |
| // to determine the value of a reduction on splat constants. Other boolean |
| // operators do not have this property, and need separate patterns to resolve |
| // reductions of their splat constants. |
| struct LowerBoolSplatConstantsIntoRegion : public OpRewritePattern<ReduceOp> { |
| using OpRewritePattern<ReduceOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ReduceOp op, |
| PatternRewriter& rewriter) const override { |
| mlir::Block& bb = op.body().front(); |
| |
| // Ensure only a compute op and return op exist and the |
| // compute op is an AND or OR op. |
| if (bb.getOperations().size() != 2) return failure(); |
| if (!mlir::isa<mhlo::AndOp, mhlo::OrOp>(bb.front())) return failure(); |
| |
| // Ensure all operands are splat constants. |
| SmallVector<DenseElementsAttr, 4> bargCstAttrs; |
| for (auto inpAndBarg : llvm::zip(op.getOperands(), bb.getArguments())) { |
| Value inp = std::get<0>(inpAndBarg); |
| BlockArgument barg = std::get<1>(inpAndBarg); |
| ConstantOp cst = inp.getDefiningOp<ConstantOp>(); |
| if (!cst) return failure(); |
| |
| auto cstAttr = cst.value().dyn_cast_or_null<DenseElementsAttr>(); |
| if (!cstAttr.isSplat()) { |
| return rewriter.notifyMatchFailure(op, "Must be splat constant."); |
| } |
| |
| auto bargShapedType = barg.getType().dyn_cast<ShapedType>(); |
| if (!bargShapedType) return failure(); |
| |
| auto bargCstAttr = DenseElementsAttr::get( |
| bargShapedType, cstAttr.getSplatValue<mlir::Attribute>()); |
| bargCstAttrs.push_back(bargCstAttr); |
| } |
| |
| // Create new splat constants to replace block arguments. |
| for (BlockArgument barg : bb.getArguments()) { |
| int argIdx = barg.getArgNumber(); |
| mhlo::ConstantOp newCst = rewriter.create<mhlo::ConstantOp>( |
| bb.front().getLoc(), barg.getType(), bargCstAttrs[argIdx]); |
| barg.replaceAllUsesWith(newCst); |
| } |
| return success(); |
| } |
| }; |
| |
| void ReduceOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<LowerBoolSplatConstantsIntoRegion>(context); |
| } |
| |
| LogicalResult ReduceOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| ReduceOp::Adaptor adaptor(operands); |
| auto inputs = adaptor.operands(); |
| |
| auto operandType = inputs[0].getType().dyn_cast<RankedTensorType>(); |
| // Not support unranked type a.t.m. |
| if (!operandType) return failure(); |
| |
| Location loc = this->getLoc(); |
| SmallVector<Value, 4> shapeValues; |
| SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>()); |
| shapeValues.reserve(operandType.getRank()); |
| Type shapeScalarType = builder.getIndexType(); |
| auto toShapeScalarType = [&](Value v) { |
| return maybeCastTo(builder, loc, v, shapeScalarType); |
| }; |
| |
| for (const auto& element : llvm::enumerate(operandType.getShape())) { |
| int64_t idx = element.index(); |
| auto* it = std::find(dimensions.begin(), dimensions.end(), idx); |
| if (it != dimensions.end()) { |
| continue; |
| } |
| Value valueDim = toShapeScalarType( |
| builder.create<tensor::DimOp>(loc, inputs[0], element.index())); |
| shapeValues.push_back(valueDim); |
| } |
| |
| Value outputShape = builder.create<tensor::FromElementsOp>( |
| loc, |
| RankedTensorType::get({static_cast<int64_t>(shapeValues.size())}, |
| shapeScalarType), |
| shapeValues); |
| for (size_t i = 0; i < inputs.size(); ++i) { |
| reifiedReturnShapes.push_back(outputShape); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RngBitGeneratorOp |
| //===----------------------------------------------------------------------===// |
| |
| // Verify that input state has the same shape as output shape |
| LogicalResult RngBitGeneratorOp::verify() { |
| auto initialShape = initial_state().getType().dyn_cast<RankedTensorType>(); |
| auto outputShape = output_state().getType().dyn_cast<RankedTensorType>(); |
| if (initialShape.getShape() != outputShape.getShape()) |
| return emitOpError() |
| << "output state shape must match initial state shape. Got: " |
| << initialShape << " and " << outputShape; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RngOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult RngOp::verify() { |
| auto dist = rng_distribution(); |
| if (dist == RngDistribution::UNIFORM) { |
| return success(); |
| } |
| auto muTy = a().getType().cast<TensorType>().getElementType(); |
| auto sigmaTy = b().getType().cast<TensorType>().getElementType(); |
| if (muTy.isa<FloatType>() && sigmaTy.isa<FloatType>()) { |
| return success(); |
| } |
| return emitOpError() << "mu and sigma must be floats"; |
| } |
| |
| LogicalResult RngOp::inferReturnTypeComponents( |
| MLIRContext* context, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| return rngInferReturnTypeComponents(context, location, operands, attributes, |
| regions, inferredReturnShapes); |
| } |
| |
| LogicalResult RngOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| RngOp::Adaptor adaptor(operands); |
| reifiedReturnShapes.push_back( |
| castToIndexTensor(builder, getLoc(), adaptor.shape())); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XlaRngGetAndUpdateStateOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult XlaRngGetAndUpdateStateOp::verify() { |
| auto resultTy = getType().cast<RankedTensorType>(); |
| if (!resultTy) return emitOpError() << "Output is not ranked."; |
| if (!resultTy.hasStaticShape()) |
| return emitOpError() << "Output is not statically shaped."; |
| auto rank = resultTy.getRank(); |
| if (rank != 1) |
| return emitOpError() << "Output is of rank " << rank << " instead of 1"; |
| auto extent = resultTy.getDimSize(0); |
| if (extent != 2) |
| return emitOpError() << "Output size is " << extent << " instead of 2"; |
| |
| return success(); |
| } |
| |
| LogicalResult XlaRngGetAndUpdateStateOp::inferReturnTypes( |
| MLIRContext* ctx, Optional<Location>, ValueRange, DictionaryAttr, |
| RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) { |
| inferredReturnTypes.push_back(mlir::RankedTensorType::get( |
| {2}, mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned))); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SelectOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult SelectOp::verify() { |
| // The operands 'on_true' and 'on_false' should have compatible types, i.e., |
| // (a) have the same element type, and |
| // (b) have compatible shapes (i.e. the same shape and/or at least one |
| // dynamic shape) |
| if (!compatibleShapeAndElementType(on_true().getType(), on_false().getType())) |
| return emitOpError() |
| << "requires compatible types for non-predicate operands"; |
| |
| // The predicate, if not-scalar, should have the same shape as the remaining |
| // operands. |
| auto predTy = pred().getType().dyn_cast<RankedTensorType>(); |
| bool predMayBeScalar = !predTy || predTy.getRank() == 0; |
| if (predMayBeScalar) return success(); |
| |
| if (failed(verifyCompatibleShape(pred().getType(), on_true().getType()))) |
| return emitOpError() << "requires the same shape for all operands"; |
| |
| return success(); |
| } |
| |
| OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) { |
| if (on_true() == on_false()) { |
| return on_true(); |
| } |
| |
| auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); |
| if (!predicate) { |
| return {}; |
| } |
| |
| auto predicateTy = predicate.getType().cast<ShapedType>(); |
| if (!predicateTy.getElementType().isInteger(1)) { |
| return {}; |
| } |
| |
| if (predicate.isSplat()) { |
| return predicate.getSplatValue<APInt>().getBoolValue() ? on_true() |
| : on_false(); |
| } |
| |
| return {}; |
| } |
| |
| // simplify select(not(%pred), true_value, false_value) => select(%pred, |
| // false_value, true_value) |
| static LogicalResult selectCanonicalization(SelectOp selectOp, |
| PatternRewriter& rewriter) { |
| auto notOp = selectOp.pred().getDefiningOp<NotOp>(); |
| if (!notOp) { |
| return failure(); |
| } |
| std::array<Value, 3> newOperands = {notOp.operand(), selectOp.on_false(), |
| selectOp.on_true()}; |
| rewriter.updateRootInPlace( |
| selectOp, [&]() { selectOp.getOperation()->setOperands(newOperands); }); |
| return success(); |
| } |
| |
| void SelectOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* /*context*/) { |
| results.add(&selectCanonicalization); |
| } |
| |
| // Makes it such that a SelectOp that is a non-root operation in a DRR infers |
| // the return type based on operand type. |
| LogicalResult SelectOp::inferReturnTypeComponents( |
| MLIRContext*, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| SelectOp::Adaptor op(operands, attributes); |
| auto trueType = op.on_true().getType().cast<TensorType>(); |
| auto falseType = op.on_false().getType().cast<TensorType>(); |
| |
| // The output shape should be the most general of the operand shapes at each |
| // dimension. |
| ShapedTypeComponents& outputType = inferredReturnShapes.emplace_back(); |
| if (trueType == falseType || !trueType.hasRank()) { |
| outputType = ShapedTypeComponents(trueType.cast<ShapedType>()); |
| } else if (!falseType.hasRank()) { |
| outputType = ShapedTypeComponents(falseType.cast<ShapedType>()); |
| } else { |
| assert(trueType.getRank() == falseType.getRank()); |
| llvm::SmallVector<int64_t, 4> dims; |
| dims.reserve(trueType.getRank()); |
| for (auto dim : llvm::zip(trueType.getShape(), falseType.getShape())) { |
| dims.push_back(std::get<0>(dim) == std::get<1>(dim) |
| ? std::get<0>(dim) |
| : ShapedType::kDynamicSize); |
| } |
| outputType = ShapedTypeComponents(dims, trueType.getElementType()); |
| } |
| return success(); |
| } |
| |
| LogicalResult SelectOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| // For `hlo.select`, the first operand may be a scalar. |
| return deriveShapeFromOperand(&builder, getOperation(), operands[1], |
| &reifiedReturnShapes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SetDimensionSizeOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult SetDimensionSizeOp::verify() { |
| if (auto size = this->size().getType().dyn_cast<RankedTensorType>()) { |
| if (size.getRank() != 0) |
| return emitOpError() << "size operand should be of rank-0"; |
| } |
| |
| return verifyDimAttr(*this); |
| } |
| |
| OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) { |
| DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>(); |
| if (input) return input; |
| |
| DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>(); |
| if (!size || !size.isSplat()) return {}; |
| |
| auto ty = getType().dyn_cast<RankedTensorType>(); |
| if (!ty) return {}; |
| |
| int64_t dimSize = ty.getDimSize(dimension()); |
| if (dimSize == size.getSplatValue<IntegerAttr>().getInt()) return operand(); |
| return {}; |
| } |
| |
| // TODO(b/238903565): Switch to inferReturnTypeComponents after adding support |
| // for the encoding upstream. |
| LogicalResult SetDimensionSizeOp::inferReturnTypes( |
| MLIRContext* context, Optional<Location> location, ValueRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<Type>& inferredReturnTypes) { |
| Location loc = location.getValueOr(UnknownLoc::get(context)); |
| |
| SetDimensionSizeOp::Adaptor adaptor(operands, attributes, regions); |
| if (failed(adaptor.verify(loc))) return failure(); |
| |
| auto inputType = adaptor.operand().getType().dyn_cast<RankedTensorType>(); |
| if (!inputType) { |
| inferredReturnTypes.push_back(adaptor.operand().getType()); |
| return success(); |
| } |
| |
| int64_t dim = adaptor.dimension(); |
| int64_t rank = inputType.getRank(); |
| if (dim < 0 || dim >= rank) { |
| return mlir::emitError(loc) << "expects dimension to be in range [0, " |
| << rank << "); got: [" << dim << "]."; |
| } |
| |
| auto shape = llvm::to_vector<4>(inputType.getShape()); |
| llvm::SmallVector<int64_t, 4> bounds(rank, ShapedType::kDynamicSize); |
| if (auto encoding = |
| inputType.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>()) |
| bounds = llvm::to_vector<4>(encoding.getBounds()); |
| |
| // TODO(hinsu): Handle the case when the size operand is a constant. |
| if (shape[dim] != ShapedType::kDynamicSize) bounds[dim] = shape[dim]; |
| shape[dim] = ShapedType::kDynamicSize; |
| |
| auto extensions = TypeExtensionsAttr::get(context, bounds); |
| auto resultType = |
| RankedTensorType::get(shape, inputType.getElementType(), extensions); |
| inferredReturnTypes.push_back(resultType); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // PadOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult PadOp::inferReturnTypeComponents( |
| MLIRContext*, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| PadOp::Adaptor adaptor(operands, attributes, regions); |
| auto inputType = adaptor.operand().getType().cast<RankedTensorType>(); |
| auto padType = adaptor.padding_value().getType().cast<RankedTensorType>(); |
| |
| if (padType.getRank() != 0) { |
| return emitOptionalError( |
| location, llvm::formatv("padding value type should be a rank-0 " |
| "tensor, is rank {0}", |
| padType.getRank())); |
| } |
| |
| const auto& paddingLow = adaptor.edge_padding_low(); |
| if (paddingLow.getType().getNumElements() != inputType.getRank()) { |
| return emitOptionalError( |
| location, |
| llvm::formatv( |
| "edge_padding_low length ({0}) must match operand rank ({1})", |
| paddingLow.getType().getNumElements(), inputType.getRank())); |
| } |
| |
| const auto& paddingHigh = adaptor.edge_padding_high(); |
| if (paddingHigh.getType().getNumElements() != inputType.getRank()) { |
| return emitOptionalError( |
| location, |
| llvm::formatv( |
| "edge_padding_high length ({0}) must match operand rank ({1})", |
| paddingHigh.getType().getNumElements(), inputType.getRank())); |
| } |
| |
| const auto& paddingInterior = adaptor.interior_padding(); |
| if (paddingInterior.getType().getNumElements() != inputType.getRank()) { |
| return emitOptionalError( |
| location, |
| llvm::formatv( |
| "interior_padding length ({0}) must match operand rank ({1})", |
| paddingInterior.getType().getNumElements(), inputType.getRank())); |
| } |
| |
| auto inputShape = inputType.getShape(); |
| SmallVector<int64_t> resultShape; |
| for (int i = 0, e = inputShape.size(); i < e; i++) { |
| if (isDynamicDimSize(inputShape[i])) { |
| resultShape.push_back(ShapedType::kDynamicSize); |
| continue; |
| } |
| |
| int64_t paddingLowVal = paddingLow.getValues<APInt>()[i].getSExtValue(); |
| int64_t paddingHighVal = paddingHigh.getValues<APInt>()[i].getSExtValue(); |
| int64_t paddingInteriorVal = |
| paddingInterior.getValues<APInt>()[i].getSExtValue(); |
| if (paddingInteriorVal < 0) { |
| return emitOptionalError( |
| location, llvm::formatv("Interior padding cannot be negative: {0}", |
| paddingInteriorVal)); |
| } |
| int64_t expectedOutput = |
| inputShape[i] + paddingLowVal + paddingHighVal + |
| std::max<int64_t>(inputShape[i] - 1, 0LL) * paddingInteriorVal; |
| if (expectedOutput < 0) { |
| return emitOptionalError( |
| location, |
| llvm::formatv("Padding result in negative size for dimension {0}", |
| i)); |
| } |
| resultShape.push_back(expectedOutput); |
| } |
| inferredReturnShapes.emplace_back(resultShape, inputType.getElementType()); |
| |
| return success(); |
| } |
| |
| template <typename T> |
| OpFoldResult padOpFoldHelper(DenseElementsAttr input, DenseElementsAttr padding, |
| RankedTensorType returnType, |
| DenseIntElementsAttr edgePaddingLow, |
| DenseIntElementsAttr /*edgePaddingHigh*/, |
| DenseIntElementsAttr interiorPadding) { |
| // Prevent folding if the result is too large. |
| if (returnType.getNumElements() > kFoldOpEltLimit) return {}; |
| |
| // Fill the full result tensor with the padding value. |
| llvm::SmallVector<T, 4> result(returnType.getNumElements(), |
| padding.getValues<T>()[0]); |
| |
| auto nextIndex = [](llvm::SmallVector<uint64_t, 8>& index, |
| llvm::ArrayRef<int64_t> shape) { |
| for (int64_t i = index.size() - 1; i >= 0; --i) { |
| ++index[i]; |
| if (static_cast<int64_t>(index[i]) < shape[i]) return; |
| index[i] = 0; |
| } |
| }; |
| |
| // Iterate over all elements of the input tensor and copy it to the correct |
| // location in the output tensor. |
| llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0); |
| uint64_t numElements = input.getNumElements(); |
| for (uint64_t operandIdx = 0; operandIdx < numElements; operandIdx++) { |
| uint64_t resultIdx = 0; |
| uint64_t idxMultiplyer = 1; |
| for (int64_t i = index.size() - 1; i >= 0; --i) { |
| resultIdx += (edgePaddingLow.getValues<int64_t>()[i] + |
| index[i] * (interiorPadding.getValues<int64_t>()[i] + 1)) * |
| idxMultiplyer; |
| idxMultiplyer *= returnType.getDimSize(i); |
| } |
| result[resultIdx] = input.getValues<T>()[index]; |
| nextIndex(index, input.getType().getShape()); |
| } |
| return DenseElementsAttr::get(returnType, result); |
| } |
| |
| OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) { |
| // If all padding is zero then it is an identity pad. |
| auto isZero = [](const APInt& i) { return i == 0; }; |
| if (llvm::all_of(edge_padding_low().getValues<APInt>(), isZero) && |
| llvm::all_of(edge_padding_high().getValues<APInt>(), isZero) && |
| llvm::all_of(interior_padding().getValues<APInt>(), isZero)) |
| return operand(); |
| |
| // If any padding is negative then it isn't supported by the folder (yet). |
| auto isNegative = [](const APInt& i) { return i.slt(0); }; |
| if (llvm::any_of(edge_padding_low().getValues<APInt>(), isNegative) || |
| llvm::any_of(edge_padding_high().getValues<APInt>(), isNegative) || |
| llvm::any_of(interior_padding().getValues<APInt>(), isNegative)) |
| return {}; |
| |
| DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>(); |
| DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>(); |
| RankedTensorType returnType = getType().dyn_cast_or_null<RankedTensorType>(); |
| if (!input || !input.getType().hasRank() || !padding || !returnType || |
| !returnType.hasStaticShape()) |
| return {}; |
| |
| if (returnType.getElementType().isa<IntegerType>()) |
| return padOpFoldHelper<APInt>(input, padding, returnType, |
| edge_padding_low(), edge_padding_high(), |
| interior_padding()); |
| if (returnType.getElementType().isa<FloatType>()) |
| return padOpFoldHelper<APFloat>(input, padding, returnType, |
| edge_padding_low(), edge_padding_high(), |
| interior_padding()); |
| if (ComplexType complex = |
| returnType.getElementType().dyn_cast_or_null<ComplexType>()) { |
| // TODO(atondwal): Allow int types in HLO_complex |
| if (complex.getElementType().isa<FloatType>()) |
| return padOpFoldHelper<std::complex<APFloat>>( |
| input, padding, returnType, edge_padding_low(), edge_padding_high(), |
| interior_padding()); |
| } |
| return {}; |
| } |
| |
| LogicalResult PadOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| PadOp::Adaptor adaptor(operands, this->getOperation()->getAttrDictionary()); |
| auto loc = this->getLoc(); |
| Value operand = adaptor.operand(); |
| auto operandTy = operand.getType().cast<RankedTensorType>(); |
| |
| llvm::SmallVector<int32_t> padHigh; |
| llvm::SmallVector<int32_t> padLow; |
| llvm::SmallVector<int32_t> padInterior; |
| |
| auto padHighAttr = adaptor.edge_padding_high(); |
| auto padLowAttr = adaptor.edge_padding_low(); |
| auto padInteriorAttr = adaptor.interior_padding(); |
| |
| padHigh.reserve(padHighAttr.getNumElements()); |
| padLow.reserve(padLowAttr.getNumElements()); |
| padInterior.reserve(padInteriorAttr.getNumElements()); |
| |
| for (const APInt& val : padHighAttr.getValues<APInt>()) |
| padHigh.push_back(val.getSExtValue()); |
| |
| for (const APInt& val : padLowAttr.getValues<APInt>()) |
| padLow.push_back(val.getSExtValue()); |
| |
| for (const APInt& val : padInteriorAttr.getValues<APInt>()) |
| padInterior.push_back(val.getSExtValue()); |
| |
| Value one = builder.create<arith::ConstantIndexOp>(loc, 1).getResult(); |
| Value zero = builder.create<arith::ConstantIndexOp>(loc, 0).getResult(); |
| |
| llvm::SmallVector<Value> dimensions; |
| dimensions.reserve(operandTy.getRank()); |
| for (int i = 0, s = operandTy.getRank(); i < s; ++i) { |
| Value padEdge = |
| builder.create<arith::ConstantIndexOp>(loc, padHigh[i] + padLow[i]); |
| |
| // First we grab the initial interior size. |
| Value dim = builder.create<tensor::DimOp>(loc, operand, i).getResult(); |
| |
| // Compute the interior of the tensor and determine padding size. |
| if (padInterior[i] > 0) { |
| Value padInter = |
| builder.create<arith::ConstantIndexOp>(loc, padInterior[i]) |
| .getResult(); |
| Value interior = builder.create<arith::SubIOp>(loc, dim, one).getResult(); |
| interior = builder.create<arith::MaxSIOp>(loc, interior, zero); |
| interior = builder.create<arith::MulIOp>(loc, interior, padInter); |
| dim = builder.create<arith::AddIOp>(loc, dim, interior).getResult(); |
| } |
| |
| // Then we add the padding on the edge of the tensor. |
| dim = builder.create<arith::AddIOp>(loc, dim, padEdge).getResult(); |
| dimensions.push_back(dim); |
| } |
| |
| Value dimensionTensor = |
| builder.create<tensor::FromElementsOp>(loc, dimensions).getResult(); |
| reifiedReturnShapes.push_back(dimensionTensor); |
| return success(); |
| } |
| |
| // If the input tensor has a dimension of length-0, the input tensor is |
| // irrelevant. Instead we can broadcast the pad value to the output size rather |
| // than pad the input tensor. |
| struct PadEmptyTensor : public OpRewritePattern<PadOp> { |
| using OpRewritePattern<PadOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(PadOp op, |
| PatternRewriter& rewriter) const override { |
| auto operand = op.operand(); |
| auto padVal = op.padding_value(); |
| |
| auto operandTy = operand.getType().cast<RankedTensorType>(); |
| auto resultTy = op.getType().cast<RankedTensorType>(); |
| |
| if (llvm::all_of(operandTy.getShape(), [](int64_t d) { return d != 0; })) { |
| return failure(); |
| } |
| |
| if (resultTy.hasStaticShape()) { |
| auto dimsType = RankedTensorType::get({0}, rewriter.getIntegerType(64)); |
| auto dims = |
| DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{}); |
| rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(op, resultTy, padVal, |
| dims); |
| return success(); |
| } |
| |
| llvm::SmallVector<Value> reifiedShapes; |
| if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(), |
| reifiedShapes))) |
| return failure(); |
| |
| auto dimsType = RankedTensorType::get({0}, rewriter.getIntegerType(64)); |
| auto broadcastDims = |
| DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{}); |
| rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>( |
| op, op.getType(), padVal, reifiedShapes.front(), broadcastDims); |
| |
| return failure(); |
| } |
| }; |
| |
| void PadOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<PadEmptyTensor>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DynamicPadOp |
| //===----------------------------------------------------------------------===// |
| |
| // If the input tensor has a dimension of length-0, the input tensor is |
| // irrelevant. Instead we can broadcast the pad value to the output size rather |
| // than pad the input tensor. |
| struct DynamicPadEmptyTensor : public OpRewritePattern<DynamicPadOp> { |
| using OpRewritePattern<DynamicPadOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(DynamicPadOp op, |
| PatternRewriter& rewriter) const override { |
| // auto loc = op.getLoc(); |
| auto operand = op.operand(); |
| auto padVal = op.padding_value(); |
| |
| auto operandTy = operand.getType().cast<RankedTensorType>(); |
| |
| if (llvm::all_of(operandTy.getShape(), [](int64_t d) { return d != 0; })) { |
| return failure(); |
| } |
| |
| llvm::SmallVector<Value> reifiedShapes; |
| if (failed(op.reifyReturnTypeShapes(rewriter, op->getOperands(), |
| reifiedShapes))) |
| return failure(); |
| |
| auto dimsType = RankedTensorType::get({0}, rewriter.getIntegerType(64)); |
| auto broadcastDims = |
| DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{}); |
| rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>( |
| op, op.getType(), padVal, reifiedShapes.front(), broadcastDims); |
| |
| return failure(); |
| } |
| }; |
| |
| void DynamicPadOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<DPadToPad, DynamicPadEmptyTensor>(context); |
| } |
| |
| LogicalResult DynamicPadOp::verify() { |
| auto inputType = operand().getType().dyn_cast<RankedTensorType>(); |
| // If operand is unranked, there is very little to verify statically. |
| if (!inputType) return success(); |
| int inputRank = inputType.getRank(); |
| |
| auto padType = padding_value().getType().cast<RankedTensorType>(); |
| if (padType.getRank() != 0) { |
| return emitOpError() << "padding value type should be a rank-0"; |
| } |
| |
| auto paddingLowType = edge_padding_low().getType().cast<RankedTensorType>(); |
| if (paddingLowType.getNumElements() != inputRank) { |
| return emitOpError() << "edge_padding_low length(" |
| << paddingLowType.getNumElements() |
| << ") must match operand rank(" << inputRank << ")."; |
| } |
| |
| auto paddingHighType = edge_padding_high().getType().cast<RankedTensorType>(); |
| if (paddingHighType.getNumElements() != inputRank) { |
| return emitOpError() << "edge_padding_high length(" |
| << paddingHighType.getNumElements() |
| << ") must match operand rank(" << inputRank << ")."; |
| } |
| |
| auto interiorPaddingType = |
| interior_padding().getType().cast<RankedTensorType>(); |
| if (interiorPaddingType.getNumElements() != inputRank) { |
| return emitOpError() << "edge_padding_interior length(" |
| << interiorPaddingType.getNumElements() |
| << ") must match operand rank(" << inputRank << ")."; |
| } |
| |
| auto outputType = getResult().getType().dyn_cast<RankedTensorType>(); |
| // If result is unranked, there is very little to verify statically. |
| if (!outputType) return success(); |
| int outputRank = outputType.getRank(); |
| if (inputRank != outputRank) { |
| return emitOpError() << "operand rank(" << inputRank |
| << ") must match result(" << outputRank << ")."; |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult DynamicPadOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| DynamicPadOp::Adaptor adaptor(operands); |
| Value operand = adaptor.operand(); |
| Value edgePaddingLow = adaptor.edge_padding_low(); |
| Value edgePaddingHigh = adaptor.edge_padding_high(); |
| Value interiorPadding = adaptor.interior_padding(); |
| |
| auto operandType = operand.getType().dyn_cast<RankedTensorType>(); |
| // Not support unranked pad a.t.m. |
| if (!operandType) return failure(); |
| |
| auto loc = this->getLoc(); |
| SmallVector<Value, 4> shapeValues; |
| shapeValues.reserve(operandType.getRank()); |
| Type shapeScalarType = |
| edgePaddingLow.getType().cast<ShapedType>().getElementType(); |
| |
| auto toShapeScalarType = [&](Value v) { |
| return maybeCastTo(builder, loc, v, shapeScalarType); |
| }; |
| |
| Value zero = |
| toShapeScalarType(builder.create<arith::ConstantIndexOp>(loc, 0)); |
| Value one = toShapeScalarType(builder.create<arith::ConstantIndexOp>(loc, 1)); |
| |
| for (int idx : llvm::seq<int>(0, operandType.getShape().size())) { |
| Value valueDim = |
| toShapeScalarType(builder.create<tensor::DimOp>(loc, operand, idx)); |
| Value offset = builder.create<arith::ConstantIndexOp>(loc, idx); |
| Value valueLow = |
| builder.create<tensor::ExtractOp>(loc, edgePaddingLow, offset); |
| Value valueHigh = |
| builder.create<tensor::ExtractOp>(loc, edgePaddingHigh, offset); |
| Value valueInterior = |
| builder.create<tensor::ExtractOp>(loc, interiorPadding, offset); |
| // output_size = input_size + padding_low + padding_high + interior * |
| // max(input_size - 1, 0) |
| Value valueDimLessThanOne = builder.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::slt, valueDim, one); |
| Value interiorSize = builder.create<arith::MulIOp>( |
| loc, valueInterior, |
| builder.create<mlir::arith::SelectOp>( |
| loc, valueDimLessThanOne, zero, |
| builder.create<arith::SubIOp>(loc, valueDim, one))); |
| shapeValues.push_back(builder.create<arith::AddIOp>( |
| loc, |
| builder.create<arith::AddIOp>( |
| loc, builder.create<arith::AddIOp>(loc, interiorSize, valueDim), |
| valueLow), |
| valueHigh)); |
| } |
| |
| reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>( |
| loc, |
| RankedTensorType::get({static_cast<int64_t>(shapeValues.size())}, |
| shapeScalarType), |
| shapeValues)); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ReshapeOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ReshapeOp::verify() { |
| // If the operand type is dynamically shaped there is nothing to verify. |
| auto operandTy = operand().getType().dyn_cast<RankedTensorType>(); |
| if (!operandTy || !operandTy.hasStaticShape()) return success(); |
| |
| // If the operand type is statically shaped (not required) the number of |
| // elements must match that of the result type. |
| auto resultTy = getType().cast<RankedTensorType>(); |
| assert(resultTy && resultTy.hasStaticShape() && |
| "result type must be statically shaped"); |
| int64_t numResultElements = resultTy.getNumElements(); |
| int64_t numOperandElements = operandTy.getNumElements(); |
| if (numResultElements != numOperandElements) |
| return emitOpError() << "number of output elements (" << numResultElements |
| << ") doesn't match expected number of elements (" |
| << numOperandElements << ")"; |
| |
| return success(); |
| } |
| |
| OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) { |
| if (getOperand().getType() == getType()) { |
| return getOperand(); |
| } |
| |
| if (auto prevOp = getOperand().getDefiningOp<ReshapeOp>()) { |
| setOperand(prevOp.getOperand()); |
| return getResult(); |
| } |
| |
| if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) { |
| return reshape(elements, getResult().getType().cast<ShapedType>()); |
| } |
| |
| return {}; |
| } |
| |
| void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<IdentityBroadcastReshape, IdentityBroadcastInDimReshape, |
| EliminateRedundantReshape, EliminateIdentityReshape>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ReplicaId Op |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ReplicaIdOp::inferReturnTypes( |
| MLIRContext* context, Optional<Location>, ValueRange operands, |
| DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) { |
| inferredReturnTypes.push_back(RankedTensorType::get( |
| /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned))); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AddDependency Op |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult AddDependencyOp::inferReturnTypes( |
| MLIRContext* context, Optional<Location>, ValueRange operands, |
| DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) { |
| inferredReturnTypes.push_back(operands.getTypes()[0]); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // If Op |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verifyConditionalBranch(Operation* op, Region& region, |
| llvm::Twine branchName) { |
| if (region.getNumArguments() != 0) |
| return op->emitOpError() |
| << branchName << " must have 0 arguments, but found " |
| << region.getNumArguments(); |
| |
| TypeRange branchReturnTypes = |
| region.front().getTerminator()->getOperandTypes(); |
| if (branchReturnTypes != op->getResultTypes()) |
| return op->emitOpError() |
| << branchName << " returned types (" << branchReturnTypes |
| << ") do not match op result types (" << op->getResultTypes() << ")"; |
| |
| return success(); |
| } |
| |
| LogicalResult IfOp::verify() { |
| if (failed(verifyConditionalBranch(*this, true_branch(), |
| /*branchName=*/"true_branch"))) { |
| return failure(); |
| } |
| |
| if (failed(verifyConditionalBranch(*this, false_branch(), |
| /*branchName=*/"false_branch"))) { |
| return failure(); |
| } |
| return success(); |
| } |
| |
| static LogicalResult inlineIfConstantCondition(IfOp ifOp, |
| PatternRewriter& rewriter) { |
| DenseIntElementsAttr predAttr; |
| if (!matchPattern(ifOp.pred(), m_Constant(&predAttr))) return failure(); |
| |
| if (predAttr.getSplatValue<BoolAttr>().getValue()) { |
| replaceOpWithRegion(rewriter, ifOp, ifOp.true_branch()); |
| } else { |
| replaceOpWithRegion(rewriter, ifOp, ifOp.false_branch()); |
| } |
| return success(); |
| } |
| |
| void IfOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add(&inlineIfConstantCondition); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Case Op |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult CaseOp::verify() { |
| auto numBranches = branches().size(); |
| |
| for (unsigned i = 0; i < numBranches; ++i) |
| if (failed(verifyConditionalBranch(*this, branches()[i], |
| /*branchName=*/"branch " + Twine(i)))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| static LogicalResult inlineCaseConstantCondition(CaseOp caseOp, |
| PatternRewriter& rewriter) { |
| DenseIntElementsAttr indexAttr; |
| if (!matchPattern(caseOp.index(), m_Constant(&indexAttr))) { |
| return failure(); |
| } |
| int64_t index = |
| indexAttr.getSplatValue<IntegerAttr>().getValue().getSExtValue(); |
| // For an OOB index, the last branch is executed as the default branch: |
| // https://www.tensorflow.org/xla/operation_semantics#conditional |
| if (index < 0 || index >= caseOp.getNumRegions()) |
| index = caseOp.getNumRegions() - 1; |
| |
| Region& region = caseOp.getRegion(index); |
| if (!llvm::hasSingleElement(region)) return failure(); |
| replaceOpWithRegion(rewriter, caseOp, region); |
| return success(); |
| } |
| |
| void CaseOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add(&inlineCaseConstantCondition); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SqrtOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) { |
| auto val = operands[0].dyn_cast_or_null<DenseElementsAttr>(); |
| if (!val) return {}; |
| |
| auto type = getElementTypeOrSelf(getType()); |
| if (!type.isF32() && !type.isF64()) return {}; |
| |
| auto shapedType = getType().cast<ShapedType>(); |
| if (!shapedType.hasStaticShape()) return {}; |
| |
| // Prevent folding if the result is too large. |
| if (val.getNumElements() > kFoldOpEltLimit) return {}; |
| |
| int bitWidth = type.getIntOrFloatBitWidth(); |
| llvm::SmallVector<APFloat, 4> values; |
| values.reserve(val.getNumElements()); |
| for (auto it : val.getValues<APFloat>()) { |
| double value = bitWidth == 32 ? it.convertToFloat() : it.convertToDouble(); |
| if (value < 0) return {}; |
| value = std::sqrt(value); |
| if (bitWidth == 32) |
| values.emplace_back(static_cast<float>(value)); |
| else |
| values.emplace_back(value); |
| } |
| return DenseFPElementsAttr::get(shapedType, values); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // UnaryOps |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult parseUnaryOp(OpAsmParser& parser, OperationState& result) { |
| SmallVector<OpAsmParser::UnresolvedOperand> operands; |
| Type type; |
| // If the operand is in-between parentheses, use generic form. |
| SMLoc loc = parser.getCurrentLocation(); |
| if (!parser.parseOptionalLParen()) { |
| if (parser.parseOperandList(operands) || parser.parseRParen() || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColon() || parser.parseType(type)) |
| return failure(); |
| auto fnType = type.dyn_cast<FunctionType>(); |
| if (!fnType) { |
| parser.emitError(loc, "expected function type"); |
| return failure(); |
| } |
| if (parser.resolveOperands(operands, fnType.getInputs(), loc, |
| result.operands)) |
| return failure(); |
| result.addTypes(fnType.getResults()); |
| return success(); |
| } |
| // Otherwise, use shorthand syntax. |
| return failure(parser.parseOperandList(operands) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(type) || |
| parser.resolveOperands(operands, type, result.operands) || |
| parser.addTypeToList(type, result.types)); |
| } |
| |
| void printUnaryOp(Operation* op, OpAsmPrinter& p) { |
| assert(op->getNumResults() == 1 && "op should have one result"); |
| assert(op->getNumOperands() == 1 && "op should have one input"); |
| // If not all types are the same, use generic form. |
| auto resultType = op->getResult(0).getType(); |
| if (resultType != op->getOperandTypes()[0]) { |
| p.printGenericOp(op, /*printOpName=*/false); |
| return; |
| } |
| // Otherwise, use the shorthand syntax. |
| p << ' '; |
| p.printOperands(op->getOperands()); |
| p.printOptionalAttrDict(op->getAttrs()); |
| p << " : " << resultType; |
| } |
| |
| template <typename Op, typename ElementType = Type, typename ValType, |
| typename Convert> |
| static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) { |
| if (!attrs[0]) return {}; |
| |
| DenseElementsAttr val = attrs[0].dyn_cast<DenseElementsAttr>(); |
| if (!val) return {}; |
| |
| ShapedType type = op->getType().template cast<ShapedType>(); |
| if (!type.hasStaticShape()) { |
| return {}; |
| } |
| |
| Type etype = type.getElementType(); |
| |
| // Evaluate for integer values. |
| if (!etype.isa<ElementType>()) { |
| return {}; |
| } |
| |
| // Prevent folding if the result is too large. |
| if (val.getNumElements() > kFoldOpEltLimit) return {}; |
| |
| SmallVector<ValType, 6> values; |
| values.reserve(val.getNumElements()); |
| for (const auto v : val.getValues<ValType>()) { |
| values.push_back(Convert()(v)); |
| } |
| |
| return DenseElementsAttr::get(type, values); |
| } |
| |
| struct Round { |
| APFloat operator()(const APFloat& f) { |
| APFloat r = f; |
| r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway); |
| return r; |
| } |
| }; |
| |
| struct RoundNearestEven { |
| APFloat operator()(const APFloat& f) { |
| APFloat r = f; |
| r.roundToIntegral(llvm::RoundingMode::NearestTiesToEven); |
| return r; |
| } |
| }; |
| |
| struct LogicalNot { |
| APInt operator()(const APInt& i) { |
| return APInt(i.getBitWidth(), static_cast<uint64_t>(!i)); |
| } |
| }; |
| |
| template <typename FloatOrInt> |
| struct Sign { |
| APFloat compute(const APFloat& f) { |
| if (f.isZero() || f.isNaN()) return f; |
| double value = f.isNegative() ? -1.0 : 1.0; |
| APFloat val(value); |
| bool unused; |
| val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused); |
| return val; |
| } |
| |
| APInt compute(const APInt& i) { |
| APInt r = i; |
| if (r == 0) return r; |
| if (r.isNegative()) { |
| return APInt(r.getBitWidth(), -1, /*isSigned=*/true); |
| } |
| return APInt(r.getBitWidth(), 1, /*isSigned=*/true); |
| } |
| |
| FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); } |
| }; |
| |
| #define UNARY_FOLDER(Op, Func) \ |
| OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \ |
| if (getElementTypeOrSelf(getType()).isa<FloatType>()) \ |
| return UnaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \ |
| if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \ |
| return UnaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs); \ |
| return {}; \ |
| } |
| |
| #define UNARY_FOLDER_INT(Op, Func) \ |
| OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \ |
| if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \ |
| return UnaryFolder<Op, IntegerType, APInt, Func>(this, attrs); \ |
| return {}; \ |
| } |
| |
| #define UNARY_FOLDER_FLOAT(Op, Func) \ |
| OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \ |
| if (getElementTypeOrSelf(getType()).isa<FloatType>()) \ |
| return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \ |
| return {}; \ |
| } |
| |
| UNARY_FOLDER(NegOp, std::negate); |
| UNARY_FOLDER(SignOp, Sign); |
| UNARY_FOLDER_INT(NotOp, LogicalNot); |
| UNARY_FOLDER_FLOAT(RoundNearestEvenOp, RoundNearestEven); |
| UNARY_FOLDER_FLOAT(RoundOp, Round); |
| |
| #undef UNARY_FOLDER |
| #undef UNARY_FOLDER_INT |
| #undef UNARY_FOLDER_FLOAT |
| |
| //===----------------------------------------------------------------------===// |
| // BinaryOps |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult parseBinaryOp(OpAsmParser& parser, OperationState& result) { |
| SmallVector<OpAsmParser::UnresolvedOperand> operands; |
| Type type; |
| // If the operand list is in-between parentheses, use generic form. |
| SMLoc loc = parser.getCurrentLocation(); |
| if (!parser.parseOptionalLParen()) { |
| if (parser.parseOperandList(operands) || parser.parseRParen() || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColon() || parser.parseType(type)) |
| return failure(); |
| auto fnType = type.dyn_cast<FunctionType>(); |
| if (!fnType) { |
| parser.emitError(loc, "expected function type"); |
| return failure(); |
| } |
| if (parser.resolveOperands(operands, fnType.getInputs(), loc, |
| result.operands)) |
| return failure(); |
| result.addTypes(fnType.getResults()); |
| return success(); |
| } |
| // Otherwise, use shorthand syntax. |
| return failure(parser.parseOperandList(operands) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(type) || |
| parser.resolveOperands(operands, type, result.operands) || |
| parser.addTypeToList(type, result.types)); |
| } |
| |
| void printBinaryOp(Operation* op, OpAsmPrinter& p) { |
| assert(op->getNumResults() == 1 && "op should have one result"); |
| // If not all types are the same, use generic form. |
| auto resultType = op->getResult(0).getType(); |
| if (llvm::any_of(op->getOperandTypes(), |
| [&](Type type) { return type != resultType; })) { |
| p.printGenericOp(op, /*printOpName=*/false); |
| return; |
| } |
| // Otherwise, use the shorthand syntax. |
| p << ' '; |
| p.printOperands(op->getOperands()); |
| p.printOptionalAttrDict(op->getAttrs()); |
| p << " : " << resultType; |
| } |
| |
| static const APFloat& addSign(const APFloat& v, Type) { return v; } |
| static APSInt addSign(const APInt& v, Type t) { |
| // Add signedness information to the value, treating signless as signed. |
| return APSInt(v, t.isUnsignedInteger()); |
| } |
| |
| template <typename Op, typename ElementType = Type, typename ValType, |
| typename Convert> |
| static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) { |
| if (!attrs[0] || !attrs[1]) return {}; |
| |
| DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>(); |
| DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>(); |
| if (!lhs || !rhs) return {}; |
| |
| ShapedType type = op->getType().template cast<ShapedType>(); |
| if (!type.hasStaticShape()) { |
| return {}; |
| } |
| |
| Type etype = type.getElementType(); |
| |
| // Evaluate for integer values. |
| if (!etype.isa<ElementType>()) { |
| return {}; |
| } |
| |
| // Special case for folding splats no matter how large. |
| // Only covers the case of both attrs being splats; operation-specific cases |
| // like adding a zero or multiplying by one are handled elsewhere. |
| SplatElementsAttr splatLhs = lhs.dyn_cast<SplatElementsAttr>(); |
| SplatElementsAttr splatRhs = rhs.dyn_cast<SplatElementsAttr>(); |
| if (splatLhs && splatRhs) { |
| auto signedLhs = addSign(splatLhs.getSplatValue<ValType>(), etype); |
| auto signedRhs = addSign(splatRhs.getSplatValue<ValType>(), etype); |
| FailureOr<decltype(signedLhs)> result(Convert()(signedLhs, signedRhs)); |
| return succeeded(result) ? SplatElementsAttr::get(type, *result) |
| : Attribute(); |
| } |
| |
| // Prevent folding if the result is too large. |
| if (lhs.getNumElements() > kFoldOpEltLimit) return {}; |
| |
| SmallVector<ValType, 6> values; |
| values.reserve(lhs.getNumElements()); |
| for (const auto zip : |
| llvm::zip(lhs.getValues<ValType>(), rhs.getValues<ValType>())) { |
| auto signedLhs = addSign(std::get<0>(zip), etype); |
| auto signedRhs = addSign(std::get<1>(zip), etype); |
| FailureOr<decltype(signedLhs)> result(Convert()(signedLhs, signedRhs)); |
| if (failed(result)) { |
| return {}; |
| } |
| values.push_back(std::move(*result)); |
| } |
| |
| return DenseElementsAttr::get(type, values); |
| } |
| |
| template <typename T> |
| struct Divide : std::divides<T> {}; |
| |
| template <> |
| struct Divide<APSInt> { |
| FailureOr<APSInt> operator()(const APSInt& a, const APSInt& b) const { |
| if (b.isZero()) return failure(); |
| return a / b; |
| } |
| }; |
| |
| template <typename T> |
| struct Remainder : std::modulus<T> {}; |
| |
| template <> |
| struct Remainder<APSInt> { |
| FailureOr<APSInt> operator()(const APSInt& a, const APSInt& b) const { |
| if (b.isZero()) return failure(); |
| return a % b; |
| } |
| }; |
| |
| template <> |
| struct Remainder<APFloat> { |
| APFloat operator()(const APFloat& a, const APFloat& b) const { |
| APFloat result(a); |
| result.remainder(b); |
| return result; |
| } |
| }; |
| |
| template <typename T> |
| struct Max { |
| T operator()(const T& a, const T& b) const { return std::max<T>(a, b); } |
| }; |
| |
| template <typename T> |
| struct Min { |
| T operator()(const T& a, const T& b) const { return std::min<T>(a, b); } |
| }; |
| |
| template <typename T> |
| struct And { |
| T operator()(const T& a, const T& b) const { return a & b; } |
| }; |
| |
| template <typename T> |
| struct Or { |
| T operator()(const T& a, const T& b) const { return a | b; } |
| }; |
| |
| template <typename T> |
| struct Xor { |
| T operator()(const T& a, const T& b) const { return a ^ b; } |
| }; |
| |
| #define BINARY_FOLDER_INTERNAL(Op, Func) \ |
| if (getElementTypeOrSelf(getType()).isa<FloatType>()) \ |
| return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \ |
| if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \ |
| return BinaryFolder<Op, IntegerType, APInt, Func<APSInt>>(this, attrs); \ |
| return {}; |
| |
| #define BINARY_FOLDER(Op, Func) \ |
| OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \ |
| BINARY_FOLDER_INTERNAL(Op, Func) \ |
| } |
| |
| // Addition, subtraction and multiplication use the std:: versions of the ops. |
| // Due to the other ops behaving differently in signed vs unsigned integers, |
| // APInts need a special implementation. Currently, it replicates signed int |
| // op behavior. |
| BINARY_FOLDER(SubtractOp, std::minus); |
| BINARY_FOLDER(DivOp, Divide); |
| BINARY_FOLDER(RemOp, Remainder); |
| BINARY_FOLDER(MaxOp, Max); |
| BINARY_FOLDER(MinOp, Min); |
| |
| bool isSplatZero(SplatElementsAttr attr) { |
| if (!attr) return false; |
| if (attr.getElementType().isa<FloatType>()) { |
| return attr.getSplatValue<APFloat>().isZero(); |
| } |
| if (attr.getElementType().isa<IntegerType>()) { |
| return attr.getSplatValue<APInt>().isZero(); |
| } |
| return false; |
| } |
| |
| OpFoldResult AddOp::fold(ArrayRef<Attribute> attrs) { |
| // Handle special case where one operand is 0: x + 0 => x |
| if (attrs[0] || attrs[1]) { |
| SplatElementsAttr splatLhs = attrs[0].dyn_cast_or_null<SplatElementsAttr>(); |
| SplatElementsAttr splatRhs = attrs[1].dyn_cast_or_null<SplatElementsAttr>(); |
| if (isSplatZero(splatLhs)) return splatRhs ? (OpFoldResult)splatRhs : rhs(); |
| if (isSplatZero(splatRhs)) return splatLhs ? (OpFoldResult)splatLhs : lhs(); |
| } |
| if (attrs[0] && attrs[1]) { |
| BINARY_FOLDER_INTERNAL(AddOp, std::plus) |
| } |
| return {}; |
| } |
| |
| bool isSplatOne(SplatElementsAttr attr) { |
| if (!attr) return false; |
| if (attr.getElementType().isa<FloatType>()) { |
| return attr.getSplatValue<APFloat>().convertToDouble() == 1.0; |
| } |
| if (attr.getElementType().isa<IntegerType>()) { |
| return attr.getSplatValue<APInt>().getSExtValue() == 1; |
| } |
| return false; |
| } |
| |
| OpFoldResult MulOp::fold(ArrayRef<Attribute> attrs) { |
| // Handle special case where one operand is 1: x * 1 => x |
| if (attrs[0] || attrs[1]) { |
| SplatElementsAttr splatLhs = attrs[0].dyn_cast_or_null<SplatElementsAttr>(); |
| SplatElementsAttr splatRhs = attrs[1].dyn_cast_or_null<SplatElementsAttr>(); |
| if (isSplatOne(splatLhs)) return splatRhs ? (OpFoldResult)splatRhs : rhs(); |
| if (isSplatOne(splatRhs)) return splatLhs ? (OpFoldResult)splatLhs : lhs(); |
| } |
| if (attrs[0] && attrs[1]) { |
| BINARY_FOLDER_INTERNAL(MulOp, std::multiplies); |
| } |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Logical Ops |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) { |
| if (lhs() == rhs()) return lhs(); |
| |
| auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>(); |
| auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>(); |
| |
| if (lhsVal && lhsVal.isSplat()) { |
| if (lhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) { |
| return rhs(); |
| } |
| |
| if (lhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) { |
| return lhsVal; |
| } |
| } |
| |
| if (rhsVal && rhsVal.isSplat()) { |
| if (rhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) { |
| return lhs(); |
| } |
| |
| if (rhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) { |
| return rhsVal; |
| } |
| } |
| |
| if (!rhsVal || !lhsVal) return {}; |
| return BinaryFolder<AndOp, IntegerType, APInt, And<APSInt>>(this, operands); |
| } |
| |
| OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) { |
| if (lhs() == rhs()) return lhs(); |
| |
| auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>(); |
| auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>(); |
| |
| if (lhsVal && lhsVal.isSplat()) { |
| if (lhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) { |
| return lhsVal; |
| } |
| |
| if (lhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) { |
| return rhs(); |
| } |
| } |
| |
| if (rhsVal && rhsVal.isSplat()) { |
| if (rhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) { |
| return rhsVal; |
| } |
| |
| if (rhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) { |
| return lhs(); |
| } |
| } |
| |
| if (!rhsVal || !lhsVal) return {}; |
| return BinaryFolder<OrOp, IntegerType, APInt, Or<APSInt>>(this, operands); |
| } |
| |
| OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) { |
| // Fold x^x to 0. Attributes only support static shapes. |
| auto rType = getType().cast<ShapedType>(); |
| if (lhs() == rhs() && rType.hasStaticShape()) { |
| Builder builder(getContext()); |
| return builder.getZeroAttr(rType); |
| } |
| |
| auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>(); |
| auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>(); |
| |
| if (lhsVal && lhsVal.isSplat()) { |
| if (lhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) { |
| return rhs(); |
| } |
| } |
| |
| if (rhsVal && rhsVal.isSplat()) { |
| if (rhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) { |
| return lhs(); |
| } |
| } |
| |
| if (!rhsVal || !lhsVal) return {}; |
| return BinaryFolder<XorOp, IntegerType, APInt, Xor<APSInt>>(this, operands); |
| } |
| |
| #undef BINARY_FOLDER_INTERNAL |
| #undef BINARY_FOLDER |
| |
| //===----------------------------------------------------------------------===// |
| // SliceOp |
| //===----------------------------------------------------------------------===// |
| |
| // Returns output dimension size for slice result for the given arguments. |
| // Returns -1 if arguments are illegal. |
| static int64_t inferSliceDim(int64_t inputDim, int64_t start, int64_t end, |
| int64_t stride) { |
| if (inputDim == -1 || start < 0 || start > end || end > inputDim || |
| stride == 0) |
| return -1; |
| |
| return llvm::divideCeil(end - start, stride); |
| } |
| |
| // The following properties are already enforced by the ODS: |
| // type(start_indices) == type(limit_indices) == type(strides). |
| // Verify the following properties: |
| // P1. Verify rank(start_indices) == 1. |
| // P2. Verify size(start_indices) == rank(operand). |
| // P3~5. Verify 0 <= start_indices[i] <= limit_indices[i] <= shape(operand)[i]. |
| // P6. Verify stride[i] > 0. |
| LogicalResult SliceOp::inferReturnTypes( |
| MLIRContext* context, Optional<Location> location, ValueRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<Type>& inferredReturnTypes) { |
| SliceOpAdaptor slice(operands, attributes); |
| Type ty = slice.operand().getType(); |
| RankedTensorType rankedTy = ty.dyn_cast<RankedTensorType>(); |
| if (!rankedTy) { |
| // The operand type is unranked, so the best we can infer for the result |
| // type is an unranked tensor with the same element type as the operand |
| // type. |
| inferredReturnTypes.assign({ty}); |
| return success(); |
| } |
| |
| ShapedType attrTy = slice.start_indices().getType(); |
| // P1. |
| // Note: ODS has type(start_indices) == type(limit_indices) == type(strides) |
| // So this implies rank(limit_indices) == rank(strides) == 1 also. |
| if (attrTy.getRank() != 1) { |
| return emitOptionalError(location, "start_indices has rank ", |
| attrTy.getRank(), " instead of required rank 1"); |
| } |
| |
| // P2. |
| int64_t rank = rankedTy.getRank(); |
| if (attrTy.getNumElements() != rank) { |
| return emitOptionalError( |
| location, "the number of elements in start_indices (", |
| attrTy.getNumElements(), ") does not match the rank of the operand (", |
| rank, ")"); |
| } |
| |
| SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>()); |
| SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>()); |
| SmallVector<int64_t, 4> strideVals(slice.strides().getValues<int64_t>()); |
| |
| SmallVector<int64_t, 4> shape; |
| shape.reserve(rank); |
| for (int64_t i = 0, e = rank; i != e; i++) { |
| if (isDynamicDimSize(rankedTy.getDimSize(i))) { |
| shape.push_back(ShapedType::kDynamicSize); |
| continue; |
| } |
| // P3. |
| if (start[i] < 0) |
| return emitOptionalError(location, "negative start index ", start[i], |
| " in dimension ", i); |
| // P4. |
| if (limit[i] > rankedTy.getDimSize(i)) |
| return emitOptionalError(location, "limit index ", limit[i], |
| " is larger than dimension size ", |
| rankedTy.getDimSize(i), " in dimension ", i); |
| // P5. |
| if (start[i] > limit[i]) |
| return emitOptionalError(location, "start index ", start[i], |
| " is larger than limit index ", limit[i], |
| " in dimension ", i); |
| // P6. |
| if (strideVals[i] <= 0) |
| return emitOptionalError(location, "stride must be positive but got ", |
| strideVals[i], " in dimension ", i); |
| |
| shape.push_back(inferSliceDim(rankedTy.getDimSize(i), start[i], limit[i], |
| strideVals[i])); |
| } |
| inferredReturnTypes.assign( |
| {RankedTensorType::get(shape, rankedTy.getElementType())}); |
| return success(); |
| } |
| |
| template <typename I, typename E> |
| static void sliceElements(I values, ArrayRef<int64_t> sizes, |
| ArrayRef<int64_t> starts, ArrayRef<int64_t> limits, |
| ArrayRef<int64_t> strides, |
| llvm::SmallVectorImpl<E>* outValues) { |
| assert(starts.size() == limits.size()); |
| assert(starts.size() == strides.size()); |
| if (starts.empty()) return; |
| |
| int64_t start = starts.front(); |
| int64_t limit = limits.front(); |
| int64_t stride = strides.front(); |
| if (starts.size() == 1) { |
| for (int i = start; i < limit; i += stride) { |
| outValues->push_back(*(values + i)); |
| } |
| return; |
| } |
| |
| for (; start < limit; start += stride) { |
| auto begin = values + start * sizes.front(); |
| sliceElements<I, E>(begin, sizes.drop_front(), starts.drop_front(), |
| limits.drop_front(), strides.drop_front(), outValues); |
| } |
| } |
| |
| template <typename I, typename E> |
| static Attribute foldSlice(SliceOp* op, I values) { |
| auto start = llvm::to_vector<6>(op->start_indices().getValues<int64_t>()); |
| auto limit = llvm::to_vector<6>(op->limit_indices().getValues<int64_t>()); |
| auto stride = llvm::to_vector<6>(op->strides().getValues<int64_t>()); |
| |
| // TODO(b/235903849): This should be op->getType().case<ShapedType>(). |
| auto resultType = op->operand().getType().cast<ShapedType>(); |
| if (!resultType.hasStaticShape()) return {}; |
| |
| auto shape = resultType.getShape(); |
| int64_t count = resultType.getNumElements(); |
| if (count == 0) { |
| return DenseElementsAttr::get<E>( |
| op->getResult().getType().cast<ShapedType>(), |
| /*list=*/{}); |
| } |
| |
| // Compute the striding for each dimension. |
| llvm::SmallVector<int64_t, 6> sizes; |
| sizes.reserve(shape.size()); |
| for (auto v : shape) { |
| count = count / v; |
| sizes.push_back(count); |
| } |
| |
| // Prevent folding if the result is too large. |
| if (resultType.getNumElements() > kFoldOpEltLimit) return {}; |
| |
| llvm::SmallVector<E, 6> outValues; |
| outValues.reserve(resultType.getNumElements()); |
| sliceElements<I, E>(values, sizes, start, limit, stride, &outValues); |
| |
| return DenseElementsAttr::get(op->getResult().getType().cast<ShapedType>(), |
| outValues); |
| } |
| |
| OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) { |
| // Check if the SliceOp is a NoOp operation. |
| auto operandType = getOperand().getType().cast<ShapedType>(); |
| auto resultType = getResult().getType().cast<ShapedType>(); |
| |
| if (operandType.hasStaticShape() && resultType.hasStaticShape() && |
| (operandType.getShape() == resultType.getShape())) { |
| return getOperand(); |
| } |
| |
| if (operands.empty() || !operands.front()) return {}; |
| |
| // Evaluate for statically valued inputs. |
| DenseElementsAttr elements = operands.front().dyn_cast<DenseElementsAttr>(); |
| if (!elements) return {}; |
| |
| auto etype = elements.getType().getElementType(); |
| if (etype.isa<IntegerType>()) { |
| return foldSlice<DenseElementsAttr::IntElementIterator, APInt>( |
| this, elements.value_begin<APInt>()); |
| } |
| if (etype.isa<FloatType>()) { |
| return foldSlice<DenseElementsAttr::FloatElementIterator, APFloat>( |
| this, elements.value_begin<APFloat>()); |
| } |
| |
| return {}; |
| } |
| |
| namespace { |
| // In cases where a concat is fed into a slice, it is possible the concat |
| // can be simplified or bypassed. This checks which inputs to the concat are |
| // used by the slice, either reducing the number of concatenated values or |
| // entirely removes the concat. |
| struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> { |
| using OpRewritePattern<SliceOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(SliceOp slice, |
| PatternRewriter& rewriter) const override { |
| auto resultTy = slice.getType().cast<ShapedType>(); |
| if (!resultTy.hasStaticShape()) { |
| return failure(); |
| } |
| |
| auto sliceInput = slice.operand(); |
| auto sliceInputTy = sliceInput.getType().cast<ShapedType>(); |
| auto concat = sliceInput.getDefiningOp<ConcatenateOp>(); |
| if (!concat) { |
| return failure(); |
| } |
| |
| auto dimension = concat.dimension(); |
| |
| auto start = slice.start_indices().getValues<APInt>(); |
| auto limit = slice.limit_indices().getValues<APInt>(); |
| |
| auto sliceStart = (*(start.begin() + dimension)).getSExtValue(); |
| auto sliceLimit = (*(limit.begin() + dimension)).getSExtValue(); |
| |
| // We need to determine what inputs from the concat affect the slice, and |
| // how the bounds of the slice need to be updated for the minimally required |
| // inputs. |
| int64_t runningSize = 0; |
| int64_t frontOffset = sliceInputTy.getShape()[dimension]; |
| |
| auto subsetStart = concat.operand_end(); |
| auto subsetEnd = concat.operand_end(); |
| for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) { |
| auto input = *it; |
| ShapedType inputTy = input.getType().cast<ShapedType>(); |
| if (inputTy.isDynamicDim(dimension)) { |
| return failure(); |
| } |
| auto dimSize = inputTy.getShape()[dimension]; |
| |
| // If this position is in the slice its the start of the subset and we |
| // need to update the start and limit values. |
| if (runningSize + dimSize > sliceStart && |
| subsetStart == concat.operand_end()) { |
| subsetStart = it; |
| frontOffset = runningSize; |
| } |
| |
| // Determine the last required offset. |
| if (runningSize < sliceLimit) { |
| subsetEnd = it + 1; |
| } |
| |
| runningSize += dimSize; |
| } |
| |
| auto subsetSize = subsetEnd - subsetStart; |
| // We need all inputs so no optimization. |
| if (subsetSize == concat.getNumOperands()) { |
| return failure(); |
| } |
| |
| // If there's nothing to slice that means the output is an empty tensor and |
| // there is dead code. We do nothing here and rely on other passes to clean |
| // this up. |
| if (subsetSize == 0) { |
| return failure(); |
| } |
| |
| if (subsetSize > 1 && !concat.getResult().hasOneUse()) { |
| return failure(); |
| } |
| |
| auto concatRange = OperandRange(subsetStart, subsetEnd); |
| auto newConcat = rewriter.create<ConcatenateOp>( |
| concat.getLoc(), concatRange, concat.dimension()); |
| |
| llvm::SmallVector<APInt, 6> newStart(start); |
| llvm::SmallVector<APInt, 6> newLimit(limit); |
| newStart[dimension] -= frontOffset; |
| newLimit[dimension] -= frontOffset; |
| |
| auto attrType = slice.start_indices().getType().cast<ShapedType>(); |
| auto create = rewriter.create<SliceOp>( |
| slice.getLoc(), newConcat, |
| DenseIntElementsAttr::get(attrType, newStart), |
| DenseIntElementsAttr::get(attrType, newLimit), slice.strides()); |
| rewriter.replaceOp(slice, create.getResult()); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void SliceOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<SimplifyConcatSlice>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SortOp |
| //===----------------------------------------------------------------------===// |
| |
| void SortOp::build(OpBuilder& builder, OperationState& state, |
| ValueRange operands, int64_t dimension, bool isStable) { |
| state.addOperands(operands); |
| state.addAttribute("dimension", builder.getI64IntegerAttr(dimension)); |
| state.addAttribute("is_stable", builder.getBoolAttr(isStable)); |
| |
| for (Value operand : operands) state.addTypes(operand.getType()); |
| |
| state.addRegion(); |
| } |
| |
| LogicalResult SortOp::verify() { |
| Operation::operand_range operands = this->operands(); |
| if (operands.empty()) return emitOpError("requires at least one input"); |
| |
| // TODO(antiagainst): verify partionally dynamic shapes |
| if (llvm::all_of(operands, [](Value operand) { |
| return operand.getType().cast<ShapedType>().hasRank(); |
| })) { |
| ArrayRef<int64_t> inputShape = |
| (*operands.begin()).getType().cast<ShapedType>().getShape(); |
| |
| if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) { |
| return operand.getType().cast<ShapedType>().getShape() != inputShape; |
| })) |
| return emitOpError("requires all inputs to have the same dimensions"); |
| |
| int64_t rank = inputShape.size(); |
| int64_t cmpDim = dimension(); |
| if (cmpDim < -rank || cmpDim >= rank) |
| return emitOpError("dimension attribute value must be in range [-") |
| << rank << ", " << rank << "), but found " << cmpDim; |
| } |
| |
| Block& block = comparator().front(); |
| size_t numOperands = getOperation()->getNumOperands(); |
| if (block.getNumArguments() != 2 * numOperands) |
| return emitOpError("comparator block should have ") |
| << 2 * numOperands << " arguments"; |
| |
| for (const auto& indexedOperand : llvm::enumerate(operands)) { |
| int index = indexedOperand.index(); |
| Type elementType = |
| indexedOperand.value().getType().cast<ShapedType>().getElementType(); |
| Type tensorType = RankedTensorType::get({}, elementType); |
| for (int i : {2 * index, 2 * index + 1}) { |
| Type argType = block.getArgument(i).getType(); |
| if (argType != tensorType) |
| return emitOpError("comparator block argument #") |
| << i << " should be of type " << tensorType << " but got " |
| << argType; |
| } |
| } |
| |
| // Mapped computation must return single output. |
| auto comparatorResult = block.getTerminator()->getOperands(); |
| if (comparatorResult.size() != 1) |
| return emitOpError() << "comparator must return single output, but got: " |
| << comparatorResult.size(); |
| |
| // The output of computation must be 0-ranked tensor with element-type i1. |
| auto comparatorResultType = |
| comparatorResult[0].getType().dyn_cast<RankedTensorType>(); |
| if (!comparatorResultType || comparatorResultType.getRank() != 0 || |
| !comparatorResultType.getElementType().isInteger(1)) |
| return emitOpError() << "comparator must return tensor<i1>, but got: " |
| << comparatorResult[0].getType(); |
| |
| // check number of return-values and their element-types. |
| auto resultTypes = getResultTypes(); |
| if (resultTypes.size() != numOperands) |
| return emitOpError() << "expects the number of results to be same as " |
| "number of operands. Got number of results = " |
| << resultTypes.size() |
| << " and number of operands = " << numOperands; |
| |
| for (auto it : llvm::zip(operands, getResultTypes())) |
| if (std::get<0>(it).getType().cast<TensorType>().getElementType() != |
| std::get<1>(it).cast<TensorType>().getElementType()) |
| return emitOpError() |
| << "expects the operands and results to have pairwize equal " |
| "element-types, but got " |
| << std::get<0>(it).getType().cast<TensorType>().getElementType() |
| << " vs " << std::get<1>(it).cast<TensorType>().getElementType(); |
| |
| return success(); |
| } |
| |
| /// Drops the operands if the results are not used and they are not used in |
| /// op.comparator(). |
| static LogicalResult sortDropEmptyUseArgs(SortOp op, |
| PatternRewriter& rewriter) { |
| DenseSet<unsigned> erasedArgs; |
| unsigned numOperands = op.getNumOperands(); |
| for (unsigned i = 0; i < numOperands; ++i) { |
| if (!op.getResult(i).use_empty()) continue; |
| Block& block = op.comparator().front(); |
| if (!block.getArgument(i * 2).use_empty()) continue; |
| if (!block.getArgument(i * 2 + 1).use_empty()) continue; |
| erasedArgs.insert(i); |
| } |
| if (erasedArgs.empty()) return failure(); |
| |
| SmallVector<Value> newOperands; |
| SmallVector<unsigned> erasedBlockArgs; |
| for (const auto& en : llvm::enumerate(op.operands())) { |
| if (erasedArgs.contains(en.index())) { |
| erasedBlockArgs.push_back(en.index() * 2); |
| erasedBlockArgs.push_back(en.index() * 2 + 1); |
| } else { |
| newOperands.push_back(en.value()); |
| } |
| } |
| |
| auto newOp = rewriter.create<SortOp>(op.getLoc(), newOperands, op.dimension(), |
| op.is_stable()); |
| Region& region = newOp.comparator(); |
| rewriter.inlineRegionBefore(op.comparator(), region, region.end()); |
| region.front().eraseArguments(erasedBlockArgs); |
| |
| SmallVector<Value> results; |
| for (unsigned i = 0, j = 0; i < numOperands; ++i) { |
| if (erasedArgs.contains(i)) { |
| results.push_back({}); |
| } else { |
| results.push_back(newOp.getResult(j++)); |
| } |
| } |
| rewriter.replaceOp(op, results); |
| |
| return success(); |
| } |
| |
| /// Set the sorting dimension to the last dimension if it's not set and the rank |
| /// is known. |
| static LogicalResult sortOpInferDefaultDimension(SortOp op, |
| PatternRewriter& rewriter) { |
| auto ty = op.getResultTypes()[0].dyn_cast<ShapedType>(); |
| if (!ty) { |
| return failure(); |
| } |
| if (static_cast<int64_t>(op.dimension()) != -1) { |
| return failure(); |
| } |
| |
| IntegerAttr dim = rewriter.getI64IntegerAttr(ty.getRank() - 1); |
| auto newOp = rewriter.create<SortOp>(op.getLoc(), op.getResultTypes(), |
| op.operands(), dim, op.is_stableAttr()); |
| Region& region = newOp.comparator(); |
| rewriter.inlineRegionBefore(op.comparator(), region, region.end()); |
| rewriter.replaceOp(op, newOp.getResults()); |
| |
| return success(); |
| } |
| |
| void SortOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* /*context*/) { |
| results.add(sortDropEmptyUseArgs); |
| results.add(sortOpInferDefaultDimension); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TransposeOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) { |
| if (auto elements = operands.front().dyn_cast_or_null<SplatElementsAttr>()) { |
| return reshape(elements, getResult().getType().cast<ShapedType>()); |
| } |
| for (const auto& it : llvm::enumerate(permutation().getValues<APInt>())) { |
| if (it.index() != it.value()) { |
| return {}; |
| } |
| } |
| return getOperand(); |
| } |
| |
| // transpose(transpose(X)) => transpose(X) |
| static LogicalResult eliminateRedundantTranspse(TransposeOp op, |
| PatternRewriter& rewriter) { |
| auto tranposeOperand = op.operand().getDefiningOp<TransposeOp>(); |
| if (!tranposeOperand) { |
| return failure(); |
| } |
| auto operandPermutation = tranposeOperand.permutation().getValues<APInt>(); |
| auto newPermutation = |
| op.permutation() |
| .mapValues(op.permutation().getElementType(), |
| [&operandPermutation](const APInt& index) -> APInt { |
| return operandPermutation[index.getSExtValue()]; |
| }) |
| .cast<DenseIntElementsAttr>(); |
| rewriter.replaceOpWithNewOp<TransposeOp>( |
| op, op.getResult().getType(), tranposeOperand.operand(), newPermutation); |
| return success(); |
| } |
| |
| // transpose(broadcast_in_dim(X)) => broadcast_in_dim(X) |
| static LogicalResult eliminateBroadcastInDimTranspose( |
| TransposeOp op, PatternRewriter& rewriter) { |
| auto broadcastInDimOp = op.operand().getDefiningOp<BroadcastInDimOp>(); |
| if (!broadcastInDimOp) { |
| return failure(); |
| } |
| DenseIntElementsAttr broadcastDimensions = |
| broadcastInDimOp.broadcast_dimensions(); |
| DenseIntElementsAttr permutation = op.permutation(); |
| SmallVector<int64_t> newBroadcastDimensions; |
| for (auto dimension : broadcastDimensions.getValues<int64_t>()) { |
| int64_t index = 0; |
| for (auto p : permutation.getValues<int64_t>()) { |
| if (p == dimension) { |
| newBroadcastDimensions.push_back(index); |
| break; |
| } |
| index++; |
| } |
| } |
| rewriter.replaceOpWithNewOp<BroadcastInDimOp>( |
| op, op->getResultTypes(), broadcastInDimOp.operand(), |
| rewriter.getI64TensorAttr(newBroadcastDimensions)); |
| return success(); |
| } |
| |
| // simplify Transpose: replace Transpose with Reshape if they are equivalent |
| static LogicalResult simplifyTranspose(TransposeOp op, |
| PatternRewriter& rewriter) { |
| auto operandType = op.operand().getType().dyn_cast<RankedTensorType>(); |
| auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>(); |
| if (!operandType || !resultType) { |
| return failure(); |
| } |
| // Not support dynamic shape a.t.m. BTW, when it's dynamic shape, |
| // maybe Transpose should be replaced by DynamicReshape. |
| if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) { |
| return failure(); |
| } |
| auto permutation = op.permutation().getValues<int64_t>(); |
| llvm::SmallVector<int64_t> sortedPermutation; |
| for (int64_t i = 0, e = resultType.getRank(); i < e; i++) { |
| if (resultType.getDimSize(i) != 1) { |
| sortedPermutation.push_back(permutation[i]); |
| } |
| } |
| if (llvm::is_sorted(sortedPermutation)) { |
| rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand()); |
| return success(); |
| } |
| return failure(); |
| } |
| |
| void TransposeOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* /*context*/) { |
| results.add(eliminateRedundantTranspse); |
| results.add(eliminateBroadcastInDimTranspose); |
| results.add(simplifyTranspose); |
| } |
| |
| LogicalResult TransposeOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| TransposeOp::Adaptor adaptor(operands); |
| Value operand = adaptor.operand(); |
| |
| auto operandType = operand.getType().dyn_cast<RankedTensorType>(); |
| // Not support unranked type a.t.m. |
| if (!operandType) return failure(); |
| |
| Location loc = this->getLoc(); |
| SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>()); |
| SmallVector<Value, 4> shapeValues(permutation.size()); |
| |
| Type shapeScalarType = builder.getIndexType(); |
| auto toShapeScalarType = [&](Value v) { |
| return maybeCastTo(builder, loc, v, shapeScalarType); |
| }; |
| |
| for (const auto& element : llvm::enumerate(operandType.getShape())) { |
| int64_t idx = element.index(); |
| auto* it = std::find(permutation.begin(), permutation.end(), idx); |
| Value valueDim = toShapeScalarType( |
| builder.createOrFold<tensor::DimOp>(loc, operand, element.index())); |
| shapeValues[std::distance(permutation.begin(), it)] = valueDim; |
| } |
| |
| Value outputShape = builder.create<tensor::FromElementsOp>( |
| loc, |
| RankedTensorType::get({static_cast<int64_t>(shapeValues.size())}, |
| shapeScalarType), |
| shapeValues); |
| reifiedReturnShapes.push_back(outputShape); |
| |
| return success(); |
| } |
| |
| // Method for InferTypeOpInterface: infer the return type from the operand type |
| // and the permutation. |
| LogicalResult TransposeOp::inferReturnTypes( |
| MLIRContext* /*context*/, Optional<Location> loc, ValueRange operands, |
| DictionaryAttr attributes, RegionRange, |
| SmallVectorImpl<Type>& inferredReturnTypes) { |
| auto type = operands[0].getType(); |
| auto rankedTy = type.dyn_cast<RankedTensorType>(); |
| if (!rankedTy) { |
| auto shapedTy = type.dyn_cast<ShapedType>(); |
| inferredReturnTypes.emplace_back(shapedTy); |
| return success(); |
| } |
| auto permutation = attributes.getAs<DenseIntElementsAttr>("permutation"); |
| int64_t rank = rankedTy.getRank(); |
| if (permutation.getType().getRank() != 1) |
| return emitOptionalError(loc, "TransposeOp permutation has rank ", |
| permutation.getType().getRank(), |
| " instead of rank 1"); |
| |
| if (permutation.size() != rank) |
| return emitOptionalError(loc, "TransposeOp operand rank ", rank, |
| " does not match permutation size ", |
| permutation.size()); |
| |
| std::vector<int64_t> range(rank); |
| std::iota(range.begin(), range.end(), 0); |
| if (!std::is_permutation(range.begin(), range.end(), permutation.begin())) |
| return emitOptionalError(loc, |
| "attribute permutation must be a permutation" |
| " of [", |
| range, "] but got ", permutation); |
| |
| SmallVector<int64_t> resultShape; |
| ArrayRef<int64_t> inputShape = rankedTy.getShape(); |
| for (int64_t dim : permutation.getValues<int64_t>()) { |
| resultShape.push_back(inputShape[dim]); |
| } |
| inferredReturnTypes.emplace_back(RankedTensorType::get( |
| resultShape, rankedTy.getElementType(), rankedTy.getEncoding())); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TriangularSolveOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult TriangularSolveOp::verify() { |
| auto aType = a().getType().dyn_cast<RankedTensorType>(); |
| |
| // Skip verifier if a is unranked tensor. |
| if (!aType) return success(); |
| |
| // Check that a should have rank >= 2 |
| auto aRank = aType.getRank(); |
| if (aRank < 2) |
| return emitOpError() << "operand 'a' must have rank >= 2, but got " |
| << aType; |
| |
| // The two minor dimensions of a must have same size. |
| if (aType.getDimSize(aRank - 2) != aType.getDimSize(aRank - 1)) |
| return emitOpError() << "two minor dimensions of operand 'a' must have " |
| "equal size, but got " |
| << aType; |
| |
| auto bType = b().getType().dyn_cast<RankedTensorType>(); |
| // If b is unranked skip remaining checks. |
| if (!bType) return success(); |
| |
| // Check that a and b have same rank. |
| auto bRank = bType.getRank(); |
| if (aRank != bRank) |
| return emitOpError() << "operands must have equal rank, but got " << aType |
| << " and " << bType; |
| |
| // The shared dimension of a and b should match. |
| if (aType.getDimSize(aRank - 1) != |
| bType.getDimSize(bRank - (left_side() ? 2 : 1))) |
| return emitOpError() << "shared dimension of operands 'a' and 'b' does " |
| "not match, but got " |
| << aType << " and " << bType; |
| |
| // The leading batch dimensions of a and b must be equal. |
| auto aBatchDims = aType.getShape().drop_back(2); |
| auto bBatchDims = bType.getShape().drop_back(2); |
| if (aBatchDims != bBatchDims) |
| return emitOpError() |
| << "leading batch dimensions of the operands must be same, but got " |
| << aType << " and " << bType; |
| |
| // Result and argument b must have same shape. |
| auto resultType = getType().dyn_cast<RankedTensorType>(); |
| if (!resultType) return success(); |
| if (resultType != bType) |
| return emitOpError() |
| << "result and operand 'b' must have same shape, but got " |
| << resultType << " and " << bType; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // GetTupleElementOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult GetTupleElementOp::inferReturnTypes( |
| MLIRContext*, Optional<Location>, ValueRange operands, |
| DictionaryAttr attributes, RegionRange, |
| SmallVectorImpl<Type>& inferredReturnTypes) { |
| auto tupleType = operands[0].getType().dyn_cast<TupleType>(); |
| if (!tupleType) return failure(); |
| |
| auto indexAttr = attributes.get("index").cast<IntegerAttr>(); |
| auto index = indexAttr.getInt(); |
| if (index < 0 || index >= static_cast<int64_t>(tupleType.size())) |
| return failure(); |
| |
| inferredReturnTypes.push_back(tupleType.getType(index)); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TupleOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult TupleOp::inferReturnTypes( |
| MLIRContext* context, Optional<Location>, ValueRange operands, |
| DictionaryAttr attributes, RegionRange, |
| SmallVectorImpl<Type>& inferredReturnTypes) { |
| inferredReturnTypes.push_back(TupleType::get(context, TypeRange(operands))); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // UnaryEinsumOp |
| //===----------------------------------------------------------------------===// |
| |
| void UnaryEinsumOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add<UnaryEinsumToEinsum>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CompareOp |
| //===----------------------------------------------------------------------===// |
| |
| void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, |
| Value rhs, ComparisonDirection comparisonDirection, |
| ComparisonType compareType) { |
| build(builder, result, lhs, rhs, |
| ComparisonDirectionAttr::get(builder.getContext(), comparisonDirection), |
| ComparisonTypeAttr::get(builder.getContext(), compareType)); |
| } |
| |
| LogicalResult CompareOp::inferReturnTypeComponents( |
| mlir::MLIRContext* ctx, llvm::Optional<mlir::Location>, |
| ValueShapeRange operands, mlir::DictionaryAttr, mlir::RegionRange, |
| llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferredReturnTypes) { |
| ShapedTypeComponents& components = |
| inferredReturnTypes.emplace_back(IntegerType::get(ctx, /*width=*/1)); |
| auto argTy = operands.front().getType().cast<TensorType>(); |
| if (argTy.hasRank()) { |
| components = |
| ShapedTypeComponents(argTy.getShape(), components.getElementType()); |
| } |
| return success(); |
| } |
| |
| LogicalResult CompareOp::reifyReturnTypeShapes( |
| OpBuilder& builder, ValueRange operands, |
| SmallVectorImpl<Value>& reifiedReturnShapes) { |
| return deriveShapeFromOperand(&builder, getOperation(), operands.front(), |
| &reifiedReturnShapes); |
| } |
| |
| template <typename T> |
| struct Less : std::less<T> {}; |
| |
| template <> |
| struct Less<APInt> { |
| bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); } |
| }; |
| |
| template <typename T> |
| struct LessEqual : std::less_equal<T> {}; |
| |
| template <> |
| struct LessEqual<APInt> { |
| bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); } |
| }; |
| |
| template <typename T> |
| struct Greater : std::greater<T> {}; |
| |
| template <> |
| struct Greater<APInt> { |
| bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); } |
| }; |
| |
| template <typename T> |
| struct GreaterEqual : std::greater_equal<T> {}; |
| |
| template <> |
| struct GreaterEqual<APInt> { |
| bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); } |
| }; |
| |
| template <typename Op, typename ElementType, typename SrcType, typename Convert> |
| static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) { |
| if (!attrs[0] || !attrs[1]) return {}; |
| |
| DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>(); |
| DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>(); |
| if (!lhs || !rhs) return {}; |
| |
| ShapedType operandType = |
| op.getOperand(0).getType().template cast<ShapedType>(); |
| if (!operandType.hasStaticShape()) { |
| return {}; |
| } |
| |
| if (!operandType.getElementType().isa<ElementType>()) { |
| return {}; |
| } |
| |
| // Prevent folding if the result is too large. |
| if (lhs.getNumElements() > kFoldOpEltLimit) return {}; |
| |
| SmallVector<bool, 6> values; |
| values.reserve(lhs.getNumElements()); |
| for (const auto zip : |
| llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) { |
| values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip))); |
| } |
| |
| auto resultTy = op.getType().cast<ShapedType>(); |
| return DenseElementsAttr::get(resultTy, values); |
| } |
| |
| OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) { |
| auto resultTy = getType().cast<ShapedType>(); |
| if (!resultTy.hasStaticShape()) return {}; |
| |
| auto direction = comparison_direction(); |
| auto lhsTy = getElementTypeOrSelf(lhs()); |
| if (lhs() == rhs() && !lhsTy.isa<FloatType>() && |
| (!lhsTy.isa<ComplexType>() || |
| !lhsTy.cast<ComplexType>().getElementType().isa<FloatType>())) { |
| if (direction == ComparisonDirection::LE || |
| direction == ComparisonDirection::EQ || |
| direction == ComparisonDirection::GE) { |
| return DenseIntElementsAttr::get(resultTy, {true}); |
| } |
| return DenseIntElementsAttr::get(resultTy, {false}); |
| } |
| |
| auto opElType = lhs().getType().cast<ShapedType>().getElementType(); |
| // Fold tensor<*xi1> != false to just return tensor<*xi1> |
| if (direction == ComparisonDirection::NE && opElType.isInteger(1)) { |
| DenseIntElementsAttr cstAttr; |
| if (matchPattern(lhs(), m_Constant(&cstAttr))) { |
| if (cstAttr.isSplat() && !cstAttr.getSplatValue<bool>()) { |
| return rhs(); |
| } |
| } |
| |
| if (matchPattern(rhs(), m_Constant(&cstAttr))) { |
| if (cstAttr.isSplat() && !cstAttr.getSplatValue<bool>()) { |
| return lhs(); |
| } |
| } |
| } |
| |
| // Fold tensor<*xi1> == True to just return tensor<*xi1> |
| if (direction == ComparisonDirection::EQ && opElType.isInteger(1)) { |
| DenseIntElementsAttr cstAttr; |
| if (matchPattern(lhs(), m_Constant(&cstAttr))) { |
| if (cstAttr.isSplat() && cstAttr.getSplatValue<bool>()) { |
| return rhs(); |
| } |
| } |
| |
| if (matchPattern(rhs(), m_Constant(&cstAttr))) { |
| if (cstAttr.isSplat() && cstAttr.getSplatValue<bool>()) { |
| return lhs(); |
| } |
| } |
| } |
| |
| if (!operands[0] || !operands[1]) { |
| return {}; |
| } |
| |
| #define COMPARE_FOLDER(Op, comparison, Func) \ |
| if (direction == comparison) { \ |
| if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \ |
| *this, operands)) \ |
| return folded; \ |
| if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>( \ |
| *this, operands)) \ |
| return folded; \ |
| } |
| |
| COMPARE_FOLDER(CompareOp, ComparisonDirection::EQ, std::equal_to); |
| COMPARE_FOLDER(CompareOp, ComparisonDirection::NE, std::not_equal_to); |
| COMPARE_FOLDER(CompareOp, ComparisonDirection::LT, Less); |
| COMPARE_FOLDER(CompareOp, ComparisonDirection::LE, LessEqual); |
| COMPARE_FOLDER(CompareOp, ComparisonDirection::GT, Greater); |
| COMPARE_FOLDER(CompareOp, ComparisonDirection::GE, GreaterEqual); |
| #undef COMPARE_FOLDER |
| |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SelectAndScatterOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| // Infer the return-type of SelectAndScatterOp. |
| TensorType inferSelectAndScatterOpReturnType( |
| TensorType operandType, const ArrayRef<WindowDimension> window) { |
| if (!operandType.hasRank()) |
| return UnrankedTensorType::get(operandType.getElementType()); |
| |
| return RankedTensorType::get( |
| inferWindowOutputShape(operandType.getShape(), window), |
| operandType.getElementType()); |
| } |
| } // namespace |
| |
| // We intend to verify the following properties: |
| // P1. Check if the select function has a proper shape of (T,T) -> PRED, where |
| // T is a 0-D tensor with element-type same as 'operand' element-type. |
| // P2. Verify scatter-computation type. |
| // P3. size-of(window_dimension) == rank-of(input), |
| // where input is an element of 'inputs'. |
| // P4. Verify and collect the window attributes. |
| // P5. Verify the return type matches the operand-type. |
| // P6. Check if the result type of window operation matches the source type. |
| LogicalResult SelectAndScatterOp::verify() { |
| auto operandType = operand().getType().cast<TensorType>(); |
| auto initValueType = init_value().getType().cast<TensorType>(); |
| auto sourceType = source().getType().cast<TensorType>(); |
| auto resultType = getResult().getType().cast<TensorType>(); |
| |
| // P1. |
| Block& selectBlock = select().front(); |
| |
| if (selectBlock.getArguments().size() != 2) |
| return emitOpError() |
| << "expects the select-region to take 2 parameters, but takes " |
| << selectBlock.getArguments().size(); |
| |
| Type expectedSelectArgType = |
| RankedTensorType::get({}, operandType.getElementType()); |
| for (const auto& selectArgIt : llvm::enumerate(selectBlock.getArguments())) |
| if (!compatibleShapeAndElementType(expectedSelectArgType, |
| selectArgIt.value().getType(), |
| /*ignoreFpPrecision=*/true)) |
| return emitOpError() |
| << "expects the type of select-region's parameter at index " |
| << selectArgIt.index() << " to be " << expectedSelectArgType |
| << ", but got " << selectArgIt.value().getType(); |
| |
| auto selectResult = selectBlock.getTerminator()->getOperands(); |
| if (selectResult.size() != 1) |
| return emitOpError() |
| << "expects select-region to return single value, but got: " |
| << selectResult.size(); |
| |
| auto selectResultType = selectResult[0].getType().dyn_cast<TensorType>(); |
| if (!selectResultType || !selectResultType.getElementType().isInteger(1) || |
| (selectResultType.hasRank() && |
| selectResultType.cast<RankedTensorType>().getRank() != 0)) |
| return emitOpError() << "expects the return-type of select-region to be " |
| "tensor<i1>, but got: " |
| << selectResult[0].getType(); |
| |
| // P2. |
| Block& scatterBlock = scatter().front(); |
| SmallVector<TensorType> accumulatorSubshapes; |
| if (failed(verifyReducerShape( |
| this->getLoc(), scatterBlock, |
| {RankedTensorType::get({}, sourceType.getElementType())}, |
| {initValueType}, |
| /*numInputs=*/1, /*allowedDimensions=*/{}, |
| /*allInputsUnranked=*/false, accumulatorSubshapes))) |
| return failure(); |
| |
| // P3. |
| SmallVector<int64_t> windowDims = |
| convertDenseIntAttr(this->window_dimensions()); |
| if (operandType.hasRank()) { |
| if (operandType.getRank() != static_cast<int64_t>(windowDims.size())) |
| return emitOpError() |
| << "expects window-dimensions size == operand rank, but got " |
| "window-dimensions size: " |
| << windowDims.size() << " and operand-type: " << operandType |
| << " with rank = " << operandType.getRank() << "."; |
| } |
| |
| // P4. |
| auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc()); |
| if (failed(paddingOrErr)) return failure(); |
| SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr; |
| |
| auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions( |
| windowDims, convertDenseIntAttr(window_strides()), padding, |
| /*lhs_dilation=*/{}, /*rhs_dilation=*/{}, getLoc()); |
| if (failed(windowOrErr)) return failure(); |
| |
| // P5. |
| if (!compatibleShapeAndElementType(operandType, resultType)) |
| return emitOpError() |
| << "expects the return-type to match the operand-type, but got " |
| << resultType << " and " << operandType << " resp."; |
| |
| // P6. |
| auto windowResultType = |
| inferSelectAndScatterOpReturnType(operandType, *windowOrErr); |
| |
| if (!compatibleShapeAndElementType(windowResultType, sourceType, |
| /*ignoreFpPrecision=*/true)) |
| return emitOpError() << "expects source-type to be " << windowResultType |
| << ", but got" << sourceType; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ScatterOp |
| //===----------------------------------------------------------------------===// |
| |
| /* |
| * We intend to verify the following properties: |
| * P1. The 'update_window_dims' must be valid indices of 'updates' tensor. |
| * P2. The 'inserted_window_dims' must be valid indices of 'operand' tensor. |
| * P3. Check if the rank-of('operand') == size-of('update_window_dims') + |
| * size-of('inserted_window_dims') |
| * P4. size-of('scatter_dims_to_operand_dims') = |
| * 'scatter_indices'['index_vector_dim'] & |
| * 'scatter_dims_to_operand_dims' must be valid indices of 'operand' tensor. |
| */ |
| LogicalResult validateScatterDimensionNumbers( |
| ShapedType operandType, ArrayRef<int64_t> scatterIndicesShape, |
| ShapedType updateType, bool operandTypeRanked, |
| bool scatterIndicesTypeRanked, bool updatesTypeRanked, |
| ScatterDimensionNumbersAttr dimNumbers, Location loc) { |
| // P1. |
| auto updateWindowDims = to_vector(dimNumbers.getUpdateWindowDims()); |
| if (!llvm::is_sorted(updateWindowDims)) |
| return mlir::emitError(loc) |
| << "Expects update_window_dims to be sorted; got: [" |
| << updateWindowDims << "]."; |
| |
| if (hasDuplicates(updateWindowDims)) |
| return mlir::emitError(loc) |
| << "Expects update_window_dims to not repeat; got: [" |
| << updateWindowDims << "]."; |
| |
| if (updatesTypeRanked) { |
| for (int64_t windowDim : updateWindowDims) { |
| if (windowDim < 0 || windowDim >= updateType.getRank()) { |
| return mlir::emitError(loc) |
| << "Expects each element of update_window_dims to be in range " |
| "[0, " |
| "rank-of('updates') i.e. [0, " |
| << updateType.getRank() << "). got: " << windowDim << "."; |
| } |
| } |
| } |
| |
| // P2. |
| auto insertedWindowDims = to_vector(dimNumbers.getInsertedWindowDims()); |
| if (!llvm::is_sorted(insertedWindowDims)) |
| return mlir::emitError(loc) |
| << "Expects inserted_window_dims to be sorted; got: [" |
| << insertedWindowDims << "]."; |
| |
| if (hasDuplicates(insertedWindowDims)) |
| return mlir::emitError(loc) |
| << "Expects inserted_window_dims to not repeat; got: [" |
| << insertedWindowDims << "]."; |
| |
| if (operandTypeRanked) { |
| for (int64_t insertedDim : insertedWindowDims) { |
| if (insertedDim < 0 || insertedDim >= operandType.getRank()) { |
| return mlir::emitError(loc) |
| << "Expects each element of inserted_window_dims to be in range " |
| "[0, rank-of('operand') i.e. [0, " |
| << operandType.getRank() << "). got: " << insertedDim << "."; |
| } |
| } |
| } |
| |
| // P3. |
| if (operandTypeRanked) { |
| auto windowSize = updateWindowDims.size() + insertedWindowDims.size(); |
| if (operandType.getRank() != static_cast<int64_t>(windowSize)) |
| return mlir::emitError(loc) |
| << "Expects rank-of operand to match " |
| "size-of('update_window_dims') + " |
| "size-of('inserted_window_dims') i.e. " |
| << windowSize << " but got " << operandType.getRank() << "."; |
| } |
| |
| // P4. |
| auto scatterDimsToOperandDims = |
| to_vector(dimNumbers.getScatterDimsToOperandDims()); |
| auto indexVectorDim = dimNumbers.getIndexVectorDim(); |
| if (scatterIndicesTypeRanked) { |
| if (!isDynamicDimSize(scatterIndicesShape[indexVectorDim]) && |
| static_cast<int64_t>(scatterDimsToOperandDims.size()) != |
| scatterIndicesShape[dimNumbers.getIndexVectorDim()]) |
| return mlir::emitError(loc) |
| << "Scatter op has " << scatterDimsToOperandDims.size() |
| << " elements in scatter_dims_to_operand_dims and the bound of " |
| "dimension index_vector_dim=" |
| << dimNumbers.getIndexVectorDim() << " of scatter_indices is " |
| << scatterIndicesShape[dimNumbers.getIndexVectorDim()] |
| << ". These two numbers must be equal."; |
| } |
| |
| if (operandTypeRanked) { |
| for (int64_t i = 0; |
| i < static_cast<int64_t>(scatterDimsToOperandDims.size()); ++i) { |
| int64_t scatterDimToOperandDim = scatterDimsToOperandDims[i]; |
| if (scatterDimToOperandDim < 0 || |
| scatterDimToOperandDim >= operandType.getRank()) |
| return mlir::emitError(loc) |
| << "Invalid scatter_dims_to_operand_dims mapping; domain is [0, " |
| << operandType.getRank() << "), got: " << i << "->" |
| << scatterDimToOperandDim << "."; |
| } |
| } |
| |
| if (hasDuplicates(scatterDimsToOperandDims)) |
| return mlir::emitError(loc) |
| << "Expects scatter_dims_to_operand_dims to not repeat; got: [" |
| << scatterDimsToOperandDims << "]."; |
| |
| return success(); |
| } |
| /* |
| * We intend to verify the following properties: |
| * P0. scatter_indices argument must be an integral tensor. Enforced by ODS. |
| * P1. Scatter index leaf dimension must be within [0, rank(scatter_indices)" |
| * " + 1). |
| * P2. Verify reducer shape. |
| * P3. rank-of('updates[i]') == size-of('update_window_dims') + |
| * rank-of('scatter_indices') - 1, where 'scatter_indices' is expanded by a |
| * trailing 1 dimension if 'index_vector_dim' == rank-of('scatter_indices') |
| * for all values of `i`. |
| * P4. Validate the scatter-dimensions-numbers. |
| * P5. Valide the bounds of each of the 'updates' w.r.t the operands. |
| * P6. Validate the bounds of each of the 'updates' w.r.t the |
| * 'scatter_indices'. |
| * P7. Check return types. |
| */ |
| LogicalResult ScatterOp::verify() { |
| // Get the first operand and update, since variadic Scatter is not yet |
| // implemented |
| auto numOperands = operands().size(); |
| auto scatterIndicesType = scatter_indices().getType().dyn_cast<TensorType>(); |
| |
| SmallVector<TensorType, 1> operandTypes = |
| llvm::to_vector(llvm::map_range(operands().getTypes(), [](Type type) { |
| return type.cast<TensorType>(); |
| })); |
| SmallVector<TensorType, 1> updatesTypes = llvm::to_vector(llvm::map_range( |
| updates().getTypes(), [](Type type) { return type.cast<TensorType>(); })); |
| bool allOperandTypesRanked = |
| llvm::all_of(operands().getTypes(), |
| [](Type type) { return type.isa<RankedTensorType>(); }); |
| bool scatterIndicesTypeRanked = scatterIndicesType.isa<RankedTensorType>(); |
| |
| // P1. |
| int64_t indexVectorDim = scatter_dimension_numbers().getIndexVectorDim(); |
| if (scatterIndicesTypeRanked) { |
| if (indexVectorDim > scatterIndicesType.getRank() || indexVectorDim < 0) |
| return emitOpError() |
| << "expects scatter index leaf dimension to be within [0, " |
| "rank(scatter_indices) + 1." |
| " rank(scatter_indices) is " |
| << scatterIndicesType.getRank() |
| << " and scatter index leaf dimension is " << indexVectorDim |
| << "."; |
| } |
| |
| // P2. |
| Block& block = update_computation().front(); |
| SmallVector<TensorType> accumulatorSubshapes; |
| SmallVector<TensorType> inputTypes, initValueTypes; |
| for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) { |
| inputTypes.push_back(operandTypes[i]); |
| initValueTypes.push_back( |
| RankedTensorType::get({}, updatesTypes[i].getElementType())); |
| } |
| if (failed(verifyReducerShape( |
| this->getLoc(), block, inputTypes, initValueTypes, numOperands, |
| /*allowedDimensions=*/{}, |
| /*allInputsUnranked=*/!allOperandTypesRanked, accumulatorSubshapes))) |
| return failure(); |
| |
| // P3. |
| auto updateWindowDims = scatter_dimension_numbers().getUpdateWindowDims(); |
| SmallVector<int64_t> expandedScatterIndicesShape; |
| if (scatterIndicesTypeRanked) { |
| expandedScatterIndicesShape = |
| llvm::to_vector(scatterIndicesType.getShape()); |
| if (static_cast<int64_t>(expandedScatterIndicesShape.size()) == |
| indexVectorDim) |
| expandedScatterIndicesShape.push_back(1); |
| } |
| |
| for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) { |
| if (scatterIndicesTypeRanked && updatesTypes[i].isa<RankedTensorType>()) { |
| int64_t expectedUpdatesRank = |
| expandedScatterIndicesShape.size() - 1 + updateWindowDims.size(); |
| if (updatesTypes[i].getRank() != expectedUpdatesRank) |
| return emitOpError() |
| << "expects updates tensor must be of rank " |
| << expectedUpdatesRank |
| << " ( == rank-of('scatter_indices') - 1 + " |
| "size-of('update_window_dims'), where 'scatter_indices' is " |
| "expanded by a trailing 1 dimension if 'index_vector_dim' == " |
| "rank-of('scatter_indices')), but got " |
| << updatesTypes[i].getRank() << "."; |
| } |
| } |
| |
| // P4. |
| for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) { |
| if (failed(validateScatterDimensionNumbers( |
| operandTypes[i], expandedScatterIndicesShape, updatesTypes[i], |
| operandTypes[i].isa<RankedTensorType>(), scatterIndicesTypeRanked, |
| updatesTypes[i].isa<RankedTensorType>(), |
| scatter_dimension_numbers(), getLoc()))) |
| return failure(); |
| } |
| |
| // P5. |
| for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) { |
| if (updatesTypes[i].isa<RankedTensorType>()) { |
| auto updatesShape = updatesTypes[i].getShape(); |
| if (operandTypes[i].isa<RankedTensorType>()) { |
| auto operandShape = operandTypes[i].getShape(); |
| auto insertedWindowDims = |
| scatter_dimension_numbers().getInsertedWindowDims(); |
| |
| int64_t insertedDimsSeen = 0; |
| SmallVector<int64_t> maxUpdateSliceSizes; |
| const auto dimensionsSize = operandTypes[i].getRank(); |
| maxUpdateSliceSizes.reserve(dimensionsSize); |
| for (int i = 0; i < dimensionsSize; ++i) { |
| if (insertedDimsSeen < |
| static_cast<int64_t>(insertedWindowDims.size()) && |
| insertedWindowDims[insertedDimsSeen] == i) { |
| ++insertedDimsSeen; |
| } else { |
| maxUpdateSliceSizes.push_back(operandShape[i]); |
| } |
| } |
| |
| for (int64_t i = 0; i < static_cast<int64_t>(updateWindowDims.size()); |
| ++i) { |
| auto updateWindowDim = updateWindowDims[i]; |
| |
| if (isDynamicDimSize(updatesShape[updateWindowDim]) || |
| isDynamicDimSize(maxUpdateSliceSizes[i])) |
| continue; |
| |
| if (updatesShape[updateWindowDim] > maxUpdateSliceSizes[i]) { |
| return emitOpError() |
| << "expects bounds of the window dimensions of " |
| "updates to not exceed the " |
| "bounds of the corresponding dimensions of " |
| "operand. For dimension " |
| << updateWindowDim << ", updates bound is " |
| << updatesShape[updateWindowDim] << ", operand bound is " |
| << maxUpdateSliceSizes[i] << "."; |
| } |
| } |
| } |
| |
| // P6. |
| if (scatterIndicesTypeRanked) { |
| int64_t scatterDimsSeen = 0; |
| for (int64_t i = 0; i < static_cast<int64_t>(updatesShape.size()); |
| ++i) { |
| bool isUpdateWindowDim = std::binary_search( |
| updateWindowDims.begin(), updateWindowDims.end(), i); |
| |
| if (isUpdateWindowDim) continue; |
| if (scatterDimsSeen == indexVectorDim) ++scatterDimsSeen; |
| |
| if (!isDynamicDimSize(updatesShape[i]) && |
| !isDynamicDimSize(expandedScatterIndicesShape[scatterDimsSeen]) && |
| (updatesShape[i] != |
| expandedScatterIndicesShape[scatterDimsSeen])) { |
| return emitOpError() |
| << "expects bounds of the scatter dimensions of " |
| "updates to be same as the " |
| "bounds of the corresponding dimensions of " |
| "scatter indices. For " |
| "scatter dimension " |
| << i << ", updates bound is " << updatesShape[i] |
| << " , scatter_indices " |
| "bound is " |
| << expandedScatterIndicesShape[scatterDimsSeen] << "."; |
| } |
| ++scatterDimsSeen; |
| } |
| } |
| } |
| } |
| |
| // P7. |
| for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) { |
| if (!compatibleShapeAndElementType(operandTypes[i], getResult(i).getType())) |
| return emitOpError() |
| << "expects the return type to be same as the operand type: " |
| << operandTypes[i] << ", but got " << getResult(i).getType() |
| << "."; |
| } |
| |
| return success(); |
| } |
| |
| llvm::SmallVector<Attribute, 4> evaluateMhloRegion(Region& region, |
| ArrayRef<Attribute> inputs) { |
| if (region.getNumArguments() != inputs.size()) return {}; |
| |
| llvm::DenseMap<Value, Attribute> values; |
| values.reserve(region.getNumArguments()); |
| for (auto it : llvm::zip(region.getArguments(), inputs)) { |
| values.try_emplace(std::get<0>(it), std::get<1>(it)); |
| } |
| |
| for (auto& op : region.getOps()) { |
| llvm::SmallVector<Attribute, 4> inputs; |
| for (auto& operand : op.getOpOperands()) { |
| inputs.push_back(values.lookup(operand.get())); |
| } |
| if (isa<ReturnOp>(op)) return inputs; |
| |
| llvm::SmallVector<OpFoldResult, 4> results; |
| if (failed(op.fold(inputs, results))) return {}; |
| for (auto it : llvm::zip(op.getResults(), results)) { |
| if (!std::get<1>(it).is<Attribute>()) return {}; |
| values.insert({std::get<0>(it), std::get<1>(it).get<Attribute>()}); |
| } |
| } |
| return {}; |
| } |
| |
| LogicalResult ScatterOp::fold( |
| ArrayRef<Attribute> args, |
| llvm::SmallVectorImpl<OpFoldResult>& foldResults) { |
| // Variadic Scatter not yet implemented |
| if (operands().size() != 1 || updates().size() != 1) return failure(); |
| auto index = args[1].dyn_cast_or_null<DenseIntElementsAttr>(); |
| if (!index) return failure(); |
| |
| auto baseType = operands().getTypes()[0].dyn_cast<RankedTensorType>(); |
| auto updateType = updates().getTypes()[0].dyn_cast<RankedTensorType>(); |
| auto indexType = index.getType().cast<RankedTensorType>(); |
| if (!baseType || !indexType || !updateType) return failure(); |
| |
| // TODO(b/228310289): Work around canonicalization crash for complex types. |
| // Remove after upstream MLIR has been fixed. |
| if (baseType.getElementType().isa<ComplexType>()) return failure(); |
| |
| // Catch a trivial full replacement of base with update, this does not require |
| // these to be constant: just that we know the type. |
| if (updateType == baseType && updateType.hasStaticShape() && |
| baseType.hasStaticShape() && index.isSplat() && |
| index.getSplatValue<uint32_t>() == 0 && |
| llvm::hasSingleElement(update_computation().front())) { |
| foldResults.push_back(updates()[0]); |
| return success(); |
| } |
| auto base = args[0].dyn_cast_or_null<DenseElementsAttr>(); |
| auto update = args[2].dyn_cast_or_null<DenseElementsAttr>(); |
| if (!base || !update) return failure(); |
| |
| // Add the virtual trailing dimension of size 1 if indexVectorDim equals to |
| // indexType.rank. |
| const int64_t indexVectorDim = |
| scatter_dimension_numbers().getIndexVectorDim(); |
| if (indexVectorDim == indexType.getRank()) { |
| auto indexShape = indexType.getShape().vec(); |
| indexShape.push_back(1); |
| indexType = RankedTensorType::get(indexShape, indexType.getElementType()); |
| index = reshape(index, indexType).cast<DenseIntElementsAttr>(); |
| } |
| |
| // Increment the multi-dimensional index vector based on the limits for each |
| // dimension specified by shape and returns false if the index rolled around |
| // with true otherwise. |
| auto nextIndex = [](llvm::SmallVector<uint64_t, 8>& index, |
| llvm::ArrayRef<int64_t> shape) { |
| for (int64_t i = index.size() - 1; i >= 0; --i) { |
| ++index[i]; |
| if (index[i] < static_cast<unsigned long>(shape[i])) return true; |
| index[i] = 0; |
| } |
| return false; |
| }; |
| |
| // Prevent folding if the result is too large. |
| if (base.getNumElements() > kFoldOpEltLimit) return failure(); |
| |
| // Iterate over all elements of the update tensor, then find the corresponding |
| // value in the indices tensor to determine which location we have to update |
| // in the base/result tensor. |
| llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>()); |
| llvm::SmallVector<uint64_t, 8> updateIndex(updateType.getRank(), 0); |
| llvm::SmallVector<uint64_t, 8> indexIndex; |
| indexIndex.reserve(indexType.getRank()); |
| llvm::SmallVector<int64_t, 8> baseIndex; |
| baseIndex.reserve(baseType.getRank()); |
| do { |
| // Compute the index for the slice of the indices tensor for this update |
| // value. |
| indexIndex.clear(); |
| if (indexVectorDim == 0) indexIndex.push_back(0); |
| for (int64_t i = 0; i < static_cast<int64_t>(updateIndex.size()); ++i) { |
| if (llvm::count(scatter_dimension_numbers().getUpdateWindowDims(), i) == |
| 0) |
| indexIndex.push_back(updateIndex[i]); |
| if (static_cast<int64_t>(indexIndex.size()) == indexVectorDim) |
| indexIndex.push_back(0); |
| } |
| |
| // Compute the index for the given update value in the base tensor. |
| baseIndex.assign(baseType.getRank(), 0); |
| uint64_t indexCount = indexType.getShape()[indexVectorDim]; |
| for (uint64_t i = 0; i < indexCount; ++i) { |
| uint64_t operandDim = |
| scatter_dimension_numbers().getScatterDimsToOperandDims()[i]; |
| indexIndex[indexVectorDim] = i; |
| baseIndex[operandDim] += |
| index.getValues<APInt>()[indexIndex].getSExtValue(); |
| } |
| uint64_t updateWindowDimIndex = 0; |
| auto insertedWindowDims = |
| scatter_dimension_numbers().getInsertedWindowDims(); |
| auto updateWindowDims = scatter_dimension_numbers().getUpdateWindowDims(); |
| for (uint64_t i = 0; i < baseIndex.size(); ++i) { |
| if (llvm::count(insertedWindowDims, i)) continue; |
| baseIndex[i] += updateIndex[updateWindowDims[updateWindowDimIndex]]; |
| updateWindowDimIndex++; |
| } |
| |
| // Compute the linear index for the index into the base tensor. |
| int64_t linearBaseIndex = 0; |
| int64_t linearBaseIndexMultiplyer = 1; |
| for (int64_t i = baseIndex.size() - 1; i >= 0; --i) { |
| // Out of bound index have backend specific behaviour so avoid folding it. |
| if (baseIndex[i] < 0 || baseIndex[i] >= baseType.getShape()[i]) |
| return failure(); |
| linearBaseIndex += baseIndex[i] * linearBaseIndexMultiplyer; |
| linearBaseIndexMultiplyer *= baseType.getShape()[i]; |
| } |
| |
| // Evaluate update computation and update the value with the newly computed |
| // attribute in the base tensor. |
| auto lhs = DenseElementsAttr::get( |
| RankedTensorType::get({}, baseType.getElementType()), |
| results[linearBaseIndex]); |
| auto rhs = DenseElementsAttr::get( |
| RankedTensorType::get({}, baseType.getElementType()), |
| update.getValues<Attribute>()[updateIndex]); |
| auto newValue = evaluateMhloRegion(update_computation(), {lhs, rhs}); |
| if (newValue.size() != 1 || !newValue[0]) return failure(); |
| results[linearBaseIndex] = |
| newValue[0].cast<DenseElementsAttr>().getValues<Attribute>()[0]; |
| } while (nextIndex(updateIndex, updateType.getShape())); |
| |
| foldResults.push_back(DenseElementsAttr::get(baseType, results)); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // WhileOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult WhileOp::verify() { |
| if (getNumOperands() != cond().front().getNumArguments()) |
| return emitOpError() << "mismatch in operand count (" << getNumOperands() |
| << ") vs the condition block argument count (" |
| << cond().front().getNumArguments() << ")"; |
| if (getNumOperands() != body().front().getNumArguments()) |
| return emitOpError() << "mismatch in operand count (" << getNumOperands() |
| << ") vs the body block argument count (" |
| << body().front().getNumArguments() << ")"; |
| for (const auto& enumeratedOperands : llvm::enumerate( |
| llvm::zip(getOperandTypes(), cond().front().getArgumentTypes(), |
| body().front().getArgumentTypes()))) { |
| int argCount = enumeratedOperands.index(); |
| const auto& operands = enumeratedOperands.value(); |
| Type operandType = std::get<0>(operands); |
| Type condType = std::get<1>(operands); |
| Type bodyType = std::get<2>(operands); |
| if (operandType != condType) |
| return emitOpError() << "type mismatch between operand #" << argCount |
| << " and the matching condition block argument: " |
| << operandType << " vs " << condType; |
| if (operandType != bodyType) |
| return emitOpError() << "type mismatch between operand #" << argCount |
| << " and the matching body block argument: " |
| << operandType << " vs " << bodyType; |
| } |
| // Check the return type for the condition block. |
| { |
| auto condReturnOp = cast<ReturnOp>(cond().front().back()); |
| if (condReturnOp->getNumOperands() != 1) |
| return condReturnOp.emitOpError() |
| << "expects a single operand for while condition body return, got " |
| << condReturnOp->getNumOperands(); |
| auto operandType = |
| condReturnOp->getOperand(0).getType().dyn_cast<RankedTensorType>(); |
| if (!operandType || operandType.getRank() != 0 || |
| !operandType.getElementType().isInteger(1)) |
| return condReturnOp.emitOpError() |
| << "expects a zero-ranked tensor of i1, got " |
| << condReturnOp->getOperand(0).getType(); |
| } |
| // Check the return type for the body block. |
| { |
| auto bodyReturnOp = cast<ReturnOp>(body().front().back()); |
| if (bodyReturnOp->getNumOperands() != getNumOperands()) |
| return bodyReturnOp.emitOpError() |
| << "expects body to return a many value as the operands (" |
| << getNumOperands() << "), got " << bodyReturnOp->getNumOperands(); |
| for (const auto& enumeratedOperandTypes : llvm::enumerate( |
| llvm::zip(bodyReturnOp->getOperandTypes(), getOperandTypes()))) { |
| Type operandType = std::get<0>(enumeratedOperandTypes.value()); |
| Type returnType = std::get<1>(enumeratedOperandTypes.value()); |
| if (operandType != returnType) |
| return bodyReturnOp.emitOpError() |
| << "type mismatch between operand #" |
| << enumeratedOperandTypes.index() |
| << " and the enclosing WhileOp returned value: " << operandType |
| << " vs " << returnType; |
| } |
| } |
| return success(); |
| } |
| |
| /// Print a `while` op. |
| /// |
| /// op ::= `mhlo.while` `(` assignment-list `)` `:` types attribute-dict |
| /// `cond` region |
| /// `do` region |
| /// assignment-list ::= assignment | assignment `,` assignment-list |
| /// assignment ::= ssa-value `=` ssa-value |
| void WhileOp::print(OpAsmPrinter& p) { |
| p << '('; |
| llvm::interleaveComma(llvm::zip(getBody()->getArguments(), getOperands()), p, |
| [&](auto zip) { |
| p.printOperand(std::get<0>(zip)); |
| p << " = "; |
| p.printOperand(std::get<1>(zip)); |
| }); |
| p << ")"; |
| if (getNumOperands()) { |
| p << " : "; |
| llvm::interleaveComma(getOperandTypes(), p); |
| } |
| p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); |
| p.printNewline(); |
| p << " cond "; |
| p.printRegion(getRegion(0), /*printEntryBlockArgs=*/false); |
| p << " do "; |
| p.printRegion(getRegion(1), /*printEntryBlockArgs=*/false); |
| } |
| |
| ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) { |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| // Parse the operands of the while: these are of the form: |
| // %iter_arg = %init_val |
| // where %iter_arg is the name of the block argument in the cond/body blocks |
| // and %init_val is the actual operand. |
| SmallVector<OpAsmParser::UnresolvedOperand> operands; |
| SmallVector<OpAsmParser::UnresolvedOperand> iterArgs; |
| if (parser.parseLParen()) return failure(); |
| do { |
| if (succeeded(parser.parseOptionalRParen())) break; |
| OpAsmParser::UnresolvedOperand operand, iterArg; |
| if (parser.parseOperand(iterArg) || parser.parseEqual() || |
| parser.parseOperand(operand)) |
| return failure(); |
| iterArgs.push_back(iterArg); |
| operands.push_back(operand); |
| if (succeeded(parser.parseOptionalRParen())) break; |
| if (failed(parser.parseComma())) return failure(); |
| } while (true); |
| if (!operands.empty()) { |
| if (parser.parseColon() || parser.parseTypeList(result.types)) |
| return failure(); |
| } |
| |
| SmallVector<OpAsmParser::Argument> args; |
| createArgs(iterArgs, result.types, args); |
| if (parser.resolveOperands(operands, result.types, loc, result.operands) || |
| parser.parseOptionalAttrDictWithKeyword(result.attributes) || |
| parser.parseKeyword("cond") || |
| parser.parseRegion(*result.addRegion(), args) || |
| parser.parseKeyword("do") || |
| parser.parseRegion(*result.addRegion(), args)) |
| return failure(); |
| return success(); |
| } |
| |
| LogicalResult WhileOp::fold(ArrayRef<Attribute> /*operands*/, |
| SmallVectorImpl<OpFoldResult>& results) { |
| DenseIntElementsAttr condValue; |
| auto condReturnOp = cast<ReturnOp>(cond().front().back()); |
| if (!matchPattern(condReturnOp.getOperand(0), m_Constant(&condValue))) |
| return failure(); |
| if (condValue.getSplatValue<BoolAttr>().getValue()) |
| return failure(); // TODO(mhlo): this is an infinite loop, should we fold? |
| |
| results.append(getOperands().begin(), getOperands().end()); |
| return success(); |
| } |
| |
| static LogicalResult whileCanonicalization(WhileOp whileOp, |
| PatternRewriter& rewriter) { |
| // Turn loop invariant values into implicit capture. |
| // Check if there is at least one value is forwarded from one iteration to the |
| // next, or one of the yielded value is an implicit capture already. Otherwise |
| // there is nothing to do here. |
| Block* cond = whileOp.getBody(0); |
| Block* body = whileOp.getBody(1); |
| auto bodyReturnOp = cast<ReturnOp>(body->getTerminator()); |
| if (!llvm::any_of(llvm::zip(whileOp->getOperands(), body->getArguments(), |
| bodyReturnOp->getOperands()), |
| [&](auto zip) { |
| return (std::get<0>(zip) == std::get<2>(zip) || |
| std::get<1>(zip) == std::get<2>(zip)); |
| })) |
| return rewriter.notifyMatchFailure(whileOp, "no loop invariant found"); |
| |
| SmallVector<Value> newOperands, resultsToReplace; |
| SmallVector<unsigned> invariantArgIdxs; |
| for (const auto& enumeratedOperands : llvm::enumerate(llvm::zip( |
| whileOp.getOperands(), cond->getArguments(), body->getArguments(), |
| bodyReturnOp->getOperands(), whileOp->getResults()))) { |
| const auto& operands = enumeratedOperands.value(); |
| Value whileOperand = std::get<0>(operands); |
| BlockArgument condBlockArg = std::get<1>(operands); |
| BlockArgument bodyBlockArg = std::get<2>(operands); |
| Value bodyReturnOperand = std::get<3>(operands); |
| Value whileResult = std::get<4>(operands); |
| |
| bool forwarded = (whileOperand == bodyReturnOperand || |
| bodyBlockArg == bodyReturnOperand); |
| if (forwarded) { |
| invariantArgIdxs.push_back(enumeratedOperands.index()); |
| condBlockArg.replaceAllUsesWith(whileOperand); |
| bodyBlockArg.replaceAllUsesWith(whileOperand); |
| whileResult.replaceAllUsesWith(whileOperand); |
| continue; |
| } |
| newOperands.push_back(whileOperand); |
| resultsToReplace.push_back(whileResult); |
| } |
| cond->eraseArguments(invariantArgIdxs); |
| body->eraseArguments(invariantArgIdxs); |
| for (int idx : llvm::reverse(invariantArgIdxs)) |
| bodyReturnOp->eraseOperand(idx); |
| |
| WhileOp newWhileOp = rewriter.create<WhileOp>( |
| whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands); |
| newWhileOp.getBodyRegion(0).takeBody(whileOp.getBodyRegion(0)); |
| newWhileOp.getBodyRegion(1).takeBody(whileOp.getBodyRegion(1)); |
| for (auto results : llvm::zip(resultsToReplace, newWhileOp->getResults())) |
| std::get<0>(results).replaceAllUsesWith(std::get<1>(results)); |
| rewriter.eraseOp(whileOp); |
| return success(); |
| } |
| |
| void WhileOp::getCanonicalizationPatterns(RewritePatternSet& results, |
| MLIRContext* context) { |
| results.add(&whileCanonicalization); |
| } |
| |
| LogicalResult UniformDequantizeOp::inferReturnTypeComponents( |
| MLIRContext*, Optional<Location> /*location*/, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { |
| UniformDequantizeOp::Adaptor adaptor(operands, attributes, regions); |
| auto operandType = (*operands.begin()).getType().cast<ShapedType>(); |
| // Trait HLO_QuantizedIntTensor in ODS guarantees QuantizedType; |
| auto quantType = operandType.getElementType().cast<quant::QuantizedType>(); |
| auto shape = operandType.dyn_cast<ShapedType>().getShape(); |
| inferredReturnShapes.emplace_back(shape, quantType.getExpressedType()); |
| return success(); |
| } |
| |
| using mlir::hlo::parseWindowAttributes; |
| using mlir::hlo::printWindowAttributes; |
| |
| } // namespace mhlo |
| } // namespace mlir |
| |
| #define GET_OP_CLASSES |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" |
| |
| namespace mlir { |
| namespace mhlo { |
| |
| //===----------------------------------------------------------------------===// |
| // mhlo Dialect Interfaces |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| struct HLOInlinerInterface : public DialectInlinerInterface { |
| using DialectInlinerInterface::DialectInlinerInterface; |
| |
| // Allow all call operations to be inlined. |
| bool isLegalToInline(Operation* call, Operation* callable, |
| bool wouldBeCloned) const final { |
| return true; |
| } |
| // We don't have any special restrictions on what can be inlined into |
| // destination regions (e.g. while/conditional bodies). Always allow it. |
| bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned, |
| BlockAndValueMapping& valueMapping) const final { |
| return true; |
| } |
| // Operations in mhlo dialect are always legal to inline since they are |
| // pure. |
| bool isLegalToInline(Operation*, Region*, bool, |
| BlockAndValueMapping&) const final { |
| return true; |
| } |
| }; |
| } // end anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // mhlo Dialect Constructor |
| //===----------------------------------------------------------------------===// |
| |
| MhloDialect::MhloDialect(MLIRContext* context) |
| : Dialect(getDialectNamespace(), context, TypeID::get<MhloDialect>()) { |
| addOperations< |
| #define GET_OP_LIST |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" |
| >(); |
| addInterfaces<HLOInlinerInterface>(); |
| addTypes<TokenType>(); |
| addAttributes< |
| #define GET_ATTRDEF_LIST |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.cc.inc" |
| >(); |
| context->loadDialect<tensor::TensorDialect>(); |
| } |
| |
| Type MhloDialect::parseType(DialectAsmParser& parser) const { |
| StringRef dataType; |
| if (parser.parseKeyword(&dataType)) return Type(); |
| |
| if (dataType == "token") return TokenType::get(getContext()); |
| parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << dataType; |
| return nullptr; |
| } |
| |
| void MhloDialect::printType(Type type, DialectAsmPrinter& os) const { |
| if (type.isa<TokenType>()) { |
| os << "token"; |
| return; |
| } |
| os << "<unknown mhlo type>"; |
| } |
| |
| // Entry point for Attribute parsing, TableGen generated code will handle the |
| // dispatch to the individual classes. |
| Attribute MhloDialect::parseAttribute(DialectAsmParser& parser, |
| Type type) const { |
| StringRef attrTag; |
| Attribute attr; |
| auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr); |
| if (parseResult.hasValue()) return attr; |
| parser.emitError(parser.getNameLoc(), "unknown mhlo attribute"); |
| return Attribute(); |
| } |
| |
| // Entry point for Attribute printing, TableGen generated code will handle the |
| // dispatch to the individual classes. |
| void MhloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const { |
| LogicalResult result = generatedAttributePrinter(attr, os); |
| (void)result; |
| assert(succeeded(result)); |
| } |
| |
| /// Helpers for attributes parsing. |
| |
| static ParseResult parseDims(AsmParser& parser, SmallVector<int64_t>& dims) { |
| dims.clear(); |
| if (parser.parseLSquare()) return failure(); |
| while (failed(parser.parseOptionalRSquare())) { |
| dims.emplace_back(); |
| if (parser.parseInteger(dims.back())) return failure(); |
| (void)parser.parseOptionalComma(); |
| } |
| return success(); |
| } |
| |
| static ParseResult parseDimsWithMinimumElements(AsmParser& parser, |
| SmallVector<int64_t>& dims, |
| int minElements) { |
| if (failed(parseDims(parser, dims))) return failure(); |
| if (static_cast<int64_t>(dims.size()) < minElements) |
| return parser.emitError(parser.getCurrentLocation()) |
| << "expected at least " << minElements << " element(s), found " |
| << dims.size(); |
| return success(); |
| } |
| |
| /// Parse a custom attribute that resembles a struct of the form |
| /// < |
| /// foo = something_parsed_by_custom_parser, |
| /// bar = something_parsed_by_different_custom_parser, |
| /// baz something_parsed_by_another_custom_parser |
| /// > |
| /// The optional argument `parse_equal` array can be used to denote if |
| /// '=' follows the keyword (see baz in the example above) for a field. If |
| /// not provided, all fields must be followed by a '='. |
| static ParseResult parseStruct( |
| AsmParser& parser, ArrayRef<StringRef> keywords, |
| ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs, |
| ArrayRef<bool> parseEqual = {}) { |
| assert(keywords.size() == parseFuncs.size()); |
| assert(parseEqual.empty() || parseEqual.size() == keywords.size()); |
| SmallVector<bool> seen(keywords.size(), false); |
| while (failed(parser.parseOptionalGreater())) { |
| bool foundOne = false; |
| for (const auto& it : llvm::enumerate(keywords)) { |
| size_t index = it.index(); |
| StringRef keyword = it.value(); |
| if (succeeded(parser.parseOptionalKeyword(keyword))) { |
| if (seen[index]) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "duplicated `" << keyword << "` entry"; |
| } |
| if (parseEqual.empty() || parseEqual[index]) { |
| if (failed(parser.parseEqual())) return failure(); |
| } |
| if (failed(parseFuncs[index]())) return failure(); |
| if (failed(parser.parseOptionalComma())) return parser.parseGreater(); |
| seen[index] = true; |
| foundOne = true; |
| } |
| } |
| if (!foundOne) { |
| auto parseError = parser.emitError(parser.getCurrentLocation()) |
| << "expected one of: "; |
| llvm::interleaveComma(keywords, parseError, [&](StringRef kw) { |
| parseError << '`' << kw << '`'; |
| }); |
| return parseError; |
| } |
| } |
| return success(); |
| } |
| |
| // Helpers to print an optional array or integer field, to simplify writing |
| // attribute printers. |
| template <typename T> |
| static void printField(AsmPrinter& printer, StringRef name, T field, |
| StringRef& separator) { |
| if (field != 0) { |
| printer << separator << name << " = " << field; |
| separator = ", "; |
| } |
| } |
| template <typename T> |
| static void printField(AsmPrinter& printer, StringRef name, ArrayRef<T> field, |
| StringRef& separator) { |
| if (!field.empty()) { |
| printer << separator << name << " = ["; |
| llvm::interleaveComma(field, printer); |
| printer << "]"; |
| separator = ", "; |
| } |
| } |
| template <typename... Ts> |
| static void printStruct(AsmPrinter& printer, StringRef name, |
| Ts... printFields) { |
| printer << "<"; |
| StringRef separator = ""; |
| // Fold expression to print each entry in the parameter pack. |
| // TODO(mhlo-team): this can be simplified when TF moves to C++17. |
| using unused = int[]; |
| (void)unused{0, (printField(printer, std::get<0>(printFields), |
| std::get<1>(printFields), separator), |
| 0)...}; |
| printer << ">"; |
| } |
| |
| // Custom printer and parser for ScatterDimensionNumbersAttr. |
| void ScatterDimensionNumbersAttr::print(AsmPrinter& printer) const { |
| printStruct(printer, "scatter", |
| std::make_pair("update_window_dims", getUpdateWindowDims()), |
| std::make_pair("inserted_window_dims", getInsertedWindowDims()), |
| std::make_pair("scatter_dims_to_operand_dims", |
| getScatterDimsToOperandDims()), |
| std::make_pair("index_vector_dim", getIndexVectorDim())); |
| } |
| Attribute ScatterDimensionNumbersAttr::parse(AsmParser& parser, Type type) { |
| if (failed(parser.parseLess())) return {}; |
| SmallVector<int64_t> updateWindowDims; |
| SmallVector<int64_t> insertedWindowDims; |
| SmallVector<int64_t> scatterDimsToOperandDims; |
| int64_t indexVectorDim = 0; |
| |
| if (failed(parseStruct( |
| parser, |
| {"update_window_dims", "inserted_window_dims", |
| "scatter_dims_to_operand_dims", "index_vector_dim"}, |
| {[&]() { return parseDims(parser, updateWindowDims); }, |
| [&]() { return parseDims(parser, insertedWindowDims); }, |
| [&]() { return parseDims(parser, scatterDimsToOperandDims); }, |
| [&]() { return parser.parseInteger(indexVectorDim); }}))) { |
| parser.emitError(parser.getCurrentLocation()) |
| << "failed parsing scatter dimension numbers attribute"; |
| return {}; |
| } |
| |
| return ScatterDimensionNumbersAttr::get( |
| parser.getContext(), updateWindowDims, insertedWindowDims, |
| scatterDimsToOperandDims, indexVectorDim); |
| } |
| |
| // Custom printer and parser for GatherDimensionNumbersAttr. |
| void GatherDimensionNumbersAttr::print(AsmPrinter& printer) const { |
| printStruct(printer, "gather", std::make_pair("offset_dims", getOffsetDims()), |
| std::make_pair("collapsed_slice_dims", getCollapsedSliceDims()), |
| std::make_pair("start_index_map", getStartIndexMap()), |
| std::make_pair("index_vector_dim", getIndexVectorDim())); |
| } |
| |
| Attribute GatherDimensionNumbersAttr::parse(AsmParser& parser, Type type) { |
| if (failed(parser.parseLess())) return {}; |
| |
| SmallVector<int64_t> offsetDims; |
| SmallVector<int64_t> collapsedSliceDims; |
| SmallVector<int64_t> startIndexMap; |
| int64_t indexVectorDim = 0; |
| |
| if (failed(parseStruct( |
| parser, |
| {"offset_dims", "collapsed_slice_dims", "start_index_map", |
| "index_vector_dim"}, |
| {[&]() { return parseDims(parser, offsetDims); }, |
| [&]() { return parseDims(parser, collapsedSliceDims); }, |
| [&]() { return parseDims(parser, startIndexMap); }, |
| [&]() { return parser.parseInteger(indexVectorDim); }}))) { |
| parser.emitError(parser.getCurrentLocation()) |
| << "failed parsing gather dimension numbers attribute"; |
| return {}; |
| } |
| |
| return GatherDimensionNumbersAttr::get(parser.getContext(), offsetDims, |
| collapsedSliceDims, startIndexMap, |
| indexVectorDim); |
| } |
| |
| // Custom printer and parser for DotDimensionNumbersAttr. |
| void DotDimensionNumbersAttr::print(AsmPrinter& printer) const { |
| printStruct( |
| printer, "dot", |
| std::make_pair("lhs_batching_dimensions", getLhsBatchingDimensions()), |
| std::make_pair("rhs_batching_dimensions", getRhsBatchingDimensions()), |
| std::make_pair("lhs_contracting_dimensions", |
| getLhsContractingDimensions()), |
| std::make_pair("rhs_contracting_dimensions", |
| getRhsContractingDimensions())); |
| } |
| |
| Attribute DotDimensionNumbersAttr::parse(AsmParser& parser, Type type) { |
| if (failed(parser.parseLess())) return {}; |
| |
| SmallVector<int64_t> lhsBatchingDimensions; |
| SmallVector<int64_t> rhsBatchingDimensions; |
| SmallVector<int64_t> lhsContractingDimensions; |
| SmallVector<int64_t> rhsContractingDimensions; |
| |
| if (failed(parseStruct( |
| parser, |
| {"lhs_batching_dimensions", "rhs_batching_dimensions", |
| "lhs_contracting_dimensions", "rhs_contracting_dimensions"}, |
| {[&]() { return parseDims(parser, lhsBatchingDimensions); }, |
| [&]() { return parseDims(parser, rhsBatchingDimensions); }, |
| [&]() { return parseDims(parser, lhsContractingDimensions); }, |
| [&]() { return parseDims(parser, rhsContractingDimensions); }}))) { |
| parser.emitError(parser.getCurrentLocation()) |
| << "failed parsing dot dimension numbers attribute"; |
| return {}; |
| } |
| return DotDimensionNumbersAttr::get( |
| parser.getContext(), lhsBatchingDimensions, rhsBatchingDimensions, |
| lhsContractingDimensions, rhsContractingDimensions); |
| } |
| |
| namespace { |
| enum NonSpatialDim : int64_t { |
| IOBatch = -1, // Input or output batch dimension |
| IOFeature = -2, // Input or output feature dimension |
| KIFeature = -3, // Kernel input feature dimension |
| KOFeature = -4, // Kernel output feature dimensions. |
| }; |
| |
| struct DenseMapInfoNonSpatialDim { |
| static inline NonSpatialDim getEmptyKey() { |
| return NonSpatialDim(DenseMapInfo<int64_t>::getEmptyKey()); |
| } |
| |
| static inline NonSpatialDim getTombstoneKey() { |
| return NonSpatialDim(DenseMapInfo<int64_t>::getTombstoneKey()); |
| } |
| |
| static unsigned getHashValue(const NonSpatialDim& key) { |
| return DenseMapInfo<int64_t>::getHashValue(key); |
| } |
| |
| static bool isEqual(const NonSpatialDim& lhs, const NonSpatialDim& rhs) { |
| return lhs == rhs; |
| } |
| }; |
| |
| char nonSpatialDimToString(NonSpatialDim dim) { |
| switch (dim) { |
| case IOBatch: |
| return 'b'; |
| case IOFeature: |
| return 'f'; |
| case KIFeature: |
| return 'i'; |
| case KOFeature: |
| return 'o'; |
| } |
| llvm_unreachable("Unknown NonSpatialDim"); |
| } |
| } // namespace |
| |
| // Custom printer and parser for convolution attribute. |
| void printConvolutionDimensions(AsmPrinter& p, ConvDimensionNumbersAttr dnums) { |
| // TODO(b/202040055): we should check the attribute invariant and print the |
| // "raw" form if they are violated, otherwise we'll crash here. |
| constexpr int64_t kUnknownDim = std::numeric_limits<int64_t>::min(); |
| auto printDim = |
| [&](ArrayRef<int64_t> spatialDims, |
| ArrayRef<std::pair<int64_t, NonSpatialDim>> nonSpatialDims) { |
| int64_t numDims = 0; |
| if (!spatialDims.empty()) { |
| numDims = |
| *std::max_element(spatialDims.begin(), spatialDims.end()) + 1; |
| } |
| for (const auto& dim : nonSpatialDims) { |
| numDims = std::max(numDims, dim.first + 1); |
| } |
| |
| llvm::SmallVector<int64_t> dims(numDims, kUnknownDim); |
| // Fill each element of dims with a (< 0) NonSpatialDim enum or a (>=0) |
| // spatial dimension index. |
| for (const std::pair<int64_t, NonSpatialDim>& nonSpatialDim : |
| nonSpatialDims) { |
| dims[nonSpatialDim.first] = nonSpatialDim.second; |
| } |
| for (const auto& spatialDim : llvm::enumerate(spatialDims)) { |
| dims[spatialDim.value()] = static_cast<int64_t>(spatialDim.index()); |
| } |
| |
| // Each dimension numbers will be printed as a comma separated list |
| // surrounded by square brackets, e.g., [b, 0, 1, 2, f] |
| p << '['; |
| llvm::interleaveComma(dims, p, [&](int64_t dim) { |
| if (dim == kUnknownDim) { |
| p << "?"; |
| } else if (dim >= 0) { |
| p << dim; |
| } else { |
| p << nonSpatialDimToString(static_cast<NonSpatialDim>(dim)); |
| } |
| }); |
| p << ']'; |
| }; |
| |
| printDim(dnums.getInputSpatialDimensions(), |
| {{dnums.getInputBatchDimension(), IOBatch}, |
| {dnums.getInputFeatureDimension(), IOFeature}}); |
| p << "x"; |
| printDim(dnums.getKernelSpatialDimensions(), |
| {{dnums.getKernelInputFeatureDimension(), KIFeature}, |
| {dnums.getKernelOutputFeatureDimension(), KOFeature}}); |
| p << "->"; |
| printDim(dnums.getOutputSpatialDimensions(), |
| {{dnums.getOutputBatchDimension(), IOBatch}, |
| {dnums.getOutputFeatureDimension(), IOFeature}}); |
| } |
| |
| void printConvolutionDimensions(AsmPrinter& p, Operation*, |
| ConvDimensionNumbersAttr dnums) { |
| printConvolutionDimensions(p, dnums); |
| } |
| |
| // Custom printer and parser for ConvDimensionNumbersAttr. |
| void ConvDimensionNumbersAttr::print(AsmPrinter& printer) const { |
| printer << "<"; |
| printConvolutionDimensions(printer, *this); |
| printer << ">"; |
| } |
| |
| // If the attribute is written with `#mhlo.conv raw<`, we parse it as a struct |
| // instead of the compressed format. This enables writing tests covering |
| // impossible/invalid internal representation for the attribute. |
| static ParseResult parseConvolutionDimensionsRaw( |
| AsmParser& parser, ConvDimensionNumbersAttr& dnums) { |
| int64_t inputBatchDimension = 0; |
| int64_t inputFeatureDimension = 0; |
| SmallVector<int64_t> inputSpatialDimensions; |
| int64_t kernelInputFeatureDimension = 0; |
| int64_t kernelOutputFeatureDimension = 0; |
| SmallVector<int64_t> kernelSpatialDimensions; |
| int64_t outBatchDimension = 0; |
| int64_t outputFeatureDimension = 0; |
| SmallVector<int64_t> outputSpatialDimensions; |
| if (failed(parseStruct( |
| parser, |
| {"input_batch_dimension", "input_feature_dimension", |
| "input_spatial_dimensions", "kernel_input_feature_dimension", |
| "kernel_output_feature_dimension", "kernel_spatial_dimensions", |
| "output_batch_dimension", "output_feature_dimension", |
| "output_spatial_dimensions"}, |
| { |
| [&]() { return parser.parseInteger(inputBatchDimension); }, |
| [&]() { return parser.parseInteger(inputFeatureDimension); }, |
| [&]() { return parseDims(parser, inputSpatialDimensions); }, |
| [&]() { |
| return parser.parseInteger(kernelInputFeatureDimension); |
| }, |
| [&]() { |
| return parser.parseInteger(kernelOutputFeatureDimension); |
| }, |
| [&]() { return parseDims(parser, kernelSpatialDimensions); }, |
| [&]() { return parser.parseInteger(outBatchDimension); }, |
| [&]() { return parser.parseInteger(outputFeatureDimension); }, |
| [&]() { return parseDims(parser, outputSpatialDimensions); }, |
| }))) { |
| parser.emitError(parser.getCurrentLocation()) |
| << "failed parsing dot dimension numbers attribute"; |
| return failure(); |
| } |
| dnums = ConvDimensionNumbersAttr::get( |
| parser.getBuilder().getContext(), inputBatchDimension, |
| inputFeatureDimension, inputSpatialDimensions, |
| kernelInputFeatureDimension, kernelOutputFeatureDimension, |
| kernelSpatialDimensions, outBatchDimension, outputFeatureDimension, |
| outputSpatialDimensions); |
| return success(); |
| } |
| |
| ParseResult parseConvolutionDimensions(AsmParser& parser, |
| ConvDimensionNumbersAttr& dnums) { |
| // Parsing a single set of dim numbers gives the spatial dimensions as a |
| // single ArrayRef<int64_t> and a list of non-spatial dimensions as |
| // IntegerAttrs (indexed by the NonSpatialDim enum). |
| using parse_dim_result_t = |
| std::pair<llvm::SmallVector<int64_t>, |
| llvm::SmallDenseMap<NonSpatialDim, int64_t, 4, |
| DenseMapInfoNonSpatialDim>>; |
| |
| // Note that the allowed_non_spatial_dims is a set (as opposed to unordered |
| // set) because its used to print a list of allowed non spatial dims in the |
| // error messages, so making it a set keeps the error messages deterministic. |
| auto parseDims = |
| [&](std::set<NonSpatialDim, std::greater<>> allowedNonSpatialDims, |
| parse_dim_result_t& parsedDims) -> ParseResult { |
| auto& spatialDims = std::get<0>(parsedDims); |
| auto& nonSpatialDims = std::get<1>(parsedDims); |
| spatialDims.clear(); |
| nonSpatialDims.clear(); |
| |
| // Parse the starting [ |
| if (parser.parseLSquare()) { |
| return failure(); |
| } |
| |
| llvm::SmallDenseMap<int64_t, int64_t> spatialDimsMap; |
| constexpr int64_t kInvalidDimension = -1; |
| // Keep track of the maximum spatial dimension parsed as we expect to see |
| // all the dimensions from 0 to maximum dimension parsed. |
| int64_t maxParsedSpatialDim = kInvalidDimension; |
| |
| int64_t index = 0; |
| do { |
| int64_t spatialDim; |
| auto dimLocation = parser.getCurrentLocation(); |
| OptionalParseResult parseResult = parser.parseOptionalInteger(spatialDim); |
| if (parseResult.hasValue()) { |
| if (parseResult.getValue().failed()) { |
| return failure(); |
| } |
| // We were successful in parsing an integer. Check if it is a valid |
| // dimension (non-negative and no duplicate) and add its index to the |
| // spatial dims map. |
| if (spatialDim < 0) |
| return parser.emitError(dimLocation) |
| << "Unexpected dimension " << spatialDim; |
| if (!spatialDimsMap |
| .insert(std::pair<int64_t, int64_t>(spatialDim, index)) |
| .second) |
| return parser.emitError(dimLocation) |
| << "Duplicate entries for spatial dimension " << spatialDim; |
| maxParsedSpatialDim = std::max(spatialDim, maxParsedSpatialDim); |
| } else if (!parser.parseOptionalQuestion()) { |
| // Do nothing other than increment `index` at the bottom of the loop; |
| // '?' means "unknown dimension", and it's not represented in the |
| // return value of this function. |
| } else { |
| // We did not parse an integer or question mark. We expect a keyword |
| // token. |
| StringRef keyword; |
| if (parser.parseKeyword(&keyword)) { |
| return failure(); |
| } |
| if (keyword.size() != 1 || allowedNonSpatialDims.empty()) { |
| return parser.emitError(dimLocation, "Unexpected keyword ") |
| << keyword; |
| } |
| // Check if the keyword matches one of the allowed non-spatial dims. |
| // If so, add it to the non_spatial dims and remove it from the |
| // allowed set so that it won't be allowed again. |
| bool isAllowed = false; |
| for (NonSpatialDim allowed : allowedNonSpatialDims) { |
| if (keyword[0] == nonSpatialDimToString(allowed)) { |
| nonSpatialDims.insert({allowed, index}); |
| allowedNonSpatialDims.erase(allowed); |
| isAllowed = true; |
| break; |
| } |
| } |
| |
| if (!isAllowed) { |
| mlir::InFlightDiagnostic diag = |
| parser.emitError(dimLocation, "Unexpected dimension "); |
| diag << keyword << ", expecting "; |
| llvm::interleaveComma( |
| allowedNonSpatialDims, diag, |
| [&](NonSpatialDim dim) { diag << nonSpatialDimToString(dim); }); |
| return diag; |
| } |
| } |
| index++; |
| } while (parser.parseOptionalComma().succeeded()); |
| |
| // Make sure all expected non-spatial dimensions are parsed. |
| if (!allowedNonSpatialDims.empty()) { |
| mlir::InFlightDiagnostic diag = |
| parser.emitError(parser.getCurrentLocation(), "Expected dimensions "); |
| llvm::interleaveComma( |
| allowedNonSpatialDims, diag, |
| [&](NonSpatialDim dim) { diag << nonSpatialDimToString(dim); }); |
| diag << " not specified"; |
| return diag; |
| } |
| |
| // parse ending ] |
| if (parser.parseRSquare()) { |
| return failure(); |
| } |
| |
| // Number of expected spatial dimensions is one more than the maximum parsed |
| // spatial dimension. For example, if we parse [0, 3, 2, b, i, 1], then the |
| // maximum parsed spatial dimension is 3 and the number of expected spatial |
| // dimensions is 4. |
| int64_t numSpatialDimensions = maxParsedSpatialDim + 1; |
| spatialDims.resize(numSpatialDimensions); |
| // Store spatial dimensions in a vector which maps spatial dim (vector |
| // index) -> index in the tensor dimensions. For example, for parsed |
| // dimension numbers [0, 3, 2, b, i, 1] the spatial dimension vector would |
| // be [0, 5, 2, 1]. |
| // |
| // Get all the unspecified spatial dimensions to throw a more descriptive |
| // error later. |
| llvm::SmallVector<int64_t> unspecifiedSpatialDims; |
| constexpr int kPrintUnspecifiedDimsMax = 10; |
| for (int dim = 0; dim < numSpatialDimensions; ++dim) { |
| auto it = spatialDimsMap.find(dim); |
| if (it == spatialDimsMap.end()) { |
| // Have an upper bound on the number of unspecified dimensions to print |
| // in the error message. |
| if (unspecifiedSpatialDims.size() < kPrintUnspecifiedDimsMax) |
| unspecifiedSpatialDims.push_back(dim); |
| continue; |
| } |
| spatialDims[dim] = it->second; |
| } |
| |
| // Verify that we got all spatial dimensions between 0 and maximum parsed |
| // spatial dimension. |
| if (!unspecifiedSpatialDims.empty()) { |
| mlir::InFlightDiagnostic diag = parser.emitError( |
| parser.getCurrentLocation(), "Expected spatial dimensions "); |
| llvm::interleaveComma(unspecifiedSpatialDims, diag); |
| diag << " not specified"; |
| return diag; |
| } |
| |
| return success(); |
| }; |
| |
| parse_dim_result_t parsedDims; |
| if (parseDims({IOBatch, IOFeature}, parsedDims)) { |
| return failure(); |
| } |
| llvm::SmallVector<int64_t> inputSpatialDimensions = parsedDims.first; |
| int64_t inputBatchDimension = parsedDims.second[IOBatch]; |
| int64_t inputFeatureDimension = parsedDims.second[IOFeature]; |
| if (parser.parseKeyword("x")) return failure(); |
| if (parseDims({KIFeature, KOFeature}, parsedDims)) { |
| return failure(); |
| } |
| llvm::SmallVector<int64_t> kernelSpatialDimensions = parsedDims.first; |
| int64_t kernelInputFeatureDimension = parsedDims.second[KIFeature]; |
| int64_t kernelOutputFeatureDimension = parsedDims.second[KOFeature]; |
| if (parser.parseArrow()) { |
| return failure(); |
| } |
| if (parseDims({IOBatch, IOFeature}, parsedDims)) { |
| return failure(); |
| } |
| llvm::SmallVector<int64_t> outputSpatialDimensions = parsedDims.first; |
| const int64_t outBatchDimension = parsedDims.second[IOBatch]; |
| const int64_t outputFeatureDimension = parsedDims.second[IOFeature]; |
| dnums = ConvDimensionNumbersAttr::get( |
| parser.getBuilder().getContext(), inputBatchDimension, |
| inputFeatureDimension, inputSpatialDimensions, |
| kernelInputFeatureDimension, kernelOutputFeatureDimension, |
| kernelSpatialDimensions, outBatchDimension, outputFeatureDimension, |
| outputSpatialDimensions); |
| |
| return success(); |
| } |
| |
| Attribute ConvDimensionNumbersAttr::parse(AsmParser& parser, Type type) { |
| if (failed(parser.parseLess())) return {}; |
| ConvDimensionNumbersAttr dnums; |
| if (succeeded(parser.parseOptionalKeyword("raw"))) { |
| if (failed(parseConvolutionDimensionsRaw(parser, dnums))) return {}; |
| return dnums; |
| } |
| if (failed(parseConvolutionDimensions(parser, dnums))) return {}; |
| if (failed(parser.parseGreater())) return {}; |
| return dnums; |
| } |
| |
| // Custom printer and parser for ArgResultAliasAttr. |
| constexpr char kMustAlias[] = "must_alias"; |
| constexpr char kResult[] = "result_index"; |
| constexpr char kArgTupleIndices[] = "tuple_indices"; |
| |
| void ArgResultAliasAttr::print(AsmPrinter& printer) const { |
| printer << "<"; |
| |
| // The attribute can have empty tuple indices. Only print argument tuple |
| // indices if they are non-empty. |
| if (!getArgTupleIndices().empty()) |
| printer << kArgTupleIndices << " = [" << getArgTupleIndices() << "], "; |
| |
| // Print the result index followed by any result tuple indices if present. |
| printer << kResult << " = ["; |
| printer << getResultIndex(); |
| if (!getResultTupleIndices().empty()) { |
| printer << ", " << getResultTupleIndices(); |
| } |
| printer << "]"; |
| |
| // Print the "must_alias" keyword if this is a must alias, otherwise skip. |
| if (getIsMustAlias()) printer << ", " << kMustAlias; |
| |
| printer << ">"; |
| } |
| |
| Attribute ArgResultAliasAttr::parse(AsmParser& parser, Type type) { |
| if (failed(parser.parseLess())) return {}; |
| llvm::SmallVector<int64_t> argTupleIndices; |
| // The first element of result indices holds the aliased result index and the |
| // remaining elements are the result tuple indices. |
| llvm::SmallVector<int64_t> resultIndices; |
| bool isMustAlias = false; |
| |
| // This conveys to parseStruct that keyword "must_alias" (3rd field) is not |
| // followed by a "=", but other fields are. |
| llvm::SmallVector<bool, 3> parseEqual = {true, true, false}; |
| |
| if (failed(parseStruct(parser, {kArgTupleIndices, kResult, kMustAlias}, |
| {[&]() { return parseDims(parser, argTupleIndices); }, |
| [&]() { |
| // Since the first element is the index of result, |
| // at least one element is expected. |
| return parseDimsWithMinimumElements( |
| parser, resultIndices, /*minElements=*/1); |
| }, |
| [&]() { |
| // always succeeds if the keyword "must_alias" was |
| // parsed |
| isMustAlias = true; |
| return success(); |
| }}, |
| parseEqual))) { |
| parser.emitError(parser.getCurrentLocation()) |
| << "failed parsing argument-result alias attribute"; |
| return {}; |
| } |
| |
| int64_t resultIndex = resultIndices[0]; |
| auto resultTupleIndices = |
| ArrayRef<int64_t>{resultIndices.begin() + 1, resultIndices.end()}; |
| |
| return ArgResultAliasAttr::get(parser.getContext(), argTupleIndices, |
| resultIndex, resultTupleIndices, isMustAlias); |
| } |
| |
| // Returns the element type pointed to by `indices` in type `t`. If the indices |
| // are invalid, returns nullptr. |
| static Type getTypeFromTupleIndices(Type type, ArrayRef<int64_t> indices) { |
| Type current = type; |
| for (auto index : indices) { |
| TupleType tupleType = current.dyn_cast<TupleType>(); |
| if (!tupleType || index >= static_cast<int64_t>(tupleType.size())) |
| return {}; |
| current = tupleType.getType(index); |
| } |
| return current; |
| } |
| |
| static LogicalResult verifyArgResultAliasAttr(StringAttr attrName, |
| ArgResultAliasAttr aliasAttr, |
| unsigned argIndex, |
| Operation* op) { |
| // The attribute can only be applied to function-like operations. |
| if (!isa<mlir::FunctionOpInterface>(op)) |
| return op->emitOpError() << "attribute " << attrName |
| << " can only be used on function-like operations"; |
| |
| // Verify there are no negative indices. |
| auto tupleIndices = llvm::concat<const int64_t>( |
| aliasAttr.getArgTupleIndices(), aliasAttr.getResultTupleIndices()); |
| if (llvm::any_of(tupleIndices, [](const int64_t val) { return val < 0; }) || |
| aliasAttr.getResultIndex() < 0) |
| return op->emitOpError() |
| << "attribute " << attrName |
| << " expects all argument and result indices to be >= 0"; |
| |
| // Verify that the result index is not out of range. Since the attribute is a |
| // function argument attribute, the argument index is always correct when this |
| // verifier is called. |
| FunctionOpInterface funcOp = cast<FunctionOpInterface>(op); |
| ArrayRef<Type> argTypes = funcOp.getArgumentTypes(); |
| ArrayRef<Type> resultTypes = funcOp.getResultTypes(); |
| if (aliasAttr.getResultIndex() >= static_cast<int64_t>(resultTypes.size())) |
| return op->emitOpError() |
| << "attribute " << attrName |
| << " result index is out of range, must be <" << resultTypes.size(); |
| |
| // Verify that argument and result types pointed to by the indices are valid |
| // and compatible. |
| Type argType = getTypeFromTupleIndices(argTypes[argIndex], |
| aliasAttr.getArgTupleIndices()); |
| if (!argType) |
| return op->emitOpError() |
| << "attribute " << attrName << " argument tuple indices are invalid"; |
| Type resultType = |
| getTypeFromTupleIndices(resultTypes[aliasAttr.getResultIndex()], |
| aliasAttr.getResultTupleIndices()); |
| if (!resultType) |
| return op->emitOpError() |
| << "attribute " << attrName << " result tuple indices are invalid"; |
| |
| if (failed(mlir::verifyCompatibleShape(argType, resultType)) || |
| getElementTypeOrSelf(argType) != getElementTypeOrSelf(resultType)) |
| return op->emitOpError() << "attribute " << attrName |
| << " aliases do not have compatible types, " |
| << argType << " vs. " << resultType; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Type utilities |
| //===----------------------------------------------------------------------===// |
| |
| Type getExpressedTypeOrSelf(Type type) { |
| auto quantType = type.dyn_cast<quant::QuantizedType>(); |
| return quantType ? quantType.getExpressedType() : type; |
| } |
| |
| bool isCompatibleForMhloTypeInference(Type tp1, Type tp2) { |
| // Dynamism: We don't require shapes to be the same, we only require them |
| // to be compatible, which means that: |
| // 1) At least one of the shapes is unranked. |
| // 2) Or both shapes have the same rank and their dimensions are compatible, |
| // i.e. for each pair of corresponding dimensions: |
| // 2.1) At least one of the dimensions is dynamic, |
| // 2.2) Or both dimensions are equal. |
| // These relaxed rules simplify the implementation of type inference, allowing |
| // ops with partially inferred types to pass verification. |
| // No additional code is needed to check bounded cases. |
| // Individual ops may introduce additional constraints. |
| auto stp1 = tp1.dyn_cast<ShapedType>(); |
| auto stp2 = tp2.dyn_cast<ShapedType>(); |
| if (stp1 && stp2) { |
| return succeeded(verifyCompatibleShape(stp1, stp2)) && |
| isCompatibleForMhloTypeInference(stp1.getElementType(), |
| stp2.getElementType()); |
| } |
| |
| // Quantization: In the most general case, we allow any combination of |
| // quantized/non-quantized across any combination of operands/results, |
| // and some differences in quantization parameters across operands/results. |
| // Individual ops may introduce additional constraints. |
| auto qtp1 = tp1.dyn_cast<quant::QuantizedType>(); |
| auto qtp2 = tp2.dyn_cast<quant::QuantizedType>(); |
| if (qtp1 && qtp2) { |
| if (qtp1.getStorageType() != qtp2.getStorageType() || |
| qtp1.getStorageTypeMin() != qtp2.getStorageTypeMin() || |
| qtp1.getStorageTypeMax() != qtp2.getStorageTypeMax()) |
| return false; |
| } |
| auto etp1 = getExpressedTypeOrSelf(tp1); |
| auto etp2 = getExpressedTypeOrSelf(tp2); |
| |
| // Sparsity: In the most general case, we allow any combination of |
| // sparsity/denseness across any combination of operands/results, as well as |
| // differences in sparsity encodings for operands and results. |
| // Individual ops may introduce additional constraints. |
| // No additional code is needed to check this because of how sparsity is |
| // currently implemented. |
| |
| // Default case: Unless dynamism, quantization and/or sparsity are involved, |
| // the types are required to be exactly equal. |
| return etp1 == etp2; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Builder utilities |
| //===----------------------------------------------------------------------===// |
| |
| // Builds the region `body` for mhlo.sort's comparator: for each type in |
| // `element_types`, create two block arguments, one for lhs and one for rhs, and |
| // generates mhlo.compare op to compare them with the given `direction`. |
| // |
| // Note that this right now only does comparision on the first pair of block |
| // arguments. |
| static void buildSortComparisonBody(llvm::ArrayRef<Type> elementTypes, |
| ComparisonDirection direction, |
| llvm::Optional<StringRef> compareType, |
| Region* body, OpBuilder* builder) { |
| OpBuilder::InsertionGuard insertionPointGurad(*builder); |
| |
| Location loc = body->getLoc(); |
| Block* block = builder->createBlock(body); |
| // Add two arguments for each element type. |
| for (Type elementType : elementTypes) { |
| TensorType tensorType = RankedTensorType::get({}, elementType); |
| block->addArguments({tensorType, tensorType}, |
| SmallVector<Location, 2>(2, loc)); |
| } |
| |
| ComparisonType typeAttr; |
| if (compareType) |
| typeAttr = symbolizeComparisonType(*compareType).getValue(); |
| else |
| typeAttr = ComparisonType::NOTYPE; |
| Value compare = builder->create<mhlo::CompareOp>( |
| loc, block->getArgument(0), block->getArgument(1), direction, typeAttr); |
| |
| builder->create<mhlo::ReturnOp>(loc, compare); |
| } |
| |
| SortOp createSortOp(PatternRewriter* rewriter, const Location& loc, |
| const llvm::ArrayRef<Value>& operands, |
| const llvm::ArrayRef<Type>& elementTypes, int64_t dimension, |
| bool isStable, ComparisonDirection direction) { |
| assert(!operands.empty() && "No operands to sort"); |
| // Create the sort op. |
| auto sortOp = |
| rewriter->create<mhlo::SortOp>(loc, operands, dimension, isStable); |
| |
| // Use TOTALORDER comparison type instead of the default comparison if the |
| // element type is of type float. |
| llvm::Optional<StringRef> compareType = llvm::None; |
| for (auto const& elementType : elementTypes) |
| if (elementType.isa<FloatType>()) { |
| compareType.emplace("TOTALORDER"); |
| break; |
| } |
| buildSortComparisonBody(elementTypes, direction, compareType, |
| &sortOp.comparator(), rewriter); |
| return sortOp; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Shape inference |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult deriveShapeFromOperand( |
| OpBuilder* builder, Operation* op, Value operand, |
| SmallVectorImpl<Value>* reifiedReturnShapes) { |
| auto shapedTy = operand.getType().dyn_cast<ShapedType>(); |
| if (!shapedTy) { |
| op->emitOpError() << "operand is not a shaped type"; |
| return failure(); |
| } |
| reifiedReturnShapes->assign( |
| {builder->create<shape::ShapeOfOp>(op->getLoc(), operand)}); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MHLO Dialect Hooks |
| //===----------------------------------------------------------------------===// |
| |
| Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, |
| Type type, Location loc) { |
| auto elementsAttr = value.dyn_cast<ElementsAttr>(); |
| // HLO dialect constants only support ElementsAttr unlike standard dialect |
| // constant which supports all attributes. |
| if (!elementsAttr) return nullptr; |
| // HLO dialect constants require the type of value and result to match. |
| if (type != elementsAttr.getType()) return nullptr; |
| |
| return builder.create<mhlo::ConstantOp>(loc, type, elementsAttr); |
| } |
| |
| LogicalResult MhloDialect::verifyRegionArgAttribute(Operation* op, |
| unsigned /*regionIndex*/, |
| unsigned argIndex, |
| NamedAttribute attr) { |
| if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) { |
| if (failed( |
| verifyArgResultAliasAttr(attr.getName(), aliasAttr, argIndex, op))) |
| return failure(); |
| } |
| return success(); |
| } |
| |
| LogicalResult MhloDialect::verifyOperationAttribute(Operation* op, |
| NamedAttribute attr) { |
| if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) { |
| if (!isa<mlir::FunctionOpInterface>(op)) |
| return op->emitOpError() |
| << "attribute " << attr.getName() |
| << " can only be used on function-like operations"; |
| } |
| return success(); |
| } |
| |
| } // namespace mhlo |
| } // namespace mlir |