blob: 970fa7ae959438b49c6fd8c887a1004f52a40177 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
namespace tensorflow {
namespace tfrt_compiler {
namespace {
using ::mlir::ArrayRef;
using ::mlir::ModuleOp;
using ::mlir::Operation;
using ::mlir::SymbolTable;
using ::mlir::SymbolTableCollection;
using ::mlir::SymbolUserMap;
// Returns true if `lhs` and `rhs` appear to be equivalent functions.
//
// This only includes some preliminary checks as this is a short term solution:
//   * the function types and the argument types must be identical;
//   * both bodies must contain the same number of operations (counting the
//     operations directly in the body region, not those in nested regions);
//   * each pair of corresponding operations must agree on op name, number of
//     regions, number of successors, and operand/result types.
// Attributes and the contents of nested regions are NOT compared, so two
// functions accepted here may still differ semantically.
bool AreEquivalent(mlir::func::FuncOp& lhs, mlir::func::FuncOp& rhs) {
  if (lhs.getFunctionType() != rhs.getFunctionType()) return false;

  for (auto arg_pair : llvm::zip(lhs.getArguments(), rhs.getArguments())) {
    auto& lhs_arg = std::get<0>(arg_pair);
    auto& rhs_arg = std::get<1>(arg_pair);
    if (lhs_arg.getType() != rhs_arg.getType()) return false;
  }

  auto lhs_ops = lhs.getBody().getOps();
  auto rhs_ops = rhs.getBody().getOps();
  if (std::distance(lhs_ops.begin(), lhs_ops.end()) !=
      std::distance(rhs_ops.begin(), rhs_ops.end()))
    return false;

  for (auto op_pair : llvm::zip(lhs_ops, rhs_ops)) {
    auto& lhs_op = std::get<0>(op_pair);
    auto& rhs_op = std::get<1>(op_pair);
    if (lhs_op.getName() != rhs_op.getName()) return false;
    if (lhs_op.getNumRegions() != rhs_op.getNumRegions()) return false;
    if (lhs_op.getNumSuccessors() != rhs_op.getNumSuccessors()) return false;

    // Use the four-iterator std::equal overload: ops with the same name may
    // have a different number of operands/results (variadic ops). The
    // three-iterator form would read past the end of the shorter range, or
    // accept the pair when `lhs`'s types are merely a prefix of `rhs`'s.
    auto lhs_operand_types = lhs_op.getOperandTypes();
    auto rhs_operand_types = rhs_op.getOperandTypes();
    if (!std::equal(lhs_operand_types.begin(), lhs_operand_types.end(),
                    rhs_operand_types.begin(), rhs_operand_types.end()))
      return false;

    auto lhs_result_types = lhs_op.getResultTypes();
    auto rhs_result_types = rhs_op.getResultTypes();
    if (!std::equal(lhs_result_types.begin(), lhs_result_types.end(),
                    rhs_result_types.begin(), rhs_result_types.end()))
      return false;
  }

  return true;
}
// Module pass that collapses duplicate functions reached through
// tf.BatchFunction ops: a function is a candidate only when every one of its
// users is a BatchFunctionOp and those users agree on shared_name; candidates
// with the same shared_name are merged into a single function.
//
// TODO(b/192463730): this is the short term solution and not needed anymore
// after the shape inference pass is revamped with ideal solution
// (b/192463730#comment11).
class DeduplicateFunctionsInovkedByBatchFunction
    : public mlir::PassWrapper<DeduplicateFunctionsInovkedByBatchFunction,
                               mlir::OperationPass<mlir::ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      DeduplicateFunctionsInovkedByBatchFunction)

 private:
  // Pass entry point: delegates to Run() and reports failure to the pass
  // manager when Run() does not succeed.
  void runOnOperation() override {
    if (mlir::failed(Run())) signalPassFailure();
  }

  // Does the actual deduplication work; defined out of line.
  mlir::LogicalResult Run();

  // Name under which the pass is registered and selected.
  llvm::StringRef getArgument() const final {
    return "tfrt-deduplicate-functions-invoked-by-batch-function";
  }

  // Human-readable summary of the pass.
  llvm::StringRef getDescription() const final {
    return "Deduplicate the functions invoked by tf.BatchFunction with the "
           "same shared_name";
  }
};
mlir::LogicalResult DeduplicateFunctionsInovkedByBatchFunction::Run() {
ModuleOp module = getOperation();
SymbolTableCollection symbol_table_collection;
SymbolTable& symbol_table = symbol_table_collection.getSymbolTable(module);
SymbolUserMap symbol_users(symbol_table_collection, module);
// Categorize the functions invoked by BatchFunctionOp by its shared_name.
llvm::StringMap<llvm::SmallVector<mlir::func::FuncOp, 2>>
shared_name_to_func_ops;
for (auto func :
llvm::make_early_inc_range(module.getOps<mlir::func::FuncOp>())) {
ArrayRef<Operation*> users = symbol_users.getUsers(func);
llvm::StringRef shared_name;
// Deduplicate the function only if all users are BatchFunctionOp and have
// the same shared_name
if (!users.empty() && llvm::all_of(users, [&shared_name](Operation* user) {
auto op = llvm::dyn_cast_or_null<mlir::TF::BatchFunctionOp>(user);
// User is not a BatchFunctionOp
if (!op) return false;
if (shared_name.empty()) {
shared_name = op.shared_name();
return true;
}
return shared_name == op.shared_name();
})) {
shared_name_to_func_ops[shared_name].push_back(func);
}
}
for (auto& it : shared_name_to_func_ops) {
auto& func_ops = it.second;
mlir::func::FuncOp& func_op_to_keep = func_ops.front();
for (mlir::func::FuncOp& func_op_to_remove : llvm::drop_begin(func_ops)) {
if (!AreEquivalent(func_op_to_keep, func_op_to_remove)) {
return func_op_to_remove.emitError(
"func_ops for BatchFunctionOp with the same shared name are "
"different");
}
if (failed(SymbolTable::replaceAllSymbolUses(
func_op_to_remove, func_op_to_keep.getSymNameAttr(), module))) {
return func_op_to_remove.emitError("unable to replace the symbol use");
}
symbol_table.erase(func_op_to_remove);
}
}
return mlir::success();
}
} // namespace
// Creates a new instance of the batch-function deduplication pass; ownership
// is transferred to the caller.
std::unique_ptr<mlir::OperationPass<ModuleOp>>
CreateDeduplicateFunctionsInovkedByBatchFunctionPass() {
  std::unique_ptr<mlir::OperationPass<ModuleOp>> pass =
      std::make_unique<DeduplicateFunctionsInovkedByBatchFunction>();
  return pass;
}
// Static registration object: constructing it registers the pass with MLIR's
// pass registry during static initialization of this translation unit.
static mlir::PassRegistration<DeduplicateFunctionsInovkedByBatchFunction>
    register_pass{};
} // namespace tfrt_compiler
} // namespace tensorflow