| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // This file implements logic for lowering LHLO dialect to Affine dialect. |
| |
| #include <utility> |
| |
| #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" |
| #include "mlir-hlo/Dialect/lhlo/transforms/PassDetail.h" |
| #include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h" |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace lmhlo { |
| namespace { |
| |
| // Builds an affine loop nest iterating from zeros to "upper_bounds" with unit |
| // steps, and populates the body of the innermost loop using "body_builder". |
| static void buildBoundedAffineLoopNest( |
| OpBuilder& builder, Location location, ArrayRef<int64_t> upperBounds, |
| function_ref<void(OpBuilder&, Location, ValueRange)> bodyBuilder) { |
| SmallVector<int64_t, 3> lowerBounds(upperBounds.size(), /*Value=*/0); |
| SmallVector<int64_t, 3> steps(upperBounds.size(), /*Value=*/1); |
| buildAffineLoopNest(builder, location, lowerBounds, upperBounds, steps, |
| bodyBuilder); |
| } |
| |
| struct DotOpConverter : public OpRewritePattern<DotOp> { |
| using OpRewritePattern<DotOp>::OpRewritePattern; |
| |
| // Supports only rank-2 tensors for LHS and RHS. |
| LogicalResult matchAndRewrite(DotOp op, |
| PatternRewriter& rewriter) const override { |
| Value lhs = op.lhs(); |
| Value rhs = op.rhs(); |
| MemRefType lhsType = lhs.getType().cast<MemRefType>(); |
| MemRefType rhsType = rhs.getType().cast<MemRefType>(); |
| Type elementType = lhsType.getElementType(); |
| ArrayRef<int64_t> shapeLhs = lhsType.getShape(); |
| ArrayRef<int64_t> shapeRhs = rhsType.getShape(); |
| |
| if ((lhsType.getRank() != 2) || (rhsType.getRank() != 2)) { |
| return failure(); |
| } |
| |
| // We don't currently support batching dimensions, or multiple contraction |
| // dimensions. |
| mhlo::DotDimensionNumbersAttr dotDimensionNumbers = |
| op.dot_dimension_numbers(); |
| if (!dotDimensionNumbers.getLhsBatchingDimensions().empty() || |
| !dotDimensionNumbers.getRhsBatchingDimensions().empty()) |
| return failure(); |
| if (dotDimensionNumbers.getLhsContractingDimensions().size() != 1 || |
| *dotDimensionNumbers.getLhsContractingDimensions().begin() != 1 || |
| dotDimensionNumbers.getRhsContractingDimensions().size() != 1 || |
| *dotDimensionNumbers.getRhsContractingDimensions().begin() != 0) { |
| return failure(); |
| } |
| |
| LogicalResult mapStatus = success(); |
| auto bodyBuilder = [&](OpBuilder& builder, Location loc, ValueRange ivs) { |
| SmallVector<Value, 2> lhsIndices{ivs[0], ivs[2]}, |
| rhsIndices{ivs[2], ivs[1]}, resultIndices{ivs[0], ivs[1]}; |
| |
| auto l = builder.create<AffineLoadOp>(loc, lhs, lhsIndices); |
| auto r = builder.create<AffineLoadOp>(loc, rhs, rhsIndices); |
| auto result = |
| rewriter.create<AffineLoadOp>(loc, op.output(), resultIndices); |
| Value opResult = lmhlo::LhloOpToStdScalarOp::map<DotOp>( |
| op, elementType, {l, r, result}, &builder); |
| mapStatus = success(opResult != nullptr); |
| if (failed(mapStatus)) return; |
| builder.create<AffineStoreOp>(loc, opResult, op.output(), resultIndices); |
| }; |
| |
| buildBoundedAffineLoopNest(rewriter, op.getLoc(), |
| {shapeLhs[0], shapeRhs[1], shapeRhs[0]}, |
| bodyBuilder); |
| if (failed(mapStatus)) return failure(); |
| |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| /// Concat Operation Example (2D): |
| /// Given inpA[2][1], inpB[2][2], concat_dimension = 1. |
| /// Compute output[x1][x2]. |
| /// Implementation Pseudocode: |
| /// s = 0 |
| /// for a in range(0, 2): |
| /// for b in range(0, 1): |
| /// output[a][b] = inpA[a][b - s] |
| /// s = 1 |
| /// for a in range(0, 2): |
| /// for b in range(1, 3): |
| /// output[a][b] = inpB[a][b - s] |
| /// |
| /// Concatenate composes an array from multiple array operands. The array is of |
| /// the same rank as each of the input array operands (which must be of the same |
| /// rank as each other) and contains the arguments in the order that they were |
| /// specified. |
| struct ConcatOpConverter : public OpRewritePattern<ConcatenateOp> { |
| using OpRewritePattern<ConcatenateOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ConcatenateOp op, |
| PatternRewriter& rewriter) const override { |
| Location loc = op.getLoc(); |
| Value output = op.output(); |
| MemRefType outputType = output.getType().cast<MemRefType>(); |
| unsigned outputRank = outputType.getRank(); |
| ArrayRef<int64_t> outputShape = outputType.getShape(); |
| |
| ValueRange operands = op.val(); |
| uint64_t concatDim = op.dimension(); |
| int64_t prevBound = 0; |
| |
| for (Value operand : operands) { |
| MemRefType operandType = operand.getType().cast<MemRefType>(); |
| ArrayRef<int64_t> operandShape = operandType.getShape(); |
| |
| // TODO(pashu123): Extend support for dynamic dimensions. |
| if (!operandType.hasStaticShape()) return failure(); |
| |
| // Only for the concatenation dimension, the value is dimension - |
| // prevBound. |
| SmallVector<AffineExpr, 4> expr; |
| for (unsigned i = 0; i < outputRank; i++) { |
| AffineExpr d0 = (i == concatDim) |
| ? rewriter.getAffineDimExpr(concatDim) - prevBound |
| : rewriter.getAffineDimExpr(i); |
| |
| expr.push_back(d0); |
| } |
| AffineMap map = |
| AffineMap::get(outputRank, 0, expr, rewriter.getContext()); |
| |
| // Create multiple for loop nests iterating along the concatenation |
| // dimension. |
| OpBuilder::InsertionGuard guard(rewriter); |
| SmallVector<Value, 3> indices; |
| AffineForOp forOp; |
| for (unsigned i = 0; i < outputRank; i++) { |
| if (i == concatDim) { |
| forOp = rewriter.create<AffineForOp>(loc, prevBound, |
| prevBound + operandShape[i]); |
| prevBound += operandShape[i]; |
| indices.push_back(forOp.getInductionVar()); |
| } else { |
| forOp = rewriter.create<AffineForOp>(loc, 0, outputShape[i]); |
| indices.push_back(forOp.getInductionVar()); |
| } |
| rewriter.setInsertionPointToStart(forOp.getBody()); |
| } |
| Value storeVal = |
| rewriter.create<AffineLoadOp>(loc, operand, map, indices); |
| rewriter.create<AffineStoreOp>(loc, storeVal, output, indices); |
| } |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| /// Returns a zero value of type `type`. `type` is expected to be either |
| /// int or float. |
| static Value getZeroValue(Type type, Location loc, PatternRewriter& rewriter) { |
| assert(type.isIntOrFloat() && "Expected int or float"); |
| |
| if (IntegerType intType = type.dyn_cast<IntegerType>()) |
| return rewriter.create<mlir::arith::ConstantIntOp>(loc, 0, |
| intType.getWidth()); |
| |
| FloatType floatType = type.cast<FloatType>(); |
| return rewriter.create<mlir::arith::ConstantFloatOp>( |
| loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); |
| } |
| |
| /// Emits a nest to fill the given `buffer` of memref type with `fillValue`. |
| static void fillBuffer(Location loc, Value buffer, Value fillValue, |
| PatternRewriter& builder) { |
| OpBuilder::InsertionGuard guard(builder); |
| MemRefType bufType = buffer.getType().cast<MemRefType>(); |
| unsigned rank = bufType.getRank(); |
| SmallVector<Value, 4> dimSizes; |
| dimSizes.reserve(rank); |
| for (unsigned i = 0; i < rank; ++i) |
| dimSizes.push_back(builder.create<mlir::memref::DimOp>(loc, buffer, i)); |
| |
| AffineMap idSymMap = builder.getSymbolIdentityMap(); |
| AffineMap lbMap = builder.getConstantAffineMap(0); |
| SmallVector<Value, 4> ivs(rank); |
| AffineForOp forOp; |
| for (unsigned i = 0; i < rank; ++i) { |
| forOp = builder.create<AffineForOp>(loc, llvm::None, lbMap, dimSizes[i], |
| idSymMap); |
| builder.setInsertionPointToStart(forOp.getBody()); |
| ivs[i] = forOp.getInductionVar(); |
| } |
| Type fillValueType = fillValue.getType(); |
| auto fillMemRefType = fillValueType.dyn_cast<MemRefType>(); |
| assert(((fillMemRefType && fillMemRefType.getRank() == 0) || |
| fillValueType.isIntOrFloat()) && |
| "init value has to be a 0-d memref or int or fp"); |
| Value initVal = fillMemRefType ? builder.create<AffineLoadOp>( |
| loc, fillValue, /*indices=*/llvm::None) |
| : fillValue; |
| builder.create<AffineStoreOp>(loc, initVal, buffer, ivs); |
| } |
| |
| /// Converts GatherOp to Affine nest form. |
| /// Pseudocode: |
| /// 1. Fill a temporary output tensor with 0. |
| /// 2. Repeat the following for each batch dimension :- |
| /// 1. For each indices in 'operand' :- |
| /// 1. Get hold of start indices from 'start_indices'. |
| /// 2. Add offset to the start indices to get the final indices. |
| /// 3. Load value from 'operand' tensor : 'operand_val'. |
| /// 4. Load value from temporary output : 'prev_val'. |
| /// 5. If the final indices match current indices of 'operand' : |
| /// 'prev_val' = 'prev_val' + 'operand_val' |
| /// 6. Store 'prev_val' back to the temporary output. |
| class GatherOpConverter : public OpRewritePattern<GatherOp> { |
| public: |
| using OpRewritePattern<GatherOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(GatherOp op, |
| PatternRewriter& rewriter) const final { |
| Location loc = op.getLoc(); |
| |
| // Operand array. |
| Value operand = op.operand(); |
| MemRefType operandType = operand.getType().cast<MemRefType>(); |
| unsigned operandRank = operandType.getRank(); |
| ArrayRef<int64_t> operandShape = operandType.getShape(); |
| |
| // Start_indices array. |
| Value startIndices = op.start_indices(); |
| MemRefType startIndicesType = startIndices.getType().cast<MemRefType>(); |
| unsigned startIndicesRank = startIndicesType.getRank(); |
| ArrayRef<int64_t> startIndicesShape = startIndicesType.getShape(); |
| |
| // Output array. |
| Value output = op.output(); |
| MemRefType outputType = output.getType().cast<MemRefType>(); |
| ArrayRef<int64_t> outputShape = outputType.getShape(); |
| |
| if (!operandType.hasStaticShape() || !startIndicesType.hasStaticShape() || |
| !outputType.hasStaticShape()) |
| return rewriter.notifyMatchFailure(op, "only static shaped type allowed"); |
| |
| mhlo::GatherDimensionNumbersAttr gatherDim = op.dimension_numbers(); |
| |
| auto collapsedSliceDims = gatherDim.getCollapsedSliceDims(); |
| auto offsetDims = gatherDim.getOffsetDims(); |
| auto startIndexMap = gatherDim.getStartIndexMap(); |
| int64_t indexVectorDim = gatherDim.getIndexVectorDim(); |
| |
| // Slice_sizes. |
| DenseIntElementsAttr sliceSizesAttr = op.slice_sizesAttr(); |
| SmallVector<int64_t, 4> sliceSizes; |
| for (const APInt& dim : sliceSizesAttr.getValues<APInt>()) |
| sliceSizes.push_back(dim.getSExtValue()); |
| |
| // Creating constants with 0 value. We need the Integer type constant value |
| // because the indices type will be Integer. |
| Value zeroIntVal = rewriter.create<mlir::arith::ConstantIntOp>( |
| loc, 0, startIndicesType.getElementType()); |
| Type elementType = outputType.getElementType(); |
| Value zeroLoadValue = getZeroValue(elementType, loc, rewriter); |
| // Initializing the output buffer with 0. |
| fillBuffer(loc, output, zeroLoadValue, rewriter); |
| |
| // We fetch the shape of start_indices at index_vector_dim. In case |
| // index_vector_dim is equal to the rank of start_indices, we implicitly |
| // consider start_indices to have a trailing 1 dimension. |
| unsigned startIndicesNumbers = (indexVectorDim == startIndicesRank) |
| ? 1 |
| : startIndicesShape[indexVectorDim]; |
| // We create integer constants till start_incides_index which help us to |
| // fetch start_indices in affine transformation. |
| SmallVector<Value, 4> startIndicesIndex; |
| for (unsigned i = 0; i < startIndicesNumbers; i++) { |
| Value iVal = rewriter.create<mlir::arith::ConstantIntOp>( |
| loc, i, startIndicesType.getElementType()); |
| iVal = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), |
| iVal); |
| startIndicesIndex.push_back(iVal); |
| } |
| |
| // S_in contains the multiple indices that form a starting index used in the |
| // input/operand tensor. O_in contains the multiple offsets of corresponding |
| // starting index used in the input/operand tensor. We initialize both of |
| // them with 0. |
| SmallVector<Value, 4> sIn; |
| SmallVector<Value, 4> oIn; |
| Value zeroIndexVal = rewriter.create<arith::IndexCastOp>( |
| loc, rewriter.getIndexType(), zeroIntVal); |
| for (unsigned i = 0; i < operandRank; i++) { |
| sIn.push_back(zeroIndexVal); |
| oIn.push_back(zeroIndexVal); |
| } |
| |
| // batch_induction_vars stores the loop induction variables pertaining to |
| // the batches of start indices. |
| SmallVector<Value, 4> batchInductionVars; |
| // output_induction_vars stores the loop induction variables pertaining to |
| // both batches and offsets within the output tensor. |
| SmallVector<Value, 4> outputInductionVars; |
| // Create loops to iterate over each batch of starting index. |
| for (unsigned i = 0; i < startIndicesRank; i++) { |
| // ith dimension of start_indices doesn't form a batch if it is equal to |
| // index_vector_dim. |
| if (i == indexVectorDim) continue; |
| AffineForOp forOp = |
| rewriter.create<AffineForOp>(loc, 0, startIndicesShape[i]); |
| batchInductionVars.push_back(forOp.getInductionVar()); |
| outputInductionVars.push_back(forOp.getInductionVar()); |
| rewriter.setInsertionPointToStart(forOp.getBody()); |
| } |
| |
| // Create loops to iterate over each offset dimension within the output |
| // tensor. |
| for (unsigned i = 0, k = 0, e = offsetDims.size(); i < e; i++) { |
| AffineForOp forOp = |
| rewriter.create<AffineForOp>(loc, 0, outputShape[offsetDims[i]]); |
| rewriter.setInsertionPointToStart(forOp.getBody()); |
| // We try to fetch the first non-collapsed dimension. |
| while (k < collapsedSliceDims.size() && collapsedSliceDims[k] == i) k++; |
| // Remapping the offset_dim[i] to the non-collapsed dimension. |
| oIn[k++] = forOp.getInductionVar(); |
| // We assume offset_dims to be sorted. So when inserted to |
| // output_induction_vars the loop induction variable gets inserted at the |
| // correct position. |
| outputInductionVars.insert(outputInductionVars.begin() + offsetDims[i], |
| forOp.getInductionVar()); |
| } |
| |
| // Create loops to iterate over all dimensions within the operand tensor. |
| SmallVector<Value, 4> operandIndex; |
| for (unsigned i = 0, k = 0; i < operandRank; i++) { |
| // We assume start_index_map to have sorted dimensions. We only include |
| // those dimensions of operand tensor which are present in |
| // start_index_map. |
| if (k < startIndexMap.size() && i == startIndexMap[k++]) { |
| AffineForOp forOp = |
| rewriter.create<AffineForOp>(loc, 0, operandShape[i]); |
| operandIndex.push_back(forOp.getInductionVar()); |
| rewriter.setInsertionPointToStart(forOp.getBody()); |
| } else { |
| operandIndex.push_back(oIn[i]); |
| } |
| } |
| |
| // In case index_vector_dim is not equal to start_indices shape then we |
| // create another loop to iterate over starting index and update |
| // batch_induction_vars. |
| if (indexVectorDim != startIndicesRank) { |
| for (unsigned i = 0; i < startIndicesNumbers; i++) { |
| batchInductionVars.insert(batchInductionVars.begin() + indexVectorDim, |
| startIndicesIndex[i]); |
| Value startIndex = rewriter.create<AffineLoadOp>(loc, startIndices, |
| batchInductionVars); |
| startIndex = rewriter.create<arith::IndexCastOp>( |
| loc, rewriter.getIndexType(), startIndex); |
| sIn[startIndexMap[i]] = startIndex; |
| batchInductionVars.erase(batchInductionVars.begin() + indexVectorDim); |
| } |
| } else { |
| // Since index_vector_dim is equal to start_indicesRank we can directly |
| // fetch the start_index from batch_induction_vars. |
| Value startIndex = |
| rewriter.create<AffineLoadOp>(loc, startIndices, batchInductionVars); |
| startIndex = rewriter.create<arith::IndexCastOp>( |
| loc, rewriter.getIndexType(), startIndex); |
| sIn[0] = startIndex; |
| } |
| |
| // We load value at a particular operand index and populate the output |
| // tensor if the index constraints match. |
| Value loadValue = rewriter.create<AffineLoadOp>(loc, operand, operandIndex); |
| SmallVector<Value, 4> predicates; |
| // Adding offsets to the corresponding starting index and comparing it with |
| // the corresponding operand index. |
| for (unsigned k = 0, i = 0; k < startIndexMap.size(); k++) { |
| i = startIndexMap[k]; |
| Value addStartIndexOffset = rewriter.create<mlir::arith::AddIOp>( |
| loc, rewriter.getIndexType(), sIn[i], oIn[i]); |
| Value predicate = rewriter.create<mlir::arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, addStartIndexOffset, operandIndex[i]); |
| predicates.push_back(predicate); |
| } |
| |
| // Since the no. of predicates is equal to start_index_map.size() we |
| // iterate over pairs of predicates and join them with arith::AndIOp. |
| // We store the final predicate formed by joining other predicates with |
| // arith::AndIOp in result_predicate. |
| Value resultPredicate = nullptr; |
| for (unsigned i = 0; i < predicates.size() - 1; i += 2) { |
| Value predicateA = predicates[i]; |
| Value predicateB = predicates[i + 1]; |
| Value andPredicate = |
| rewriter.create<mlir::arith::AndIOp>(loc, predicateA, predicateB); |
| resultPredicate = (i == 0) ? andPredicate |
| : rewriter.create<mlir::arith::AndIOp>( |
| loc, resultPredicate, andPredicate); |
| } |
| // We fetch the last predicate value. In case this is the only predicate |
| // we let result_predicate be equal to this predicate value. Else if there |
| // are odd number of predicates we join it to other predicates using |
| // arith::AndIOp. |
| Value predicate = predicates.back(); |
| if (!resultPredicate) resultPredicate = predicate; |
| // In case there are odd number of predicates we join the last predicate |
| // to the result_predicate using arith::AndIOp. |
| else if (startIndexMap.size() % 2 == 1) |
| resultPredicate = |
| rewriter.create<mlir::arith::AndIOp>(loc, resultPredicate, predicate); |
| |
| // We use the loaded value if the index computed by adding offsets to |
| // starting index is equal to the current operand index. We use 0 as a value |
| // otherwise. |
| Value selectLoad = rewriter.create<mlir::arith::SelectOp>( |
| loc, resultPredicate, loadValue, zeroLoadValue); |
| // We load value at output array. |
| Value outputValue = |
| rewriter.create<AffineLoadOp>(loc, output, outputInductionVars); |
| |
| // The selected value is added to the previous value stored in output array. |
| if (elementType.isa<FloatType>()) |
| outputValue = rewriter.create<arith::AddFOp>(loc, elementType, selectLoad, |
| outputValue); |
| else |
| outputValue = rewriter.create<arith::AddIOp>(loc, elementType, selectLoad, |
| outputValue); |
| rewriter.create<AffineStoreOp>(loc, outputValue, output, |
| outputInductionVars); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| /// Converts PadOp to Affine nest form. |
| /// Pseudocode: |
| /// 1. Fill `output` tensor with `padding_value`. |
| /// 2. Compute AffineMap for store into `output`. |
| /// out_idx = edge_padding_low + |
| /// operand_idx * (1 + interior_padding) |
| /// 3. Create nested loop from `operand` shape. |
| /// 3.1 load from `operand`. |
| /// 3.2 store into `output`. |
| /// NOTE: The lowering handles only ranked shapes and bails out in case any of |
| /// output type/edge_padding_low/edge_padding_high/interior_padding size |
| /// doesn't match that of the operand's rank. |
| /// Limitation for now: |
| /// interior_padding == 0 && edge_padding_* >= 0 |
| struct PadOpConverter : public OpRewritePattern<PadOp> { |
| using OpRewritePattern<PadOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(PadOp op, |
| PatternRewriter& rewriter) const override { |
| Value operand = op.operand(); |
| Value paddingValue = op.padding_value(); |
| Value output = op.output(); |
| |
| auto operandType = operand.getType().dyn_cast<ShapedType>(); |
| auto outputType = output.getType().dyn_cast<ShapedType>(); |
| // We allow lowering for only ranked input/output. |
| if (!(operandType && outputType && operandType.hasRank() && |
| outputType.hasRank())) |
| return failure(); |
| unsigned rank = operandType.getRank(); |
| |
| auto edgePadLowRanges = op.edge_padding_low().getValues<int64_t>(); |
| auto edgePadHighRanges = op.edge_padding_high().getValues<int64_t>(); |
| auto interiorPadRanges = op.interior_padding().getValues<int64_t>(); |
| // Check whether the constraints for the lowering are satisfied :- |
| // 1. interior_padding[i] == 0 |
| // 2. edge_padding_*[i] >= 0 |
| for (auto paddings : |
| llvm::zip(edgePadLowRanges, edgePadHighRanges, interiorPadRanges)) { |
| // Only handle non-negative edge padding. |
| if (std::get<0>(paddings) < 0 || std::get<1>(paddings) < 0) |
| return rewriter.notifyMatchFailure( |
| op, "expected non-negative edge padding"); |
| // Only handle interior padding being zero for now. |
| if (std::get<2>(paddings) != 0) |
| return rewriter.notifyMatchFailure(op, |
| "expected zero interior padding"); |
| } |
| |
| SmallVector<int64_t> edgePaddingLow(edgePadLowRanges.begin(), |
| edgePadLowRanges.end()); |
| SmallVector<int64_t> edgePaddingHigh(edgePadHighRanges.begin(), |
| edgePadHighRanges.end()); |
| SmallVector<int64_t> interiorPadding(interiorPadRanges.begin(), |
| interiorPadRanges.end()); |
| |
| ArrayRef<int64_t> operandShape = operandType.getShape(); |
| ArrayRef<int64_t> outputShape = outputType.getShape(); |
| |
| // Mapping the `operand` index to the `output` index. |
| SmallVector<AffineExpr, 4> expr; |
| for (unsigned i = 0; i < rank; i++) { |
| AffineExpr dimExpr = rewriter.getAffineDimExpr(i); |
| expr.push_back(dimExpr + edgePaddingLow[i]); |
| } |
| AffineMap map = |
| AffineMap::get(rank, /*symbolCount=*/0, expr, rewriter.getContext()); |
| |
| SmallVector<Value, 4> indices; |
| |
| Location loc = op.getLoc(); |
| // Set padding_value to output. |
| { |
| OpBuilder::InsertionGuard regionGuard(rewriter); |
| Value scalarPaddingValue = rewriter.create<AffineLoadOp>( |
| loc, paddingValue, SmallVector<Value, 4>()); |
| AffineForOp initForOp; |
| for (unsigned i = 0; i < rank; i++) { |
| initForOp = rewriter.create<AffineForOp>(loc, 0, outputShape[i]); |
| indices.push_back(initForOp.getInductionVar()); |
| rewriter.setInsertionPointToStart(initForOp.getBody()); |
| } |
| rewriter.create<AffineStoreOp>(loc, scalarPaddingValue, output, indices); |
| } |
| |
| // Store `operand` into `output`, loop upper bounds from `operand` shape. |
| indices.clear(); |
| AffineForOp padForOp; |
| for (unsigned i = 0; i < rank; i++) { |
| padForOp = rewriter.create<AffineForOp>(loc, 0, operandShape[i]); |
| indices.push_back(padForOp.getInductionVar()); |
| rewriter.setInsertionPointToStart(padForOp.getBody()); |
| } |
| Value storeVal = rewriter.create<AffineLoadOp>(loc, operand, indices); |
| rewriter.create<AffineStoreOp>(loc, storeVal, output, map, indices); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| template <typename LhloOpTy> |
| struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> { |
| using OpRewritePattern<LhloOpTy>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(LhloOpTy op, |
| PatternRewriter& rewriter) const override { |
| const auto& lhs = op.lhs(); |
| const auto& rhs = op.rhs(); |
| const auto& lhsType = lhs.getType().template cast<MemRefType>(); |
| const auto& rhsType = rhs.getType().template cast<MemRefType>(); |
| const auto& elementType = lhsType.getElementType(); |
| |
| if (lhsType.getShape() != rhsType.getShape()) { |
| return failure(); |
| } |
| |
| LogicalResult mapStatus = success(); |
| auto bodyBuilder = [&](OpBuilder& builder, Location loc, |
| ValueRange inductionVars) { |
| auto l = builder.create<AffineLoadOp>(loc, lhs, inductionVars); |
| auto r = builder.create<AffineLoadOp>(loc, rhs, inductionVars); |
| Value opResult = lmhlo::LhloOpToStdScalarOp::map<LhloOpTy>( |
| op, elementType, {l, r}, &builder); |
| mapStatus = success(opResult != nullptr); |
| if (failed(mapStatus)) return; |
| rewriter.create<AffineStoreOp>(loc, opResult, op.out(), inductionVars); |
| }; |
| |
| buildBoundedAffineLoopNest(rewriter, op.getLoc(), lhsType.getShape(), |
| bodyBuilder); |
| if (failed(mapStatus)) return failure(); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| /// Conversion for unary operations i.e. tanh sin cos log log1p etc. |
| template <typename LhloOpTy> |
| struct UnaryOpConverter : public OpRewritePattern<LhloOpTy> { |
| using OpRewritePattern<LhloOpTy>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(LhloOpTy op, |
| PatternRewriter& rewriter) const override { |
| Value input = op.input(); |
| auto inputType = input.getType().cast<MemRefType>(); |
| auto elementType = inputType.getElementType(); |
| ArrayRef<int64_t> shape = inputType.getShape(); |
| |
| SmallVector<Value, 4> inductionVars; |
| |
| LogicalResult mapStatus = success(); |
| auto bodyBuilder = [&](OpBuilder& builder, Location loc, |
| ValueRange inductionVars) { |
| Value loadInput = builder.create<AffineLoadOp>(loc, input, inductionVars); |
| Value opResult = lmhlo::LhloOpToStdScalarOp::map<LhloOpTy>( |
| op, elementType, {loadInput}, &builder); |
| mapStatus = success(opResult != nullptr); |
| if (failed(mapStatus)) return; |
| rewriter.create<AffineStoreOp>(loc, opResult, op.output(), inductionVars); |
| }; |
| buildBoundedAffineLoopNest(rewriter, op.getLoc(), shape, bodyBuilder); |
| if (failed(mapStatus)) return failure(); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| void populateLHLOToAffineConversionPattern(MLIRContext* context, |
| RewritePatternSet* patterns) { |
| // clang-format off |
| patterns->add< |
| BinaryOpConverter<lmhlo::AddOp>, |
| BinaryOpConverter<lmhlo::AndOp>, |
| BinaryOpConverter<lmhlo::DivOp>, |
| BinaryOpConverter<lmhlo::MaxOp>, |
| BinaryOpConverter<lmhlo::MinOp>, |
| BinaryOpConverter<lmhlo::MulOp>, |
| BinaryOpConverter<lmhlo::SubOp>, |
| ConcatOpConverter, |
| DotOpConverter, |
| GatherOpConverter, |
| PadOpConverter, |
| UnaryOpConverter<lmhlo::LogOp>>(context); |
| // clang-format on |
| } |
| |
| struct LhloLegalizeToAffinePass |
| : public LhloLegalizeToAffinePassBase<LhloLegalizeToAffinePass> { |
| void getDependentDialects(DialectRegistry& registry) const override { |
| registry.insert<AffineDialect, math::MathDialect>(); |
| } |
| void runOnOperation() override { |
| auto func = getOperation(); |
| RewritePatternSet patterns(&getContext()); |
| populateLHLOToAffineConversionPattern(&getContext(), &patterns); |
| if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) |
| return signalPassFailure(); |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<OperationPass<func::FuncOp>> createLhloLegalizeToAffinePass() { |
| return std::make_unique<LhloLegalizeToAffinePass>(); |
| } |
| |
| } // namespace lmhlo |
| } // namespace mlir |