Add type and shape constraints to TFLite Add op
PiperOrigin-RevId: 305618100
Change-Id: Ida8022e0b76e031905e3ab343aec9092e2dc62fd
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 789d06b..697b161 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -307,7 +307,7 @@
"transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
- "transforms/runtime_type_verify.cc",
+ "transforms/runtime_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/while_loop_outline.cc",
diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
index 82d0589..2ed63fc 100644
--- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
+++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
@@ -36,7 +36,8 @@
form_clusters(false),
unfold_batch_matmul(true),
legalize_tf_while(true),
- shape_inference(true) {}
+ shape_inference(true),
+ runtime_verification(true) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.
@@ -65,6 +66,8 @@
bool legalize_tf_while;
// Whether to do shape inference.
bool shape_inference;
+ // Whether to do TFLite runtime verification.
+ bool runtime_verification;
};
} // namespace TFL
diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc
index bc894d3..f1ed97c 100644
--- a/tensorflow/compiler/mlir/lite/converter_gen.cc
+++ b/tensorflow/compiler/mlir/lite/converter_gen.cc
@@ -441,7 +441,7 @@
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
- << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
+ << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
"failure_on_operand_type_mismatch) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");
@@ -466,6 +466,25 @@
"operand");
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
"result");
+
+ for (auto &trait : op.getTraits()) {
+ if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
+ continue;
+ }
+ if (trait.getDef().getValueAsString("trait") !=
+ "OpTrait::TFLRuntimeOpTrait") {
+ continue;
+ }
+
+ auto *val = trait.getDef().getValue("tflRuntimePredicate");
+ if (!val) continue;
+
+ mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
+ os << tgfmt(
+ " if (!($0)) {\n "
+ " return ::mlir::LogicalResult::Failure;\n }\n",
+ &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
+ }
os << " return top.verify();\n}\n";
}
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
index b20e81a..ccad3cb 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
@@ -86,7 +86,7 @@
let methods = [
StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}],
- "LogicalResult", "VerifyTflRuntimeTypes",
+ "LogicalResult", "VerifyTflRuntimeConstraints",
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
>,
];
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 45efe8f..dc47c1e 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -46,6 +46,30 @@
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL {
+// Returns true when the given two types have the same shape or broadcastable
+// shape within the given rank. If any given shapes are non-static, this method
+// returns true.
+bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
+ int max_bcast_rank) {
+ // Ignore shape checking on the non-static shapes for model compatibility.
+ auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
+ if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
+ auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
+ if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
+
+ if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
+ return true;
+
+ SmallVector<int64_t, 4> result_shape;
+ if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
+ rhs_shaped_type.getShape(),
+ result_shape)) {
+ return false;
+ }
+ return lhs_shaped_type.getRank() <= max_bcast_rank &&
+ rhs_shaped_type.getRank() <= max_bcast_rank;
+}
+
//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 2bf6ca2..cb1f8c6 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -106,6 +106,22 @@
class DerivedTFLiteTypeAttr<code body> :
DerivedAttr<"tflite::TensorType", body>;
+// TFL Runtime op trait predicate.
+class TFL_RuntimePredOpTrait<string desc, Pred pred> :
+ GenInternalOpTrait<"TFLRuntimeOpTrait"> {
+ Pred tflRuntimePredicate = pred;
+ string tflRuntimeDescription = desc;
+}
+
+class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
+ int i, int j, int max_bcast_rank> :
+ TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
+ " have the same shape or broadcastable shapes within the rank " #
+ max_bcast_rank,
+ CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
+ "$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
+ ").getType(), " # max_bcast_rank # ")">>;
+
// These additional types/type constraints here are used to decouple the ops
// from runtime support for the ops. Prefer to use these types when defining
// new TF_Ops for uniformity.
@@ -360,10 +376,9 @@
let hasFolder = 1;
}
-def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
- NoSideEffect,
- Commutative,
- TFL_GpuTargetOp]> {
+def TFL_AddOp : TFL_Op<"add", [
+ TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
+ ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
let summary = "Addition operator";
let description = [{
@@ -371,11 +386,11 @@
}];
let arguments = (
- ins AnyTensor:$lhs,
- AnyTensor:$rhs,
+ ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
+ TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
TFL_AFAttr:$fused_activation_function);
- let results = (outs AnyTensor:$output);
+ let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
let hasFolder = 1;
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index a17cdda..6dd44e6 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -285,7 +285,7 @@
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
- pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+ pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 7e9b1bd..239f453 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -9,6 +9,20 @@
// CHECK: return
}
+// CHECK-LABEL: testAddHighDimsHaveSameShape
+func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
+ // CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
+ %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
+ return %0 : tensor<1x2x3x4x5x6x7x8xi32>
+}
+
+// CHECK-LABEL: testAddTooHighBroadcastableDims
+func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+ // expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
+ %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+ return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
return %2: tensor<1xf32>
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
index bff289d..2cb269b 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -173,7 +173,8 @@
pass_manager->addPass(
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
- pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
+ pass_manager->addPass(
+ mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
// This pass operates on TensorFlow ops but is triggered after legalization
// so that it can target constants introduced once TensorFlow Identity ops
@@ -255,7 +256,8 @@
// TFLite dialect passes.
pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
- pm.addPass(mlir::TFL::CreateLegalizeTFPass());
+ pm.addPass(
+ mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
pm.addPass(mlir::TFL::CreateOptimizePass());
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
@@ -268,7 +270,7 @@
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
- pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+ pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
}
// Registers a pass pipeline for the standard TFL passes.
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
index 038adeb..35f9b24 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -214,7 +214,7 @@
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
- pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+ pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
std::string result;
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 6a50ad4..90549ef 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -70,8 +70,21 @@
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
// Legalize operations in functions.
-struct LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+ public:
+ LegalizeTF() = default;
+ LegalizeTF(const LegalizeTF&) {}
+ explicit LegalizeTF(bool run_tfl_runtime_verification) {
+ run_tfl_runtime_verification_ = run_tfl_runtime_verification;
+ }
+
+ /// Performs the lowering to TFLite dialect.
void runOnFunction() override;
+
+ private:
+ Option<bool> run_tfl_runtime_verification_{
+ *this, "run-tfl-runtime-verification",
+ llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
};
// Returns true if all tensor value in `values` has static shape and same shape.
@@ -741,13 +754,19 @@
// graph.
target.addLegalOp<mlir::ConstantOp>();
target.addLegalOp<ConstOp>();
- target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
- Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
- auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
- if (!tfl_op) return false;
- return succeeded(tfl_op.VerifyTflRuntimeTypes(
- tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
- }));
+ if (run_tfl_runtime_verification_) {
+ target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
+ Optional<ConversionTarget::DynamicLegalityCallbackFn>(
+ [](Operation* op) {
+ auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
+ if (!tfl_op) return false;
+ return succeeded(tfl_op.VerifyTflRuntimeConstraints(
+ tfl_op.getOperation(),
+ /*failure_on_operand_type_mismatch=*/false));
+ }));
+ } else {
+ target.addLegalDialect<TensorFlowLiteDialect>();
+ }
// Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
// Look if there is a function that tries until it converge.
@@ -763,8 +782,9 @@
} // namespace
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass() {
- return std::make_unique<LegalizeTF>();
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
+ bool run_tfl_runtime_verification) {
+ return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
}
static PassRegistration<LegalizeTF> pass(
diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h
index c86ac56..a744a57 100644
--- a/tensorflow/compiler/mlir/lite/transforms/passes.h
+++ b/tensorflow/compiler/mlir/lite/transforms/passes.h
@@ -30,7 +30,11 @@
class QuantizationSpecs;
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass();
+// When the given run_tfl_runtime_verification value is true, it will check each
+// TFL builtin op towards the TFL runtime capability and the incompatible TF ops
+// will be left in the graph without getting legalized.
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
+ bool run_tfl_runtime_verification);
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
@@ -91,8 +95,8 @@
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
-// Verifies runtime supports types used.
-std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass();
+// Verifies runtime constraints.
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
} // namespace TFL
diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
similarity index 63%
rename from tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc
rename to tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
index 3cb26a5..3268329 100644
--- a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
@@ -22,34 +22,32 @@
namespace TFL {
namespace {
-// This pass verifies that the operands and results types are supported by
-// TFLite runtime.
-class RuntimeTypeVerifyPass
- : public mlir::PassWrapper<RuntimeTypeVerifyPass, FunctionPass> {
+// This pass verifies that the TFL ops meet the TFL runtime constraints.
+class RuntimeVerifyPass
+ : public mlir::PassWrapper<RuntimeVerifyPass, FunctionPass> {
public:
- explicit RuntimeTypeVerifyPass() {}
+ explicit RuntimeVerifyPass() {}
private:
void runOnFunction() override;
};
-void RuntimeTypeVerifyPass::runOnFunction() {
+void RuntimeVerifyPass::runOnFunction() {
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
- if (failed(op.VerifyTflRuntimeTypes(
- op.getOperation(),
- /*failure_on_operand_type_mismatch=*/true)))
+ if (failed(op.VerifyTflRuntimeConstraints(
+ op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
signalPassFailure();
});
}
} // namespace
-// Verifies runtime supports types used.
-std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass() {
- return std::make_unique<RuntimeTypeVerifyPass>();
+// Verifies TFL runtime constraints.
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass() {
+ return std::make_unique<RuntimeVerifyPass>();
}
-static PassRegistration<RuntimeTypeVerifyPass> pass(
- "tfl-runtime-verify", "TFLite runtime verification");
+static PassRegistration<RuntimeVerifyPass> pass("tfl-runtime-verify",
+ "TFLite runtime verification");
} // namespace TFL
} // namespace mlir
diff --git a/tensorflow/lite/testing/op_tests/hardswish.py b/tensorflow/lite/testing/op_tests/hardswish.py
index 2816fe5..97dad80 100644
--- a/tensorflow/lite/testing/op_tests/hardswish.py
+++ b/tensorflow/lite/testing/op_tests/hardswish.py
@@ -48,10 +48,17 @@
"""Make a set of tests to do hardswish."""
# Chose a set of parameters
- test_parameters = [{
- "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
- [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
- }]
+ if options.run_with_flex:
+ # Only Flex is able to execute on the data bigger than four dimension.
+ test_parameters = [{
+ "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+ [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ }]
+ else:
+ test_parameters = [{
+ "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+ [3, 15, 14, 3]],
+ }]
def build_graph(parameters):
inp = tf.compat.v1.placeholder(