/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <set>
#include <string>
#include <utility>
#include "absl/base/attributes.h"
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
namespace mlir {
namespace TF {
namespace internal {
// The name prefix of Flex ops.
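// For example, a TF "AddV2" op is rewritten as a TFL custom op named
// "FlexAddV2".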
constexpr absl::string_view kFlexOpNamePrefix = "Flex";
// Don't fall back to Flex ops if this attribute is set. This attribute is
// transient and is only used inside this pass. First, the pass looks for
// predefined patterns and sets this attribute on ops in those patterns. Then,
// when walking the function, if it finds ops with this attribute, the pass
// removes the attribute and skips further processing on those ops.
constexpr char kNoFallbackAttr[] = "no_fallback";
// TF Quantization modes. These constants are defined as char arrays so they
// can be parsed by the pass option.
constexpr char kDefaultMode[] = "DEFAULT";
constexpr char kLegacyIntegerMode[] = "LEGACY_INTEGER";
// Checks if the operation is a TF FakeQuant op.
bool IsTfFakeQuantOp(Operation *op) {
return llvm::isa<
// clang-format off
TF::FakeQuantWithMinMaxArgsOp,
TF::FakeQuantWithMinMaxVarsOp,
TF::FakeQuantWithMinMaxVarsPerChannelOp
// clang-format on
>(op);
}
// Checks if the operation is allowlisted in both modes. These ops are not
// quantizable but are necessary to make the conversion successful.
bool IsAlwaysAllowlistedOp(Operation *op) {
return llvm::isa<
// clang-format off
// go/keep-sorted start
TF::ConstOp,
TF::IdentityOp,
TF::PartitionedCallOp,
TF::StatefulPartitionedCallOp
// go/keep-sorted end
// clang-format on
>(op);
}
// LINT.IfChange
// The list of quantizable ops in the Legacy Integer mode.
ABSL_ATTRIBUTE_NOINLINE const std::set<std::string>
&QuantizableOpsInLegacyMode() {
static const std::set<std::string> *legacy_op_list =
new std::set<std::string>({
// clang-format off
// go/keep-sorted start
TF::AbsOp::getOperationName().str(),
TF::AddOp::getOperationName().str(),
TF::AddV2Op::getOperationName().str(),
TF::ArgMaxOp::getOperationName().str(),
TF::AvgPoolOp::getOperationName().str(),
TF::BiasAddOp::getOperationName().str(),
TF::BucketizeOp::getOperationName().str(),
TF::ConcatV2Op::getOperationName().str(),
TF::Conv2DBackpropInputOp::getOperationName().str(),
TF::Conv2DOp::getOperationName().str(),
TF::DepthwiseConv2dNativeOp::getOperationName().str(),
TF::FusedBatchNormV3Op::getOperationName().str(),
TF::GatherV2Op::getOperationName().str(),
TF::MatMulOp::getOperationName().str(),
TF::MaxPoolOp::getOperationName().str(),
TF::MaximumOp::getOperationName().str(),
TF::MeanOp::getOperationName().str(),
TF::MinimumOp::getOperationName().str(),
TF::MulOp::getOperationName().str(),
TF::PadOp::getOperationName().str(),
TF::PadV2Op::getOperationName().str(),
TF::Relu6Op::getOperationName().str(),
TF::ReluOp::getOperationName().str(),
TF::ReshapeOp::getOperationName().str(),
TF::SoftmaxOp::getOperationName().str(),
TF::SubOp::getOperationName().str(),
TF::TransposeOp::getOperationName().str(),
// go/keep-sorted end
// clang-format on
});
return *legacy_op_list;
}
// The list of quantizable ops in the Default mode.
ABSL_ATTRIBUTE_NOINLINE const std::set<std::string>
&QuantizableOpsInDefaultMode() {
static const std::set<std::string> *default_op_list =
new std::set<std::string>({
// clang-format off
// go/keep-sorted start
TF::BiasAddOp::getOperationName().str(),
TF::Conv2DBackpropInputOp::getOperationName().str(),
TF::Conv2DOp::getOperationName().str(),
TF::DepthwiseConv2dNativeOp::getOperationName().str(),
TF::FusedBatchNormV3Op::getOperationName().str(),
TF::MatMulOp::getOperationName().str(),
TF::Relu6Op::getOperationName().str(),
TF::ReluOp::getOperationName().str(),
// go/keep-sorted end
// clang-format on
});
return *default_op_list;
}
// LINT.ThenChange(Google-internal path)
// Checks if the operation can be fused with bias.
inline bool IsFusibleWithBiasOp(Operation *op) {
return llvm::isa<
// clang-format off
TF::MatMulOp,
TF::Conv2DOp,
TF::DepthwiseConv2dNativeOp,
TF::Conv2DBackpropInputOp,
TF::Conv3DOp,
TF::Conv3DBackpropInputV2Op
// clang-format on
>(op);
}
// Creates the custom options for a Flex op.
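// The custom options are a FlexBuffer vector of two strings: the original TF
// op name followed by the op's serialized NodeDef.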
inline void CreateFlexOpCustomOptions(const std::string &op_name,
const std::string &node_def_str,
std::string &custom_option_buffer) {
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
flex_builder->Vector([&]() {
flex_builder->String(op_name);
flex_builder->String(node_def_str);
});
flex_builder->Finish();
custom_option_buffer.assign(flex_builder->GetBuffer().begin(),
flex_builder->GetBuffer().end());
}
// Creates an opaque ElementsAttr holding a Flex op's custom options.
inline OpaqueElementsAttr CustomOptionForFlexOp(OpBuilder *builder,
const std::string &content) {
ShapedType type = RankedTensorType::get(
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
type,
StringRef(content.data(), content.size()));
}
// Falls back ops that are not supported by TF Quantization to TFLite Flex ops.
class FallbackToFlexOps
: public PassWrapper<FallbackToFlexOps, OperationPass<func::FuncOp>> {
public:
FallbackToFlexOps() {}
explicit FallbackToFlexOps(const std::string &mode) { mode_ = mode; }
FallbackToFlexOps(const FallbackToFlexOps &other) { mode_ = other.mode_; }
void runOnOperation() override;
StringRef getArgument() const final { return "quant-raise-flex-fallback"; }
StringRef getDescription() const final {
return "Fallback TF-Quantization-unsupported ops to TFLite Flex ops.";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TFL::TensorFlowLiteDialect>();
}
private:
// The mode of TF Quantization; it may indicate different users/devices.
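// Expected values are kDefaultMode ("DEFAULT") or kLegacyIntegerMode
// ("LEGACY_INTEGER"); an empty mode makes the pass a no-op.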
Option<std::string> mode_{*this, "mode",
llvm::cl::desc("The mode of TF Quantization."),
llvm::cl::init("")};
// Checks if the operation is allowlisted in the current mode.
bool IsAllowListedOp(Operation *op) {
std::string op_name = op->getName().getStringRef().str();
if (IsAlwaysAllowlistedOp(op) || IsTfFakeQuantOp(op)) {
return true;
} else if (mode_ == kDefaultMode) {
return QuantizableOpsInDefaultMode().count(op_name) > 0;
} else if (mode_ == kLegacyIntegerMode) {
return QuantizableOpsInLegacyMode().count(op_name) > 0;
} else {
mlir::emitError(getOperation().getLoc(), "Unrecognized mode: " + mode_);
signalPassFailure();
return true;
}
}
// Converts the operation to a TFLite Flex op.
bool ConvertToFlexOp(Operation *op);
};
bool FallbackToFlexOps::ConvertToFlexOp(Operation *op) {
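// Export the TF dialect op to a TensorFlow NodeDef, which is later embedded
// in the Flex op's custom options.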
tensorflow::StatusOr<std::unique_ptr<tensorflow::NodeDef>> node_def =
tensorflow::ConvertTFDialectOpToNodeDef(
op, /*name=*/"", /*ignore_unregistered_attrs=*/true);
if (!node_def.ok()) {
op->emitError("Failed to obtain TensorFlow NodeDef: " +
node_def.status().ToString());
return false;
}
std::string node_def_str;
if (!(*node_def)->SerializeToString(&node_def_str)) {
op->emitError("Failed to serialize tensorflow NodeDef");
return false;
}
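// Build a TFL custom op named "Flex<op name>" carrying the op name and the
// serialized NodeDef in its custom options, then replace the original op.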
std::string op_name = (*node_def)->op();
OpBuilder builder(op);
std::string flex_op_name = std::string(kFlexOpNamePrefix) + op_name;
std::string custom_option_buffer;
CreateFlexOpCustomOptions(op_name, node_def_str, custom_option_buffer);
auto flex_op = builder.create<TFL::CustomOp>(
op->getLoc(), op->getResultTypes(), op->getOperands(), flex_op_name,
CustomOptionForFlexOp(&builder, custom_option_buffer));
op->replaceAllUsesWith(flex_op);
op->erase();
return true;
}
// Sets the "no_fallback" attribute.
Value SetNoFallbackAttr(PatternRewriter &rewriter, Value val) {
val.getDefiningOp()->setAttr(kNoFallbackAttr, rewriter.getUnitAttr());
return val;
}
// Returns true if the attr is a float attribute and all its values equal the
// given value.
static bool FloatValueEquals(const Attribute &attr, double value) {
auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
if (fp_attr == nullptr) return false;
if (fp_attr.isSplat()) {
return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
}
return llvm::all_of(fp_attr.getValues<APFloat>(), [value](const APFloat &f) {
return f.isExactlyValue(value);
});
}
// Returns true if the rank of the value equals the given rank.
bool RankEquals(Value value, int rank) {
auto rank_type = value.getType().template dyn_cast<RankedTensorType>();
return (rank_type && rank_type.getRank() == rank);
}
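// The generated rewrite patterns; the include also provides the
// populateWithGenerated() helper used in runOnOperation() below.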
#include "tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_patterns.inc"
void FallbackToFlexOps::runOnOperation() {
if (mode_.empty()) return;
func::FuncOp func = getOperation();
MLIRContext *ctx = &getContext();
// Convert binary ops to BiasAdd ops if possible.
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
// Convert unsupported ops to Flex ops.
auto tf_dialect = ctx->getLoadedDialect<TF::TensorFlowDialect>();
func.walk([&](Operation *op) {
if (op->getDialect() != tf_dialect) return;
if (IsAllowListedOp(op)) return;
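// Ops tagged with the transient kNoFallbackAttr by the generated patterns are
// kept as TF ops; only the attribute is removed.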
if (op->hasAttr(kNoFallbackAttr)) {
op->removeAttr(kNoFallbackAttr);
return;
}
if (!ConvertToFlexOp(op)) signalPassFailure();
});
}
} // namespace internal
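// Creates an instance of the fallback pass. `mode` should be one of the TF
// Quantization modes defined above, e.g.:
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::TF::CreateFallbackToFlexOpsPass("DEFAULT"));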
std::unique_ptr<OperationPass<func::FuncOp>> CreateFallbackToFlexOpsPass(
const std::string &mode) {
return std::make_unique<internal::FallbackToFlexOps>(mode);
}
static PassRegistration<internal::FallbackToFlexOps> pass([] {
return CreateFallbackToFlexOpsPass(/*mode=*/internal::kDefaultMode);
});
} // namespace TF
} // namespace mlir