/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <vector>
#include "llvm/ADT/BitVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/DataFlowAnalysis.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/savedmodel_passes_detail.h"
namespace mlir {
namespace tf_saved_model {
namespace {
// The lattice value tracks the set of GlobalTensorOps that a given resource
// value may refer to.
struct ResourceLatticeValue {
explicit ResourceLatticeValue(GlobalTensorOp op = nullptr) {
if (op) ops.insert(op);
}
static ResourceLatticeValue getPessimisticValueState(MLIRContext *context) {
return ResourceLatticeValue();
}
static ResourceLatticeValue getPessimisticValueState(Value value) {
if (auto barg = value.dyn_cast<BlockArgument>()) {
if (func::FuncOp func =
dyn_cast<func::FuncOp>(barg.getOwner()->getParentOp())) {
SymbolTable symbol_table(func->getParentOfType<ModuleOp>());
auto global_tensor = LookupBoundInputOfType<GlobalTensorOp>(
func, barg.getArgNumber(), symbol_table);
return ResourceLatticeValue(global_tensor);
}
}
return ResourceLatticeValue();
}
bool operator==(const ResourceLatticeValue &rhs) const {
return ops == rhs.ops;
}
static ResourceLatticeValue join(const ResourceLatticeValue &lhs,
const ResourceLatticeValue &rhs) {
    // Take the union of both sets of possible GlobalTensorOp values that can
    // be referenced here.
ResourceLatticeValue ret;
ret.ops.insert(lhs.ops.begin(), lhs.ops.end());
ret.ops.insert(rhs.ops.begin(), rhs.ops.end());
return ret;
}
  // The set of GlobalTensorOps that this resource value may refer to.
  DenseSet<GlobalTensorOp> ops;
};
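// Forward dataflow analysis that propagates ResourceLatticeValue through the
// module, so each resource block argument can be mapped back to the global
// tensors it may be bound to.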
class ResourceAnalysis : public ForwardDataFlowAnalysis<ResourceLatticeValue> {
public:
using LatticeElementT = LatticeElement<ResourceLatticeValue>;
using ForwardDataFlowAnalysis<ResourceLatticeValue>::ForwardDataFlowAnalysis;
~ResourceAnalysis() override = default;
ChangeResult visitOperation(Operation *op,
ArrayRef<LatticeElementT *> operands) override {
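    // No transfer function is needed: resource values of interest originate
    // only from block arguments bound to global tensors, so all operation
    // results are pinned to the pessimistic fixpoint.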
return markAllPessimisticFixpoint(op->getResults());
}
};
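// Pass that replaces each freezable bound input with a tf.Const op holding
// the value of its GlobalTensorOp, then erases the now-unused global tensors.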
struct FreezeGlobalTensorsPass
: public FreezeGlobalTensorsPassBase<FreezeGlobalTensorsPass> {
explicit FreezeGlobalTensorsPass(bool allow_mutable_tensors) {
this->allow_mutable_tensors = allow_mutable_tensors;
}
void runOnOperation() override;
};
void FreezeGlobalTensorsPass::runOnOperation() {
auto module = getOperation();
if (!tf_saved_model::HasTfSavedModelSemantics(module)) return;
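  // Map every resource argument to the set of global tensors it may be bound
  // to before deciding what can be frozen.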
ResourceAnalysis analysis(&getContext());
analysis.run(module);
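  // Start with all global tensors in the module; frozen ones are removed from
  // this set below so that any leftovers can be diagnosed.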
DenseSet<GlobalTensorOp> remaining_global_tensor_ops;
{
auto ops = module.getOps<GlobalTensorOp>();
remaining_global_tensor_ops.insert(ops.begin(), ops.end());
}
for (auto global_tensor : remaining_global_tensor_ops) {
    // This pass assumes that all global tensors are immutable (e.g., made so
    // by a previous optimize-global-tensors pass). If not, the pass has to
    // fail, since it cannot achieve one of its goals.
if (global_tensor.is_mutable()) {
if (allow_mutable_tensors) continue;
global_tensor.emitError()
<< "is not immutable, try removing mutable variables in your model "
"since mutable variables are currently not supported through "
"this converter";
return signalPassFailure();
}
}
  // Collect all arguments that are freezable. This is an extra scan, but it
  // enables the partial-freezing behavior permitted by
  // `allow_mutable_tensors`.
DenseMap<BlockArgument, bool> freezeable;
for (auto func : module.getOps<func::FuncOp>()) {
for (BlockArgument val : func.getArguments()) {
if (!getElementTypeOrSelf(val.getType()).isa<TF::ResourceType>())
continue;
      // Check that there is exactly one global tensor associated with the arg.
LatticeElement<ResourceLatticeValue> *latticeElement =
analysis.lookupLatticeElement(val);
if (!latticeElement || latticeElement->getValue().ops.size() != 1)
continue;
// Don't freeze mutable tensors.
if (latticeElement->getValue().ops.begin()->is_mutable()) {
freezeable[val] = false;
continue;
}
freezeable[val] = true;
      // Verify that all users are of a supported kind.
for (Operation *user : val.getUsers()) {
if (!(isa<TF::ReadVariableOp>(user) || isa<CallOpInterface>(user))) {
freezeable[val] = false;
// Error out early if possible.
if (!allow_mutable_tensors) {
user->emitError()
<< "could not rewrite use of immutable bound input";
return signalPassFailure();
}
}
}
}
}
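  // Rewrite each function: materialize a tf.Const from the global tensor
  // backing every freezable argument, rewrite that argument's users, and
  // record the tensors that were frozen so they can be erased afterwards.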
DenseSet<GlobalTensorOp> frozen_global_tensors;
for (auto func : module.getOps<func::FuncOp>()) {
llvm::BitVector args_to_erase(func.getNumArguments());
DenseMap<Operation *, llvm::BitVector> remove_operands;
OpBuilder builder(func.getBody());
for (BlockArgument val : func.getArguments()) {
if (!freezeable[val]) continue;
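      // `freezeable[val]` guarantees a lattice element with exactly one
      // immutable global tensor, so the lookup below cannot fail.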
LatticeElement<ResourceLatticeValue> *latticeElement =
analysis.lookupLatticeElement(val);
GlobalTensorOp global_tensor = *latticeElement->getValue().ops.begin();
SmallVector<TF::ReadVariableOp, 4> read_variable_ops_to_erase;
frozen_global_tensors.insert(global_tensor);
for (Operation *user : val.getUsers()) {
if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
          // Collect all read-variable ops so that their uses can be replaced
          // with the tf.Const corresponding to the global tensor op.
read_variable_ops_to_erase.push_back(read_op);
        } else {
          // The remaining users are call operations. Record only the operand
          // positions that forward this resource, so that just those operands
          // are erased from the call site and unrelated operands stay intact.
          llvm::BitVector &bvector = remove_operands[user];
          bvector.resize(user->getNumOperands());
          for (OpOperand &use : user->getOpOperands())
            if (use.get() == val) bvector.set(use.getOperandNumber());
}
}
// Replace the arg with a tf.Const op in the function body.
builder.setInsertionPointToStart(&func.getBody().front());
auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(),
global_tensor.value());
args_to_erase.set(val.getArgNumber());
for (auto read_op : read_variable_ops_to_erase) {
read_op.getResult().replaceAllUsesWith(const_op.getResult());
read_op.erase();
}
}
    // The other uses are call operations, so we only remove the call operands
    // here; the callee's matching arguments are erased when that function is
    // itself processed by this loop.
    for (auto &it : remove_operands) {
it.first->eraseOperands(it.second);
}
func.eraseArguments(args_to_erase);
}
// Erase all global tensors that were frozen.
for (auto global_tensor : frozen_global_tensors) {
remaining_global_tensor_ops.erase(global_tensor);
global_tensor->erase();
}
// Verify that there are no remaining global tensors.
if (!allow_mutable_tensors && !remaining_global_tensor_ops.empty()) {
module.emitError() << "could not freeze all global tensors in the module";
return signalPassFailure();
}
}
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass(
bool allow_mutable_tensors) {
return std::make_unique<FreezeGlobalTensorsPass>(allow_mutable_tensors);
}
} // namespace tf_saved_model
} // namespace mlir