blob: 56ee3f377821e64d49b34d75a0b4b6141d530d19 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
namespace {
// Name of the boolean module attribute written by AnalyzeVariablesPass. Its
// value records whether resource variables in the module should be legalized
// to TFLite-native variables (true) or left alone (false).
constexpr char kLegalizeTflVariables[] = "tfl._legalize_tfl_variables";
// Returns true if `op` is a TF op that takes resource-typed values but still
// has TFLite support: the variable ops (read / assign / var-handle) and the
// hash-table lookup ops (find / import / size).
bool IsSupportedTFLiteResourceOp(Operation* op) {
  // Resource-variable ops.
  if (llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp, TF::VarHandleOp>(op))
    return true;
  // Lookup-table ops that operate on table resources.
  return llvm::isa<TF::LookupTableFindV2Op, TF::LookupTableImportV2Op,
                   TF::LookupTableSizeV2Op>(op);
}
// Returns true if `op` is a TF or TFLite control-flow op that may carry
// resource-typed operands. These ops are usually pass-through: they invoke
// another subgraph and forward their operands to it.
bool IsSupportedTFLiteControlFlow(Operation* op) {
  const bool is_tfl_control_flow =
      llvm::isa<TFL::WhileOp, TFL::IfOp, TFL::CallOnceOp>(op);
  return is_tfl_control_flow || llvm::isa<TF::IfOp>(op);
}
} // namespace
// Pass that analyzes the variables in the module and attaches a boolean
// attribute (named by kLegalizeTflVariables) to the module, indicating whether
// variables should be legalized to TFLite-native ones.
// The attribute is always written. It is true unless some op that is neither
// TFLite-supported nor in the "tfl" dialect consumes a resource-typed operand.
// This pass needs to run after TF->TFL legalization and before variable
// legalization.
class AnalyzeVariablesPass
    : public PassWrapper<AnalyzeVariablesPass, OperationPass<ModuleOp>> {
 public:
  AnalyzeVariablesPass() = default;
  // The pass declares no members, so the copy constructor copies nothing.
  AnalyzeVariablesPass(const AnalyzeVariablesPass&) {}

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the commandline for example).
    return "tfl-analyze-variables-pass";
  }

  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Analyze variables in the graph.";
  }

  void runOnOperation() override {
    auto* context = &getContext();
    auto module = getOperation();
    bool legalize_to_tfl = true;
    module.walk([&](Operation* op) {
      // Skip ops that are supported natively by TFLite.
      if (IsSupportedTFLiteResourceOp(op)) return WalkResult::advance();
      if (IsSupportedTFLiteControlFlow(op)) return WalkResult::advance();
      // Ops already legalized to TFLite are fine. getDialect() returns null
      // for unregistered operations, so guard before dereferencing. An
      // unregistered op falls through to the conservative resource check.
      Dialect* dialect = op->getDialect();
      if (dialect && dialect->getNamespace() == "tfl") {
        return WalkResult::advance();
      }
      // Check for ops that are not legalized to TFLite.
      // If any of the operands is a resource type, then we stop the walk
      // and mark the module as not valid for TFLite legalization.
      // Note: this might disable native variables in more cases than needed.
      // TODO(b/189370197): Enhance variable analysis.
      for (auto operand : op->getOperands()) {
        if (getElementTypeOrSelf(operand.getType()).isa<TF::ResourceType>()) {
          legalize_to_tfl = false;
          return WalkResult::interrupt();
        }
      }
      return WalkResult::advance();
    });
    module->setAttr(kLegalizeTflVariables,
                    BoolAttr::get(context, legalize_to_tfl));
  }
};
// Creates a new AnalyzeVariablesPass instance; ownership goes to the caller.
std::unique_ptr<OperationPass<ModuleOp>> CreateAnalyzeVariablesPass() {
  std::unique_ptr<OperationPass<ModuleOp>> analyze_pass =
      std::make_unique<AnalyzeVariablesPass>();
  return analyze_pass;
}
// Static registration of AnalyzeVariablesPass. Each new instance is built
// through the CreateAnalyzeVariablesPass() factory.
static PassRegistration<AnalyzeVariablesPass> pass(
    []() -> std::unique_ptr<Pass> { return CreateAnalyzeVariablesPass(); });
} // namespace TFL
} // namespace mlir