Add type and shape constraints to TFLite Add op

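tfl.add now only accepts tensors of F32, I32, QI8, QUI8 or QI16, and
operand pairs that require broadcasting are rejected beyond rank 5. The
constraint is carried by a new TFL_RuntimePredOpTrait whose predicate
converter_gen.cc folds into the (renamed) VerifyTflRuntimeConstraints
hook; the runtime_type_verify pass is renamed to runtime_verify, and
legalization gains a run-tfl-runtime-verification option, plumbed
through PassConfig::runtime_verification, so the checks can be skipped
when unsupported ops are acceptable (e.g. under Flex).
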
PiperOrigin-RevId: 305618100
Change-Id: Ida8022e0b76e031905e3ab343aec9092e2dc62fd
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 789d06b..697b161 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -307,7 +307,7 @@
         "transforms/optimize_functional_ops.cc",
         "transforms/prepare_composite_functions_tf.cc",
         "transforms/prepare_tf.cc",
-        "transforms/runtime_type_verify.cc",
+        "transforms/runtime_verify.cc",
         "transforms/split_merged_operands.cc",
         "transforms/trim_functions_tf.cc",
         "transforms/while_loop_outline.cc",
diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
index 82d0589..2ed63fc 100644
--- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
+++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
@@ -36,7 +36,8 @@
         form_clusters(false),
         unfold_batch_matmul(true),
         legalize_tf_while(true),
-        shape_inference(true) {}
+        shape_inference(true),
+        runtime_verification(true) {}
 
   // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
   // added, which produces TF Lite ops.
@@ -65,6 +66,8 @@
   bool legalize_tf_while;
   // Whether to do shape inference.
   bool shape_inference;
+  // Whether to do TFLite runtime verification.
+  bool runtime_verification;
 };
 
 }  // namespace TFL
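
[Note: the new flag defaults to true; a caller that wants the old permissive
behavior can flip it before building the pipeline. A minimal sketch, assuming
the QuantizationSpecs/PassConfig types from this header and its neighbors:

    // Sketch only (surrounding converter setup assumed): skip TFL runtime
    // verification, e.g. when Flex fallback will run unsupported ops.
    mlir::TFL::QuantizationSpecs quant_specs;        // assumed populated elsewhere
    mlir::TFL::PassConfig pass_config(quant_specs);  // runtime_verification starts true
    pass_config.runtime_verification = false;
]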
diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc
index bc894d3..f1ed97c 100644
--- a/tensorflow/compiler/mlir/lite/converter_gen.cc
+++ b/tensorflow/compiler/mlir/lite/converter_gen.cc
@@ -441,7 +441,7 @@
 
     mlir::tblgen::FmtContext verify_ctx;
     os << "::mlir::LogicalResult " << op.getCppClassName()
-       << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
+       << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
           "failure_on_operand_type_mismatch) {\n";
     os << "  auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
     verify_ctx.withOp("top");
@@ -466,6 +466,25 @@
                              "operand");
     GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
                              "result");
+
+    for (auto &trait : op.getTraits()) {
+      if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
+        continue;
+      }
+      if (trait.getDef().getValueAsString("trait") !=
+          "OpTrait::TFLRuntimeOpTrait") {
+        continue;
+      }
+
+      auto *val = trait.getDef().getValue("tflRuntimePredicate");
+      if (!val) continue;
+
+      mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
+      os << tgfmt(
+          "  if (!($0)) {\n    "
+          "    return ::mlir::LogicalResult::Failure;\n  }\n",
+          &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
+    }
     os << "  return top.verify();\n}\n";
   }
 
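[Note: for concreteness, the verifier emitted for tfl.add (given the
<0, 1, 5> trait instantiation added in tfl_ops.td below) now looks roughly
like the following; the per-operand/per-result type checks produced by
GenOperandResultVerifier are elided and formatting is approximate:

    ::mlir::LogicalResult AddOp::VerifyTflRuntimeConstraints(
        ::mlir::Operation *op, bool failure_on_operand_type_mismatch) {
      auto top = cast<AddOp>(op); (void)top;
      // ... operand/result type checks ...
      if (!(TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape(
              top.getOperand(0).getType(), top.getOperand(1).getType(), 5))) {
        return ::mlir::LogicalResult::Failure;
      }
      return top.verify();
    }
]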
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
index b20e81a..ccad3cb 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
@@ -86,7 +86,7 @@
   let methods = [
     StaticInterfaceMethod<
       [{Returns whether the op's operands/results are supported by runtime.}],
-      "LogicalResult", "VerifyTflRuntimeTypes",
+      "LogicalResult", "VerifyTflRuntimeConstraints",
       (ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
     >,
   ];
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 45efe8f..dc47c1e 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -46,6 +46,30 @@
 #include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
 namespace TFL {
 
+// Returns true if the given two types have the same static shape, or shapes
+// that are broadcastable within the given maximum rank. If either shape is
+// non-static, this method conservatively returns true.
+bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
+                                                        int max_bcast_rank) {
+  // Skip the shape check for non-static shapes, for model compatibility.
+  auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
+  if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
+  auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
+  if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
+
+  if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
+    return true;
+
+  SmallVector<int64_t, 4> result_shape;
+  if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
+                                          rhs_shaped_type.getShape(),
+                                          result_shape)) {
+    return false;
+  }
+  return lhs_shaped_type.getRank() <= max_bcast_rank &&
+         rhs_shaped_type.getRank() <= max_bcast_rank;
+}
+
 //===----------------------------------------------------------------------===//
 // TensorFlowLiteDialect
 //===----------------------------------------------------------------------===//
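[Note: the helper defers the actual compatibility test to
OpTrait::util::getBroadcastedShape, which implements the usual numpy-style
rule: align shapes from the trailing dimension; two dims are compatible when
they are equal or either is 1. A standalone sketch of that rule (plain C++,
no MLIR dependency; for illustration only):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Numpy-style broadcast compatibility: compare dims from the back;
    // missing leading dims act as 1.
    static bool BroadcastCompatible(const std::vector<int64_t>& lhs,
                                    const std::vector<int64_t>& rhs) {
      size_t rank = std::max(lhs.size(), rhs.size());
      for (size_t i = 0; i < rank; ++i) {
        int64_t l = i < lhs.size() ? lhs[lhs.size() - 1 - i] : 1;
        int64_t r = i < rhs.size() ? rhs[rhs.size() - 1 - i] : 1;
        if (l != r && l != 1 && r != 1) return false;
      }
      return true;
    }

    int main() {
      // Broadcastable, but rank 6 exceeds tfl.add's max_bcast_rank of 5,
      // so the new constraint still rejects it (see the MLIR test below).
      std::cout << BroadcastCompatible({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 1})
                << "\n";  // 1
      std::cout << BroadcastCompatible({2, 3}, {3, 2}) << "\n";  // 0
    }
]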
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 2bf6ca2..cb1f8c6 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -106,6 +106,22 @@
 class DerivedTFLiteTypeAttr<code body> :
   DerivedAttr<"tflite::TensorType", body>;
 
+// TFL Runtime op trait predicate.
+class TFL_RuntimePredOpTrait<string desc, Pred pred> :
+    GenInternalOpTrait<"TFLRuntimeOpTrait"> {
+  Pred tflRuntimePredicate = pred;
+  string tflRuntimeDescription = desc;
+}
+
+class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
+    int i, int j, int max_bcast_rank> :
+  TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
+      " have the same shape or broadcastable shapes within the rank " #
+      max_bcast_rank,
+    CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
+              "$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
+              ").getType(), " # max_bcast_rank # ")">>;
+
 // These additional types/type constraints here are used to decouple the ops
 // from runtime support for the ops. Prefer to use these types when defining
 // new TF_Ops for uniformity.
@@ -360,10 +376,9 @@
   let hasFolder = 1;
 }
 
-def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
-                               NoSideEffect,
-                               Commutative,
-                               TFL_GpuTargetOp]> {
+def TFL_AddOp : TFL_Op<"add", [
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
+    ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
   let summary = "Addition operator";
 
   let description = [{
@@ -371,11 +386,11 @@
   }];
 
   let arguments = (
-    ins AnyTensor:$lhs,
-    AnyTensor:$rhs,
+    ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
+    TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
     TFL_AFAttr:$fused_activation_function);
 
-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
 
   let hasFolder = 1;
 
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index a17cdda..6dd44e6 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -285,7 +285,7 @@
   if (pass_config.legalize_tf_while) {
     pm.addPass(mlir::TFL::CreateWhileOutlinePass());
   }
-  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
 
   auto status = ConvertTFExecutorToTFLOrFlatbuffer(
       module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 7e9b1bd..239f453 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -9,6 +9,20 @@
 // CHECK:  return
 }
 
+// CHECK-LABEL: testAddHighDimsHaveSameShape
+func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
+  // CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
+  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
+  return %0 : tensor<1x2x3x4x5x6x7x8xi32>
+}
+
+// CHECK-LABEL: testAddTooHighBroadcastableDims
+func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 5}}
+  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
 func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
   %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
   return %2: tensor<1xf32>
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
index bff289d..2cb269b 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -173,7 +173,8 @@
     pass_manager->addPass(
         mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
     pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
-    pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
+    pass_manager->addPass(
+        mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
     pass_manager->addPass(mlir::TFL::CreateOptimizePass());
     // This pass operates on TensorFlow ops but is triggered after legalization
     // so that it can target constants introduced once TensorFlow Identity ops
@@ -255,7 +256,8 @@
   // TFLite dialect passes.
   pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::TFL::CreateLegalizeTFPass());
+  pm.addPass(
+      mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
   pm.addPass(mlir::TFL::CreateOptimizePass());
   pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
 
@@ -268,7 +270,7 @@
 
   pm.addPass(mlir::TFL::CreateWhileOutlinePass());
 
-  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
 }
 
 // Registers a pass pipeline for the standard TFL passes.
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
index 038adeb..35f9b24 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -214,7 +214,7 @@
   if (pass_config.legalize_tf_while) {
     pm.addPass(mlir::TFL::CreateWhileOutlinePass());
   }
-  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
 
   std::string result;
   auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 6a50ad4..90549ef 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -70,8 +70,21 @@
 constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
 
 // Legalize operations in functions.
-struct LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+ public:
+  LegalizeTF() = default;
+  LegalizeTF(const LegalizeTF&) {}
+  explicit LegalizeTF(bool run_tfl_runtime_verification) {
+    run_tfl_runtime_verification_ = run_tfl_runtime_verification;
+  }
+
+  /// Performs the lowering to TFLite dialect.
   void runOnFunction() override;
+
+ private:
+  Option<bool> run_tfl_runtime_verification_{
+      *this, "run-tfl-runtime-verification",
+      llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
 };
 
 // Returns true if all tensor values in `values` have the same static shape.
@@ -741,13 +754,19 @@
   // graph.
   target.addLegalOp<mlir::ConstantOp>();
   target.addLegalOp<ConstOp>();
-  target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
-      Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
-        auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
-        if (!tfl_op) return false;
-        return succeeded(tfl_op.VerifyTflRuntimeTypes(
-            tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
-      }));
+  if (run_tfl_runtime_verification_) {
+    target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
+        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
+            [](Operation* op) {
+              auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
+              if (!tfl_op) return false;
+              return succeeded(tfl_op.VerifyTflRuntimeConstraints(
+                  tfl_op.getOperation(),
+                  /*failure_on_operand_type_mismatch=*/false));
+            }));
+  } else {
+    target.addLegalDialect<TensorFlowLiteDialect>();
+  }
   // Keep trying to convert.
   // TODO(karimnosseir): This is similar to what apply greedy patterns does.
   // Look if there is a function that tries until it converge.
@@ -763,8 +782,9 @@
 }  // namespace
 
 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass() {
-  return std::make_unique<LegalizeTF>();
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
+    bool run_tfl_runtime_verification) {
+  return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
 }
 
 static PassRegistration<LegalizeTF> pass(
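
[Note: the explicit copy constructor above is deliberate. MLIR pass Option<>
members are bound to their owning pass and are not copyable, so a pass that
declares one must provide its own copy constructor for pass cloning; the empty
body leaves the option at its registered default (init(true)), and the bool
constructor overrides it for programmatic construction via
CreateLegalizeTFPass.]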
diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h
index c86ac56..a744a57 100644
--- a/tensorflow/compiler/mlir/lite/transforms/passes.h
+++ b/tensorflow/compiler/mlir/lite/transforms/passes.h
@@ -30,7 +30,11 @@
 class QuantizationSpecs;
 
 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass();
+// When `run_tfl_runtime_verification` is true, each TFL builtin op is checked
+// against the TFL runtime's capabilities, and TF ops that are incompatible
+// with the runtime are left in the graph without being legalized.
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
+    bool run_tfl_runtime_verification);
 
 // Creates an instance of the TensorFlow Lite dialect Optimize pass.
 std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
@@ -91,8 +95,8 @@
 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
 std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
 
-// Verifies runtime supports types used.
-std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass();
+// Verifies runtime constraints.
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
 
 }  // namespace TFL
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
similarity index 63%
rename from tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc
rename to tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
index 3cb26a5..3268329 100644
--- a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
@@ -22,34 +22,32 @@
 namespace TFL {
 namespace {
 
-// This pass verifies that the operands and results types are supported by
-// TFLite runtime.
-class RuntimeTypeVerifyPass
-    : public mlir::PassWrapper<RuntimeTypeVerifyPass, FunctionPass> {
+// This pass verifies that the TFL ops meet the TFL runtime constraints.
+class RuntimeVerifyPass
+    : public mlir::PassWrapper<RuntimeVerifyPass, FunctionPass> {
  public:
-  explicit RuntimeTypeVerifyPass() {}
+  explicit RuntimeVerifyPass() {}
 
  private:
   void runOnFunction() override;
 };
 
-void RuntimeTypeVerifyPass::runOnFunction() {
+void RuntimeVerifyPass::runOnFunction() {
   getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
-    if (failed(op.VerifyTflRuntimeTypes(
-            op.getOperation(),
-            /*failure_on_operand_type_mismatch=*/true)))
+    if (failed(op.VerifyTflRuntimeConstraints(
+            op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
       signalPassFailure();
   });
 }
 }  // namespace
 
-// Verifies runtime supports types used.
-std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass() {
-  return std::make_unique<RuntimeTypeVerifyPass>();
+// Verifies TFL runtime constraints.
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass() {
+  return std::make_unique<RuntimeVerifyPass>();
 }
 
-static PassRegistration<RuntimeTypeVerifyPass> pass(
-    "tfl-runtime-verify", "TFLite runtime verification");
+static PassRegistration<RuntimeVerifyPass> pass("tfl-runtime-verify",
+                                                "TFLite runtime verification");
 
 }  // namespace TFL
 }  // namespace mlir
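[Note: end to end, the pieces compose as sketched below (illustrative only;
creation of `context` and `module` is assumed): legalization refuses to
produce TFL ops that fail their runtime constraints, and the renamed pass
re-verifies whatever TFL ops ended up in the module.

    mlir::PassManager pm(&context);
    pm.addPass(mlir::TFL::CreateLegalizeTFPass(
        /*run_tfl_runtime_verification=*/true));
    pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
    if (mlir::failed(pm.run(module))) {
      // Some op violated a TFL runtime constraint (or legalization failed).
    }
]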
diff --git a/tensorflow/lite/testing/op_tests/hardswish.py b/tensorflow/lite/testing/op_tests/hardswish.py
index 2816fe5..97dad80 100644
--- a/tensorflow/lite/testing/op_tests/hardswish.py
+++ b/tensorflow/lite/testing/op_tests/hardswish.py
@@ -48,10 +48,17 @@
   """Make a set of tests to do hardswish."""
 
   # Choose a set of parameters
-  test_parameters = [{
-      "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
-                      [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
-  }]
+  if options.run_with_flex:
+    # Only Flex is able to execute on inputs with more than four dimensions.
+    test_parameters = [{
+        "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+                        [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+    }]
+  else:
+    test_parameters = [{
+        "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+                        [3, 15, 14, 3]],
+    }]
 
   def build_graph(parameters):
     inp = tf.compat.v1.placeholder(