Update TensorFlow MLIR ODS tf.FusedBatchNormGradV3 and tf.FusedBatchNormV3 to support NDHWC and NCDHW data formats.
This is in preparation of https://github.com/tensorflow/tensorflow/pull/42970 where FusedBatchNormGradV3 and FusedBatchNormV3 op defs are updated to support 5D data formats.
PiperOrigin-RevId: 334835832
Change-Id: I04396665beacedd7358ab169606f01b996b00389
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 1b73c5f..eb60f47 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -3929,7 +3929,7 @@
TF_Float32Tensor:$reserve_space_3,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
- DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+ DefaultValuedAttr<TF_AnyStrAttrOf<["NHWC", "NCHW", "NDHWC", "NCDHW"]>, "NHWC">:$data_format,
DefaultValuedAttr<BoolAttr, "true">:$is_training
);
@@ -4014,7 +4014,7 @@
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
- DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+ DefaultValuedAttr<TF_AnyStrAttrOf<["NHWC", "NCHW", "NDHWC", "NCDHW"]>, "NHWC">:$data_format,
DefaultValuedAttr<BoolAttr, "true">:$is_training
);