Canonicalize the FusedBatchNorm op to FusedBatchNormV3
PiperOrigin-RevId: 337394368
Change-Id: I7bd7f0513815e1b27584a1b6cba7c447a9d9c9a2
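
The ConvertFusedBatchNorm rewrite that previously lived only in the TFLite
prepare-tf pass is now registered as a canonicalization pattern on
tf.FusedBatchNorm (tf_ops_n_z.cc), so the generic canonicalizer
(tf-opt -canonicalize) upgrades the op and appends the extra
reserve_space_3 result that the V3 op carries.

A minimal sketch of the effect on the IR; the value names (%x, %scale,
%offset, %mean, %var) are illustrative and not taken from this change,
and the attributes mirror the existing tests:

  // Before: tf.FusedBatchNorm with five results.
  %0:5 = "tf.FusedBatchNorm"(%x, %scale, %offset, %mean, %var)
      {data_format = "NHWC", epsilon = 1.0e-3 : f32, is_training = false}
      : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
      -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)

  // After -canonicalize: same operands and attributes on tf.FusedBatchNormV3,
  // with an unranked f32 reserve_space_3 result appended; uses of the old
  // results are remapped to the first five results of the new op.
  %1:6 = "tf.FusedBatchNormV3"(%x, %scale, %offset, %mean, %var)
      {data_format = "NHWC", epsilon = 1.0e-3 : f32, is_training = false}
      : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
      -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)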
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 4e7c089..186c863 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -62,40 +62,6 @@
// LAYOUT: "tfl.conv_2d"
}
-
-func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
-^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
- // OK
- %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
- // Unsupported training
- %1:5 = "tf.FusedBatchNorm"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
- // Use other output
- %2:5 = "tf.FusedBatchNorm"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-
- return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
-
-// CHECK-LABEL: fusedBatchNorm
-// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
-// variance + epsilon
-// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
-// rsqrt(variance + epsilon)
-// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]])
-// scale * rsqrt(variance + epsilon)
-// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]])
-// x * scale * rsqrt(variance + epsilon)
-// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]])
-// mean * scale * rsqrt(variance + epsilon)
-// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]])
-// offset - mean * scale * rsqrt(variance + epsilon)
-// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]])
-// x * scale * rsqrt(variance + epsilon) +
-// offset - mean * scale * rsqrt(variance + epsilon)
-// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
-
-// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
-// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
-}
-
func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
// OK
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index ecca3d3..c4f30c2 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -740,31 +740,6 @@
}
};
-struct ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
- explicit ConvertFusedBatchNorm(MLIRContext *context)
- : OpRewritePattern<TF::FusedBatchNormOp>(context) {}
-
- LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
- PatternRewriter &rewriter) const override {
- auto new_result_types =
- llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
- // reserve_space_3
- new_result_types.push_back(
- UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
-
- OperationState new_state(tf_fused_batch_norm_op.getLoc(),
- TF::FusedBatchNormV3Op::getOperationName(),
- tf_fused_batch_norm_op.getOperands(),
- new_result_types,
- tf_fused_batch_norm_op.getAttrs());
- Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
-
- rewriter.replaceOp(tf_fused_batch_norm_op,
- tf_fused_batch_norm_op_v3->getResults().drop_back());
- return success();
- }
-};
-
// The below pattern is equivalent to the DRR rule below
// The checks are dependent on generated values, so we can't add
// the checks on intermediate values, ideally we should find equivalent
@@ -1202,7 +1177,6 @@
patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
- patterns.insert<ConvertFusedBatchNorm>(ctx);
TFL::populateWithGenerated(ctx, patterns);
// TODO(karimnosseir): Split to separate pass probably after
// deciding on long term plan for this optimization.
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index aa087ec..aa1b7bb 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -3942,6 +3942,8 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+ let hasCanonicalizer = 1;
+
let verifier = [{
return Verify(*this);
}];
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index d7a2427..e9ccbed 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -2337,6 +2337,41 @@
}
//===----------------------------------------------------------------------===//
+// FusedBatchNormOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
+ using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
+ PatternRewriter &rewriter) const override {
+ auto new_result_types =
+ llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
+ // reserve_space_3
+ new_result_types.push_back(
+ UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
+
+ OperationState new_state(tf_fused_batch_norm_op.getLoc(),
+ TF::FusedBatchNormV3Op::getOperationName(),
+ tf_fused_batch_norm_op.getOperands(),
+ new_result_types,
+ tf_fused_batch_norm_op.getAttrs());
+ Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
+
+ rewriter.replaceOp(tf_fused_batch_norm_op,
+ tf_fused_batch_norm_op_v3->getResults().drop_back());
+ return success();
+ }
+};
+} // namespace.
+
+void FusedBatchNormOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ConvertFusedBatchNorm>(context);
+}
+
+//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index ea9820d..e77dd36 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -1284,3 +1284,10 @@
%0 = "tf.NonMaxSuppressionV3"(%arg0, %arg1, %max_size, %arg2, %arg3): (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>)
return %0 : tensor<2xi32>
}
+
+// CHECK-LABEL: testFusedBatchNormToBatchNormV3
+func @testFusedBatchNormToBatchNormV3(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+ // CHECK: "tf.FusedBatchNormV3"
+ %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4): (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> )
+ return %0#0 : tensor<8x8x8x8xf32>
+}