// Source blob: 35893f2ac8d15306d398234e69c1c6155299ccfb
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass optimizes tf_saved_model.global_tensor ops.
#include <cstddef>
#include <map>
#include <set>
#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/savedmodel_passes_detail.h"
namespace mlir {
namespace tf_saved_model {
namespace {
// Pass that optimizes tf_saved_model.global_tensor ops. All of the work happens
// in runOnOperation(), which is defined out of line further down in this file.
class OptimizeGlobalTensorsPass
    : public OptimizeGlobalTensorsPassBase<OptimizeGlobalTensorsPass> {
 public:
  void runOnOperation() override;
};
// Records a single binding of a global tensor to a function argument.
// A global tensor may be bound to arguments of several funcs; one
// GlobalTensorUse is stored per (func, argument index) pair.
struct GlobalTensorUse {
// The function that has an argument bound to the global tensor.
// NOTE(review): `mutable` presumably lets callers invoke FuncOp methods when
// they only hold a const reference to this struct (e.g. elements of an
// ArrayRef) — confirm.
mutable func::FuncOp func;
// Index, within `func`, of the argument bound to the global tensor.
size_t arg_index;
};
// Maps a global tensor op to the function arguments recorded as bound to it.
using GlobalTensorUsesMap =
std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;
// Returns true when `global_tensor` is currently marked mutable but can be
// treated as immutable: it is not exported, and the resource analyzer reports
// that none of the function arguments in `global_tensor_uses` is potentially
// written. Returns false for tensors that are already immutable, because there
// is nothing left to change for them.
bool IsImmutable(GlobalTensorOp global_tensor,
                 ArrayRef<GlobalTensorUse> global_tensor_uses,
                 const TF::ResourceAnalyzer& resource_analyzer) {
  // Nothing to do for a tensor that is already immutable; an exported tensor
  // that is still mutable might be mutated externally.
  if (!global_tensor.is_mutable() || IsExported(global_tensor)) {
    return false;
  }
  // A single potentially-written bound argument keeps the tensor mutable.
  for (const GlobalTensorUse& use : global_tensor_uses) {
    BlockArgument bound_arg = use.func.getArgument(use.arg_index);
    if (resource_analyzer.IsPotentiallyWritten(bound_arg)) {
      return false;
    }
  }
  return true;
}
// Builds the map from each global tensor to the function arguments bound to it.
//
// Scans every argument of every func::FuncOp returned by `module.getOps` for a
// "tf_saved_model.bound_input" attribute and resolves the referenced symbol in
// the module's symbol table. Arguments without such an attribute, or whose
// symbol does not resolve to a GlobalTensorOp, are ignored.
GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
  GlobalTensorUsesMap global_tensor_uses;
  SymbolTable symbol_table(module);
  for (auto func : module.getOps<func::FuncOp>()) {
    for (size_t i = 0, e = func.getNumArguments(); i < e; i++) {
      // Request a FlatSymbolRefAttr directly: a bound_input holding any other
      // kind of attribute (including a nested symbol reference) is skipped like
      // a missing attribute instead of going through an unchecked cast.
      auto sym = func.getArgAttrOfType<FlatSymbolRefAttr>(
          i, "tf_saved_model.bound_input");
      if (!sym) {
        continue;
      }
      auto global_tensor = symbol_table.lookup<GlobalTensorOp>(sym.getValue());
      if (!global_tensor) {
        continue;
      }
      global_tensor_uses[global_tensor].push_back({func, i});
    }
  }
  return global_tensor_uses;
}
// Strips the `is_mutable` attribute from every tf_saved_model.global_tensor in
// `global_tensor_uses_map` for which IsImmutable() reports that doing so is
// safe. Tensors that do not appear in the map are left untouched.
void MarkGlobalTensorsImmutable(
    ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map,
    const TF::ResourceAnalyzer& resource_analyzer) {
  for (const auto& entry : global_tensor_uses_map) {
    GlobalTensorOp tensor = entry.first;
    if (!IsImmutable(tensor, entry.second, resource_analyzer)) {
      continue;
    }
    tensor->removeAttr("is_mutable");
  }
}
// Erases every GlobalTensorOp in `module` that is neither exported nor present
// as a key in `global_tensor_uses`.
void EraseUnusedGlobalTensors(ModuleOp module,
                              const GlobalTensorUsesMap& global_tensor_uses) {
  // Early-inc iteration so that erasing the current op does not break the
  // traversal.
  for (GlobalTensorOp tensor :
       llvm::make_early_inc_range(module.getOps<GlobalTensorOp>())) {
    // A tensor counts as used when it is exported or bound to an argument.
    const bool in_use =
        IsExported(tensor) || global_tensor_uses.count(tensor) != 0;
    if (!in_use) {
      tensor.erase();
    }
  }
}
// Removes, from every func::FuncOp returned by `module.getOps`, each argument
// that carries a "tf_saved_model.bound_input" attribute and has no uses.
void EraseUnusedBoundInputs(ModuleOp module) {
  for (auto func : module.getOps<func::FuncOp>()) {
    // One bit per argument; set bits mark arguments to erase. Collecting them
    // first lets all erasures happen in one call after the scan.
    llvm::BitVector args_to_erase(func.getNumArguments());
    // Unsigned index: matches what getNumArguments() returns and what the
    // index-taking calls below expect, avoiding signed/unsigned conversions.
    for (unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
      if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
          func.getArgument(i).use_empty()) {
        args_to_erase.set(i);
      }
    }
    func.eraseArguments(args_to_erase);
  }
}
void OptimizeGlobalTensorsPass::runOnOperation() {
auto module = getOperation();
if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
return;
}
EraseUnusedBoundInputs(module);
TF::ResourceAnalyzer resource_analyzer(module);
GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module);
MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer);
EraseUnusedGlobalTensors(module, global_tensor_uses);
}
} // namespace
// Factory for the pass: returns a newly allocated OptimizeGlobalTensorsPass
// owned through the generic module-pass pointer type.
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<OptimizeGlobalTensorsPass>();
  return pass;
}
} // namespace tf_saved_model
} // namespace mlir