blob: 01fb68e4478c499a603081394f6e4e36d40ca022 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering lmhlo.while to tfrt.while.
#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/DialectRegistry.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime
namespace tensorflow {
namespace {
class ConvertLmhloToTfrtWhilePass
: public PassWrapper<ConvertLmhloToTfrtWhilePass, OperationPass<ModuleOp>> {
public:
ConvertLmhloToTfrtWhilePass() = default;
ConvertLmhloToTfrtWhilePass(const ConvertLmhloToTfrtWhilePass&) {}
llvm::StringRef getArgument() const override { return "lmhlo-to-tfrt-while"; }
llvm::StringRef getDescription() const override {
return "Convert lmhlo.while to tfrt.while.";
}
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<memref::MemRefDialect, tfrt::compiler::TFRTDialect,
tfrt::gpu::GpuDialect>();
}
void runOnOperation() override;
};
// Replaces lmhlo ops within a module with tfrt.while ops.
//
// "lmhlo.while"(%cond) ({
// <cond_ops>
// return
// }, {
// <body_ops>
// return
// }) : (memref<i1>) -> ()
//
// is rewritten to:
//
// func.func @while_cond(%cond, <captures>) : (memref<i1>, ...) -> i1 {
// <cond_ops>
// %value = memref.load %cond : memref<i1>
// return %value : i1
// }
// func.func @while_body(%cond, <captures>)
// : (memref<i1>, ...) -> (memref<i1>, ..., i1) {
// <body_ops>
// %value = tfrt.call @while_cond(%cond, <captures>)
// return %cond, <captures>, %value
// }
// %value = tfrt.call @while_cond(%cond, <captures>)
// %results:N = tfrt.while %value @while_body(%cond, <captures>)
//
struct WhilePattern : OpRewritePattern<lmhlo::WhileOp> {
WhilePattern(MLIRContext* context, SymbolTable& symbol_table)
: OpRewritePattern(context), symbol_table(symbol_table) {}
LogicalResult matchAndRewrite(lmhlo::WhileOp while_op,
PatternRewriter& rewriter) const override;
SymbolTable& symbol_table;
};
LogicalResult WhilePattern::matchAndRewrite(lmhlo::WhileOp while_op,
PatternRewriter& rewriter) const {
if (while_op->getNumOperands() != 1)
return rewriter.notifyMatchFailure(while_op, "expected single condition");
if (while_op.trip_count())
return rewriter.notifyMatchFailure(while_op, "trip count not supported");
// Collect condition value and implicit captures.
llvm::SetVector<Value> while_args;
while_args.insert(while_op.cond_val().front());
getUsedValuesDefinedAbove(while_op.getOperation()->getRegions(), while_args);
auto return_types = llvm::to_vector<4>(TypeRange(while_args.getArrayRef()));
auto i1_type = rewriter.getI1Type();
return_types.push_back(i1_type);
auto argument_types = TypeRange(return_types).drop_back();
// Clones single-block 'region' into 'func' and returns the `func.return` op.
auto clone_region = [&](Region& region, func::FuncOp func) {
Block* block = func.addEntryBlock();
BlockAndValueMapping mapping;
for (auto pair : llvm::zip_first(while_args, block->getArguments()))
mapping.map(std::get<0>(pair), std::get<1>(pair));
rewriter.cloneRegionBefore(region, func.getRegion(), func.end(), mapping);
// Merge cloned block into entry block.
rewriter.mergeBlocks(&func.back(), block);
rewriter.setInsertionPointToEnd(block);
Operation* terminator = block->getTerminator();
return rewriter.replaceOpWithNewOp<func::ReturnOp>(terminator);
};
// Insert while_cond function.
rewriter.setInsertionPoint(while_op->getParentOfType<func::FuncOp>());
auto cond_func_type = rewriter.getFunctionType(argument_types, i1_type);
auto cond_func = rewriter.create<func::FuncOp>(while_op.cond().getLoc(),
"while_cond", cond_func_type);
symbol_table.insert(cond_func);
auto cond_return = clone_region(while_op.cond(), cond_func);
rewriter.setInsertionPoint(cond_return);
Value cond_result = rewriter.create<memref::LoadOp>(cond_return.getLoc(),
cond_func.getArgument(0));
cond_return->setOperands(cond_result);
// Insert while_body function.
rewriter.setInsertionPointAfter(cond_func);
auto body_func_type = rewriter.getFunctionType(argument_types, return_types);
auto body_func = rewriter.create<func::FuncOp>(while_op.body().getLoc(),
"while_body", body_func_type);
symbol_table.insert(body_func);
auto body_return = clone_region(while_op.body(), body_func);
rewriter.setInsertionPoint(body_return);
auto body_call = rewriter.create<tfrt::compiler::CallOp>(
body_return.getLoc(), i1_type, cond_func.getSymName(),
body_func.getArguments());
body_return->setOperands(body_call.getResults());
body_return->insertOperands(0, body_func.getArguments());
// Replace lmhlo.while with calls to cond and body functions.
rewriter.setInsertionPoint(while_op);
auto while_call = rewriter.create<tfrt::compiler::CallOp>(
while_op.getLoc(), i1_type, cond_func.getSymName(),
while_args.getArrayRef());
rewriter.create<tfrt::compiler::WhileOp>(
while_op.getLoc(), argument_types, while_call.getResult(0),
while_args.getArrayRef(), body_func.getSymName());
rewriter.eraseOp(while_op);
return success();
}
} // namespace
void ConvertLmhloToTfrtWhilePass::runOnOperation() {
RewritePatternSet patterns(&getContext());
SymbolTable symbol_table(getOperation());
patterns.add<WhilePattern>(&getContext(), symbol_table);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
std::unique_ptr<Pass> createConvertLmhloToGpuWhilePass() {
return std::make_unique<ConvertLmhloToTfrtWhilePass>();
}
void registerConvertLmhloToGpuWhilePass() {
::registerPass([] { return createConvertLmhloToGpuWhilePass(); });
}
} // namespace tensorflow