Canonicalize FusedBatchNorm op to FusedBatchNormV3

Move the FusedBatchNorm -> FusedBatchNormV3 rewrite out of TFLite's
prepare-tf pass and register it as a canonicalization pattern on
tf.FusedBatchNorm itself, so the conversion runs wherever the TF dialect
canonicalizer runs rather than only during TFLite conversion.
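
For illustration, a rough sketch of the rewrite (mirroring the test added
to canonicalize.mlir below): a five-result op such as

  %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4)
      : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>,
         tensor<8xf32>)
     -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>,
         tensor<8xf32>)

is replaced by a "tf.FusedBatchNormV3" with the same operands and
attributes plus an appended unranked f32 result for reserve_space_3; the
original op's five results are replaced with the first five results of the
new op, and the extra reserve_space_3 result is simply dropped
(drop_back() in the pattern).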

PiperOrigin-RevId: 337394368
Change-Id: I7bd7f0513815e1b27584a1b6cba7c447a9d9c9a2
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 4e7c089..186c863 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -62,40 +62,6 @@
   // LAYOUT: "tfl.conv_2d"
 }
 
-
-func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
-^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
-  // OK
-  %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-  // Unsupported training
-  %1:5 = "tf.FusedBatchNorm"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true}  : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-  // Use other output
-  %2:5 = "tf.FusedBatchNorm"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-
-  return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
-
-// CHECK-LABEL: fusedBatchNorm
-// CHECK:  %[[CONSTANT:.*]] = constant dense<1.000000e-03>
-//              variance + epsilon
-// CHECK:  %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
-//              rsqrt(variance + epsilon)
-// CHECK:  %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]])
-//              scale * rsqrt(variance + epsilon)
-// CHECK:  %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]])
-//              x * scale * rsqrt(variance + epsilon)
-// CHECK:  %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]])
-//              mean * scale * rsqrt(variance + epsilon)
-// CHECK:  %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]])
-//              offset - mean * scale * rsqrt(variance + epsilon)
-// CHECK:  %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]])
-//              x * scale * rsqrt(variance + epsilon) +
-//              offset - mean * scale * rsqrt(variance + epsilon)
-// CHECK:  %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
-
-// CHECK:  %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
-// CHECK:  "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
-}
-
 func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
 ^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
   // OK
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index ecca3d3..c4f30c2 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -740,31 +740,6 @@
   }
 };
 
-struct ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
-  explicit ConvertFusedBatchNorm(MLIRContext *context)
-      : OpRewritePattern<TF::FusedBatchNormOp>(context) {}
-
-  LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
-                                PatternRewriter &rewriter) const override {
-    auto new_result_types =
-        llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
-    // reserve_space_3
-    new_result_types.push_back(
-        UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
-
-    OperationState new_state(tf_fused_batch_norm_op.getLoc(),
-                             TF::FusedBatchNormV3Op::getOperationName(),
-                             tf_fused_batch_norm_op.getOperands(),
-                             new_result_types,
-                             tf_fused_batch_norm_op.getAttrs());
-    Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
-
-    rewriter.replaceOp(tf_fused_batch_norm_op,
-                       tf_fused_batch_norm_op_v3->getResults().drop_back());
-    return success();
-  }
-};
-
 // The below pattern is equivalent to the DRR rule below
 // The checks are dependent on generated values, so we can't add
 // the checks on intermediate values, ideally we should find equivalent
@@ -1202,7 +1177,6 @@
   patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
                   ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
 
-  patterns.insert<ConvertFusedBatchNorm>(ctx);
   TFL::populateWithGenerated(ctx, patterns);
   // TODO(karimnosseir): Split to separate pass probably after
   // deciding on long term plan for this optimization.
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index aa087ec..aa1b7bb 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -3942,6 +3942,8 @@
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
+  let hasCanonicalizer = 1;
+
   let verifier = [{
     return Verify(*this);
   }];
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index d7a2427..e9ccbed 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -2337,6 +2337,41 @@
 }
 
 //===----------------------------------------------------------------------===//
+// FusedBatchNormOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
+  using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
+                                PatternRewriter &rewriter) const override {
+    auto new_result_types =
+        llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
+    // reserve_space_3
+    new_result_types.push_back(
+        UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
+
+    OperationState new_state(tf_fused_batch_norm_op.getLoc(),
+                             TF::FusedBatchNormV3Op::getOperationName(),
+                             tf_fused_batch_norm_op.getOperands(),
+                             new_result_types,
+                             tf_fused_batch_norm_op.getAttrs());
+    Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
+
+    rewriter.replaceOp(tf_fused_batch_norm_op,
+                       tf_fused_batch_norm_op_v3->getResults().drop_back());
+    return success();
+  }
+};
+}  // namespace.
+
+void FusedBatchNormOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ConvertFusedBatchNorm>(context);
+}
+
+//===----------------------------------------------------------------------===//
 // UnpackOp
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index ea9820d..e77dd36 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -1284,3 +1284,10 @@
   %0 = "tf.NonMaxSuppressionV3"(%arg0, %arg1, %max_size, %arg2, %arg3): (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>)
   return %0 : tensor<2xi32>
 }
+
+// CHECK-LABEL: testFusedBatchNormToBatchNormV3
+func @testFusedBatchNormToBatchNormV3(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+  // CHECK: "tf.FusedBatchNormV3"
+  %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4): (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> )
+  return %0#0  : tensor<8x8x8x8xf32>
+}