blob: a2f0539896ee68f8448e34f5d08a2e0fdf6f6a9b [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>
#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
namespace tensorflow {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
using llvm::makeArrayRef;
using mlir::BlockAndValueMapping;
using mlir::BlockArgument;
using mlir::dyn_cast;
using mlir::failure;
using mlir::Location;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpBuilder;
using mlir::Operation;
using mlir::OpFoldResult;
using mlir::OpRewritePattern;
using mlir::PatternRewriter;
using mlir::SmallVector;
using mlir::success;
using mlir::Value;
using mlir::ValueRange;
using mlir::gml_st::LoopOp;
using mlir::linalg::FillOp;
using mlir::linalg::GenericOp;
using mlir::linalg::InitTensorOp;
using mlir::linalg::LinalgOp;
using mlir::linalg::YieldOp;
using mlir::tensor::ExtractSliceOp;
using mlir::tensor::InsertSliceOp;
// Returns the step of the parallel dimension of a 2D `tiled_loop` as a
// one-element list. A step defined by `arith.constant` is returned as its
// attribute (i.e. a static size); any other step is returned as the SSA value.
SmallVector<OpFoldResult> GetParallelDimStep(LoopOp tiled_loop) {
  assert(tiled_loop.getNumLoops() == 2 && "Expected a 2D loop");
  auto steps = tiled_loop.step();
  // Dimension 0 is parallel -> use the first step, otherwise the last one.
  Value parallel_step = steps.back();
  if (tiled_loop.isParallelDimension(0)) parallel_step = steps.front();

  auto cst_op = parallel_step.getDefiningOp<mlir::arith::ConstantOp>();
  if (!cst_op) return {parallel_step};
  return {cst_op.getValue()};
}
// Fuses `linalg.fill` into a loop with a tiled reduction.
// Currently, only the 2D case is supported. Fusion into a tiled 1D reduction is
// also possible but not implemented.
//
// The pattern matches a 2D `linalg.generic` with a single output whose
// immediate parent is a 2D `gml_st::LoopOp` whose first output is produced by a
// `linalg.fill`. Every precondition is checked before any IR is modified, so a
// `failure()` result never leaves the IR half-rewritten.
//
// In the IR sketches below, `linalg.tiled_loop` denotes the matched LoopOp.
struct FuseFillIntoTiledReductionPattern : public OpRewritePattern<GenericOp> {
  explicit FuseFillIntoTiledReductionPattern(MLIRContext *context,
                                             mlir::PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  LogicalResult matchAndRewrite(GenericOp linalg_op,
                                PatternRewriter &rewriter) const override {
    if (linalg_op.getNumOutputs() != 1) return failure();
    if (linalg_op.getNumLoops() != 2) return failure();

    // Get immediate parent. `dyn_cast_or_null` tolerates a detached op.
    auto tiled_loop_op =
        mlir::dyn_cast_or_null<LoopOp>(linalg_op->getParentOp());
    if (!tiled_loop_op) return failure();
    if (tiled_loop_op.getNumLoops() != 2) return failure();

    return RewriteTiledReduction(rewriter, tiled_loop_op, linalg_op);
  }

 private:
  // Add a new output argument to the `tiled_loop`. It is produced by a new
  // `init_tensor` op whose size is the step of the parallel dimension and whose
  // element type matches the output of `fill`.
  //
  // Rewrite
  //
  //   %init = linalg.init_tensor
  //   %fill = linalg.fill(%cst, %init)
  //   linalg.tiled_loop outs(%fill)
  //
  // into
  //
  //   %init = linalg.init_tensor
  //** %init_tile = linalg.init_tensor [%stride]
  //   %fill = linalg.fill(%cst, %init)
  //** linalg.tiled_loop outs(%fill, %init_tile)
  //
  // Returns the loop-body block argument tied to the appended output.
  //
  // NOTE(review): the new op is placed right before `fill`. If the parallel
  // step is a non-constant SSA value defined between `fill` and the loop, this
  // would violate dominance — confirm steps are always constants or defined
  // earlier.
  BlockArgument CloneAndAppendInitTensorToTiledLoop(PatternRewriter &rewriter,
                                                    FillOp fill,
                                                    LoopOp tiled_loop) const {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(fill);

    // Take the location and element type from the fill's output operand
    // directly; it need not be produced by an `init_tensor`.
    Value fill_output = fill.output();
    Value init_clone = rewriter.create<InitTensorOp>(
        fill_output.getLoc(), GetParallelDimStep(tiled_loop),
        fill_output.getType().cast<mlir::ShapedType>().getElementType());

    mlir::OpOperand *init_clone_output_operand = nullptr;
    rewriter.updateRootInPlace(tiled_loop, [&]() {
      init_clone_output_operand =
          &tiled_loop.appendOutputOperand(rewriter, init_clone);
    });
    return tiled_loop.getTiedBlockArgument(*init_clone_output_operand);
  }

  // Fuse the `fill` operation into the `tiled_loop` and rewire the
  // `linalg.generic` to use the fused fill as the output for the reduced tile.
  // Also create an additional `insert_slice` into the new output tile and yield
  // it as the last loop result.
  //
  // Rewrite
  //
  //   %init = linalg.init_tensor
  //   %init_tile = linalg.init_tensor [%stride]
  //   %fill = linalg.fill(%cst, %init)
  //   linalg.tiled_loop outs(%fill, %init_tile) {
  //     %extract_output_slice = tensor.extract_slice %fill
  //     %reduce = linalg.generic outs (%extract_output_slice)
  //     %insert_output_slice = tensor.insert_slice %reduce into %fill
  //     linalg.yield %insert_output_slice
  //   }
  //
  // into
  //
  //   %init = linalg.init_tensor
  //   %init_tile = linalg.init_tensor
  //   %fill = linalg.fill(%cst, %init)
  //   linalg.tiled_loop outs(%fill, %init_tile) {
  //     %extract_output_slice = tensor.extract_slice %fill
  //
  //**   %slice_of_output_tile = tensor.extract_slice %init_tile
  //**   %fill_of_output_tile = linalg.fill(%cst, %slice_of_output_tile)
  //**   %reduce = linalg.generic outs (%fill_of_output_tile)
  //**   %update_output_tile =
  //**       tensor.insert_slice %fill_of_output_tile into %init_tile
  //
  //     %insert_output_slice = tensor.insert_slice %reduce into %fill
  //     linalg.yield %insert_output_slice, %update_output_tile
  //   }
  //
  // NOTE(review): the inserted value is the filled tile, not `%reduce`. The
  // corresponding loop result is dropped by CreateLoopWithUpdatedResults, so
  // this does not affect computed values. Confirm whether `%reduce` was
  // intended, for in-place bufferization.
  void FuseFill(PatternRewriter &rewriter, LinalgOp tiled_op, FillOp fill,
                BlockArgument loop_output_bb_arg,
                BlockArgument output_tile_bb_arg,
                ExtractSliceOp extract_output_slice,
                InsertSliceOp insert_output_slice) const {
    Location loc = tiled_op.getLoc();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(tiled_op);

    SmallVector<OpFoldResult> offset{rewriter.getIndexAttr(0)};
    Value slice_of_output_tile = rewriter.create<ExtractSliceOp>(
        loc, output_tile_bb_arg, offset, extract_output_slice.getMixedSizes(),
        extract_output_slice.getMixedStrides());

    auto fused_fill =
        rewriter.create<FillOp>(loc, fill.value(), slice_of_output_tile);
    rewriter.updateRootInPlace(tiled_op, [&]() {
      tiled_op.getOutputOperand(0)->set(fused_fill.result());
    });

    rewriter.setInsertionPointAfter(tiled_op);
    Value cloned_insert = rewriter.create<mlir::tensor::InsertSliceOp>(
        loc, fused_fill.getResult(0), output_tile_bb_arg, offset,
        extract_output_slice.getMixedSizes(),
        extract_output_slice.getMixedStrides());

    // The new output was appended last to the loop, so its yielded value must
    // also go last (not at a fixed index, which breaks multi-output loops).
    Operation *yield = tiled_op.getOperation()->getBlock()->getTerminator();
    rewriter.updateRootInPlace(yield, [&]() {
      yield->insertOperands(yield->getNumOperands(), cloned_insert);
    });
  }

  // Add an operation that combines the partial result with the output, using
  // a clone of `combiner`. `combiner` is the op returned by
  // DetectCombiner(tiled_op). Its operands are remapped positionally to the
  // block arguments (partial element, output element).
  //
  // Rewrite
  //
  //   linalg.tiled_loop outs(%fill, %init_tile) {
  //     %extract_output_slice = tensor.extract_slice %fill
  //     ...
  //     %reduce = linalg.generic outs (%fill_of_output_tile)
  //     ...
  //     %insert_output_slice = tensor.insert_slice %reduce into %fill
  //     linalg.yield %insert_output_slice, %update_output_tile
  //   }
  //
  // into
  //
  //   linalg.tiled_loop outs(%fill, %init_tile) {
  //     %extract_output_slice = tensor.extract_slice %fill
  //     ...
  //     %reduce = linalg.generic outs (%fill_of_output_tile)
  //**   %combine = linalg.generic ins (%reduce) outs (%extract_output_slice)
  //     ...
  //**   %insert_output_slice = tensor.insert_slice %combine into %fill
  //     linalg.yield %insert_output_slice, %update_output_tile
  //   }
  void CombineReducedTileWithOutput(PatternRewriter &rewriter,
                                    LinalgOp tiled_op, Operation *combiner,
                                    Value partial_result,
                                    ExtractSliceOp extract_output_slice,
                                    InsertSliceOp insert_output_slice) const {
    rewriter.setInsertionPointAfter(tiled_op);
    auto num_parallel_loops = tiled_op.getNumParallelLoops();
    SmallVector<mlir::StringRef, 3> parallel_iter_types(
        num_parallel_loops, mlir::getParallelIteratorTypeName());
    auto id_map = rewriter.getMultiDimIdentityMap(num_parallel_loops);

    auto accumulator = rewriter.create<GenericOp>(
        tiled_op.getLoc(), partial_result.getType(),
        makeArrayRef(partial_result),
        makeArrayRef(extract_output_slice.result()),
        makeArrayRef({id_map, id_map}), parallel_iter_types,
        [&](OpBuilder &b, Location nested_loc, ValueRange args) {
          BlockAndValueMapping bvm;
          bvm.map(combiner->getOperands(), args);
          Value result_val = b.clone(*combiner, bvm)->getResult(0);
          b.create<YieldOp>(nested_loc, result_val);
        });

    rewriter.updateRootInPlace(insert_output_slice, [&]() {
      insert_output_slice.sourceMutable().assign(accumulator.getResult(0));
    });
  }

  // Unfortunately, there is no way to modify the results of the loop in place,
  // so we have to replace it with a clone that has one result per output.
  // Only the original results are forwarded to existing users; the result of
  // the appended output is left unused.
  LoopOp CreateLoopWithUpdatedResults(PatternRewriter &rewriter,
                                      LoopOp tiled_loop) const {
    auto loc = tiled_loop.getLoc();
    rewriter.setInsertionPoint(tiled_loop);
    auto new_loop = rewriter.create<LoopOp>(
        loc, mlir::TypeRange(tiled_loop.outputs()), tiled_loop.getOperands(),
        tiled_loop->getAttrs());
    rewriter.inlineRegionBefore(tiled_loop.region(), new_loop.region(),
                                new_loop.region().begin());
    // replaceOp requires exactly as many values as the old op has results.
    rewriter.replaceOp(tiled_loop, new_loop->getResults().take_front(
                                       tiled_loop->getNumResults()));
    return new_loop;
  }

  // Fuses the FillOp producer of the output argument of the LoopOp and inserts
  // an operation that accumulates the partial result, i.e. the reduced tile,
  // and the current value of the output tile.
  //
  // All matching (including combiner detection) happens before the first IR
  // mutation, so `failure()` is only returned on an untouched IR.
  LogicalResult RewriteTiledReduction(PatternRewriter &rewriter,
                                      LoopOp tiled_loop,
                                      LinalgOp tiled_op) const {
    OpBuilder::InsertionGuard guard(rewriter);

    // ---- Match phase: no IR modifications below until all checks pass. ----

    // Find tiled loop output operand and the corresponding block argument.
    if (tiled_loop.outputs().empty()) return failure();
    mlir::OpOperand *loop_output_operand =
        tiled_loop.findOutputOperand(tiled_loop.outputs().front());
    BlockArgument loop_output_bb_arg =
        tiled_loop.getTiedBlockArgument(*loop_output_operand);

    // Find `linalg.fill` producer of the output.
    auto fill = loop_output_operand->get().getDefiningOp<FillOp>();
    if (!fill) return failure();

    // Find extract_slice/insert_slice pair used to RMW output.
    auto extract_output_slice =
        tiled_op.getOutputOperand(0)->get().getDefiningOp<ExtractSliceOp>();
    if (!extract_output_slice) return failure();

    // The reduced tile must feed exactly one insert_slice. Without users there
    // is nothing to dereference. With extra users, they would observe the
    // un-accumulated partial result after the rewrite.
    Value tiled_op_result = tiled_op->getResult(0);
    if (!tiled_op_result.hasOneUse()) return failure();
    auto insert_output_slice =
        dyn_cast<InsertSliceOp>(*tiled_op_result.getUsers().begin());
    if (!insert_output_slice) return failure();

    // Detect the combiner up front; failing after the rewrite started would
    // leave the IR modified while reporting failure.
    auto combiner_or = DetectCombiner(tiled_op);
    if (failed(combiner_or)) return failure();
    Operation *combiner = combiner_or.getValue();

    // ---- Rewrite phase. ----

    // Fuse the output.
    BlockArgument output_tile_bb_arg =
        CloneAndAppendInitTensorToTiledLoop(rewriter, fill, tiled_loop);
    FuseFill(rewriter, tiled_op, fill, loop_output_bb_arg, output_tile_bb_arg,
             extract_output_slice, insert_output_slice);
    // We have already modified the loop above, so we need to update the
    // results.
    CreateLoopWithUpdatedResults(rewriter, tiled_loop);
    CombineReducedTileWithOutput(rewriter, tiled_op, combiner, tiled_op_result,
                                 extract_output_slice, insert_output_slice);
    return success();
  }
};
// Function pass that greedily applies FuseFillIntoTiledReductionPattern to the
// ops nested in the function.
struct FuseFillIntoTiledReductionPass
    : public FuseFillIntoTiledReductionBase<FuseFillIntoTiledReductionPass> {
  void runOnOperation() override {
    mlir::func::FuncOp func_op = getOperation();
    MLIRContext *ctx = func_op.getContext();

    mlir::RewritePatternSet fusion_patterns(ctx);
    fusion_patterns.add<FuseFillIntoTiledReductionPattern>(ctx);

    // The driver's LogicalResult is discarded, so non-convergence is not
    // reported as a pass failure.
    (void)mlir::applyPatternsAndFoldGreedily(func_op,
                                             std::move(fusion_patterns));
  }
};
} // namespace
// Creates the function pass that fuses `linalg.fill` producers into loops with
// tiled reductions (see FuseFillIntoTiledReductionPattern).
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateFuseFillIntoTiledReductionPass() {
  auto pass = std::make_unique<FuseFillIntoTiledReductionPass>();
  return pass;
}
} // namespace tensorflow