[TF:MLIR] Implement optimal layout assignment for FusedBatchNormV3
PiperOrigin-RevId: 304046758
Change-Id: I2f60d09a8308a9453df3b4a031972bcac4300ebf
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 02dcdcb..d6eb3b7 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -2648,7 +2648,7 @@
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
}
-def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
+def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
let summary = "Batch normalization.";
let description = [{
@@ -2685,6 +2685,10 @@
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
+
+ // TF_LayoutSensitiveInterface:
+ StringRef GetOptimalLayout(const RuntimeDevices& devices);
+ LogicalResult UpdateDataFormat(StringRef data_format);
}];
}
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 3622a63..9e9ba63 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -1547,9 +1547,36 @@
LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
ArrayRef<int64_t> permutation) {
+ // FusedBatchNorm in training mode is a layout sentitive operation, and should
+ // have already assigned an optimal data format.
+ if (is_training()) return failure();
+
return ::mlir::TF::FoldOperandsPermutation(permutation, this);
}
+LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
+ return ::mlir::TF::UpdateDataFormat(data_format, this);
+}
+
+StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) {
+ // In inference mode FusedBatchNorm is not sensitive to data layout.
+ if (!is_training()) return data_format();
+
+ // Keep current data format if no GPUs are available or if explicit placement
+ // does not allow to use GPU for this operation.
+ if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
+ return data_format();
+
+ // For f16 data type on devices with Tensor Cores support NHWC data format
+ // is up to ~2x faster.
+ auto x_ty = x().getType().cast<TensorType>();
+ const bool is_f16 = x_ty.getElementType().isF16();
+ if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
+
+ // For all other data types prefer NCHW.
+ return "NCHW";
+}
+
//===----------------------------------------------------------------------===//
// GatherV2Op
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir
index 3839b00..852ecfa 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir
@@ -62,4 +62,54 @@
return %0 : tensor<1x28x28x64xf16>
}
+// CHECK-LABEL: func @transposeFusedBatchNormV3_f32
+func @transposeFusedBatchNormV3_f32(
+ %arg0: tensor<1x28x28x64xf32>,
+ %arg1: tensor<64xf32>
+) -> tensor<1x28x28x64xf32> {
+
+ // CHECK: "tf.FusedBatchNormV3"
+ // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
+ // CHECK-SAME: data_format = "NCHW"
+ %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
+ = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
+ {
+ data_format = "NHWC",
+ epsilon = 1.001 : f32,
+ exponential_avg_factor = 1.0 : f32,
+ is_training = true
+ }
+ : (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>)
+ -> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
+
+ return %y : tensor<1x28x28x64xf32>
+}
+
+// CHECK-LABEL: func @transposeFusedBatchNormV3_f16
+func @transposeFusedBatchNormV3_f16(
+ %arg0: tensor<1x28x28x64xf16>,
+ %arg1: tensor<64xf32>
+) -> tensor<1x28x28x64xf16> {
+
+ // CHECK: "tf.FusedBatchNormV3"
+ // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
+ // CHECK-SAME: data_format = "NCHW"
+ %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
+ = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
+ {
+ data_format = "NHWC",
+ epsilon = 1.001 : f32,
+ exponential_avg_factor = 1.0 : f32,
+ is_training = true
+ }
+ : (tensor<1x28x28x64xf16>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>)
+ -> (tensor<1x28x28x64xf16>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
+
+ return %y : tensor<1x28x28x64xf16>
+}
+
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir
index b52ef1c..5358438 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir
@@ -143,4 +143,54 @@
return %0 : tensor<1x64x28x28xf16>
}
+// CHECK-LABEL: func @transposeFusedBatchNormV3_f32
+func @transposeFusedBatchNormV3_f32(
+ %arg0: tensor<1x28x28x64xf32>,
+ %arg1: tensor<64xf32>
+) -> tensor<1x28x28x64xf32> {
+
+ // CHECK: "tf.FusedBatchNormV3"
+ // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
+ // CHECK-SAME: data_format = "NCHW"
+ %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
+ = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
+ {
+ data_format = "NHWC",
+ epsilon = 1.001 : f32,
+ exponential_avg_factor = 1.0 : f32,
+ is_training = true
+ }
+ : (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>)
+ -> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
+
+ return %y : tensor<1x28x28x64xf32>
+}
+
+// CHECK-LABEL: func @transposeFusedBatchNormV3_f16
+func @transposeFusedBatchNormV3_f16(
+ %arg0: tensor<1x64x28x28xf16>,
+ %arg1: tensor<64xf32>
+) -> tensor<1x64x28x28xf16> {
+
+ // CHECK: "tf.FusedBatchNormV3"
+ // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
+ // CHECK-SAME: data_format = "NHWC"
+ %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
+ = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
+ {
+ data_format = "NCHW",
+ epsilon = 1.001 : f32,
+ exponential_avg_factor = 1.0 : f32,
+ is_training = true
+ }
+ : (tensor<1x64x28x28xf16>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>)
+ -> (tensor<1x64x28x28xf16>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
+
+ return %y : tensor<1x64x28x28xf16>
+}
+
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir
index 22be653..e8f0c60 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir
@@ -146,3 +146,40 @@
return %0 : tensor<1x32x32x3xf32>
}
+
+// CHECK-LABEL: func @transposeFusedBatchNormV3
+func @transposeFusedBatchNormV3(
+ %arg0: tensor<1x28x28x64xf32>,
+ %arg1: tensor<64xf32>
+) -> tensor<1x28x28x64xf32> {
+
+ // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"()
+ // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+ // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
+
+ // CHECK: "tf.FusedBatchNormV3"
+ // CHECK-SAME: (%[[ARG_TRANSPOSE]], %arg1, %arg1, %arg1, %arg1)
+ // CHECK-SAME: data_format = "NCHW"
+ // CHECK-SAME: (tensor<1x64x28x28xf32>, tensor<64xf32>,
+ // CHECK-SAME: -> (tensor<1x64x28x28xf32>, tensor<64xf32>,
+
+ // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"()
+ // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
+ // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
+ // CHECK: return %[[RES_TRANSPOSE]]
+
+ %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
+ = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
+ {
+ data_format = "NHWC",
+ epsilon = 1.001 : f32,
+ exponential_avg_factor = 1.0 : f32,
+ is_training = true
+ }
+ : (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>)
+ -> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
+
+ return %y : tensor<1x28x28x64xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
index e27448e..e6b3bf0 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
@@ -33,3 +33,40 @@
return %0 : tensor<1x8x32x32xf32>
}
+
+// CHECK-LABEL: func @transposeFusedBatchNormV3
+func @transposeFusedBatchNormV3(
+ %arg0: tensor<1x64x28x28xf32>,
+ %arg1: tensor<64xf32>
+) -> tensor<1x64x28x28xf32> {
+
+ // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"()
+ // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
+ // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
+
+ // CHECK: "tf.FusedBatchNormV3"
+ // CHECK-SAME: (%[[ARG_TRANSPOSE]], %arg1, %arg1, %arg1, %arg1)
+ // CHECK-SAME: data_format = "NHWC"
+ // CHECK-SAME: (tensor<1x28x28x64xf32>, tensor<64xf32>,
+ // CHECK-SAME: -> (tensor<1x28x28x64xf32>, tensor<64xf32>,
+
+ // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"()
+ // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+ // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
+ // CHECK: return %[[RES_TRANSPOSE]]
+
+ %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
+ = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
+ {
+ data_format = "NCHW",
+ epsilon = 1.001 : f32,
+ exponential_avg_factor = 1.0 : f32,
+ is_training = true
+ }
+ : (tensor<1x64x28x28xf32>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>)
+ -> (tensor<1x64x28x28xf32>, tensor<64xf32>, tensor<64xf32>,
+ tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
+
+ return %y : tensor<1x64x28x28xf32>
+}