| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. |
| |
| #include <algorithm> |
| #include <numeric> |
| #include <string> |
| #include <utility> |
| |
| #include "llvm/ADT/BitVector.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SmallSet.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" |
| #include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h" |
| #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" |
| #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h" |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
| #include "mlir/Dialect/Arithmetic/Utils/Utils.h" |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/Complex/IR/Complex.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/Shape/IR/Shape.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Tensor/Utils/Utils.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir { |
| namespace { |
| |
| template <typename OpTy> |
| SmallVector<NamedAttribute> pruneAttributeList(OpTy op) { |
| auto opAttributes = op.getAttributeNames(); |
| llvm::StringSet<> elidedAttrs; |
| elidedAttrs.insert(opAttributes.begin(), opAttributes.end()); |
| SmallVector<NamedAttribute> preservedAttrs; |
| for (auto attr : op->getAttrs()) { |
| if (elidedAttrs.count(attr.getName())) continue; |
| preservedAttrs.push_back(attr); |
| } |
| return preservedAttrs; |
| } |
| |
| /// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes |
| /// are "parallel" except the last `nReduction` elements, where are "reduction" |
| /// attributes. |
| SmallVector<StringRef, 3> getParallelAndReductionIterators( |
| unsigned nLoops, unsigned nReduction) { |
| SmallVector<StringRef, 3> res(nLoops - nReduction, |
| getParallelIteratorTypeName()); |
| res.append(nReduction, getReductionIteratorTypeName()); |
| return res; |
| } |
| |
| SmallVector<StringRef, 3> getNParallelLoopsAttrs(unsigned nParallelLoops) { |
| return getParallelAndReductionIterators(nParallelLoops, 0); |
| } |
| |
/// Returns the first (and, for the ops handled here, only) result of `op`.
Value getResultValue(Operation* op) { return op->getResult(0); }
| |
| ShapedType getHloOpResultType(Operation* op) { |
| return getResultValue(op).getType().cast<ShapedType>(); |
| } |
| |
| bool verifyHloOpBufferOrTensorSemantics(Operation* op) { |
| auto verifyType = [&](Value val) -> bool { |
| return val.getType().isa<RankedTensorType>(); |
| }; |
| if (!llvm::all_of(op->getOperands(), verifyType)) return false; |
| return llvm::all_of(op->getResults(), verifyType); |
| } |
| |
/// Creates a linalg.init_tensor of the given shape/element type; `dynSizes`
/// supplies one SSA value per dynamic dimension of `type`.
Value getInitTensor(OpBuilder& b, Location loc, ShapedType type,
                    ArrayRef<Value> dynSizes) {
  return b.create<linalg::InitTensorOp>(loc, dynSizes, type.getShape(),
                                        type.getElementType());
}
| |
/// Creates a bufferization.alloc_tensor as the destination for a sparse
/// result; `dynSizes` supplies one SSA value per dynamic dimension of `type`.
Value getInitSparseTensor(OpBuilder& b, Location loc, ShapedType type,
                          ArrayRef<Value> dynSizes) {
  return b.create<bufferization::AllocTensorOp>(loc, type, dynSizes,
                                                /*copy=*/Value(),
                                                /*memory_space=*/IntegerAttr());
}
| |
/// Creates the output (init) tensor for `op`'s result of type `resultType`.
/// Dynamic dimension sizes are reified from the op itself; sparse-encoded
/// results get an alloc_tensor, dense results an init_tensor.
Value getInitTensorFor(OpBuilder& b, Location loc, ShapedType resultType,
                       Operation* op, ValueRange operands) {
  bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr;
  // Collect the sizes for a ranked tensor to be passed as parameter to a
  // new tensor initialization operation. This operation only needs the
  // dynamic sizes.
  SmallVector<Value> sizes;
  if (resultType.hasRank() && !resultType.hasStaticShape()) {
    // Ask the op for its output shape. The cast asserts that `op` implements
    // InferShapedTypeOpInterface.
    auto shapeSource = cast<InferShapedTypeOpInterface>(op);
    SmallVector<Value, 1> reifiedShapes;
    (void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes);
    assert(reifiedShapes.size() == 1 && "Expected one reified result");
    // Construct sizes for the dynamic dimensions only, by extracting the
    // corresponding element of the reified shape tensor.
    for (auto& en : llvm::enumerate(resultType.getShape())) {
      if (en.value() != ShapedType::kDynamicSize) continue;
      sizes.push_back(b.create<tensor::ExtractOp>(
          loc, reifiedShapes[0],
          ValueRange{b.create<arith::ConstantIndexOp>(loc, en.index())}));
    }
  }
  return isSparse ? getInitSparseTensor(b, loc, resultType, sizes)
                  : getInitTensor(b, loc, resultType, sizes);
}
| |
| Value fillTensorWithZeros(OpBuilder& builder, Location loc, Value tensor) { |
| auto type = tensor.getType().cast<ShapedType>(); |
| Value zero; |
| // Complex numbers are a special case. |
| if (auto complexType = type.getElementType().dyn_cast<ComplexType>()) { |
| auto zeroElement = builder.getZeroAttr(complexType.getElementType()); |
| auto zeroAttr = builder.getArrayAttr({zeroElement, zeroElement}); |
| zero = builder.create<complex::ConstantOp>(loc, complexType, zeroAttr); |
| } else { |
| auto zeroAttr = builder.getZeroAttr(type.getElementType()); |
| zero = builder.create<arith::ConstantOp>(loc, zeroAttr); |
| } |
| return builder.create<linalg::FillOp>(loc, zero, tensor).result(); |
| } |
| |
| static inline bool hasIntegralShapeType(Operation* op) { |
| auto stp = op->getOperand(0).getType().dyn_cast<ShapedType>(); |
| return stp && stp.getElementType().isIntOrIndex(); |
| } |
| |
/// Sparsifies a (block of) operation(s) that cannot be handled directly
/// by the sparse compiler but has well-known semi-ring semantics.
///
/// This yields something of the following form:
///
///   %result = sparse_tensor.unary %values[0]
///     present={
///       ^bb1(%val):
///          ... codegen proceeds here using %val ....
///          sparse_tensor.yield
///     }
///     absent={}
///   linalg.yield %result
///
/// On a match, `values[0]` is redirected to the block argument of the
/// `present` region and the builder's insertion point is moved inside that
/// region; `postSparsify` must be called afterwards to close the region.
/// Returns the semi-ring op, or a null Value when no sparsification applies.
Value preSparsify(Operation* op, llvm::SmallVector<Value, 2>& values, Type rtp,
                  OpBuilder* b) {
  // Apply for semi-ring operations that lower to elaborate code
  // (any sign-op, any elt-wise conversion, or an integral abs-op).
  if (isa<mhlo::SignOp>(op) || isa<mhlo::ConvertOp>(op) ||
      (isa<mhlo::AbsOp>(op) && hasIntegralShapeType(op))) {
    // Only worthwhile when either the operand or the result carries a sparse
    // tensor encoding.
    if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) &&
        !sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType()))
      return Value();
    Location loc = op->getLoc();
    auto semiring = b->create<sparse_tensor::UnaryOp>(loc, rtp, values[0]);
    Type itp = values[0].getType();
    Block* present = b->createBlock(&semiring.presentRegion(), {}, itp, loc);
    b->setInsertionPointToStart(&semiring.presentRegion().front());
    // Subsequent codegen consumes the region's block argument instead of the
    // original operand value.
    values[0] = present->getArgument(0);
    return semiring;
  }
  return Value();
}
| |
| /// Finalizes sparse semi-ring construction. |
| Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b) { |
| if (semiring) { |
| b->create<sparse_tensor::YieldOp>(op->getLoc(), result); |
| b->setInsertionPointAfter(semiring.getDefiningOp()); |
| return semiring; |
| } |
| return result; |
| } |
| |
| SmallVector<int64_t, 4> extract1DVector(DenseIntElementsAttr elements) { |
| SmallVector<int64_t, 4> ret; |
| for (const APInt& element : elements) { |
| ret.push_back(element.getLimitedValue()); |
| } |
| return ret; |
| } |
| |
| /// Returns a permutation AffineMap that puts all reduction dimensions to the |
| /// last. The order of parallel loops and reduction loops are all sorted. E.g., |
| /// if `rank` is 4 and `reductionDims` is {1, 3}, then |
| /// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of |
| /// the AffineMap is returned. |
| AffineMap getTransposeMapForReduction(MLIRContext* context, int rank, |
| ArrayRef<int64_t> reductionDims) { |
| llvm::SmallSetVector<int, 4> s; |
| for (auto dim : reductionDims) s.insert(dim); |
| |
| SmallVector<unsigned, 4> permutation; |
| for (int i = 0; i < rank; ++i) |
| if (!s.count(i)) permutation.push_back(i); |
| for (auto dim : reductionDims) permutation.push_back(dim); |
| |
| auto map = AffineMap::getPermutationMap(permutation, context); |
| return inversePermutation(map); |
| } |
| |
| /// Returns true if the given `attr` is a splat of the given `value`. |
| bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) { |
| return attr.isSplat() && attr.getSplatValue<uint64_t>() == value; |
| } |
| |
| /// Returns true if the given `dimensionNumbers` from a mhlo.convolution op |
| /// follows a canonical form: |
| /// |
| /// * Input dimensions have order: (batch_count, spatial_dims, |
| /// input_channel_count). |
| /// * Filter dimensions have order: (spatial_dims, input_channel_count, |
| /// output_channel_count). |
| /// * Output dimensions have order: (batch_count, spatial_dims, |
| /// output_channel_count). |
| static bool hasCanonicalDimensionNumbers( |
| mhlo::ConvDimensionNumbersAttr dimensionNumbers) { |
| const int inputSpatialRank = |
| llvm::size(dimensionNumbers.getInputSpatialDimensions()); |
| // The dimensions for input should follow the order of |
| // batch_count, spatial_dims..., input_feature_count. |
| if (dimensionNumbers.getInputBatchDimension() != 0 || |
| dimensionNumbers.getInputFeatureDimension() != (inputSpatialRank + 1)) { |
| return false; |
| } |
| |
| const int kernelSpatialRank = |
| llvm::size(dimensionNumbers.getKernelSpatialDimensions()); |
| // The dimensions for filter should follow the order of |
| // spatial_dims..., input_feature_count, num_output_feature_count. |
| if (dimensionNumbers.getKernelInputFeatureDimension() != kernelSpatialRank || |
| dimensionNumbers.getKernelOutputFeatureDimension() != |
| (kernelSpatialRank + 1)) { |
| return false; |
| } |
| |
| const int outputSpatialRank = |
| llvm::size(dimensionNumbers.getOutputSpatialDimensions()); |
| // The dimensions for output should follow the order of |
| // batch_count, spatial_dims.., output_feature_count. |
| if (dimensionNumbers.getOutputBatchDimension() != 0 || |
| dimensionNumbers.getOutputFeatureDimension() != (outputSpatialRank + 1)) { |
| return false; |
| } |
| |
| if (inputSpatialRank != outputSpatialRank || |
| inputSpatialRank != kernelSpatialRank) { |
| return false; |
| } |
| |
| const auto* inputSpatialDim = |
| dimensionNumbers.getInputSpatialDimensions().begin(); |
| const auto* kernelSpatialDim = |
| dimensionNumbers.getKernelSpatialDimensions().begin(); |
| const auto* outputSpatialDim = |
| dimensionNumbers.getOutputSpatialDimensions().begin(); |
| // Check spatial dims are ordered correctly. |
| for (int i = 0; i < inputSpatialRank; ++i) { |
| const int dim = i + 1; |
| if ((*inputSpatialDim++) != dim || (*outputSpatialDim++) != dim || |
| (*kernelSpatialDim++) != i) { |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // mhlo.RngOp conversion patterns. |
| //===----------------------------------------------------------------------===// |
| |
| // Pass to lower from rng to stateless pseudo RNG with LCG |
| // algorithm |
| struct RngUniformConversion : public OpConversionPattern<mhlo::RngOp> { |
| using OpConversionPattern<mhlo::RngOp>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| mhlo::RngOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter& rewriter) const final { |
| // We only handle uniform distributions |
| if (op.rng_distribution() != ::mlir::mhlo::RngDistribution::UNIFORM) { |
| return failure(); |
| } |
| // TODO(raikonenfnu): Handle other element types as well. |
| auto minTy = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>(); |
| auto maxTy = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>(); |
| if (!minTy.getElementType().dyn_cast<FloatType>() || |
| !maxTy.getElementType().dyn_cast<FloatType>()) { |
| return rewriter.notifyMatchFailure( |
| op, "expected min/max for rng op to be FloatType"); |
| } |
| auto targetTy = this->typeConverter->convertType(op.getResult().getType()) |
| .cast<ShapedType>(); |
| if (!targetTy) { |
| return rewriter.notifyMatchFailure( |
| op, "expected target shape of rng op to be ShapedType"); |
| } |
| auto loc = op.getLoc(); |
| Value initTensor = |
| getInitTensorFor(rewriter, loc, targetTy, op, adaptor.getOperands()); |
| // Creates index map using target matrix's rank. |
| auto targetRank = targetTy.getRank(); |
| SmallVector<AffineMap, 3> indexingMaps( |
| 2, AffineMap::get(targetRank, /*symbolCount=*/0, |
| SmallVector<AffineExpr>({}), rewriter.getContext())); |
| indexingMaps.push_back(rewriter.getMultiDimIdentityMap(targetRank)); |
| const int kInitialSeed = 0; |
| // Generic region with LCG Algorithm that make use of element index from: |
| // https://reviews.llvm.org/D101364 |
| auto linalgOp = rewriter.create<linalg::GenericOp>( |
| loc, /*resultTensors=*/targetTy, |
| /*inputs=*/ |
| ValueRange{adaptor.getOperands()[0], adaptor.getOperands()[1]}, |
| /*outputs=*/initTensor, indexingMaps, |
| getParallelAndReductionIterators(/*nLoops=*/targetRank, |
| /*nReduction=*/0), |
| [&](OpBuilder& b, Location loc, ValueRange args) { |
| llvm::SmallVector<Value> updateVec = {b.create<arith::ConstantOp>( |
| loc, b.getI32IntegerAttr(kInitialSeed))}; |
| Value multiplier = |
| b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1103515245)); |
| Value incrementStep = |
| b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(12345)); |
| // For output matrix with rank N: |
| // temp1 = (cast(I32, index(D.0)) + seed) * mult + incr |
| // ... |
| // tempN = (cast(I32, index(D.(N))) + tempN_1) * mult + incr |
| for (int i = 0; i < targetRank; i++) { |
| Value update = updateVec.back(); |
| Value ind = b.create<linalg::IndexOp>(loc, i); |
| Value castInd = |
| b.create<arith::IndexCastOp>(loc, b.getI32Type(), ind); |
| Value addRes = b.create<arith::AddIOp>(loc, castInd, update); |
| Value multRes = b.create<arith::MulIOp>(loc, addRes, multiplier); |
| Value incRes = b.create<arith::AddIOp>(loc, multRes, incrementStep); |
| updateVec.push_back(incRes); |
| } |
| // Scaling = (max - min) * const(F64, 2.3283064E-10) |
| // which is derived from rand(min,max) = rand()/(RAND_MAX/(max-min)). |
| Value epsilon = b.create<arith::ConstantOp>( |
| loc, b.getFloatAttr(args[0].getType(), 2.3283064E-10)); |
| Value range = b.create<arith::SubFOp>(loc, args[1], args[0]); |
| Value scale = b.create<arith::MulFOp>(loc, range, epsilon); |
| // Res = cast(T, cast(F64, tempN) * scaling + min) |
| Value updateCast = b.create<arith::UIToFPOp>( |
| loc, targetTy.getElementType(), updateVec.back()); |
| Value scaleUpdate = b.create<arith::MulFOp>(loc, updateCast, scale); |
| Value res = b.create<arith::AddFOp>(loc, scaleUpdate, args[0]); |
| b.create<linalg::YieldOp>(loc, res); |
| }, |
| pruneAttributeList(op)); |
| rewriter.replaceOp(op, linalgOp.getResults()); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // mhlo.Einsum conversion patterns. |
| //===----------------------------------------------------------------------===// |
| |
| // Looks through a set of dimension that has been marked as reduction axes, |
| // if it is found within the set, then we set it as "reduction", otherwise |
| // we can label it as "parallel". |
| SmallVector<StringRef, 3> getEinsumLoopsAttrs( |
| const llvm::SmallSetVector<StringRef, 4>& inputInd, |
| const llvm::SmallSetVector<StringRef, 4>& reductionDims) { |
| SmallVector<StringRef, 3> res; |
| for (StringRef dim : inputInd) { |
| if (!reductionDims.contains(dim)) { |
| res.push_back(getParallelIteratorTypeName()); |
| } else { |
| res.push_back(getReductionIteratorTypeName()); |
| } |
| } |
| return res; |
| } |
| |
/// Collects the dynamic dimension sizes of the einsum output, in
/// output-loop order. Each output loop label is resolved against the lhs
/// labels first and the rhs labels second, and the matching operand
/// dimension is queried with tensor.dim when it is dynamic.
/// NOTE(review): the rhs branch assumes any output label absent from the lhs
/// is present in the rhs; otherwise `dimIndPos` would index past the end of
/// `rhsLoopVec` — presumably guaranteed by einsum config validation upstream.
/// TODO: confirm.
SmallVector<Value, 2> extractDynamicEinsumSizes(
    OpBuilder& b, Location loc, Value lhs, Value rhs,
    const SmallVector<std::string>& lhsLoopVec,
    const SmallVector<std::string>& rhsLoopVec,
    const SmallVector<std::string>& outputLoopVec) {
  SmallVector<Value, 2> dynSizes;
  for (const std::string& dimInd : outputLoopVec) {
    Value dimSize;
    const auto* dimIndIt =
        std::find(lhsLoopVec.begin(), lhsLoopVec.end(), dimInd);
    if (dimIndIt != lhsLoopVec.end()) {
      // Query from lhs vars.
      auto dimIndPos = dimIndIt - lhsLoopVec.begin();
      auto lhsShape = lhs.getType().dyn_cast<RankedTensorType>().getShape();
      // Static dims need no runtime size.
      if (lhsShape[dimIndPos] != ShapedType::kDynamicSize) continue;
      dimSize = b.create<tensor::DimOp>(loc, lhs, dimIndPos);
    } else {
      // query from rhs vars.
      dimIndIt = std::find(rhsLoopVec.begin(), rhsLoopVec.end(), dimInd);
      auto dimIndPos = dimIndIt - rhsLoopVec.begin();
      auto rhsShape = rhs.getType().dyn_cast<RankedTensorType>().getShape();
      if (rhsShape[dimIndPos] != ShapedType::kDynamicSize) continue;
      dimSize = b.create<tensor::DimOp>(loc, rhs, dimIndPos);
    }
    dynSizes.push_back(dimSize);
  }
  return dynSizes;
}
| |
| // Adds indices/axes that are missing from output set. |
| llvm::SmallSetVector<StringRef, 4> findSummationAxes( |
| const llvm::SmallSetVector<StringRef, 4>& inputSet, |
| const llvm::SmallSetVector<StringRef, 4>& outputSet) { |
| llvm::SmallSetVector<StringRef, 4> summationAxes; |
| for (StringRef ind : inputSet) { |
| if (!outputSet.contains(ind)) summationAxes.insert(ind); |
| } |
| return summationAxes; |
| } |
| |
| // Given a 1:1 map from std::string -> affine dimension expression |
| // we can get the affine expression of dimensions that an |
| // operand will access based on the input_str of einsum_config. |
| // For example: |
| // let string_dim_umap = {'a' : d0, 'b' : d1, 'c' : d2} |
| // for einsum_config "abc,cb->acb" |
| // first_input_operand will get umap[{"a","b","c"}] -> (d0, d1, d2). |
| // second_input_operand will get umap[{"c","b"}] -> (d2, d1). |
| // output_operand will get umap[{"a","c","b"}] -> (d0, d2, d1). |
| SmallVector<AffineExpr> getExprFromConfig( |
| const SmallVector<std::string>& loopDims, |
| const DenseMap<StringRef, AffineExpr>& strAffineDimUmap) { |
| SmallVector<AffineExpr> exprs; |
| for (const auto& dim : loopDims) { |
| exprs.push_back(strAffineDimUmap.lookup(dim)); |
| } |
| return exprs; |
| } |
| |
// Convert mhlo.einsum op into linalg.generic.
// Algorithm in general 3 steps:
//
// Step1) Dissect entire einsum_config to different operands
// e.g f("abc,cd->abd") = {lhs:["abc"], rhs:["cd"], out:["abd"]}.
//
// Step2) Split up the string into vector of the elements
// e.g {lhs:["abc"], rhs:["cd"], out:["abd"]} = {lhs:["a","b","c"],
// rhs:["c","d"], out:["a","b","d"]}.
//
// Step3) Convert the vector into data access
// pattern represented by affineMaps with affineDimensions e.g
// {lhs:["a","b","c"], rhs:["c","d"], out:["a","b","d"]} = {lhs:[d0,d1,d2],
// rhs:[d2,d3], out:[d0,d1,d3]}.
class EinsumToLinalgConverter : public OpConversionPattern<mhlo::EinsumOp> {
 public:
  using OpConversionPattern<mhlo::EinsumOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::EinsumOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    auto getRank = [](Value v) {
      return v.getType().cast<ShapedType>().getRank();
    };
    auto einsumConfig = op.einsum_config();

    // With the assumption of binary input operand and single output
    // get the inputs and output operands' indices.
    // einsum_config = "lhs_loop,rhs_loop->out_loop"
    // NOTE(review): a config with no comma (unary einsum) leaves posComma ==
    // npos, so the substr splits below would be wrong — presumably such
    // configs never reach this pattern; confirm upstream validation.
    std::size_t posArrow = einsumConfig.find(kArrow);
    std::size_t posComma = einsumConfig.find(kComma);

    StringRef lhsLoop = einsumConfig.substr(0, posComma);
    StringRef rhsLoop = einsumConfig.substr(
        posComma + kComma.size(), posArrow - (posComma + kComma.size()));
    StringRef outLoop = einsumConfig.substr(posArrow + kArrow.size());

    // Check for Invalid Configs.
    // 1.Check that there is only maximum 2 inputs
    // 2.Check that there is only maximum 1 output
    // 3.Check that there is 1 kArrow
    if (rhsLoop.find(kComma) != std::string::npos ||
        outLoop.find(kComma) != std::string::npos ||
        outLoop.find(kArrow) != std::string::npos) {
      return rewriter.notifyMatchFailure(op, "Invalid einsum config!");
    }

    // Find result type, if on tensors.
    auto resultTy = this->typeConverter->convertType(getHloOpResultType(op))
                        .dyn_cast<RankedTensorType>();

    // Check result type compatibility.
    if (!resultTy || !(resultTy.getElementType().isSignlessIntOrFloat())) {
      return rewriter.notifyMatchFailure(op, "Invalid result type");
    }

    // Convert the representation to vector<string>.
    SmallVector<std::string> lhsEin =
        getEinsumConfigAsVector(lhsLoop, getRank(adaptor.lhs()));
    SmallVector<std::string> rhsEin =
        getEinsumConfigAsVector(rhsLoop, getRank(adaptor.rhs()));
    SmallVector<std::string> outEin =
        getEinsumConfigAsVector(outLoop, resultTy.getRank());

    if (!checkBatchHasEqualRank(lhsEin.size(), lhsLoop, rhsEin.size(), rhsLoop,
                                outEin.size(), outLoop)) {
      return rewriter.notifyMatchFailure(
          op, "Invalid elipsis('...') within einsum config!");
    }

    // Find all unique indices in the input and output.
    llvm::SmallSetVector<StringRef, 4> inputInd;
    llvm::SmallSetVector<StringRef, 4> outputInd;

    inputInd.insert(lhsEin.begin(), lhsEin.end());
    inputInd.insert(rhsEin.begin(), rhsEin.end());
    outputInd.insert(outEin.begin(), outEin.end());

    // Reduction axes: indices that appear in the inputs but not the output.
    llvm::SmallSetVector<StringRef, 4> reductionAxe =
        findSummationAxes(inputInd, outputInd);

    // Find input/output values and types.
    auto loc = op.getLoc();

    // Prepare init tensor for linalg.generic op. When reducing, the output
    // must start from zero (the accumulator identity for add).
    auto dynSizes = extractDynamicEinsumSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), lhsEin, rhsEin, outEin);
    Value output = getInitTensor(rewriter, loc, resultTy, dynSizes);
    if (!reductionAxe.empty()) {
      output = fillTensorWithZeros(rewriter, loc, output);
    }

    // Create indexing maps.
    // Create a 1:1 map from f:strDimension -> affineDimension.
    int64_t nloops = inputInd.size();
    DenseMap<StringRef, AffineExpr> strAffineDimUmap;
    for (auto& it : llvm::enumerate(inputInd)) {
      strAffineDimUmap[it.value()] = rewriter.getAffineDimExpr(it.index());
    }

    // From einsum_config of each operand in vector<string>, generate
    // the equivalent vector<AffineExpr>.
    SmallVector<AffineMap, 4> maps;
    for (const SmallVector<std::string>& loopOperand :
         {lhsEin, rhsEin, outEin}) {
      auto exprs = getExprFromConfig(loopOperand, strAffineDimUmap);
      maps.push_back(AffineMap::get(nloops, 0, exprs, rewriter.getContext()));
    }

    // Body: multiply the two inputs; accumulate into the output when there
    // are reduction axes.
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy ? resultTy : TypeRange{}, adaptor.getOperands(), output,
        maps, getEinsumLoopsAttrs(inputInd, reductionAxe),
        [reductionAxe](OpBuilder& b, Location nestedLoc, ValueRange args) {
          Value resultVal =
              b.create<mlir::arith::MulFOp>(nestedLoc, args[0], args[1]);
          if (!reductionAxe.empty()) {
            resultVal =
                b.create<mlir::arith::AddFOp>(nestedLoc, args[2], resultVal);
          }
          b.create<linalg::YieldOp>(nestedLoc, resultVal);
        },
        pruneAttributeList(op));
    rewriter.replaceOp(op, linalgOp.getResults());
    return success();
  }

 private:
  // Token separating inputs from the output in an einsum config.
  static constexpr StringRef kArrow = "->";
  // Token separating the two input specs.
  static constexpr StringRef kComma = ",";
  // Placeholder for (leading) batch dimensions.
  static constexpr StringRef kEllipsis = "...";

  static bool checkBatchHasEqualRank(size_t lhsRank, StringRef lhsLoop,
                                     size_t rhsRank, StringRef rhsLoop,
                                     size_t outRank, StringRef outLoop);
  static SmallVector<std::string> getEinsumConfigAsVector(StringRef loop,
                                                          size_t operandRank);
};
| |
// Out-of-class definitions of the static constexpr members (required for
// ODR-use prior to C++17; harmless redundancy with C++17 inline variables).
constexpr StringRef EinsumToLinalgConverter::kArrow;
constexpr StringRef EinsumToLinalgConverter::kComma;
constexpr StringRef EinsumToLinalgConverter::kEllipsis;
| |
| // Convert the representation from string/vector<char> to vector<string>. |
| // i.e ("abc") -> {"a", "b", "c"}. For cases with ellipsis with batch rank 3: |
| // get loop_dim = f("ab...cde") = {"a","b","0","1","2","c","d","e"} |
| SmallVector<std::string> EinsumToLinalgConverter::getEinsumConfigAsVector( |
| StringRef loop, size_t operandRank) { |
| SmallVector<std::string> loopDim; |
| size_t preElip = loop.find(kEllipsis); |
| bool hasElip = preElip != std::string::npos; |
| if (!hasElip) preElip = loop.size(); |
| // Add the dimension until the end or up to ellipsis if it exist. |
| for (int preElipInd = 0; preElipInd < preElip; preElipInd++) { |
| loopDim.push_back(loop.substr(preElipInd, 1).str()); |
| } |
| if (!hasElip) return loopDim; |
| // Case where Ellipsis presence: |
| size_t nonBatchRank = loop.size() - kEllipsis.size(); |
| size_t batchRank = operandRank - nonBatchRank; |
| // Add the batch dimension ("0",...,"N") where N is rank of batch into the |
| // loop. |
| for (int batchInd = 0; batchInd < batchRank; batchInd++) { |
| loopDim.push_back(std::to_string(batchInd)); |
| } |
| // Add the dimension after ellipsis into the loop. |
| int postElip = preElip + kEllipsis.size(); |
| for (int postElipInd = postElip; postElipInd < loop.size(); postElipInd++) { |
| loopDim.push_back(loop.substr(postElipInd, 1).str()); |
| } |
| return loopDim; |
| } |
| |
| // Returns true if all operand's batch has same rank. |
| bool EinsumToLinalgConverter::checkBatchHasEqualRank( |
| size_t lhsRank, StringRef lhsLoop, size_t rhsRank, StringRef rhsLoop, |
| size_t outRank, StringRef outLoop) { |
| SmallVector<int, 3> batchRankVec; |
| if (lhsRank != lhsLoop.size()) { |
| size_t lhsBatchRank = lhsRank - (lhsLoop.size() - kEllipsis.size()); |
| batchRankVec.push_back(lhsBatchRank); |
| } |
| if (rhsRank != rhsLoop.size()) { |
| size_t rhsBatchRank = rhsRank - (rhsLoop.size() - kEllipsis.size()); |
| batchRankVec.push_back(rhsBatchRank); |
| } |
| if (outRank != outLoop.size()) { |
| size_t outBatchRank = outRank - (outLoop.size() - kEllipsis.size()); |
| batchRankVec.push_back(outBatchRank); |
| } |
| bool batchHasEqualRank = true; |
| |
| // Condition is valid if only 1 operand or less have batches. |
| if (batchRankVec.size() < 2) return batchHasEqualRank; |
| if (!std::equal(batchRankVec.begin() + 1, batchRankVec.end(), |
| batchRankVec.begin()) && |
| batchRankVec.size() > 1) |
| batchHasEqualRank = false; |
| return batchHasEqualRank; |
| } |
| |
| template <typename OpTy> |
| class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> { |
| public: |
| using OpConversionPattern<OpTy>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| OpTy op, typename OpTy::Adaptor adaptor, |
| ConversionPatternRewriter& rewriter) const final { |
| // Find maximum rank / number of loops. |
| auto getRank = [](Value v) { |
| return v.getType().cast<ShapedType>().getRank(); |
| }; |
| auto isScalar = [&](Value v) { return getRank(v) == 0; }; |
| auto it = llvm::find_if_not(adaptor.getOperands(), isScalar); |
| Value maxRankArg = |
| it != adaptor.getOperands().end() ? *it : adaptor.getOperands().front(); |
| int64_t nloops = getRank(maxRankArg); |
| |
| // Apply only if all operands are scalar or have the same rank. Some ops, |
| // like `mhlo.select`, support implicit broadcasting of scalars. |
| if (!llvm::all_of(adaptor.getOperands(), [&](Value v) { |
| int64_t r = getRank(v); |
| return r == 0 || r == nloops; |
| })) { |
| return rewriter.notifyMatchFailure( |
| op, "Operands must be os same rank or scalar."); |
| } |
| |
| // Find result type, if on tensors. |
| Optional<ShapedType> resultTy; |
| resultTy = this->typeConverter->convertType(op->getResultTypes().front()) |
| .template dyn_cast<ShapedType>(); |
| |
| // Check result type compatibility. |
| if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != nloops || |
| !(resultTy->getElementType().isSignlessIntOrFloat() || |
| resultTy->getElementType().isa<ComplexType>())) { |
| return rewriter.notifyMatchFailure( |
| op, "mismatched operand/result types or iterator count"); |
| } |
| |
| // Find input/output values and types. |
| auto loc = op.getLoc(); |
| ValueRange inputs = adaptor.getOperands(); |
| Value output = |
| getInitTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands()); |
| |
| // Create indexing maps. |
| AffineMap scalarMap = AffineMap::get(nloops, 0, rewriter.getContext()); |
| AffineMap idMap = rewriter.getMultiDimIdentityMap(nloops); |
| SmallVector<AffineMap, 4> maps; |
| for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap); |
| maps.push_back(idMap); |
| |
| // Build `linalg.generic` op. |
| bool failed = false; |
| auto linalgOp = rewriter.create<linalg::GenericOp>( |
| loc, resultTy ? *resultTy : TypeRange{}, inputs, output, maps, |
| getNParallelLoopsAttrs(nloops), |
| [&](OpBuilder& nestedBuilder, Location /*nested_loc*/, |
| ValueRange args) { |
| Type innerResultTy = getElementTypeOrSelf(output); |
| auto argvec = llvm::to_vector<2>(args.take_front(inputs.size())); |
| auto semiring = preSparsify(op, argvec, innerResultTy, &rewriter); |
| Value innerResult = mhlo::MhloOpToStdScalarOp::mapOp( |
| op, innerResultTy, argvec, &rewriter); |
| if (innerResult == nullptr) { |
| failed = true; |
| } else { |
| innerResult = postSparsify(op, semiring, innerResult, &rewriter); |
| nestedBuilder.create<linalg::YieldOp>(loc, innerResult); |
| } |
| }, |
| pruneAttributeList(op)); |
| if (failed) return failure(); |
| |
| rewriter.replaceOp(op, linalgOp->getResults()); |
| return success(); |
| } |
| }; |
| |
/// Lowers a rank-0 (scalar) binary op on buffers to load / scalar-compute /
/// store on memrefs.
/// NOTE(review): this uses the single-argument matchAndRewrite overload and
/// the lhs()/rhs()/out() accessors, i.e. the buffer (out-parameter) form of
/// the op — confirm the ops registered with this pattern provide `out()`.
template <typename MhloOp>
class ScalarPointwiseToStandardConverter : public OpConversionPattern<MhloOp> {
 public:
  using OpConversionPattern<MhloOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      MhloOp mhloOp, ConversionPatternRewriter& rewriter) const final {
    auto loc = mhloOp.getLoc();
    // Only rank-0 shaped operands with signless int/float elements qualify.
    auto argType =
        mhloOp.getOperand(0).getType().template dyn_cast<ShapedType>();
    if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
        (argType.getRank() != 0)) {
      return failure();
    }

    // Create two loads from the input.
    auto lhs = rewriter.create<memref::LoadOp>(loc, mhloOp.lhs());
    auto rhs = rewriter.create<memref::LoadOp>(loc, mhloOp.rhs());
    // Map to the equivalent scalar op and store the result to the out buffer.
    Value opResult = mhlo::MhloOpToStdScalarOp::mapOp(
        mhloOp, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
        &rewriter);
    rewriter.create<memref::StoreOp>(loc, opResult, mhloOp.out());
    rewriter.eraseOp(mhloOp);
    return success();
  }
};
| |
| /// Base class for lowering HLO operations that have one operand and one result, |
| /// and are semantically equivalent to a copy of the input to the output (like |
| /// transpose, some reshape, etc.). The derived classes need to provide a method |
| /// `getIndexingMaps` that returns AffineMaps for the index maps of the input |
| /// and the output. |
| template <typename Derived, typename OpTy> |
| class DataMovementOpConverter : public OpConversionPattern<OpTy> { |
| public: |
| using OpConversionPattern<OpTy>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| OpTy op, typename OpTy::Adaptor adaptor, |
| ConversionPatternRewriter& rewriter) const final { |
| if (!verifyHloOpBufferOrTensorSemantics(op)) return failure(); |
| auto resultType = getHloOpResultType(op); |
| resultType = this->typeConverter->convertType(resultType) |
| .template cast<ShapedType>(); |
| |
| SmallVector<AffineMap, 2> indexingMaps = |
| Derived::getIndexingMaps(op, &rewriter); |
| if (indexingMaps.empty()) return failure(); |
| |
| auto nloops = resultType.getRank(); |
| auto loc = op.getLoc(); |
| auto linalgOp = rewriter.create<linalg::GenericOp>( |
| loc, |
| /*resultTensorTypes=*/resultType, |
| /*inputs=*/adaptor.getOperands().front(), |
| /*outputBuffers=*/ |
| |
| ValueRange{getInitTensorFor(rewriter, loc, resultType, op, |
| adaptor.getOperands())}, |
| indexingMaps, getNParallelLoopsAttrs(nloops), |
| [&](OpBuilder& nestedBuilder, Location /*nested_loc*/, |
| ValueRange args) { |
| nestedBuilder.create<linalg::YieldOp>(loc, *args.begin()); |
| }, |
| pruneAttributeList(op)); |
| rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); |
| return success(); |
| } |
| }; |
| |
| /// Pattern to convert BroadcastOp to Linalg ops. |
| template <typename OpTy> |
| class BroadcastConverter |
| : public DataMovementOpConverter<BroadcastConverter<OpTy>, OpTy> { |
| public: |
| using DataMovementOpConverter<BroadcastConverter, |
| OpTy>::DataMovementOpConverter; |
| |
| static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp, |
| Builder* b) { |
| ShapedType inputType = |
| broadcastOp.operand().getType().template cast<ShapedType>(); |
| unsigned inputRank = inputType.getRank(); |
| unsigned nloops = getHloOpResultType(broadcastOp).getRank(); |
| |
| // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to |
| // the input's dimensions. |
| unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes()); |
| SmallVector<AffineExpr, 4> inputDimExprs; |
| inputDimExprs.reserve(inputRank); |
| for (unsigned i = 0; i < inputRank; ++i) { |
| inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i)); |
| } |
| |
| AffineMap inputMap; |
| MLIRContext* context = b->getContext(); |
| if (inputDimExprs.empty()) { |
| // The input is a scalar, i.e. this is a scalar broadcast op. |
| inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context); |
| } else { |
| inputMap = |
| AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context); |
| } |
| return {inputMap, b->getMultiDimIdentityMap(nloops)}; |
| } |
| }; |
| |
| class HloBroadcastInDimConverter |
| : public DataMovementOpConverter<HloBroadcastInDimConverter, |
| mhlo::BroadcastInDimOp> { |
| public: |
| using DataMovementOpConverter< |
| HloBroadcastInDimConverter, |
| mhlo::BroadcastInDimOp>::DataMovementOpConverter; |
| |
| static SmallVector<AffineMap, 2> getIndexingMaps( |
| mhlo::BroadcastInDimOp broadcastOp, Builder* b) { |
| auto resultType = getHloOpResultType(broadcastOp); |
| auto operandType = |
| broadcastOp.operand().getType().template cast<ShapedType>(); |
| unsigned nloops = resultType.getRank(); |
| |
| // The input is a scalar, i.e. this is a scalar broadcast op. |
| if (operandType.getRank() == 0) { |
| return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), |
| b->getMultiDimIdentityMap(nloops)}; |
| } |
| |
| auto operandShape = operandType.getShape(); |
| SmallVector<AffineExpr, 4> dimExprs; |
| dimExprs.reserve(nloops); |
| |
| if (broadcastOp.broadcast_dimensions()) { |
| for (const auto& broadcastDim : |
| enumerate(broadcastOp.broadcast_dimensions().getValues<APInt>())) { |
| int size = broadcastDim.value().getSExtValue(); |
| bool expansionNeeded = operandShape[broadcastDim.index()] == 1 && |
| resultType.getShape()[size] != 1; |
| dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0) |
| : b->getAffineDimExpr(size)); |
| } |
| } |
| return { |
| AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), |
| b->getMultiDimIdentityMap(nloops)}; |
| } |
| }; |
| |
| // If the input has a static shape we know exactly when the broadcast must |
| // expand (the dimension is 1, which also trivially expands to 1) or will never |
| // expand (the dimension is not 1). We can also source the information from the |
// optionally provided attributes on statically known broadcasting behavior.
| // This means we can lower the broadcast just as we would lower a fully static |
| // broadcast and go directly to `linalg.generic`. |
| |
| // This also covers the important case of broadcasting a scalar. Ideally the |
| // pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be |
| // converted to a tensor dialect op similar to TF's `ConstantLikeOp`. |
class HloDynamicBroadcastInDimConverter
    : public OpConversionPattern<mhlo::DynamicBroadcastInDimOp> {
 public:
  using OpConversionPattern<mhlo::DynamicBroadcastInDimOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicBroadcastInDimOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    // Requires a ranked operand and a result type that converts to a ranked
    // tensor; otherwise this pattern does not apply.
    Value operand = adaptor.operand();
    auto operandType = operand.getType().dyn_cast<RankedTensorType>();
    if (!operandType) return failure();
    auto resultType =
        typeConverter->convertType(op.getType()).dyn_cast<RankedTensorType>();
    if (!resultType) return failure();

    // Determine dimension expressions based on whether the dimension is
    // expanding (0) or non-expanding (identity), and fail if we cannot decide
    // this. Entry i of `dimExprs` corresponds to operand dimension i; nullptr
    // marks "undecided".
    SmallVector<AffineExpr> dimExprs(operandType.getRank(), nullptr);

    // Use static type info. `bcastDims[i]` is the result dimension that
    // operand dimension i maps to.
    auto bcastDims = llvm::to_vector(
        llvm::map_range(op.broadcast_dimensions(), [](const APInt& d) {
          return static_cast<int64_t>(d.getLimitedValue());
        }));
    for (const auto& it : llvm::enumerate(operandType.getShape())) {
      if (ShapedType::isDynamic(it.value())) continue;
      // A static size-1 dim always expands (1 trivially "expands" to 1); any
      // other static size never expands.
      bool isExpanding = it.value() == 1;
      dimExprs[it.index()] =
          isExpanding ? rewriter.getAffineConstantExpr(0)
                      : rewriter.getAffineDimExpr(bcastDims[it.index()]);
    }

    // Use annotated expansion behavior, if available. The annotated indices
    // are operand-dimension indices and may overwrite the static deductions
    // above.
    if (op.known_expanding_dimensions()) {
      for (const auto& it :
           op.known_expanding_dimensions()->getValues<APInt>()) {
        auto i = it.getLimitedValue();
        dimExprs[i] = rewriter.getAffineConstantExpr(0);
      }
    }
    if (op.known_nonexpanding_dimensions()) {
      for (const auto& it :
           op.known_nonexpanding_dimensions()->getValues<APInt>()) {
        auto i = it.getLimitedValue();
        dimExprs[i] = rewriter.getAffineDimExpr(bcastDims[i]);
      }
    }

    // Fail if unknown expansion behavior remains.
    if (!llvm::all_of(dimExprs, [](AffineExpr expr) { return expr; }))
      return failure();

    // Materialize `linalg.generic` op: input map built from `dimExprs`,
    // identity map on the output.
    Location loc = op.getLoc();
    int64_t nloops = resultType.getRank();
    Value init =
        getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, TypeRange{init.getType()}, ValueRange{operand},
        /*outputBuffers=*/ValueRange{init},
        llvm::makeArrayRef(
            {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dimExprs,
                            rewriter.getContext()),
             rewriter.getMultiDimIdentityMap(nloops)}),
        getNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
            ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        },
        pruneAttributeList(op));
    return success();
  }
};
| |
| template <typename OpTy> |
| class TransposeConverter |
| : public DataMovementOpConverter<TransposeConverter<OpTy>, OpTy> { |
| public: |
| using DataMovementOpConverter<TransposeConverter<OpTy>, |
| OpTy>::DataMovementOpConverter; |
| static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) { |
| auto resultType = getHloOpResultType(op).template cast<ShapedType>(); |
| auto nloops = resultType.getRank(); |
| SmallVector<AffineExpr, 2> inputExprs; |
| inputExprs.resize(resultType.getRank()); |
| for (const auto& permutation : llvm::enumerate(op.permutation())) { |
| inputExprs[permutation.value().getZExtValue()] = |
| b->getAffineDimExpr(permutation.index()); |
| } |
| return { |
| AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), |
| b->getMultiDimIdentityMap(nloops)}; |
| } |
| }; |
| |
// Lowers mhlo.RealDynamicSliceOp to tensor.extract_slice and other
// arith/tensor dialect ops.
class RealDynamicSliceConverter
    : public OpConversionPattern<mhlo::RealDynamicSliceOp> {
 public:
  using OpConversionPattern<mhlo::RealDynamicSliceOp>::OpConversionPattern;

  // Computes size of a slice as
  //   size = ceil((limit - start)/stride)
  // and casts the result to index type when the arithmetic was done in an
  // integer type.
  static Value computeSize(Location loc, Value start, Value limit, Value stride,
                           ConversionPatternRewriter& b) {
    Value delta = b.create<arith::SubIOp>(loc, limit, start);
    Value ret = b.create<arith::CeilDivUIOp>(loc, delta, stride);
    if (ret.getType().isIndex()) return ret;
    return b.create<arith::IndexCastOp>(loc, b.getIndexType(), ret);
  }

  LogicalResult matchAndRewrite(
      mhlo::RealDynamicSliceOp realDynamicSliceOp, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = realDynamicSliceOp.getLoc();
    auto argType = adaptor.operand().getType().dyn_cast<ShapedType>();
    if (!argType || !argType.hasRank()) {
      return rewriter.notifyMatchFailure(realDynamicSliceOp,
                                         "require known-rank args");
    }

    // The start/limit/stride tensors must share one element type so the arith
    // ops emitted below are well-typed.
    Type dimElementType = getElementTypeOrSelf(adaptor.start_indices());
    if (getElementTypeOrSelf(adaptor.limit_indices()) != dimElementType ||
        getElementTypeOrSelf(adaptor.strides()) != dimElementType) {
      return rewriter.notifyMatchFailure(
          realDynamicSliceOp,
          "requires same element type for all dimension specification");
    }
    // Clamp arithmetic is done in an integer type; `index` is mapped to i64.
    Type arithType =
        dimElementType.isIndex() ? rewriter.getI64Type() : dimElementType;
    Type indexType = rewriter.getIndexType();

    auto resultType =
        this->typeConverter->convertType(realDynamicSliceOp.getType())
            .cast<RankedTensorType>();
    // Lower clamp bound, shared by all dimensions.
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(arithType, 0));
    SmallVector<OpFoldResult, 4> offsets, sizes, strides;
    SmallVector<Type, 3> clampType(3, arithType);
    for (auto i : llvm::seq<unsigned>(0, argType.getRank())) {
      // Extract the per-dimension start/limit/stride scalars.
      Value dim = rewriter.create<arith::ConstantIndexOp>(loc, i);
      Value start =
          rewriter.create<tensor::ExtractOp>(loc, adaptor.start_indices(), dim);
      Value limit =
          rewriter.create<tensor::ExtractOp>(loc, adaptor.limit_indices(), dim);
      Value stride =
          rewriter.create<tensor::ExtractOp>(loc, adaptor.strides(), dim);

      // Compute i-th dimension size of the result : size[i].
      // If the i-th dimension of the result type is known, we go ahead with it
      // else we compute it using limit, start and stride values.
      int64_t resultDimSize = resultType.getDimSize(i);
      Value size =
          ShapedType::isDynamic(resultDimSize)
              ? computeSize(loc, start, limit, stride, rewriter)
              : rewriter.create<arith::ConstantIndexOp>(loc, resultDimSize);

      // Fetch i-th dimension size of the operand and calculate upper bound as
      //   ub = operand_dim[i] - size[i]
      Value operandDimSize =
          rewriter.createOrFold<tensor::DimOp>(loc, adaptor.operand(), dim);
      Value upperBound =
          rewriter.createOrFold<arith::SubIOp>(loc, operandDimSize, size);

      // We clamp the start_index to keep it bounded as
      //   0 <= start_index[i] <= ub
      // Clamp does not support index type, so cast to integer type.
      start = rewriter.createOrFold<arith::IndexCastOp>(loc, arithType, start);
      upperBound =
          rewriter.createOrFold<arith::IndexCastOp>(loc, arithType, upperBound);
      start = mhlo::MhloOpToStdScalarOp::mapOpOfType<mhlo::ClampOp>(
          loc, arithType, clampType, ValueRange{zero, start, upperBound},
          &rewriter);

      // extract_slice expects index-typed offsets/strides; statically known
      // sizes are passed as attributes, dynamic ones as values.
      offsets.push_back(
          rewriter.createOrFold<arith::IndexCastOp>(loc, indexType, start));
      if (ShapedType::isDynamic(resultDimSize))
        sizes.push_back(size);
      else
        sizes.push_back(IntegerAttr::get(indexType, resultDimSize));
      strides.push_back(
          rewriter.createOrFold<arith::IndexCastOp>(loc, indexType, stride));
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        realDynamicSliceOp, resultType, adaptor.operand(), offsets, sizes,
        strides);
    return success();
  }
};
| |
// Converts reshape ops that can be proven to be either a collapse of dimensions
// or expansion of dimensions of the operand.
class ReshapeOpConverter : public OpConversionPattern<mhlo::ReshapeOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ReshapeOp reshapeOp, mhlo::ReshapeOp::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    if (!verifyHloOpBufferOrTensorSemantics(reshapeOp)) return failure();
    auto operand = adaptor.operand();
    auto operandType = operand.getType().cast<ShapedType>();
    auto elemType = operandType.getElementType();
    auto resultType = reshapeOp.getType().cast<ShapedType>();

    // Only statically-shaped results are supported by this pattern.
    if (!resultType.hasStaticShape()) return failure();

    resultType = typeConverter->convertType(resultType).cast<ShapedType>();

    // Special case where the result is a scalar.
    if (resultType.getRank() == 0 && !operandType.hasStaticShape()) {
      // This means all dimensions of the operand need to be 1. We add a cast to
      // cast the dynamic dimensions to 1.
      auto staticType = RankedTensorType::get(
          llvm::SmallVector<int64_t>(operandType.getRank(), 1), elemType);
      operand = rewriter.create<tensor::CastOp>(reshapeOp.getLoc(), staticType,
                                                operand);
      // An empty reassociation collapses all unit dims into the scalar.
      rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
          reshapeOp, resultType, operand, ArrayRef<ReassociationIndices>{});
      return success();
    }

    // Compute the reassociation maps for the linalg operation. This will
    // succeed if the reshape can be done with a single expand_shape or
    // collapse_shape.
    if (Optional<SmallVector<ReassociationIndices>> reassociationMap =
            getReassociationIndicesForReshape(operandType, resultType)) {
      if (resultType.getRank() < operandType.getRank()) {
        // We have found a working reassociation map. If the operand is dynamic,
        // we first need to cast all unknown dimensions in the input that get
        // collapsed to a static-sized dimension in the output, to 1.
        SmallVector<int64_t> shape(operandType.getShape().begin(),
                                   operandType.getShape().end());
        for (const auto& map : llvm::enumerate(*reassociationMap)) {
          // If the result dim is dynamic, we do not mind dynamic entries in the
          // source.
          if (resultType.isDynamicDim(map.index())) continue;
          for (auto targetDim : map.value()) {
            if (shape[targetDim] == ShapedType::kDynamicSize)
              shape[targetDim] = 1;
          }
        }
        // Insert a cast if types are not the same (ignoring sparse encoding).
        auto enc = sparse_tensor::getSparseTensorEncoding(operandType);
        auto newOperandType = RankedTensorType::get(shape, elemType, enc);
        if (newOperandType != operandType) {
          operand = rewriter.create<tensor::CastOp>(reshapeOp.getLoc(),
                                                    newOperandType, operand);
        }
        // Generate collapse operation.
        rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
            reshapeOp, resultType, operand, *reassociationMap);
      } else {
        // Generate expand operation.
        rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
            reshapeOp, resultType, operand, *reassociationMap);
      }
      return success();
    }

    // No single-step reassociation exists: go through a rank-1 intermediate.
    Value collapsedOp = operand;
    Location loc = reshapeOp.getLoc();
    // Builds the identity reassociation exprs (d0, d1, ..., d{n-1}).
    auto getIdentityExprs = [&rewriter](int64_t n) {
      SmallVector<AffineExpr, 4> exprs;
      for (int i = 0; i < n; ++i) exprs.push_back(rewriter.getAffineDimExpr(i));
      return exprs;
    };
    // Otherwise, we need to first reduce all source dimensions into one and
    // then expand to the destination dimensions. If there is only a single
    // source dimension, the reduce step can be skipped. TensorCollapseShape
    // expects a different rank of operand and result.
    if (operandType.getRank() != 1) {
      SmallVector<ReassociationExprs, 4> collapsingMap = {
          // Use operand_type here because we need to collapse all operands
          // dimensions.
          getIdentityExprs(operandType.getRank())};

      collapsedOp =
          rewriter.create<tensor::CollapseShapeOp>(loc, operand, collapsingMap);
    }
    // Cast to a known static type if the input has dynamic dimensions.
    int64_t totalElems = resultType.getNumElements();
    auto collapsedType = RankedTensorType::get({totalElems}, elemType);
    collapsedOp =
        rewriter.create<tensor::CastOp>(loc, collapsedType, collapsedOp);
    if (resultType.getRank() == 1) {
      // Rank-1 result: the collapsed (and cast) value already has the right
      // shape.
      rewriter.replaceOp(reshapeOp, collapsedOp);
    } else {
      SmallVector<ReassociationExprs, 4> expandingMap = {
          // Use resultType here because we need to expand to all result
          // dimensions.
          getIdentityExprs(resultType.getRank())};
      rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
          reshapeOp, resultType, collapsedOp, expandingMap);
    }
    return success();
  }
};
| |
// Lowers iota-like ops to a linalg.generic whose body materializes the index
// along `iota_dimension` and converts it to the result element type.
template <typename OpTy>
class IotaConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy iotaOp, typename OpTy::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    ShapedType resultShapedType = getHloOpResultType(iotaOp);
    if (!resultShapedType) return failure();
    resultShapedType = this->typeConverter->convertType(resultShapedType)
                           .template dyn_cast<ShapedType>();

    Type resultElementType = resultShapedType.getElementType();

    // Construct the indexing maps needed for linalg.generic ops.
    unsigned nloops = resultShapedType.getRank();

    Location loc = iotaOp.getLoc();
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultTensorTypes=*/
        ArrayRef<Type>{resultShapedType},
        /*inputs=*/ValueRange{},
        /*outputBuffers=*/

        ValueRange{getInitTensorFor(rewriter, loc, resultShapedType, iotaOp,
                                    adaptor.getOperands())},
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        getNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange /*args*/) {
          // The iota value is the loop index along the iota dimension.
          Value indexOp = nestedBuilder.create<linalg::IndexOp>(
              nestedLoc, iotaOp.iota_dimension());
          // For complex results, the index is first cast to an integer of the
          // complex type's element bit width, then converted.
          Type unwrappedResultElementType = resultElementType;
          if (auto complexType =
                  unwrappedResultElementType.dyn_cast<ComplexType>())
            unwrappedResultElementType = complexType.getElementType();
          Value castOp = nestedBuilder.create<arith::IndexCastOp>(
              nestedLoc,
              nestedBuilder.getIntegerType(
                  unwrappedResultElementType.getIntOrFloatBitWidth()),
              indexOp);
          // Reuse the mhlo.convert lowering to reach the final element type.
          castOp = mhlo::MhloOpToStdScalarOp::mapOpOfType<mhlo::ConvertOp>(
              nestedLoc, resultElementType, castOp.getType(), castOp,
              &nestedBuilder);
          nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
        },
        pruneAttributeList(iotaOp));
    rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
    return success();
  }
};
| |
/// Converts mhlo.concatenate operation to a linalg.generic op.
struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
  using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConcatenateOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    // Shortcut the one-operand case, simplifies code below.
    if (adaptor.getOperands().size() == 1) {
      rewriter.replaceOp(op, adaptor.getOperands()[0]);
      return success();
    }

    auto resultType = this->typeConverter->convertType(op.getResult().getType())
                          .dyn_cast<RankedTensorType>();
    if (!resultType) return failure();

    uint64_t dim = op.dimension();
    Location loc = op.getLoc();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // Allocate the output tensor with init_tensor.
    Value result =
        getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());

    // Generate a generic op to gather the elements of the concatenate. This is
    // awkward standalone but allows fusion with other generic ops.
    int64_t nloops = resultType.getRank();
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op,
        /*resultTensorTypes=*/resultType,
        /*inputs=*/ValueRange{}, /*outputBuffers=*/result,
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        getNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location loc, ValueRange) {
          // The body builds a chain of nested scf.if ops: every operand except
          // the last gets an "is the concat index within my extent?" test; the
          // last operand is the unconditional innermost else case. `b` is
          // redirected into the current then/else region as the chain grows.
          OpBuilder b = nestedBuilder;
          Value concatDimSize = zero;
          Value result;

          SmallVector<Value, 4> extractIndices;
          extractIndices.reserve(nloops);
          for (int64_t i = 0; i < nloops; i++) {
            extractIndices.push_back(b.create<linalg::IndexOp>(loc, i));
          }

          Value indexOp = b.create<linalg::IndexOp>(loc, dim);
          for (auto& it : llvm::enumerate(adaptor.getOperands())) {
            Value arg = it.value();
            Value newConcatDimSize;
            scf::IfOp ifOp;
            if (it.index() != (adaptor.getOperands().size() - 1)) {
              // Calculate how far along we have iterated along the concatenate
              // dimension. That way we can tell which input to select.
              newConcatDimSize = b.create<arith::AddIOp>(
                  loc, concatDimSize, b.create<tensor::DimOp>(loc, arg, dim));
              Value cmp = b.create<arith::CmpIOp>(loc, rewriter.getI1Type(),
                                                  arith::CmpIPredicate::ult,
                                                  indexOp, newConcatDimSize);
              ifOp = b.create<scf::IfOp>(loc, resultType.getElementType(), cmp,
                                         true);
              // The outermost if produces the element yielded by the generic
              // body; every nested if yields its result into the enclosing
              // if's else branch.
              if (result) {
                b.create<scf::YieldOp>(loc, ifOp->getResults()[0]);
              } else {
                result = ifOp->getResults()[0];
              }

              b = ifOp.getThenBodyBuilder(b.getListener());
            }

            // Now adjust the index for the concatenated dimension to fit into
            // the selected tensor and do an extract at that position.
            extractIndices[dim] =
                b.create<arith::SubIOp>(loc, indexOp, concatDimSize);
            Value extract =
                b.create<tensor::ExtractOp>(loc, arg, extractIndices);
            b.create<scf::YieldOp>(loc, extract);

            if (ifOp) {
              // Emit the remaining operands into the else branch, past this
              // operand's extent along the concat dimension.
              b = ifOp.getElseBodyBuilder(b.getListener());
              concatDimSize = newConcatDimSize;
            }
          }
          nestedBuilder.create<linalg::YieldOp>(loc, result);
        },
        pruneAttributeList(op));
    return success();
  }
};
| |
| class ConstConverterTensor : public OpConversionPattern<mhlo::ConstantOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| mhlo::ConstantOp constOp, OpAdaptor /*adaptor*/, |
| ConversionPatternRewriter& rewriter) const final { |
| auto valueAttr = constOp.value().cast<DenseElementsAttr>(); |
| auto type = |
| typeConverter->convertType(constOp.getType()).cast<ShapedType>(); |
| if (type != constOp.getType()) { |
| // Signedness conversion. |
| valueAttr = valueAttr.mapValues(type.getElementType(), |
| [](const APInt& i) { return i; }); |
| } |
| rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, type, valueAttr); |
| return success(); |
| } |
| }; |
| |
| // TODO(b/156787842): Support the lowering for dynamic shapes. |
| class ReverseConverter |
| : public DataMovementOpConverter<ReverseConverter, mhlo::ReverseOp> { |
| public: |
| using DataMovementOpConverter<ReverseConverter, |
| mhlo::ReverseOp>::DataMovementOpConverter; |
| static SmallVector<AffineMap, 2> getIndexingMaps(mhlo::ReverseOp op, |
| Builder* b) { |
| auto resultType = getHloOpResultType(op).cast<ShapedType>(); |
| auto nloops = resultType.getRank(); |
| SmallVector<AffineExpr, 2> inputExprs; |
| inputExprs.reserve(nloops); |
| for (int i = 0; i < nloops; ++i) |
| inputExprs.push_back(b->getAffineDimExpr(i)); |
| for (auto dim : op.dimensions()) { |
| int i = dim.getZExtValue(); |
| if (resultType.isDynamicDim(i)) return {}; |
| int n = resultType.getShape()[i]; |
| inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; |
| } |
| return { |
| AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), |
| b->getMultiDimIdentityMap(nloops)}; |
| } |
| }; |
| |
| class SliceConverter : public OpConversionPattern<mhlo::SliceOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| mhlo::SliceOp sliceOp, typename mhlo::SliceOp::Adaptor adaptor, |
| ConversionPatternRewriter& rewriter) const final { |
| auto argType = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>(); |
| if (!argType || !argType.hasRank()) { |
| return rewriter.notifyMatchFailure(sliceOp, "expects known-rank args"); |
| } |
| |
| SmallVector<OpFoldResult, 3> offsets, sizes, strides; |
| for (int i = 0, e = argType.getRank(); i < e; ++i) { |
| auto start = sliceOp.start_indices().getValues<int64_t>()[i]; |
| auto limit = sliceOp.limit_indices().getValues<int64_t>()[i]; |
| auto stride = sliceOp.strides().getValues<int64_t>()[i]; |
| offsets.push_back(rewriter.getI64IntegerAttr(start)); |
| // Say that there are k elements in total, we have condition: |
| // start + (k - 1) * strides <= limit - 1 |
| // -> |
| // k <= (limit - 1 - start) / strides + 1 |
| sizes.push_back( |
| rewriter.getI64IntegerAttr((limit - 1 - start) / stride + 1)); |
| strides.push_back(rewriter.getI64IntegerAttr(stride)); |
| } |
| rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>( |
| sliceOp, adaptor.getOperands()[0], offsets, sizes, strides); |
| return success(); |
| } |
| }; |
| |
// Lowers mhlo.dynamic_slice to tensor.extract_slice with the start indices
// clamped into bounds.
class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
 public:
  using OpConversionPattern<mhlo::DynamicSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicSliceOp dynamicSliceOp, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = dynamicSliceOp.getLoc();
    auto argType = adaptor.operand().getType().dyn_cast<ShapedType>();
    if (!argType || !argType.hasRank()) {
      return rewriter.notifyMatchFailure(dynamicSliceOp,
                                         "require known-rank args");
    }

    auto indexType = rewriter.getIndexType();
    SmallVector<OpFoldResult, 3> startIndices, sizes;
    // Lower clamp bound in the element type of the start-index tensors,
    // shared by all dimensions.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(adaptor.start_indices()[0]
                                      .getType()
                                      .cast<RankedTensorType>()
                                      .getElementType()));
    for (auto& en : llvm::enumerate(
             llvm::zip(adaptor.start_indices(),
                       dynamicSliceOp.slice_sizes().getValues<int64_t>()))) {
      int64_t size = std::get<1>(en.value());
      sizes.push_back(rewriter.getI64IntegerAttr(size));

      // By mhlo.DynamicSlice definition:
      //   `start_indices[i] = clamp(start_indices[i],
      //       0, operand.dimension_size[i] - size_indices[i])`
      // Each start index is a rank-0 tensor, hence the index-less extract.
      Value startIndex =
          rewriter.create<tensor::ExtractOp>(loc, std::get<0>(en.value()));
      Value ub = rewriter.createOrFold<tensor::DimOp>(loc, adaptor.operand(),
                                                      en.index());
      // ClampOp lowering does not support index type, so cast it into integer
      // type.
      ub = rewriter.createOrFold<arith::IndexCastOp>(loc, startIndex.getType(),
                                                     ub);
      ub = rewriter.createOrFold<arith::SubIOp>(
          loc, ub,
          rewriter.create<arith::ConstantOp>(
              loc, rewriter.getIntegerAttr(startIndex.getType(), size)));
      startIndex = mhlo::MhloOpToStdScalarOp::mapOpOfType<mhlo::ClampOp>(
          loc, startIndex.getType(),
          ArrayRef<Type>{startIndex.getType(), startIndex.getType(),
                         startIndex.getType()},
          ArrayRef<Value>{zero, startIndex, ub}, &rewriter);
      startIndices.push_back(
          rewriter.create<arith::IndexCastOp>(loc, indexType, startIndex)
              .getResult());
    }

    int64_t rank = argType.getRank();
    // The slice is contiguous: unit strides in every dimension.
    SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));

    auto resultType = this->typeConverter->convertType(dynamicSliceOp.getType())
                          .cast<RankedTensorType>();

    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        dynamicSliceOp, resultType, adaptor.operand(), startIndices, sizes,
        strides);
    return success();
  }
};
| |
| class DynamicUpdateSliceConverter |
| : public OpConversionPattern<mhlo::DynamicUpdateSliceOp> { |
| public: |
| using OpConversionPattern<mhlo::DynamicUpdateSliceOp>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter& rewriter) const final { |
| auto loc = op.getLoc(); |
| auto operandType = adaptor.operand().getType().dyn_cast<RankedTensorType>(); |
| if (!operandType || !operandType.hasStaticShape()) { |
| return rewriter.notifyMatchFailure( |
| op, "require static ranked type for operand"); |
| } |
| |
| auto updateType = adaptor.update().getType().dyn_cast<RankedTensorType>(); |
| if (!updateType || !updateType.hasStaticShape()) { |
| return rewriter.notifyMatchFailure( |
| op, "require static ranked type for operand"); |
| } |
| |
| // We do not have to clamp sizes because the semantic of `update` |
| // guarantees that it is always in the bounds. See |
| // https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice |
| SmallVector<OpFoldResult, 3> sizes; |
| for (auto size : updateType.getShape()) { |
| sizes.push_back(rewriter.getIndexAttr(size)); |
| } |
| |
| auto indexType = rewriter.getIndexType(); |
| SmallVector<OpFoldResult, 3> startIndices; |
| Type startIndexType = adaptor.start_indices()[0] |
| .getType() |
| .cast<RankedTensorType>() |
| .getElementType(); |
| Value zero = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getZeroAttr(startIndexType)); |
| for (auto& en : llvm::enumerate(adaptor.start_indices())) { |
| // By mhlo.DynamicUpdateSlice definition: |
| // `start_indices[i] = clamp(start_indices[i], |
| // 0, operand.dimension_size[i] - update.dimension_size[i])` |
| Value startIndex = rewriter.create<tensor::ExtractOp>(loc, en.value()); |
| Value ub = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getIntegerAttr(startIndexType, |
| operandType.getDimSize(en.index()) - |
| updateType.getDimSize(en.index()))); |
| startIndex = mhlo::MhloOpToStdScalarOp::mapOpOfType<mhlo::ClampOp>( |
| loc, startIndexType, |
| ArrayRef<Type>{startIndexType, startIndexType, startIndexType}, |
| ArrayRef<Value>{zero, startIndex, ub}, &rewriter); |
| startIndices.push_back( |
| rewriter.create<arith::IndexCastOp>(loc, indexType, startIndex) |
| .getResult()); |
| } |
| |
| int64_t rank = operandType.getRank(); |
| SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1)); |
| rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>( |
| op, adaptor.update(), adaptor.operand(), startIndices, sizes, strides); |
| return success(); |
| } |
| }; |
| |
// Classification of an mhlo.dot by operand ranks; used to select the matching
// named linalg op and to compute dynamic sizes of the init tensor.
enum class DotOperationType {
  kVectorDot = 0,  // 1-D lhs, 1-D rhs: inner (dot) product.
  kMatrixVector,   // 2-D lhs, 1-D rhs: matrix-vector product.
  kVectorMatrix,   // 1-D lhs, 2-D rhs: vector-matrix product.
  kMatrixMatrix,   // 2-D lhs, 2-D rhs: matrix-matrix product.
  kUnsupported     // Ranks/shapes match none of the cases above.
};
| |
| DotOperationType getDotOperationType(mhlo::DotOp dotOp) { |
| ArrayRef<int64_t> lhsShape = |
| dotOp.lhs().getType().cast<ShapedType>().getShape(); |
| ArrayRef<int64_t> rhsShape = |
| dotOp.rhs().getType().cast<ShapedType>().getShape(); |
| auto shapeMatches = [](int64_t a, int64_t b) { |
| return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize || |
| a == b; |
| }; |
| if (lhsShape.size() == 1 && rhsShape.size() == 1 && |
| shapeMatches(lhsShape[0], rhsShape[0])) { |
| return DotOperationType::kVectorDot; |
| } |
| if (lhsShape.size() == 2 && rhsShape.size() == 1 && |
| shapeMatches(lhsShape[1], rhsShape[0])) { |
| return DotOperationType::kMatrixVector; |
| } |
| if (lhsShape.size() == 1 && rhsShape.size() == 2 && |
| shapeMatches(lhsShape[0], rhsShape[0])) { |
| return DotOperationType::kVectorMatrix; |
| } |
| if (lhsShape.size() == 2 && rhsShape.size() == 2 && |
| shapeMatches(lhsShape[1], rhsShape[0])) { |
| return DotOperationType::kMatrixMatrix; |
| } |
| return DotOperationType::kUnsupported; |
| } |
| |
| SmallVector<Value, 2> getDotOpInitTensorDynSizes(OpBuilder& b, Location loc, |
| Value lhs, Value rhs, |
| DotOperationType type) { |
| SmallVector<Value, 2> dynShape; |
| switch (type) { |
| case DotOperationType::kMatrixMatrix: { |
| if (lhs.getType().cast<ShapedType>().isDynamicDim(0)) |
| dynShape.push_back(b.create<tensor::DimOp>(loc, lhs, 0)); |
| if (rhs.getType().cast<ShapedType>().isDynamicDim(1)) |
| dynShape.push_back(b.create<tensor::DimOp>(loc, rhs, 1)); |
| break; |
| } |
| case DotOperationType::kMatrixVector: { |
| if (lhs.getType().cast<ShapedType>().isDynamicDim(0)) |
| dynShape.push_back(b.create<tensor::DimOp>(loc, lhs, 0)); |
| break; |
| } |
| case DotOperationType::kVectorMatrix: { |
| if (rhs.getType().cast<ShapedType>().isDynamicDim(1)) |
| dynShape.push_back(b.create<tensor::DimOp>(loc, rhs, 1)); |
| break; |
| } |
| case DotOperationType::kVectorDot: |
| case DotOperationType::kUnsupported: |
| default: { |
| break; |
| } |
| } |
| return dynShape; |
| } |
| |
| template <DotOperationType op_type, typename LinalgOp> |
| class DotOpConversion : public OpConversionPattern<mhlo::DotOp> { |
| public: |
| using OpConversionPattern<mhlo::DotOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| mhlo::DotOp op, mhlo::DotOp::Adaptor adaptor, |
| ConversionPatternRewriter& rewriter) const final { |
| if (!verifyHloOpBufferOrTensorSemantics(op)) { |
| return failure(); |
| } |
| if (getDotOperationType(op) != op_type) return failure(); |
| |
| Location loc = op.getLoc(); |
| // Convert unsigned to signed. This works because signed and unsigned |
| // integer matmul is the same operation in two's complement. |
| auto outputType = |
| typeConverter->convertType(op.getType()).cast<ShapedType>(); |
| SmallVector<Value, 2> dynShape = getDotOpInitTensorDynSizes( |
| rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type); |
| auto initTensor = getInitTensor(rewriter, loc, outputType, dynShape); |
| Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor); |
| rewriter.replaceOpWithNewOp<LinalgOp>( |
| op, TypeRange{outputType}, ValueRange{adaptor.lhs(), adaptor.rhs()}, |
| ValueRange{zeroTensor}, pruneAttributeList(op)); |
| return success(); |
| } |
| }; |
| |
| class DotGeneralBatchMatMulOpConversion |
| : public OpConversionPattern<mhlo::DotGeneralOp> { |
| public: |
| using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| mhlo::DotGeneralOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter& rewriter) const final { |
| if (!verifyHloOpBufferOrTensorSemantics(op)) { |
| return failure(); |
| } |
| if (op.getType().cast<RankedTensorType>().getRank() != 3) { |
| return rewriter.notifyMatchFailure(op, "expected a batch matmul"); |
| } |
| |
| mhlo::DotDimensionNumbersAttr dimNumbers = op.dot_dimension_numbers(); |
| auto lhsBatchingDims = dimNumbers.getLhsBatchingDimensions(); |
| auto rhsBatchingDims = dimNumbers.getRhsBatchingDimensions(); |
| auto lhsContractingDims = dimNumbers.getLhsContractingDimensions(); |
| auto rhsContractingDims = dimNumbers.getRhsContractingDimensions(); |
| if (lhsBatchingDims.size() != 1 || lhsBatchingDims[0] != 0) { |
| return rewriter.notifyMatchFailure( |
| op, "expected lhs batching dimensions exactly {0}"); |
| } |
| if (rhsBatchingDims.size() != 1 || rhsBatchingDims[0] != 0) { |
| return rewriter.notifyMatchFailure( |
| op, "expected rhs batching dimensions exactly {0}"); |
| } |
| if (lhsContractingDims.size() != 1 || lhsContractingDims[0] != 2) { |
| return rewriter.notifyMatchFailure( |
| op, "expected lhs contracting dimensions exactly {2}"); |
| } |
| if (rhsContractingDims.size() != 1 || rhsContractingDims[0] != 1) { |
| return rewriter.notifyMatchFailure( |
| op, "expected rhs contracting dimensions exactly {1}"); |
| } |
| |
| Location loc = op.getLoc(); |
| // Convert unsigned to signed. This works because signed and unsigned |
| // integer matmul is the same operation in two's complement. |
| auto outputType = |
| typeConverter->convertType(op.getType()).cast<ShapedType>(); |
| auto initTensor = |
| getInitTensorFor(rewriter, loc, outputType, op, adaptor.getOperands()); |
| Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor); |
| Operation* linalgOp = rewriter.create<linalg::BatchMatmulOp>( |
| loc, /*resultTensorTypes=*/TypeRange{outputType}, |
| /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()}, |
| /*outputBuffers=*/ValueRange{zeroTensor}, pruneAttributeList(op)); |
| |
| rewriter.replaceOp(op, linalgOp->getResults()); |
| return success(); |
| } |
| }; |
| |
// Lowers mhlo.map to a linalg.generic whose region is the map computation
// with scalarized (element-typed) block arguments.
class MapOpConverter : public OpConversionPattern<mhlo::MapOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::MapOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    if (!verifyHloOpBufferOrTensorSemantics(op)) return failure();

    auto resultType =
        typeConverter->convertType(op.getType()).cast<ShapedType>();
    assert(op.dimensions().size() == resultType.getRank() &&
           "Expected a pointwise map");

    Location loc = op.getLoc();
    Value output =
        getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
    // mhlo.map is pointwise, so every input and the output use the identity
    // indexing map over all result dimensions.
    SmallVector<AffineMap> indexingMaps(
        op.getNumOperands() + 1,
        rewriter.getMultiDimIdentityMap(resultType.getRank()));

    // Build the generic with an empty body; the body is filled by inlining
    // the map's computation region below.
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, resultType, adaptor.getOperands(), output, indexingMaps,
        getNParallelLoopsAttrs(resultType.getRank()),
        /*bodyBuild=*/nullptr, pruneAttributeList(op));

    // Convert the signature of the body. We scalarize the operands and add a
    // scalar operand representing the output tensor.
    Region& region = linalgOp.region();
    rewriter.inlineRegionBefore(op.computation(), region, region.end());
    TypeConverter::SignatureConversion signatureConverter(op.getNumOperands() +
                                                          1);

    // Each original tensor operand becomes a block argument of its
    // (converted) element type.
    for (const auto& it : llvm::enumerate(op.getOperation()->getOperands())) {
      signatureConverter.addInputs(
          it.index(),
          typeConverter->convertType(
              it.value().getType().cast<ShapedType>().getElementType()));
    }
    // Trailing block argument: the current element of the output tensor.
    signatureConverter.addInputs(resultType.getElementType());

    rewriter.applySignatureConversion(&region, signatureConverter);
    rewriter.replaceOp(op, linalgOp.getResults());
    return success();
  }
};
| |
| bool isInBodyOfLinalgOps(Operation* op) { |
| auto* parentOp = op->getParentRegion()->getParentOp(); |
| return parentOp->getDialect() == |
| parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>(); |
| } |
| |
// Converts an mhlo op `OpTy` appearing inside the body of a linalg op (e.g.
// the combiner region of a lowered mhlo.reduce) to the corresponding scalar
// standard op.
template <typename OpTy>
struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      OpTy op, typename OpTy::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    if (!isInBodyOfLinalgOps(op)) {
      return failure();
    }
    if (!op.getResult().getType().template isa<TensorType>()) return failure();
    // Only fire once at least one operand has been scalarized; if every
    // operand is still a tensor, conversion of the arguments hasn't happened
    // yet, so bail and let the pattern retry later.
    if (llvm::all_of(adaptor.getOperands(), [](Value arg) {
          return arg.getType().template isa<TensorType>();
        })) {
      return failure();
    }
    // RemoveSignTypeConverter would give us a tensor. We also have to scalarize
    // so do it manually.
    Type resultType = getElementTypeOrSelf(op.getType());
    // Strip signedness: an unsigned result becomes a signless integer of the
    // same bit width.
    if (resultType.isUnsignedInteger()) {
      resultType = IntegerType::get(resultType.getContext(),
                                    resultType.getIntOrFloatBitWidth());
    }
    // The scalar mapper has to know the original type. At this point the
    // operands have been converted from `tensor<ui32>` to `i32` so recreate
    // `ui32` from the original operands.
    auto operandTypes = llvm::to_vector(llvm::map_range(
        op->getOperandTypes(), [](Type t) { return getElementTypeOrSelf(t); }));
    Value result = mhlo::MhloOpToStdScalarOp::mapOpWithArgTypes(
        op, resultType, operandTypes, adaptor.getOperands(), &rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};
| |
| SmallVector<Value, 8> getReduceOpInitTensorDynSizes( |
| OpBuilder& b, Location loc, Value arg, ShapedType resultType, |
| ArrayRef<int64_t> reductionDims) { |
| llvm::SmallSetVector<int, 4> s; |
| for (auto dim : reductionDims) s.insert(dim); |
| |
| SmallVector<unsigned, 4> parallelDims; |
| SmallVector<Value, 8> dynShape; |
| int rank = arg.getType().cast<RankedTensorType>().getRank(); |
| for (int i = 0, j = 0; i < rank; ++i) { |
| if (s.count(i)) continue; |
| if (!resultType.isDynamicDim(j++)) continue; |
| dynShape.push_back(b.create<tensor::DimOp>(loc, arg, i)); |
| } |
| |
| return dynShape; |
| } |
| |
| class ReduceRegionReturnOpConversion |
| : public OpConversionPattern<mhlo::ReturnOp> { |
| public: |
| using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| mhlo::ReturnOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter& rewriter) const final { |
| if (!isInBodyOfLinalgOps(op)) { |
| return failure(); |
| } |
| SmallVector<Value, 4> operands(adaptor.getOperands()); |
| for (size_t i = 0; i < operands.size(); ++i) { |
| if (operands[i].getType().isa<ShapedType>()) { |
| auto loc = operands[i].getLoc(); |
| operands[i] = rewriter.create<tensor::ExtractOp>(loc, operands[i]); |
| } |
| } |
| rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, operands); |
| return success(); |
| } |
| }; |
| |
// Lowers mhlo.reduce to a linalg.generic: the inputs are read through a
// transposed indexing map that makes the reduction loops innermost, the init
// values are materialized as filled output tensors, and the reduce body is
// inlined as the generic's region with scalarized block arguments.
class ReduceConversion : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReduceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = op.getLoc();

    int numOperands = static_cast<int>(adaptor.operands().size());

    // NOTE(review): `!getRank()` also rejects rank-0 (0-d) operands, not only
    // unranked ones — confirm rank-0 reduces are meant to be excluded here.
    if (llvm::any_of(adaptor.operands(), [](Value v) {
          return !v.getType().cast<ShapedType>().getRank();
        })) {
      return rewriter.notifyMatchFailure(op, "expects known-rank args");
    }
    auto srcRank = adaptor.operands()[0].getType().cast<ShapedType>().getRank();

    SmallVector<int64_t, 4> reductionDims = extract1DVector(op.dimensions());

    SmallVector<Type> resultTypes;
    if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
      return failure();

    SmallVector<Value> operands, outputs;
    SmallVector<AffineMap, 3> indexingMaps;
    for (auto values :
         llvm::zip(adaptor.operands(), adaptor.init_values(), resultTypes)) {
      // Check if init_value is constant. If so, inline the value into the
      // region.
      Value operand = std::get<0>(values);
      Value initValue = std::get<1>(values);
      Type resultType = std::get<2>(values);
      // init_value is a 0-d tensor; extract (or constant-fold) its scalar.
      initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);

      operands.push_back(operand);
      SmallVector<Value, 8> dynShape = getReduceOpInitTensorDynSizes(
          rewriter, loc, operand, resultType, reductionDims);
      auto initTensor = getInitTensor(rewriter, loc, resultType, dynShape);
      // Seed each output tensor with its scalar init value.
      Value filledTensor =
          rewriter.create<linalg::FillOp>(loc, initValue, initTensor).result();
      outputs.push_back(filledTensor);
    }

    // Prepare indexing maps for linalg generic op. The elements are for src
    // and dst. Transpose `src` to make the reduction loops be the innermost,
    // because it's easier to fully utilize processors.
    indexingMaps.append(
        numOperands, getTransposeMapForReduction(rewriter.getContext(),
                                                 (int)srcRank, reductionDims));

    // The indexing map of `dst` should drop the reduction loops. Since the
    // reduction loops now are all in the innermost, drops
    // `reduction_dims.size()` dimensions. We don't need an inverse
    // permutation here because they are the same.
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0, e = srcRank - reductionDims.size(); i < e; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    indexingMaps.append(numOperands,
                        AffineMap::get(srcRank, /*symbolCount=*/0, exprs,
                                       rewriter.getContext()));

    // Create the generic with an empty body; the reduce body is inlined next.
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/resultTypes, operands,
        /*outputBuffers=*/ValueRange{outputs}, indexingMaps,
        getParallelAndReductionIterators(srcRank, reductionDims.size()),
        /*bodyBuild=*/nullptr, pruneAttributeList(op));

    // Convert the signature of the body. The reduce op region apply function
    // has a signature (lhs, rhs) -> output, all of the same tensor type t.
    // This is converted to a function with the same signature but with
    // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
    // be converted to "(f32, f32, f32)".
    Region& region = linalgOp.region();
    rewriter.inlineRegionBefore(op.body(), region, region.end());
    TypeConverter::SignatureConversion signatureConverter(numOperands * 2);

    // map operand and init values's types
    for (const auto& it : llvm::enumerate(op.getOperation()->getOperands())) {
      signatureConverter.addInputs(
          it.index(),
          typeConverter->convertType(
              it.value().getType().cast<ShapedType>().getElementType()));
    }

    rewriter.applySignatureConversion(&region, signatureConverter);
    rewriter.replaceOp(op, linalgOp.getResults());
    return success();
  }
};
| |
| // Decomposes a pad with negative edge padding into a pad without negative edge |
| // padding and a tensor.extract_slice. |
| struct PadOpNegativePaddingConversion |
| : public OpConversionPattern<mhlo::PadOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| mhlo::PadOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter& rewriter) const override { |
| SmallVector<int64_t, 4> padLow; |
| SmallVector<int64_t, 4> padHigh; |
| SmallVector<OpFoldResult, 4> sliceStarts; |
| |
| bool hasNegativePadding = false; |
| for (int64_t low : op.edge_padding_low().getValues<int64_t>()) { |
| if (low >= 0) { |
| padLow.push_back(low); |
| sliceStarts.push_back(rewriter.getIndexAttr(0)); |
| } else { |
| padLow.push_back(0); |
| sliceStarts.push_back(rewriter.getIndexAttr(-low)); |
| hasNegativePadding = true; |
| } |
| } |
| |
| for (int64_t high : op.edge_padding_high().getValues<int64_t>()) { |
| if (high >= 0) { |
| padHigh.push_back(high); |
| } else { |
| padHigh.push_back(-high); |
| hasNegativePadding = true; |
| } |
| } |
| |
| // If there's no negative edge padding we're done. |
| if (!hasNegativePadding) return failure(); |
| |
| // Create a new pad op with the positive values. |
| Value pad = rewriter.create<mhlo::PadOp>( |
| op.getLoc(), adaptor.operand(), adaptor.padding_value(), |
| rewriter.getI64TensorAttr(padLow), rewriter.getI64TensorAttr(padHigh), |
| op.interior_padding()); |
| |
| // Then slice according to the negative edge padding. Static shapes only for |
| // now. |
| if (!op.getType().hasStaticShape()) return failure(); |
| SmallVector<OpFoldResult, 4> sizes(llvm::map_range( |
| op.getType().getShape(), |
| [&](int64_t dim) { return rewriter.getIndexAttr(dim); })); |
| SmallVector<OpFoldResult, 4> strides(sliceStarts.size(), |
| rewriter.getIndexAttr(1)); |
| rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(op, pad, sliceStarts, |
| sizes, strides); |
| return success(); |
| } |
| }; |
| |
/// Converts mhlo.pad operation to tensor.pad or tensor.insert_slice. With no
/// interior padding the op maps directly onto tensor.pad; with interior
/// padding the operand is strided-inserted into a tensor prefilled with the
/// padding value.
struct PadOpConversion : public OpConversionPattern<mhlo::PadOp> {
  using OpConversionPattern<mhlo::PadOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::PadOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    auto loc = op.getLoc();
    auto resultType = typeConverter->convertType(op.getResult().getType());

    // Negative edge padding is decomposed separately.
    auto isNegative = [](const APInt& intVal) { return intVal.isNegative(); };
    if (llvm::any_of(op.edge_padding_low().getValues<APInt>(), isNegative) ||
        llvm::any_of(op.edge_padding_high().getValues<APInt>(), isNegative))
      return failure();

    // The padding value is a 0-d tensor; extract (or fold to) the scalar.
    Value paddingVal =
        rewriter.createOrFold<tensor::ExtractOp>(loc, adaptor.padding_value());

    // Low edge padding doubles as the insertion offsets on the
    // insert_slice path below.
    SmallVector<OpFoldResult, 4> low(
        op.edge_padding_low().getValues<IntegerAttr>());

    // If there is no interior padding lower to tensor.pad directly.
    if (llvm::all_of(op.interior_padding().getValues<APInt>(),
                     [](const APInt& intVal) { return intVal.isZero(); })) {
      SmallVector<OpFoldResult, 4> high(
          op.edge_padding_high().getValues<IntegerAttr>());
      auto padTensorOp = tensor::createPadScalarOp(
          resultType, adaptor.operand(), paddingVal, low, high,
          /*nofold=*/false, loc, rewriter);
      rewriter.replaceOp(op, padTensorOp.getResult());
      return success();
    }

    // We have interior padding, which can be lowered to tensor.insert_slice.
    // Start by filling a result-sized tensor with the pad value.
    auto initTensor =
        getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
    auto fill =
        rewriter.create<linalg::FillOp>(loc, paddingVal, initTensor).result();

    // Get sizes of the original operand.
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<int64_t>(0, operandType.getRank()),
        [&](int64_t dim) -> OpFoldResult {
          if (!operandType.isDynamicDim(dim))
            return rewriter.getIndexAttr(operandType.getDimSize(dim));
          return rewriter.create<tensor::DimOp>(loc, adaptor.operand(), dim)
              .result();
        }));
    // Map interior padding to strides: interior padding of k means operand
    // elements land every k + 1 positions in the destination.
    auto strides = llvm::to_vector<4>(
        llvm::map_range(op.interior_padding().getValues<IntegerAttr>(),
                        [&](IntegerAttr stride) -> OpFoldResult {
                          return rewriter.getIntegerAttr(stride.getType(),
                                                         stride.getValue() + 1);
                        }));

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        op, adaptor.operand(), fill, low, sizes, strides);
    return success();
  }
};
| |
// Apply dilation and padding to the input of a convolution. Returns `input`
// unchanged when both the padding and the lhs dilation are trivial;
// otherwise materializes them as a single mhlo.pad (edge padding for
// `padding`, interior padding for `lhsDilation`) with a zero padding value.
Value applyConvolutionPadding(Location loc, Value input,
                              DenseIntElementsAttr padding,
                              DenseIntElementsAttr lhsDilation,
                              OpBuilder& rewriter) {
  if ((!padding || isSplatValue(padding, 0)) &&
      (!lhsDilation || isSplatValue(lhsDilation, 1)))
    return input;

  auto inputType = input.getType().cast<ShapedType>();
  auto rank = inputType.getRank();

  // Translate window padding into low/high padding. The padding attribute
  // contains two values per dimension, but excludes the batch and feature
  // dimensions; spatial dimensions start at index 1 (canonical layout:
  // dim 0 is batch, last dim is feature), hence the `i + 1` below.
  SmallVector<int64_t, 8> padLow(rank, 0);
  SmallVector<int64_t, 8> padHigh(rank, 0);
  if (padding) {
    assert(rank * 2 == padding.size() + 4 &&
           "There should be 2 padding values per dimension, i.e low and high.");
    for (auto i : llvm::seq<int64_t>(0, padding.size() / 2)) {
      padLow[i + 1] = padding.getValues<int64_t>()[i * 2];
      padHigh[i + 1] = padding.getValues<int64_t>()[i * 2 + 1];
    }
  }

  // Translate input dilation into interior padding: a dilation factor of d
  // inserts d - 1 zeros between adjacent input elements.
  SmallVector<int64_t, 8> padInterior(rank, 0);
  if (lhsDilation) {
    assert(rank == lhsDilation.size() + 2);
    for (auto i : llvm::seq<int64_t>(0, lhsDilation.size())) {
      padInterior[i + 1] = lhsDilation.getValues<int64_t>()[i] - 1;
    }
  }

  // Pad with a zero scalar of the input's element type.
  auto indexType = rewriter.getIndexType();
  auto attrType = RankedTensorType::get({rank}, indexType);
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(
               RankedTensorType::get({}, inputType.getElementType())));
  return rewriter.create<mhlo::PadOp>(
      loc, input, zero, DenseIntElementsAttr::get(attrType, padLow),
      DenseIntElementsAttr::get(attrType, padHigh),
      DenseIntElementsAttr::get(attrType, padInterior));
}
| |
/// Converts mhlo.conv operation to linalg named op. This only covers normal
/// convolution cases. The op must have canonical dimension numbers. Depthwise
/// convolution and pointwise convolution are not handled in the conversion.
struct NormalConvolutionOpConversion
    : public OpConversionPattern<mhlo::ConvolutionOp> {
  using OpConversionPattern<mhlo::ConvolutionOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvolutionOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    // Only canonical layouts and ungrouped convolutions are handled here.
    if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
    if (op.feature_group_count() != 1u) return failure();

    Location loc = op.getLoc();
    Value input = adaptor.lhs();
    Value filter = adaptor.rhs();
    auto resultType =
        typeConverter->convertType(op.getResult().getType()).cast<ShapedType>();
    int64_t rank = resultType.getRank();

    // The output shape is N spatial_dims F. A dynamic batch size comes from
    // the input and a dynamic feature count from the filter; the spatial
    // output dimensions must be static.
    SmallVector<Value, 8> dynSizes;
    if (resultType.isDynamicDim(0)) {
      dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
    }
    for (int64_t i = 1, e = rank - 1; i < e; ++i) {
      if (resultType.isDynamicDim(i)) {
        return rewriter.notifyMatchFailure(
            op, "expected output spatial dims to be static shapes");
      }
    }
    if (resultType.isDynamicDim(rank - 1)) {
      dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, filter, rank - 1));
    }
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynSizes, resultType.getShape(), resultType.getElementType());
    Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
    linalg::LinalgOp res;
    Attribute strides = op.window_stridesAttr();
    Attribute dilations = op.rhs_dilationAttr();

    // Apply padding and input dilation.
    input = applyConvolutionPadding(loc, input, op.paddingAttr(),
                                    op.lhs_dilationAttr(), rewriter);

    // Dispatch on the output rank (batch + spatial dims + feature): rank 2
    // means no spatial dims, which degenerates to a plain matmul.
    switch (rank) {
      case 2: {
        res = rewriter.create<linalg::MatmulOp>(
            loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
            pruneAttributeList(op));
        break;
      }
      case 3: {
        res = rewriter.create<linalg::Conv1DNwcWcfOp>(
            loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
            strides, dilations, pruneAttributeList(op));
        break;
      }
      case 4: {
        res = rewriter.create<linalg::Conv2DNhwcHwcfOp>(
            loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
            strides, dilations, pruneAttributeList(op));
        break;
      }
      case 5: {
        res = rewriter.create<linalg::Conv3DNdhwcDhwcfOp>(
            loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
            strides, dilations, pruneAttributeList(op));
        break;
      }
      default:
        return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
    }
    rewriter.replaceOp(op, res.getOperation()->getResults());
    return success();
  }
};
| |
/// Converts mhlo.convolution operation to
/// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or
/// depthwise_conv_2d_input_nhwc_filter_hwc op (and the 1-d/3-d variants).
/// Fires only when feature_group_count equals the input feature count, i.e.
/// the convolution is depthwise.
struct DepthwiseConvolutionOpConversion
    : public OpConversionPattern<mhlo::ConvolutionOp> {
  using OpConversionPattern<mhlo::ConvolutionOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvolutionOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    if (op.batch_group_count() != 1) return failure();
    // Fall into the normal convolution cases.
    if (op.feature_group_count() == 1) return failure();

    const mhlo::ConvDimensionNumbersAttr& dimensionNumbers =
        op.dimension_numbers();
    const auto spatialRank =
        llvm::size(dimensionNumbers.getInputSpatialDimensions());
    if (spatialRank == 0 || spatialRank > 3) {
      return rewriter.notifyMatchFailure(op, "only support up to 3D for now");
    }

    // Make sure that this is depthwise convolution.
    int64_t inputFeatureDim = dimensionNumbers.getInputFeatureDimension();
    int64_t inputFeatureCount =
        op.lhs().getType().cast<ShapedType>().getDimSize(inputFeatureDim);
    if (op.feature_group_count() != inputFeatureCount) {
      return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
    }

    // Make sure that this convolution has a canonical form.
    if (!hasCanonicalDimensionNumbers(dimensionNumbers)) {
      return rewriter.notifyMatchFailure(op, "does not have canonical form");
    }

    // Default window strides to all-ones when the attribute is absent.
    Attribute windowStrides;
    if (op.window_strides()) {
      windowStrides = op.window_strides().getValue();
    } else {
      windowStrides = SplatElementsAttr::get(
          VectorType::get({spatialRank}, rewriter.getI64Type()),
          rewriter.getI64IntegerAttr(1));
    }

    // Likewise default the kernel (rhs) dilation to all-ones.
    Attribute rhsDilation;
    if (op.rhs_dilation()) {
      rhsDilation = op.rhs_dilation().getValue();
    } else {
      rhsDilation = SplatElementsAttr::get(
          VectorType::get({spatialRank}, rewriter.getI64Type()),
          rewriter.getI64IntegerAttr(1));
    }

    Location loc = op.getLoc();
    Value input = adaptor.lhs();
    Value filter = adaptor.rhs();
    auto resultType = typeConverter->convertType(op.getResult().getType())
                          .cast<RankedTensorType>();
    if (!resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected output has static shapes");
    }

    // Apply padding and input dilation.
    input = applyConvolutionPadding(loc, input, op.paddingAttr(),
                                    op.lhs_dilationAttr(), rewriter);

    auto filterDims =
        llvm::to_vector<4>(op.rhs().getType().cast<ShapedType>().getShape());

    // Builds the reassociation [[0], [1], ..., [rank-2, rank-1]] that
    // collapses the trailing two dimensions of `v` into one.
    auto getReassociationIndicesToCollapseLastTwoDims = [](Value v) {
      SmallVector<ReassociationIndices> reassociations;
      int64_t rank = v.getType().cast<ShapedType>().getRank();
      for (int64_t i = 0; i < rank - 1; ++i) reassociations.emplace_back(1, i);
      reassociations.back().push_back(rank - 1);
      return reassociations;
    };

    int64_t kernelInputFeatureDimension =
        dimensionNumbers.getKernelInputFeatureDimension();
    int64_t kernelOutputFeatureDimension =
        dimensionNumbers.getKernelOutputFeatureDimension();
    if (filterDims[kernelInputFeatureDimension] *
            filterDims[kernelOutputFeatureDimension] !=
        op.feature_group_count()) {
      // For cases where channel multiplier != 1

      // Reshaping filter shape
      // [filter_height, filter_width, 1, kernel-output-feature].
      // to
      // [filter_height, filter_width, feature_group_count,
      // kernel-output-feature/feature_group_count ]
      SmallVector<int64_t> reshapedFilterDims;
      reshapedFilterDims.assign(filterDims.begin(), filterDims.end());
      auto reshapedFilter = filter;
      if (filterDims[kernelInputFeatureDimension] == 1) {
        reshapedFilterDims[kernelInputFeatureDimension] =
            op.feature_group_count();
        reshapedFilterDims[kernelOutputFeatureDimension] /=
            op.feature_group_count();
        auto reshapedFilterType = RankedTensorType::get(
            reshapedFilterDims,
            op.rhs().getType().cast<RankedTensorType>().getElementType());

        reshapedFilter =
            rewriter.create<mhlo::ReshapeOp>(loc, reshapedFilterType, filter);
      }

      // The multiplier variant of the named op produces an output with an
      // extra trailing channel-multiplier dimension; build that expanded
      // shape from the result shape.
      auto outputDims = resultType.getShape();
      auto channelMultiplier = reshapedFilterDims.back();
      SmallVector<int64_t> reshapedOutputDims;
      reshapedOutputDims.assign(outputDims.begin(), outputDims.end());
      reshapedOutputDims.push_back(channelMultiplier);
      reshapedOutputDims[reshapedOutputDims.size() - 2] /= channelMultiplier;

      Value initTensor = rewriter.create<linalg::InitTensorOp>(
          loc, reshapedOutputDims, resultType.getElementType());
      Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);

      auto reshapedOutputType = RankedTensorType::get(
          reshapedOutputDims, resultType.getElementType());
      Value conv;
      // Select the spatial-rank-specific channel-multiplier named op.
      switch (spatialRank) {
        case 1:
          conv =
              rewriter
                  .create<linalg::DepthwiseConv1DNwcWcmOp>(
                      loc, reshapedOutputType,
                      ValueRange{input, reshapedFilter}, ValueRange{zeroTensor},
                      windowStrides, rhsDilation, pruneAttributeList(op))
                  .getResult(0);
          break;
        case 2:
          conv =
              rewriter
                  .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                      loc, reshapedOutputType,
                      ValueRange{input, reshapedFilter}, ValueRange{zeroTensor},
                      windowStrides, rhsDilation, pruneAttributeList(op))
                  .getResult(0);
          break;
        case 3:
          conv =
              rewriter
                  .create<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
                      loc, reshapedOutputType,
                      ValueRange{input, reshapedFilter}, ValueRange{zeroTensor},
                      windowStrides, rhsDilation, pruneAttributeList(op))
                  .getResult(0);
          break;
      }

      // Create a Linalg reshape op that converts the output from 5 dimensions
      // into 4 dimensions (by collapsing the last two dimensions). This is
      // needed because linalg.depthwise_conv_2d_input_nhwc_filter_hwcf returns
      // 5 dimensions for the output.
      rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
          op, resultType, conv,
          getReassociationIndicesToCollapseLastTwoDims(conv));
    } else {
      // For cases where channel multiplier == 1
      Value initTensor = rewriter.create<linalg::InitTensorOp>(
          loc, resultType.getShape(), resultType.getElementType());
      Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);

      // Create a Linalg reshape op that converts the filter from 4 dimensions
      // into 3 dimensions (by droping the unit dimension). This is needed
      // because linalg.depthwise_conv_2d_input_nhwc_filter_hwc expects 3
      // dimensions for the filter.

      filterDims[filterDims.size() - 2] =
          static_cast<int64_t>(op.feature_group_count());
      filterDims.pop_back();

      RankedTensorType filterShape =
          RankedTensorType::get(filterDims, op.getType().getElementType());

      Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
          loc, filterShape, filter,
          getReassociationIndicesToCollapseLastTwoDims(filter));

      // Select the spatial-rank-specific unit-multiplier named op.
      switch (spatialRank) {
        case 1:
          rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(
              op, resultType, ValueRange{input, reshapedFilter},
              ValueRange{zeroTensor}, windowStrides, rhsDilation,
              pruneAttributeList(op));
          break;
        case 2:
          rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(
              op, resultType, ValueRange{input, reshapedFilter},
              ValueRange{zeroTensor}, windowStrides, rhsDilation,
              pruneAttributeList(op));
          break;
        case 3:
          rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(
              op, resultType, ValueRange{input, reshapedFilter},
              ValueRange{zeroTensor}, windowStrides, rhsDilation,
              pruneAttributeList(op));
          break;
      }
    }

    return success();
  }
};
| |
// Converts mhlo.reduce_window to a linalg.generic whose iteration space is
// the output dimensions (parallel) followed by one reduction dimension per
// non-unit window dimension. The reduction body is cloned from the
// reduce_window op's own region. Padding and base dilation, when present,
// are materialized beforehand as an mhlo.pad on each input.
struct ReduceWindowOpOnTensorsGenericConversion
    : public OpConversionPattern<mhlo::ReduceWindowOp> {
  using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    MLIRContext* ctx = op->getContext();
    Location loc = op.getLoc();
    llvm::SmallVector<Value> initValues = adaptor.init_values();
    llvm::SmallVector<Type> resultTypes = llvm::to_vector(op.getResultTypes());
    auto numOperands = initValues.size();

    llvm::SmallVector<int64_t> windowDimensions =
        extract1DVector(op.window_dimensions());

    // Optional attribute; empty means no padding anywhere.
    llvm::SmallVector<int64_t> padding;
    if (op.padding()) {
      padding = extract1DVector(*op.padding());
    }

    // Optional attribute; empty means no base dilation.
    llvm::SmallVector<int64_t> baseDilations;
    if (op.base_dilations()) {
      baseDilations = extract1DVector(*op.base_dilations());
    }

    // Window strides default to 1 when the attribute is absent.
    llvm::SmallVector<int64_t> windowStrides(windowDimensions.size(), 1);
    if (op.window_strides()) {
      windowStrides = extract1DVector(*op.window_strides());
    }

    // Window dilations default to 1 when the attribute is absent.
    llvm::SmallVector<int64_t> windowDilations(windowDimensions.size(), 1);
    if (op.window_dilations()) {
      windowDilations = extract1DVector(*op.window_dilations());
    }

    auto rank = static_cast<int64_t>(windowDimensions.size());
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> windowExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    SmallVector<int64_t> filteredWindowDims;

    // Build affine exprs for the input (src), window, and output (dst)
    // indexing maps. Dims [0, rank) iterate over the output; a fresh dim
    // (rank + windowDim) is appended for every non-unit window dimension.
    // The input index is `output * stride + window * dilation`. Unit window
    // dimensions need no reduction loop and are skipped.
    int windowDim = 0;
    for (int64_t i = 0; i < rank; i++) {
      AffineExpr srcExpr = mlir::getAffineDimExpr(i, ctx);

      if (windowStrides[i] != 1) srcExpr = srcExpr * windowStrides[i];

      if (windowDimensions[i] != 1) {
        filteredWindowDims.push_back(windowDimensions[i]);
        AffineExpr windowExpr = mlir::getAffineDimExpr(rank + windowDim, ctx);
        // The window tensor itself is indexed by the undilated expr; the
        // dilation only scales the window's contribution to the input index.
        windowExprs.push_back(windowExpr);

        if (windowDilations[i] != 1)
          windowExpr = windowExpr * windowDilations[i];

        srcExpr = srcExpr + windowExpr;
        windowDim++;
      }

      srcExprs.push_back(srcExpr);
      dstExprs.push_back(mlir::getAffineDimExpr(i, ctx));
    }

    // Infer the three maps together so they share a consistent dim count.
    // For rank 0 every map is the empty map.
    SmallVector<AffineMap, 4> inferredMaps(3, AffineMap::get(ctx));
    if (rank > 0)
      inferredMaps =
          AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs});

    SmallVector<AffineMap, 4> indexingMaps;

    // One src map per operand, then the single window-tensor map, then one
    // dst map per result.
    indexingMaps.append(numOperands, inferredMaps[0]);
    indexingMaps.append(1, inferredMaps[1]);
    indexingMaps.append(numOperands, inferredMaps[2]);

    // Setup the initial values: broadcast each scalar init value to the full
    // result shape so it can seed the reduction outputs. Requires a static
    // result shape.
    llvm::SmallVector<Value> broadcastValues;
    for (uint64_t i = 0, s = initValues.size(); i < s; i++) {
      Value initValue = initValues[i];
      auto resultTy = resultTypes[i].cast<ShapedType>();
      if (!resultTy.hasStaticShape()) return failure();

      auto broadcastSizes = rewriter.getI64TensorAttr(resultTy.getShape());
      broadcastValues.push_back(rewriter.create<mhlo::BroadcastOp>(
          loc, resultTy, initValue, broadcastSizes));
    }

    llvm::SmallVector<Value> inputs = llvm::to_vector(adaptor.operands());

    // Pad as necessary: edge padding comes straight from the padding
    // attribute, and base dilation is expressed as interior padding. Each
    // input is padded with its corresponding init value.
    if (llvm::any_of(padding, [](int64_t v) { return v != 0; }) ||
        llvm::any_of(baseDilations, [](int64_t v) { return v != 1; })) {
      llvm::SmallVector<int64_t> staticLows(rank, 0);
      llvm::SmallVector<int64_t> staticHighs(rank, 0);
      // padding is interleaved as [low0, high0, low1, high1, ...].
      for (int i = 0; i < padding.size(); i += 2) {
        staticLows[i / 2] = padding[i];
        staticHighs[i / 2] = padding[i + 1];
      }
      // Translate base dilation into interior padding.
      llvm::SmallVector<int64_t> staticInteriors(rank, 0);
      for (const auto& dilation : llvm::enumerate(baseDilations)) {
        staticInteriors[dilation.index()] = dilation.value() - 1;
      }

      auto padAttrType = RankedTensorType::get({rank}, rewriter.getIndexType());
      auto padLows = DenseIntElementsAttr::get(padAttrType, staticLows);
      auto padHighs = DenseIntElementsAttr::get(padAttrType, staticHighs);
      auto padInteriors =
          DenseIntElementsAttr::get(padAttrType, staticInteriors);

      for (auto values : llvm::zip(inputs, initValues)) {
        auto& input = std::get<0>(values);
        auto& initValue = std::get<1>(values);
        input = rewriter.create<mhlo::PadOp>(loc, input, initValue, padLows,
                                             padHighs, padInteriors);
      }
    }

    // Add the extra input for the reduction dimension. Only the shape of
    // this tensor matters (it drives the window reduction loops); the cloned
    // body never reads its block argument, so the element type is arbitrary.
    inputs.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, filteredWindowDims, rewriter.getF32Type()));

    rewriter.setInsertionPoint(op);
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/resultTypes,
        /*inputs=*/inputs,
        /*outputs=*/broadcastValues, indexingMaps,
        getParallelAndReductionIterators(rank + filteredWindowDims.size(),
                                         filteredWindowDims.size()),
        /*bodyBuild=*/nullptr, pruneAttributeList(op));

    // Convert the signature of the body. This includes converting scalar
    // tensors to their scalar values and inserting an additional block arg for
    // the window arg.
    Region& region = linalgOp.region();
    rewriter.cloneRegionBefore(op.body(), region, region.end());

    TypeConverter::SignatureConversion signatureConverter(
        inputs.size() + op->getNumResults() - 1);

    // Scalar element types for the original operand block args.
    for (uint64_t i = 0, s = inputs.size(); i < s - 1; i++) {
      signatureConverter.addInputs(
          i, inputs[i].getType().cast<ShapedType>().getElementType());
    }

    // The new (unused) block arg for the window tensor.
    signatureConverter.addInputs(
        inputs.back().getType().cast<ShapedType>().getElementType());

    // Scalar element types for the accumulator (output) block args.
    for (uint64_t i = 0, s = resultTypes.size(); i < s; i++) {
      auto idx = inputs.size() + i - 1;
      signatureConverter.addInputs(
          idx, resultTypes[i].cast<ShapedType>().getElementType());
    }

    rewriter.applySignatureConversion(&region, signatureConverter);
    rewriter.replaceOp(op, linalgOp.getResults());
    return success();
  }
};
| |
// Converts mhlo.reduce_window to a named linalg pooling op
// (linalg.pooling_{nhwc,ndhwc}_{min,max,sum}) when the op matches a
// pooling pattern: rank 4/5 results, zero padding, unit window/stride on
// the batch and channel dims, and an f32 element type.
struct ReduceWindowOpConversion
    : public OpConversionPattern<mhlo::ReduceWindowOp> {
  using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;

  /// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
  /// the pooling is determined based on the body of the reduce window
  /// operation. This class enumerates the different variants.
  enum class PoolingType {
    kInvalid,
    k2DMin,
    k3DMin,
    k2DMax,
    k3DMax,
    k2DAdd,
    k3DAdd,
  };

  /// Classifies the reduction applied to result `resultIndex` by inspecting
  /// the single reduction op in the body (min/max/add) and the result rank
  /// (4 -> 2D/NHWC pooling, 5 -> 3D/NDHWC pooling). Anything else yields
  /// kInvalid.
  static PoolingType getPoolingType(mhlo::ReduceWindowOp reduceOp,
                                    int resultIndex) {
    auto rank =
        reduceOp.getResultTypes()[resultIndex].cast<ShapedType>().getRank();
    if (Operation* op = reduceOp.getReductionOp(resultIndex)) {
      if (isa<mhlo::MinOp>(*op) && rank == 4) return PoolingType::k2DMin;
      if (isa<mhlo::MinOp>(*op) && rank == 5) return PoolingType::k3DMin;
      if (isa<mhlo::MaxOp>(*op) && rank == 4) return PoolingType::k2DMax;
      if (isa<mhlo::MaxOp>(*op) && rank == 5) return PoolingType::k3DMax;
      if (isa<mhlo::AddOp>(*op) && rank == 4) return PoolingType::k2DAdd;
      if (isa<mhlo::AddOp>(*op) && rank == 5) return PoolingType::k3DAdd;
    }
    return PoolingType::kInvalid;
  }

  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    auto loc = op.getLoc();
    // Only NHWC (rank 4) and NDHWC (rank 5) layouts map onto the named
    // linalg pooling ops.
    int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
    if (rank != 4 && rank != 5) {
      return rewriter.notifyMatchFailure(
          op, "expected NHWC/NDHWC pooling-based op");
    }

    if (op.padding() && !isSplatValue(*op.padding(), 0)) {
      return rewriter.notifyMatchFailure(op, "require paddings are all zero");
    }

    // Only the spatial dims (everything but the leading batch dim and the
    // trailing channel dim) form the pooling window.
    int lastDim = rank - 1;
    SmallVector<int64_t, 2> fakeWindowShapes;
    for (int i = 1; i < lastDim; ++i) {
      fakeWindowShapes.push_back(
          op.window_dimensions().getValues<int64_t>()[i]);
    }

    // Batch and channel dims must be neither strided nor windowed.
    if (op.window_strides() &&
        (op.window_strides().getValue().getValues<int64_t>()[0] != 1 ||
         op.window_strides().getValue().getValues<int64_t>()[lastDim] != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_strides to be [1,x,y,(z),1]");
    }
    if (op.window_dimensions() &&
        (op.window_dimensions().getValues<int64_t>()[0] != 1 ||
         op.window_dimensions().getValues<int64_t>()[lastDim] != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_dimensions to be [1,x,y,(z),1]");
    }

    // Strides over the spatial dims; default to 1 when absent.
    Attribute strides;
    SmallVector<int64_t> vec;
    if (op.window_stridesAttr()) {
      for (int i = 1; i < lastDim; ++i) {
        vec.push_back(op.window_strides().getValue().getValues<int64_t>()[i]);
      }
    } else {
      vec.assign(rank - 2, 1);
    }
    strides = rewriter.getI64VectorAttr(vec);

    // Dilations over the spatial dims; default to 1 when absent.
    Attribute dilations;
    vec.clear();
    if (op.window_dilations()) {
      for (int i = 1; i < lastDim; ++i) {
        vec.push_back(op.window_dilations().getValue().getValues<int64_t>()[i]);
      }
    } else {
      vec.assign(rank - 2, 1);
    }
    dilations = rewriter.getI64VectorAttr(vec);

    SmallVector<Value> poolingOps;

    // Lower each (result, operand, init value) triple to its own pooling op.
    ValueRange operands = adaptor.operands();
    ValueRange initValues = adaptor.init_values();
    for (auto it : llvm::zip(op.getResults(), operands, initValues)) {
      OpResult result = std::get<0>(it);
      Value input = std::get<1>(it);
      Value initValue = std::get<2>(it);
      auto resultType = result.getType().cast<ShapedType>();
      if (!input.getType().cast<ShapedType>().getElementType().isF32()) {
        return rewriter.notifyMatchFailure(op,
                                           "expected element type to be f32");
      }

      // Create a fake window dimension: only the tensor's shape is used, to
      // drive the pooling op's window loops.
      auto fakeWindowDims = rewriter.create<linalg::InitTensorOp>(
          loc, fakeWindowShapes, resultType.getElementType());

      // Compute the dynamic output sizes needed for the init tensor.
      SmallVector<Value> resultDynamicDims;
      for (auto& en : llvm::enumerate(resultType.getShape())) {
        if (en.value() != ShapedType::kDynamicSize) continue;
        Value dimSize = rewriter.create<tensor::DimOp>(loc, input, en.index());
        if (en.index() == 0 || en.index() == rank - 1) {
          // batch dims and channel dims can be derived from input dims
          // directly.
          resultDynamicDims.push_back(dimSize);
        } else {
          auto i = en.index() - 1;
          auto stride =
              strides.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
          auto dilation =
              dilations.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
          // let j = i * stride
          // output[i] = reduce( input[j, j + window_size * dilation) )
          Value offset = rewriter.create<arith::ConstantIndexOp>(
              loc, fakeWindowShapes[i] * dilation);
          dimSize = rewriter.create<arith::SubIOp>(loc, dimSize, offset);
          dimSize = rewriter.create<arith::DivUIOp>(
              loc, dimSize,
              rewriter.create<arith::ConstantIndexOp>(loc, stride));
          dimSize = rewriter.create<arith::AddIOp>(
              loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, 1));
          resultDynamicDims.push_back(dimSize);
        }
      }
      Value initTensor = rewriter.create<linalg::InitTensorOp>(
          loc, resultDynamicDims, resultType.getShape(),
          resultType.getElementType());

      // Seed the output with the scalar init value before pooling.
      initValue = rewriter.create<tensor::ExtractOp>(loc, initValue);
      Value filledInitTensor =
          rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
              .getResult(0);
      // Instantiates the concrete pooling op selected in the switch below;
      // `typePtr` is a null pointer whose type alone selects the op class.
      auto createOp = [&](auto* typePtr) -> linalg::LinalgOp {
        return cast<linalg::LinalgOp>(
            rewriter
                .create<std::remove_pointer_t<decltype(typePtr)>>(
                    loc, ArrayRef<Type>{resultType},
                    ValueRange{input, fakeWindowDims.getResult()},
                    filledInitTensor, strides, dilations,
                    pruneAttributeList(op))
                .getOperation());
      };
      linalg::LinalgOp poolingOp;
      PoolingType poolingType = getPoolingType(op, result.getResultNumber());
      switch (poolingType) {
        case PoolingType::k2DMin: {
          poolingOp = createOp(static_cast<linalg::PoolingNhwcMinOp*>(nullptr));
          break;
        }
        case PoolingType::k3DMin: {
          poolingOp =
              createOp(static_cast<linalg::PoolingNdhwcMinOp*>(nullptr));
          break;
        }
        case PoolingType::k2DMax: {
          poolingOp = createOp(static_cast<linalg::PoolingNhwcMaxOp*>(nullptr));
          break;
        }
        case PoolingType::k3DMax: {
          poolingOp =
              createOp(static_cast<linalg::PoolingNdhwcMaxOp*>(nullptr));
          break;
        }
        case PoolingType::k2DAdd: {
          poolingOp = createOp(static_cast<linalg::PoolingNhwcSumOp*>(nullptr));
          break;
        }
        case PoolingType::k3DAdd: {
          poolingOp =
              createOp(static_cast<linalg::PoolingNdhwcSumOp*>(nullptr));
          break;
        }
        case PoolingType::kInvalid:
          return rewriter.notifyMatchFailure(op, "unknown reduction operation");
      }
      poolingOps.push_back(poolingOp->getResult(0));
    }
    rewriter.replaceOp(op, poolingOps);
    return success();
  }
};
| |
| /// Converts xla-hlo.torch_index_select op to a linalg.generic op. |
| struct TorchIndexSelectOpConversion |
| : public OpConversionPattern<mhlo::TorchIndexSelectOp> { |
| using OpConversionPattern<mhlo::TorchIndexSelectOp>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| mhlo::TorchIndexSelectOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter& rewriter) const final { |
| int axis = static_cast<int>(op.dim()); |
| int batch = static_cast<int>(op.batch_dims()); |
| auto indexShapedType = adaptor.index().getType().cast<ShapedType>(); |
| int numIndices = static_cast<int>(indexShapedType.getRank()); |
| auto operandShapedType = adaptor.operand().getType().cast<ShapedType>(); |
| if (axis < 0) axis += static_cast<int>(operandShapedType.getRank()); |
| if (batch < 0) batch += numIndices; |
| |
| Location loc = op.getLoc(); |
| auto resultType = this->typeConverter->convertType(op.getResult().getType()) |
| .cast<ShapedType>(); |
| int rank = static_cast<int>(resultType.getRank()); |
| |
| // The output shape is |
| // `params[:axis] + indices[batch_dims:] + params[axis + 1:]` |
| SmallVector<Value, 4> dynSizes; |
| for (int i = 0; i < rank; ++i) { |
| if (!resultType.isDynamicDim(i)) continue; |
| if (i < axis) { |
| dynSizes.push_back( |
| rewriter.create<tensor::DimOp>(loc, adaptor.operand(), i)); |
| } else if (i < (axis + numIndices - batch)) { |
| int idx = i - axis + batch; |
| dynSizes.push_back( |
| rewriter.create<tensor::DimOp>(loc, adaptor.index(), idx)); |
| } else { |
| int idx = i - (axis + numIndices - batch) + axis + 1; |
| dynSizes.push_back( |
| rewriter.create<tensor::DimOp>(loc, adaptor.operand(), idx)); |
| } |
| } |
| |
| // Generate dummy tensor to preserve slice shape information. |
| SmallVector<int64_t> sliceShape; |
| SmallVector<Value, 4> dynSliceSizes; |
| SmallVector<AffineExpr, 4> sliceExprs; |
| auto resultShape = resultType.getShape(); |
| for (int i = 0; i < axis; ++i) { |
| sliceExprs.push_back(rewriter.getAffineDimExpr(i)); |
| sliceShape.push_back(resultShape[i]); |
| if (!resultType.isDynamicDim(i)) continue; |
| dynSliceSizes.push_back( |
| rewriter.create<tensor::DimOp>(loc, adaptor.operand(), i)); |
| } |
| for (int i = axis + numIndices - batch; i < rank; ++i) { |
| sliceExprs.push_back(rewriter.getAffineDimExpr(i)); |
| sliceShape.push_back(resultShape[i]); |
| if (!resultType.isDynamicDim(i)) continue; |
| int idx = i - (axis + numIndices - batch) + axis + 1; |
| dynSliceSizes.push_back( |
| rewriter.create<tensor::DimOp>(loc, adaptor.operand(), idx)); |
| } |
| |
| // Setup AffineMap for operand tensor. |
| SmallVector<AffineExpr, 4> exprs; |
| for (int i = 0; i < batch; ++i) { |
| exprs.push_back(rewriter.getAffineDimExpr(i)); |
| } |
| for (int i = 0, e = numIndices - batch; i < e; ++i) { |
| exprs.push_back(rewriter.getAffineDimExpr(axis + i)); |
| } |
| |
| SmallVector<AffineMap, 2> indexingMaps; |
| indexingMaps.emplace_back( |
| AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext())); |
| indexingMaps.emplace_back(AffineMap::get( |
| rank, /*symbolCount=*/0, sliceExprs, rewriter.getContext())); |
| indexingMaps.emplace_back(rewriter.getMultiDimIdentityMap(rank)); |
| |
| Value sliceOp = rewriter.create<linalg::InitTensorOp>( |
| loc, dynSliceSizes, sliceShape, resultType.getElementType()); |
| |
| Value initOp = rewriter.create<linalg::InitTensorOp>( |
| loc, dynSizes, resultType.getShape(), resultType.getElementType()); |
| auto linalgOp = rewriter.create<linalg::GenericOp>( |
| loc, /*resultTensors=*/ArrayRef<Type>{resultType}, |
| /*inputs=*/ValueRange{adaptor.index(), sliceOp}, |
| /*outputs=*/initOp, indexingMaps, getNParallelLoopsAttrs(rank), |
| /*bodyBuild=*/nullptr, pruneAttributeList(op)); |
| |
| SmallVector<Type, 4> bodyArgTypes; |
| SmallVector<Value, 2> linalgOpArgs = {adaptor.index(), sliceOp}; |
| // Add a block to the region. |
| auto* region = &linalgOp.region(); |
| auto* block = rewriter.createBlock(region, region->end()); |
| for (auto blockArgs : linalgOpArgs) { |
| bodyArgTypes.push_back( |
| blockArgs.getType().cast<ShapedType>().getElementType()); |
| } |
| block->addArguments(bodyArgTypes, |
| SmallVector<Location>(bodyArgTypes.size(), loc)); |
| block->addArguments(resultType.getElementType(), loc); |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPointToEnd(block); |
| |
| Value castedValue = rewriter.create<arith::IndexCastOp>( |
| loc, rewriter.getIndexType(), block->getArgument(0)); |
| |
| SmallVector<Value, 4> indices; |
| for (int i = 0; i < axis; ++i) { |
| indices.push_back(rewriter.create<linalg::IndexOp>(loc, i)); |
| } |
| indices.push_back(castedValue); |
| for (int i = axis + numIndices - batch; i < rank; ++i) { |
| indices.push_back(rewriter.create<linalg::IndexOp>(loc, i)); |
| } |
| Value res = |
| rewriter.create<tensor::ExtractOp>(loc, adaptor.operand(), indices); |
| rewriter.create<linalg::YieldOp>(loc, res); |
| |
| rewriter.replaceOp(op, linalgOp.getResults()); |
| return success(); |
| } |
| }; |
| |
/// This lowering encompasses the full range of the Gather operation and
/// therefore is very general and just loops over the output and calculates
/// the corresponding input index. It follows the explanation at
/// https://www.tensorflow.org/xla/operation_semantics#gather. The compiler
/// should be able to optimize that a bit, but in order to get efficient
/// lowerings, special-cases of gather should be extracted in separate
/// lowerings, and ideally encapsulated as separate ops or canonicalization
/// patterns.
struct GatherConversion : public OpConversionPattern<mhlo::GatherOp> {
  using OpConversionPattern<mhlo::GatherOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::GatherOp gatherOp, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = gatherOp.getLoc();

    Value startIndices = adaptor.start_indices();
    Value operand = adaptor.operand();

    auto resultType = typeConverter->convertType(gatherOp.getType())
                          .dyn_cast<RankedTensorType>();
    RankedTensorType startIndicesType =
        startIndices.getType().dyn_cast<RankedTensorType>();
    // We could actually deal with an unranked result by inferring the result
    // rank, but the current reifyReturnTypes doesn't support unranked either.
    if (!resultType || !startIndicesType)
      return rewriter.notifyMatchFailure(gatherOp,
                                         "unranked start indices or result");

    int resultRank = resultType.getRank();
    // slice_sizes has to have the same size as operand.rank, and doing it this
    // way permits an unranked operand.
    int operandRank = gatherOp.slice_sizes().getNumElements();

    int64_t indexVectorDim = gatherOp.dimension_numbers().getIndexVectorDim();

    ArrayRef<int64_t> offsetDims = gatherOp.dimension_numbers().getOffsetDims();
    ArrayRef<int64_t> collapsedSliceDims =
        gatherOp.dimension_numbers().getCollapsedSliceDims();
    ArrayRef<int64_t> startIndexMap =
        gatherOp.dimension_numbers().getStartIndexMap();

    // Reads an element of `input` at `index` and casts it to `index` type.
    auto extractAsIndex = [&](Value input, ArrayRef<Value> index) -> Value {
      return rewriter.create<arith::IndexCastOp>(
          loc, rewriter.getIndexType(),
          rewriter.create<tensor::ExtractOp>(loc, input, index));
    };

    // We'll need these later, and creating them on demand would end up with
    // duplicates, which also makes lit tests really hard to write.
    SmallVector<Value> constants;
    for (unsigned i = 0; i < std::max({resultRank, operandRank, 2}); ++i) {
      constants.push_back(
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i)));
    }

    // Create ops to calculate the dynamic dimensions of the return shape,
    // which are needed for the init tensor.
    SmallVector<Value> dynDimSizes;
    if (!resultType.hasStaticShape()) {
      SmallVector<Value> returnShapes;
      if (failed(gatherOp.reifyReturnTypeShapes(rewriter, adaptor.getOperands(),
                                                returnShapes)))
        return rewriter.notifyMatchFailure(gatherOp,
                                           "could not reify return shape");
      assert(returnShapes.size() == 1);
      Value returnShape = returnShapes[0];

      for (int i = 0; i < resultRank; ++i)
        if (resultType.isDynamicDim(i))
          dynDimSizes.push_back(extractAsIndex(returnShape, constants[i]));
    }

    Value initOp = rewriter.create<linalg::InitTensorOp>(
        loc, dynDimSizes, resultType.getShape(), resultType.getElementType());

    // The generic op has no tensor inputs: everything is computed from
    // linalg.index values and tensor.extract inside the body.
    ValueRange ins;
    SmallVector<AffineMap, 1> indexingMaps(
        {rewriter.getMultiDimIdentityMap(resultRank)});
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/resultType,
        /*inputs=*/ins,
        /*outputs=*/initOp, indexingMaps, getNParallelLoopsAttrs(resultRank),
        /*bodyBuild=*/nullptr, pruneAttributeList(gatherOp));

    // Now populate the linalg generic region
    auto* region = &linalgOp.region();
    auto* block = rewriter.createBlock(region, region->end());
    block->addArguments(resultType.getElementType(), loc);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(block);

    // Dimensions in the result that aren't offset dimensions are called batch.
    SmallVector<int64_t> batchDims;
    for (int dim = 0; dim < resultRank; ++dim)
      if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);

    // Same as with the constants. Creating these all up front is easier than
    // potentially getting duplicates later.
    SmallVector<Value> linalgIndices;
    for (unsigned i = 0; i < resultRank; ++i)
      linalgIndices.push_back(rewriter.create<linalg::IndexOp>(loc, i));

    // Now the complicated part. For a given output dimension we build up an
    // index into the input. It's composed of two parts: the index coming from
    // start_indices, and the offset from that index along the offset
    // dimensions. Everything includes dimension shuffling and remapping as
    // well because of the way gather is defined to allow for any-layout input
    // by adding more attributes.

    // The base gather index (`G` in the documentation) points to a place in
    // start_indices along the batch dimensions.
    SmallVector<Value> gatherIndex;
    for (auto dim : batchDims) gatherIndex.push_back(linalgIndices[dim]);

    SmallVector<Value> indexFromStartIndices;
    for (unsigned i = 0; i < startIndexMap.size(); ++i) {
      // The index along the index_vector dimension of start_indices varies.
      // Basically indexFromStartIndices indexes into a "row" along
      // index_vector_dim, where the row is selected by the current output
      // index.
      // But if index_vector_dim is equal to start_indices.rank, then
      // start_indices gets a trailing 1 dimension added. So the row we're
      // extracting always has length 1 and the index into it is always 0, so
      // we just use the gather index directly.
      SmallVector<Value> gCombine(gatherIndex);
      if (indexVectorDim != startIndicesType.getRank()) {
        assert(indexVectorDim <= gCombine.size());
        gCombine.insert(gCombine.begin() + indexVectorDim, constants[i]);
      }

      indexFromStartIndices.push_back(extractAsIndex(startIndices, gCombine));
    }

    // But then start indices are shuffled by the start index map. To make a
    // full index into the operand, all missing indices are zeroes.
    SmallVector<Value> remappedIndexFromIndices(operandRank, constants[0]);
    for (auto& it : llvm::enumerate(startIndexMap))
      remappedIndexFromIndices[it.value()] = indexFromStartIndices[it.index()];

    // Now we construct the index based on the offset. First we need to remap
    // the offset dimensions by dropping the collapsed indices.
    SmallVector<unsigned> remappedOffsetDims;
    for (unsigned i = 0; i < operandRank; ++i)
      if (!llvm::is_contained(collapsedSliceDims, i))
        remappedOffsetDims.push_back(i);

    assert(remappedOffsetDims.size() == offsetDims.size());

    // Clamp out of bounds indices.
    for (unsigned i = 0, operandIndexDim = 0; i < operandRank; ++i) {
      // Compute the size of the output shape dimension corresponding to this
      // index dimension. If it's collapsed set it to 1.
      Value outputDimSize = constants[1];
      if (!llvm::is_contained(collapsedSliceDims, i)) {
        outputDimSize = rewriter.createOrFold<tensor::DimOp>(
            loc, initOp, offsetDims[operandIndexDim++]);
      }

      // If this is a skipped dimension, we're done and don't have to clamp.
      if (remappedIndexFromIndices[i] == constants[0]) continue;

      Value operandDimSize =
          rewriter.createOrFold<tensor::DimOp>(loc, operand, i);
      Value largestValidIndex = rewriter.createOrFold<arith::SubIOp>(
          loc, operandDimSize, outputDimSize);

      // Clamp the index to the range [0, operand_dim - output_dim].
      Value clamp = rewriter.create<arith::MinSIOp>(
          loc,
          rewriter.create<arith::MaxSIOp>(loc, constants[0],
                                          remappedIndexFromIndices[i]),
          largestValidIndex);
      remappedIndexFromIndices[i] = clamp;
    }

    // For the (remapped) offset dimensions, the index is the current index in
    // the output. As before this is expanded to a full index into the operand
    // by using zeros for the missing indices.
    SmallVector<Value> indexFromOffset(operandRank, constants[0]);
    for (unsigned k = 0; k < offsetDims.size(); ++k)
      indexFromOffset[remappedOffsetDims[k]] = linalgIndices[offsetDims[k]];

    // Now we add together our two indices to get the final index into the
    // operand.
    SmallVector<Value> combinedIndex;
    for (unsigned i = 0; i < operandRank; ++i)
      combinedIndex.push_back(rewriter.createOrFold<arith::AddIOp>(
          loc, rewriter.getIndexType(), remappedIndexFromIndices[i],
          indexFromOffset[i]));

    Value element =
        rewriter.create<tensor::ExtractOp>(loc, operand, combinedIndex);
    rewriter.create<linalg::YieldOp>(loc, element);

    rewriter.replaceOp(gatherOp, linalgOp.getResults());

    return success();
  }
};
| |
| struct ScatterUpdateConversion : public OpConversionPattern<mhlo::ScatterOp> { |
| using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| mhlo::ScatterOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter& rewriter) const final { |
| // Variadic Scatter support not yet implemented |
| if (op.operands().size() != 1 || op.updates().size() != 1) return failure(); |
| |
| // Check if it is a tensor_scatter_nd_update-like op. |
| auto& bodyOps = op.getRegion().front().getOperations(); |
| if (bodyOps.size() != 1) return failure(); |
| auto retArg = bodyOps.front().getOperand(0).dyn_cast<BlockArgument>(); |
| if (!retArg || retArg.getArgNumber() != 1) return failure(); |
| |
| auto operandTy = |
| adaptor.operands()[0].getType().dyn_cast<RankedTensorType>(); |
| auto indicesTy = |
| adaptor.scatter_indices().getType().dyn_cast<RankedTensorType>(); |
| if (!operandTy || !indicesTy) return failure(); |
| |
| // Linalg operations put all the computation to the innermost loop. Since we |
| // also iterate over scatter_indices() with some loops, we can only check |
| // one scatter index in one iteration. If there are multiple indices (ie, |
| // the index depth is greater than 1), we don't have a way to keep the |
| // comparison state. E.g., if the index_depth is 2, like indices = [[0, 1]], |
| // we should use the update value only if (i == 0 and j == 1). However, we |
| // can not get both indices in one iteration unless we pack them together. |
| auto indexVectorDim = op.scatter_dimension_numbers().getIndexVectorDim(); |
| if (indicesTy.getDimSize(indexVectorDim) != 1) |
| return rewriter.notifyMatchFailure(op, "require index depth to be 1"); |
| if (indexVectorDim != indicesTy.getRank() - 1) { |
| return rewriter.notifyMatchFailure( |
| op, "require index_vector_dim to be the last dim"); |
| } |
| |
| // One of indices dims is index depth vector. |
| int64_t nloops = operandTy.getRank() + indicesTy.getRank() - 1; |
| SmallVector<AffineMap, 3> indexingMaps; |
| { |
| SmallVector<AffineExpr> exprs; |
| for (int64_t i = 0, e = operandTy.getRank(); i < e; ++i) |
| exprs.push_back(rewriter.getAffineDimExpr(i)); |
| indexingMaps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs, |
| rewriter.getContext())); |
| } |
| { |
| SmallVector<AffineExpr> exprs; |
| for (int64_t i = operandTy.getRank(); i < nloops; ++i) |
| exprs.push_back(rewriter.getAffineDimExpr(i)); |
| // The index depth is 1. |
| exprs.push_back(rewriter.getAffineConstantExpr(0)); |
| indexingMaps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs, |
| rewriter.getContext())); |
| |
| exprs.pop_back(); |
| auto updateWindowDims = |
| op.scatter_dimension_numbers().getUpdateWindowDims(); |
| for (auto d : updateWindowDims) |
| exprs.push_back(rewriter.getAffineDimExpr(d)); |
| indexingMaps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs, |
| rewriter.getContext())); |
| } |
| indexingMaps.push_back(indexingMaps.front()); |
| |
| auto resultTy = |
| this->typeConverter->convertType(op.getResults()[0].getType()) |
| .cast<ShapedType>(); |
| auto scatterDimsToOperandDims = |
| op.scatter_dimension_numbers().getScatterDimsToOperandDims(); |
| assert(scatterDimsToOperandDims.size() == 1); |
| // Do not need init_tensor because we'd like to initialize the output as |
| // operand. |
| auto linalgOp = rewriter.create<linalg::GenericOp>( |
| op.getLoc(), /*resultTensors=*/ArrayRef<Type>{resultTy}, |
| /*inputs=*/ |
| ValueRange{adaptor.operands()[0], adaptor.scatter_indices(), |
| adaptor.updates()[0]}, |
| /*outputs=*/adaptor.operands()[0], indexingMaps, |
| getNParallelLoopsAttrs(nloops), |
| [scatterDimsToOperandDims](OpBuilder& b, Location loc, |
| ValueRange args) { |
| Value cmpIdx = |
| b.create<linalg::IndexOp>(loc, scatterDimsToOperandDims[0]); |
| Value idx = |
| b.create<arith::IndexCastOp>(loc, b.getIndexType(), args[1]); |
| Value pred = b.create<arith::CmpIOp>( |
| loc, b.getI1Type(), arith::CmpIPredicate::eq, cmpIdx, idx); |
| // Use the output arg, so some update values won't be init value |
| // again. |
| Value res = b.create<arith::SelectOp>(loc, args[2].getType(), pred, |
| args[2], args[3]); |
| b.create<linalg::YieldOp>(loc, res); |
| }, |
| pruneAttributeList(op)); |
| rewriter.replaceOp(op, linalgOp.getResults()); |
| return success(); |
| } |
| }; |
| |
/// Converts mhlo.dot_general with arbitrary batching/contracting dimension
/// lists into a linalg.generic whose body performs a multiply-accumulate.
/// A separate, higher-benefit pattern handles the canonical batch-matmul
/// layout; this pattern is the general fallback.
class DotGeneralOpConversion : public OpConversionPattern<mhlo::DotGeneralOp> {
 public:
  using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::DotGeneralOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    if (!verifyHloOpBufferOrTensorSemantics(op)) {
      return failure();
    }

    // Get various dimension iterator information
    mhlo::DotDimensionNumbersAttr dimNumbers = op.dot_dimension_numbers();
    auto lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
    auto rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
    auto lhsContractingDims = dimNumbers.getLhsContractingDimensions();
    auto rhsContractingDims = dimNumbers.getRhsContractingDimensions();

    // Get shape information and initialize output
    assert(lhsContractingDims.size() == rhsContractingDims.size() &&
           "number of contracting dims must be equal");
    auto numContracting = lhsContractingDims.size();
    // Convert unsigned to signed. This works because signed and unsigned
    // integer matmul is the same operation in two's complement.
    auto outputType =
        typeConverter->convertType(op.getType()).cast<ShapedType>();
    auto targetRank = outputType.getRank();
    // Loop dims are laid out as the `targetRank` output dims — batch dims,
    // then lhs free dims, then rhs free dims — followed by `numContracting`
    // reduction dims.
    auto totalLoopCount = numContracting + targetRank;

    auto lhsRank = adaptor.lhs().getType().cast<ShapedType>().getRank();
    // "Free" dims are operand dims that are neither batching nor
    // contracting; they pass through unchanged into the output.
    auto lhsExtraDims =
        lhsRank - lhsBatchingDims.size() - lhsContractingDims.size();
    auto rhsRank = adaptor.rhs().getType().cast<ShapedType>().getRank();

    Location loc = op.getLoc();
    // Accumulation starts from a zero-filled init tensor.
    auto initTensor =
        getInitTensorFor(rewriter, loc, outputType, op, adaptor.getOperands());
    Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
    SmallVector<AffineMap, 3> indexingMaps;

    // Builds the indexing map for one operand: batching dims map to the
    // leading loop dims, contracting dims map to the reduction dims starting
    // at `targetRank`, and the remaining (free) dims are assigned loop dims
    // sequentially starting at `extraDims`.
    auto getMap = [&](int64_t rank, ArrayRef<int64_t> batchingDims,
                      ArrayRef<int64_t> contractingDims, size_t extraDims) {
      llvm::SmallVector<AffineExpr> indices(rank);
      for (const auto& i : llvm::enumerate(batchingDims)) {
        indices[i.value()] = rewriter.getAffineDimExpr(i.index());
      }
      for (const auto& i : llvm::enumerate(contractingDims)) {
        indices[i.value()] = rewriter.getAffineDimExpr(i.index() + targetRank);
      }
      for (int i = 0; i < rank; ++i) {
        if (!indices[i]) {
          indices[i] = rewriter.getAffineDimExpr(extraDims++);
        }
      }
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount,
                                            /*symbolCount=*/0, indices,
                                            op->getContext()));
    };
    // lhs free dims directly follow the batch dims in the output; rhs free
    // dims follow the lhs free dims.
    getMap(lhsRank, lhsBatchingDims, lhsContractingDims,
           lhsBatchingDims.size());
    getMap(rhsRank, rhsBatchingDims, rhsContractingDims,
           rhsBatchingDims.size() + lhsExtraDims);

    {
      // The output's indexing map is the identity over the first
      // `targetRank` loop dims.
      SmallVector<AffineExpr, 4> dimExprs;
      dimExprs.reserve(targetRank);
      for (unsigned i = 0; i < targetRank; ++i)
        dimExprs.push_back(rewriter.getAffineDimExpr(i));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount,
                                            /*symbolCount=*/0, dimExprs,
                                            op.getContext()));
    }

    Operation* linalgOp = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/TypeRange{outputType},
        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
        /*outputBuffers=*/ValueRange{zeroTensor}, indexingMaps,
        getParallelAndReductionIterators(
            /*nLoops=*/totalLoopCount,
            /*nReduction=*/numContracting),
        [](OpBuilder& b, Location loc, ValueRange) {
          // Reuse linalg.matmul's scalar region (out += lhs * rhs) instead
          // of emitting the multiply-accumulate by hand.
          ImplicitLocOpBuilder builder(loc, b);
          linalg::MatmulOp::regionBuilder(builder, *b.getInsertionBlock(), {});
        },
        pruneAttributeList(op));

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
| |
| struct HloLegalizeToLinalgPass |
| : public mhlo::HloLegalizeToLinalgPassBase<HloLegalizeToLinalgPass> { |
| void getDependentDialects(DialectRegistry& registry) const override { |
| registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect, |
| scf::SCFDialect, complex::ComplexDialect, math::MathDialect, |
| memref::MemRefDialect, shape::ShapeDialect>(); |
| } |
| |
| void runOnOperation() override { |
| MLIRContext& ctx = getContext(); |
| RewritePatternSet patterns(&ctx); |
| ConversionTarget target(ctx); |
| target.addLegalDialect< |
| bufferization::BufferizationDialect, arith::ArithmeticDialect, |
| complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect, |
| tensor::TensorDialect, sparse_tensor::SparseTensorDialect, |
| scf::SCFDialect, shape::ShapeDialect>(); |
| |
| target.addLegalOp<UnrealizedConversionCastOp>(); |
| |
| mhlo::RemoveSignTypeConverter typeConverter; |
| auto func = getOperation(); |
| mhlo::populateHloToLinalgConversionPattern(&ctx, typeConverter, &patterns); |
| if (failed(applyPartialConversion(func, target, std::move(patterns)))) { |
| signalPassFailure(); |
| } |
| } |
| }; |
| |
| } // namespace |
| |
| namespace mhlo { |
| |
// Registers every MHLO->Linalg conversion pattern on `patterns`. Pattern
// benefits establish priority: the specialized dot lowerings (benefit 2) are
// tried before the generic dot_general fallback (benefit 1), and the
// reduce-region conversions get a very high benefit (1000) so ops inside
// reduction regions are rewritten to scalar form before other patterns can
// match them.
void populateHloToLinalgConversionPattern(MLIRContext* context,
                                          TypeConverter& typeConverter,
                                          RewritePatternSet* patterns) {
  // clang-format off
  // Structural, broadcast, element-wise, and op-specific conversions, all at
  // the default benefit.
  patterns->add<
      BroadcastConverter<mhlo::BroadcastOp>, ConcatenateConverter,
      ConstConverterTensor, HloDynamicBroadcastInDimConverter,
      HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp>,
      EinsumToLinalgConverter,
      IotaConverter<mhlo::DynamicIotaOp>,
      MapOpConverter,
      PointwiseToLinalgConverter<mhlo::AbsOp>,
      PointwiseToLinalgConverter<mhlo::AddOp>,
      PointwiseToLinalgConverter<mhlo::AndOp>,
      PointwiseToLinalgConverter<mhlo::Atan2Op>,
      PointwiseToLinalgConverter<mhlo::BitcastConvertOp>,
      PointwiseToLinalgConverter<mhlo::CbrtOp>,
      PointwiseToLinalgConverter<mhlo::CeilOp>,
      PointwiseToLinalgConverter<mhlo::ClampOp>,
      PointwiseToLinalgConverter<mhlo::ClzOp>,
      PointwiseToLinalgConverter<mhlo::CompareOp>,
      PointwiseToLinalgConverter<mhlo::ComplexOp>,
      PointwiseToLinalgConverter<mhlo::ConvertOp>,
      PointwiseToLinalgConverter<mhlo::CopyOp>,
      PointwiseToLinalgConverter<mhlo::CosOp>,
      PointwiseToLinalgConverter<mhlo::DivOp>,
      PointwiseToLinalgConverter<mhlo::ExpOp>,
      PointwiseToLinalgConverter<mhlo::Expm1Op>,
      PointwiseToLinalgConverter<mhlo::FloorOp>,
      PointwiseToLinalgConverter<mhlo::ImagOp>,
      PointwiseToLinalgConverter<mhlo::IsFiniteOp>,
      PointwiseToLinalgConverter<mhlo::LogOp>,
      PointwiseToLinalgConverter<mhlo::LogisticOp>,
      PointwiseToLinalgConverter<mhlo::Log1pOp>,
      PointwiseToLinalgConverter<mhlo::MaxOp>,
      PointwiseToLinalgConverter<mhlo::MinOp>,
      PointwiseToLinalgConverter<mhlo::MulOp>,
      PointwiseToLinalgConverter<mhlo::NegOp>,
      PointwiseToLinalgConverter<mhlo::NotOp>,
      PointwiseToLinalgConverter<mhlo::OrOp>,
      PointwiseToLinalgConverter<mhlo::PopulationCountOp>,
      PointwiseToLinalgConverter<mhlo::PowOp>,
      PointwiseToLinalgConverter<mhlo::RealOp>,
      PointwiseToLinalgConverter<mhlo::RemOp>,
      PointwiseToLinalgConverter<mhlo::RoundOp>,
      PointwiseToLinalgConverter<mhlo::RsqrtOp>,
      PointwiseToLinalgConverter<mhlo::SelectOp>,
      PointwiseToLinalgConverter<mhlo::ShiftLeftOp>,
      PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp>,
      PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp>,
      PointwiseToLinalgConverter<mhlo::SignOp>,
      PointwiseToLinalgConverter<mhlo::SinOp>,
      PointwiseToLinalgConverter<mhlo::SqrtOp>,
      PointwiseToLinalgConverter<mhlo::SubOp>,
      PointwiseToLinalgConverter<mhlo::TanhOp>,
      PointwiseToLinalgConverter<mhlo::XorOp>,
      PointwiseToLinalgConverter<mhlo::ReducePrecisionOp>,
      RealDynamicSliceConverter,
      ReshapeOpConverter,
      ReverseConverter,
      SliceConverter,
      DynamicSliceConverter,
      DynamicUpdateSliceConverter,
      TransposeConverter<mhlo::TransposeOp>,
      NormalConvolutionOpConversion,
      DepthwiseConvolutionOpConversion,
      GatherConversion,
      PadOpConversion,
      PadOpNegativePaddingConversion,
      ReduceConversion,
      ReduceWindowOpOnTensorsGenericConversion,
      ReduceWindowOpConversion,
      RngUniformConversion,
      ScatterUpdateConversion,
      TorchIndexSelectOpConversion>(typeConverter, context);
  // Specialized dot lowerings: prefer named linalg ops for the common
  // matrix/vector shapes over the generic lowering below.
  patterns->add<
      DotOpConversion<DotOperationType::kMatrixMatrix, linalg::MatmulOp>,
      DotOpConversion<DotOperationType::kMatrixVector, linalg::MatvecOp>,
      DotOpConversion<DotOperationType::kVectorMatrix, linalg::VecmatOp>,
      DotOpConversion<DotOperationType::kVectorDot, linalg::DotOp>,
      DotGeneralBatchMatMulOpConversion>(typeConverter, context,
                                         PatternBenefit(2));
  // clang-format on
  // Catch-all dot_general lowering (linalg.generic) when no specialized
  // form above applies.
  patterns->add<DotGeneralOpConversion>(typeConverter, context,
                                        PatternBenefit(1));
  // Scalar conversions for ops appearing inside reduction regions; highest
  // benefit so they win over the tensor-level patterns above.
  patterns->add<ReduceRegionXLAOpConversion<mhlo::AddOp>,
                ReduceRegionXLAOpConversion<mhlo::AndOp>,
                ReduceRegionXLAOpConversion<mhlo::CompareOp>,
                ReduceRegionXLAOpConversion<mhlo::ConvertOp>,
                ReduceRegionXLAOpConversion<mhlo::ImagOp>,
                ReduceRegionXLAOpConversion<mhlo::MaxOp>,
                ReduceRegionXLAOpConversion<mhlo::MinOp>,
                ReduceRegionXLAOpConversion<mhlo::MulOp>,
                ReduceRegionXLAOpConversion<mhlo::OrOp>,
                ReduceRegionXLAOpConversion<mhlo::RealOp>,
                ReduceRegionXLAOpConversion<mhlo::SelectOp>,
                ReduceRegionXLAOpConversion<mhlo::XorOp>,
                ReduceRegionReturnOpConversion>(context, PatternBenefit(1000));
}
| |
| std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeHloToLinalgPass() { |
| return std::make_unique<HloLegalizeToLinalgPass>(); |
| } |
| |
| std::unique_ptr<TypeConverter> createHloToLinalgSignedIntegerConverter() { |
| return std::make_unique<RemoveSignTypeConverter>(); |
| } |
| |
| } // namespace mhlo |
| } // namespace mlir |