Internal change
PiperOrigin-RevId: 334901840
Change-Id: I5b3673b270c3731b7644132cef8ee93baa70dbab
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 2d30d41..c54418b 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1121,17 +1121,8 @@
}
Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
- string data_format_str;
- TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
- TensorFormat data_format;
- if (!FormatFromString(data_format_str, &data_format)) {
- return errors::InvalidArgument("Invalid data format string: ",
- data_format_str);
- }
- const int rank =
- (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
ShapeHandle x;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
@@ -1140,8 +1131,14 @@
exponential_avg_factor = 1.0f; // default value
}
int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
-
- int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
+ string data_format_str;
+ TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
+ TensorFormat data_format;
+ if (!FormatFromString(data_format_str, &data_format)) {
+ return errors::InvalidArgument("Invalid data format string: ",
+ data_format_str);
+ }
+ int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
// covers scale, offset, and if is_training is false, mean, variance
@@ -1194,6 +1191,13 @@
}
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
+ ShapeHandle y_backprop;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
+ ShapeHandle x;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
+
+ bool is_training;
+ TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TensorFormat data_format;
@@ -1201,17 +1205,7 @@
return errors::InvalidArgument("Invalid data format string: ",
data_format_str);
}
- const int rank =
- (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
- ShapeHandle y_backprop;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
- ShapeHandle x;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
-
- bool is_training;
- TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
-
- int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
+ int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
TF_RETURN_IF_ERROR(
c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index 3c466ed..10253f1 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -670,25 +670,7 @@
Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
- const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
- const auto& shape = output_shape_attr->list().shape(0);
- const int rank = shape.dim_size();
- std::string src_format = context->src_format;
- std::string dst_format = context->dst_format;
- // Update the format from 4D to 5D layout if necessary.
- bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
- if (allow_5d) {
- std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
- std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
- context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
- dst_format_3d);
- }
- if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
- // Change back to the original layout due to early exit.
- if (allow_5d) {
- context->AssignDeviceAndDataFormats(context->target_device, src_format,
- dst_format);
- }
+ if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -697,11 +679,6 @@
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
- // Change back the format from 5D to 4D layout.
- if (allow_5d) {
- context->AssignDeviceAndDataFormats(context->target_device, src_format,
- dst_format);
- }
return context->graph_view->GetMutationBuilder()->Apply();
}
@@ -904,26 +881,8 @@
Status FusedBatchNormGradTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsFusedBatchNormGrad(*node->node()));
- const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
- const auto& shape = output_shape_attr->list().shape(0);
- const int rank = shape.dim_size();
- std::string src_format = context->src_format;
- std::string dst_format = context->dst_format;
- // Update the format from 4D to 5D layout if necessary.
- bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
- if (allow_5d) {
- std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
- std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
- context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
- dst_format_3d);
- }
- if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
+ if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsTraining(*node)) {
- // Change back to the original layout due to early exit.
- if (allow_5d) {
- context->AssignDeviceAndDataFormats(context->target_device, src_format,
- dst_format);
- }
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -933,11 +892,6 @@
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
- // Change back the format from 5D to 4D layout.
- if (allow_5d) {
- context->AssignDeviceAndDataFormats(context->target_device, src_format,
- dst_format);
- }
return context->graph_view->GetMutationBuilder()->Apply();
}
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index db528da..115428f 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -1438,41 +1438,29 @@
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
Status status;
- string x_format = fused_node.attr().at(kDataFormat).s();
- if (x_format == "NCHW" or x_format == "NCDHW") {
+ if (fused_node.attr().at(kDataFormat).s() == "NCHW") {
// Need to reshape the last 4 inputs
NodeDef new_shape;
const string new_shape_name =
- AddPrefixToNodeName(x_format + "Shape", fused_node.name());
+ AddPrefixToNodeName("NCHWShape", fused_node.name());
new_shape.set_name(new_shape_name);
new_shape.set_op("Const");
new_shape.set_device(fused_node.device());
*new_shape.add_input() = AsControlDependency(scale);
(*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
- if (x_format == "NCHW") {
- Tensor t(DT_INT32, {4});
- t.flat<int32>()(0) = 1;
- t.flat<int32>()(1) = -1;
- t.flat<int32>()(2) = 1;
- t.flat<int32>()(3) = 1;
- t.AsProtoTensorContent(
- (*new_shape.mutable_attr())["value"].mutable_tensor());
- } else {
- Tensor t(DT_INT32, {5});
- t.flat<int32>()(0) = 1;
- t.flat<int32>()(1) = -1;
- t.flat<int32>()(2) = 1;
- t.flat<int32>()(3) = 1;
- t.flat<int32>()(4) = 1;
- t.AsProtoTensorContent(
- (*new_shape.mutable_attr())["value"].mutable_tensor());
- }
+ Tensor t(DT_INT32, {4});
+ t.flat<int32>()(0) = 1;
+ t.flat<int32>()(1) = -1;
+ t.flat<int32>()(2) = 1;
+ t.flat<int32>()(3) = 1;
+ t.AsProtoTensorContent(
+ (*new_shape.mutable_attr())["value"].mutable_tensor());
mutation->AddNode(std::move(new_shape), &status);
TF_RETURN_IF_ERROR(status);
NodeDef reshaped_scale;
reshaped_scale.set_name(
- AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
+ AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
reshaped_scale.set_op("Reshape");
reshaped_scale.set_device(fused_node.device());
*reshaped_scale.add_input() = scale;
@@ -1485,7 +1473,7 @@
NodeDef reshaped_offset;
reshaped_offset.set_name(
- AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
+ AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
reshaped_offset.set_op("Reshape");
reshaped_offset.set_device(fused_node.device());
*reshaped_offset.add_input() = offset;
@@ -1498,7 +1486,7 @@
NodeDef reshaped_mean;
reshaped_mean.set_name(
- AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
+ AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
reshaped_mean.set_op("Reshape");
reshaped_mean.set_device(fused_node.device());
*reshaped_mean.add_input() = mean;
@@ -1511,7 +1499,7 @@
NodeDef reshaped_variance;
reshaped_variance.set_name(
- AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
+ AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
reshaped_variance.set_op("Reshape");
reshaped_variance.set_device(fused_node.device());
*reshaped_variance.add_input() = variance;
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index d8e5809..00ac9be 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1241,15 +1241,15 @@
// If use_reserved_space is false, we don't have 5th output.
virtual void ComputeWithReservedSpace(OpKernelContext* context,
bool use_reserved_space) {
- Tensor x = context->input(0);
+ const Tensor& x = context->input(0);
const Tensor& scale = context->input(1);
const Tensor& offset = context->input(2);
const Tensor& estimated_mean = context->input(3);
const Tensor& estimated_variance = context->input(4);
const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
- OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
- errors::InvalidArgument("input must be 4 or 5-dimensional",
+ OP_REQUIRES(context, x.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
x.shape().DebugString()));
OP_REQUIRES(context, scale.dims() == 1,
errors::InvalidArgument("scale must be 1-dimensional",
@@ -1264,21 +1264,6 @@
context, estimated_variance.dims() == 1,
errors::InvalidArgument("estimated_variance must be 1-dimensional",
estimated_variance.shape().DebugString()));
- bool use_reshape = (x.dims() == 5);
- auto x_shape = x.shape();
- TensorShape dest_shape;
- if (use_reshape) {
- const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
- int64 in_planes = GetTensorDim(x, tensor_format_, '0');
- int64 in_rows = GetTensorDim(x, tensor_format_, '1');
- int64 in_cols = GetTensorDim(x, tensor_format_, '2');
- const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
- dest_shape = ShapeFromFormat(tensor_format_, in_batch,
- {{in_planes, in_rows * in_cols}}, in_depth);
- OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
- errors::InvalidArgument("Error during tensor copy."));
- }
-
if (has_side_input_) {
OP_REQUIRES(context, side_input->shape() == x.shape(),
errors::InvalidArgument(
@@ -1297,10 +1282,8 @@
}
Tensor* y = nullptr;
- auto alloc_shape = use_reshape ? dest_shape : x_shape;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {0}, 0, alloc_shape, &y));
-
+ {0}, 0, x.shape(), &y));
Tensor* batch_mean = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{3}, 1, scale.shape(), &batch_mean));
@@ -1327,10 +1310,6 @@
batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
tensor_format_, use_reserved_space);
}
- if (use_reshape) {
- OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
- errors::InvalidArgument("Error during tensor copy."));
- }
}
private:
@@ -1396,8 +1375,8 @@
virtual void ComputeWithReservedSpace(OpKernelContext* context,
bool use_reserved_space) {
- Tensor y_backprop = context->input(0);
- Tensor x = context->input(1);
+ const Tensor& y_backprop = context->input(0);
+ const Tensor& x = context->input(1);
const Tensor& scale = context->input(2);
// When is_training=True, batch mean and variance/inverted variance are
// saved in the forward pass to be reused here. When is_training=False,
@@ -1408,11 +1387,11 @@
// saves inverted variance.
const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
- OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5,
- errors::InvalidArgument("input must be 4 or 5-dimensional",
+ OP_REQUIRES(context, y_backprop.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
y_backprop.shape().DebugString()));
- OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
- errors::InvalidArgument("input must be 4 or 5-dimensional",
+ OP_REQUIRES(context, x.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
x.shape().DebugString()));
OP_REQUIRES(context, scale.dims() == 1,
errors::InvalidArgument("scale must be 1-dimensional",
@@ -1425,27 +1404,10 @@
errors::InvalidArgument(
"saved variance must be 1-dimensional",
saved_maybe_inv_var_or_pop_var.shape().DebugString()));
- bool use_reshape = (x.dims() == 5);
- auto x_shape = x.shape();
- TensorShape dest_shape;
- if (use_reshape) {
- const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
- int64 in_planes = GetTensorDim(x, tensor_format_, '0');
- int64 in_rows = GetTensorDim(x, tensor_format_, '1');
- int64 in_cols = GetTensorDim(x, tensor_format_, '2');
- const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
- dest_shape = ShapeFromFormat(tensor_format_, in_batch,
- {{in_planes, in_rows * in_cols}}, in_depth);
- OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
- errors::InvalidArgument("Error during tensor copy."));
- OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
- errors::InvalidArgument("Error during tensor copy."));
- }
Tensor* x_backprop = nullptr;
- auto alloc_shape = use_reshape ? dest_shape : x_shape;
OP_REQUIRES_OK(context,
- context->allocate_output(0, alloc_shape, &x_backprop));
+ context->allocate_output(0, x.shape(), &x_backprop));
const TensorShape& scale_offset_shape = scale.shape();
Tensor* scale_backprop = nullptr;
@@ -1479,20 +1441,15 @@
offset_backprop, use_reserved_space, tensor_format_);
} else {
// Necessary layout conversion is currently done in python.
- OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
- errors::InvalidArgument(
- "The implementation of "
- "FusedBatchNormGrad with is_training=False only support "
- "NHWC tensor format for now."));
+ CHECK(tensor_format_ == FORMAT_NHWC)
+ << "The implementation of FusedBatchNormGrad with is_training=False "
+ "only support "
+ << "NHWC tensor format for now.";
functor::FusedBatchNormFreezeGrad<Device, T, U>()(
context, y_backprop, x, scale, saved_mean_or_pop_mean,
saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
offset_backprop);
}
- if (use_reshape) {
- OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
- errors::InvalidArgument("Error during tensor copy."));
- }
}
private:
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 759bf0f..2b6330d 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -221,7 +221,7 @@
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("exponential_avg_factor: float = 1.0")
- .Attr(GetConvnetDataFormat2D3DAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormV3Shape);
@@ -308,7 +308,7 @@
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr(GetConvnetDataFormat2D3DAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
// --------------------------------------------------------------------------
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 263b050..c80ab53 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -1275,94 +1275,6 @@
self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes)
self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes)
self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes)
-
- @test_util.deprecated_graph_mode_only
- def testBatchNorm3D(self):
- if test.is_gpu_available(cuda_only=True):
- random_seed.set_random_seed(0)
- x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
- filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
- strides_val = [1, 1, 1, 1, 1]
- scale = constant_op.constant(0.1, shape=[3])
- offset = constant_op.constant(0.3, shape=[3])
- conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
- y, _, _ = nn.fused_batch_norm(conv3d, scale, offset, data_format='NDHWC')
- output = array_ops.identity(y)
-
- with session.Session(config=_get_config(False)) as sess:
- output_val_ref = sess.run(output)
-
- with session.Session(config=_get_config()) as sess:
- metadata = config_pb2.RunMetadata()
- output_val = sess.run(output, run_metadata=metadata)
-
- nodes = []
- num_transposes = 0
- for node in metadata.cost_graph.node:
- if _is_transpose(node.name):
- num_transposes += 1
- nodes.append(node.name)
-
- expected_num_transposes = 2
- self.assertEqual(expected_num_transposes, num_transposes)
- self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
- self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormV3-0-0', nodes)
- self.assertAllClose(output_val_ref, output_val, atol=1e-3)
-
- @test_util.deprecated_graph_mode_only
- def testBatchNormGrad3D(self):
- if test.is_gpu_available(cuda_only=True):
- random_seed.set_random_seed(0)
- x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
- filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
- strides_val = [1, 1, 1, 1, 1]
- scale = constant_op.constant(0.1, shape=[3])
- offset = constant_op.constant(0.3, shape=[3])
- mean = constant_op.constant(0.1, shape=[3])
- variance = constant_op.constant(0.3, shape=[3])
- conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
- y, running_mean, running_var, r0, r1, r2 = gen_nn_ops.fused_batch_norm_v3(
- conv3d,
- scale,
- offset,
- mean,
- variance,
- epsilon=1.001e-5,
- exponential_avg_factor=1.0,
- data_format='NDHWC',
- is_training=True,
- name='batch_norm')
- dx, dscale, doffset, _, _ = gen_nn_ops.fused_batch_norm_grad_v3(
- y,
- x_3d,
- scale,
- r0,
- r1,
- r2,
- epsilon=1.001e-5,
- data_format='NDHWC',
- is_training=True)
- output = array_ops.identity(dx)
-
- with session.Session(config=_get_config(False)) as sess:
- output_val_ref = sess.run(output)
-
- with session.Session(config=_get_config()) as sess:
- metadata = config_pb2.RunMetadata()
- output_val = sess.run(output, run_metadata=metadata)
-
- nodes = []
- num_transposes = 0
- for node in metadata.cost_graph.node:
- if _is_transpose(node.name):
- num_transposes += 1
- nodes.append(node.name)
-
- expected_num_transposes = 3
- self.assertEqual(expected_num_transposes, num_transposes)
- self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
- self._assert_trans_ndhwc_to_ncdhw('FusedBatchNormGradV3-1', nodes)
- self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormGradV3-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 957f413..dc6eda6 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -330,13 +330,13 @@
# output back to its original shape accordingly.
if self._USE_V2_BEHAVIOR:
if self.fused is None:
- self.fused = ndims in (4, 5)
- elif self.fused and ndims not in (4, 5):
+ self.fused = (ndims == 4)
+ elif self.fused and ndims != 4:
raise ValueError('Batch normalization layers with fused=True only '
- 'support 4D or 5D input tensors.')
+ 'support 4D input tensors.')
else:
assert self.fused is not None
- self.fused = (ndims in (4, 5) and self._fused_can_be_used())
+ self.fused = (ndims == 4 and self._fused_can_be_used())
# TODO(chrisying): fused batch norm is currently not supported for
# multi-axis batch norm and by extension virtual batches. In some cases,
# it might be possible to use fused batch norm but would require reshaping
@@ -345,18 +345,13 @@
# common use case (turning 5D w/ virtual batch to NCHW)
if self.fused:
- if self.axis == [1] and ndims == 4:
+ if self.axis == [1]:
self._data_format = 'NCHW'
- elif self.axis == [1] and ndims == 5:
- self._data_format = 'NCDHW'
- elif self.axis == [3] and ndims == 4:
+ elif self.axis == [3]:
self._data_format = 'NHWC'
- elif self.axis == [4] and ndims == 5:
- self._data_format = 'NDHWC'
else:
raise ValueError('Unsupported axis, fused batch norm only supports '
- 'axis == [1] or axis == [3] for 4D input tensors or'
- 'axis == [1] or axis == [4] for 5D input tensors')
+ 'axis == [1] or axis == [3]')
axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
for x in axis_to_dim:
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index 79ecc3c..f89a615 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -66,15 +66,6 @@
kwargs={'scale': False,
'center': False},
input_shape=(3, 3))
- testing_utils.layer_test(
- keras.layers.BatchNormalization,
- kwargs={
- 'gamma_initializer': 'ones',
- 'beta_initializer': 'ones',
- 'moving_mean_initializer': 'zeros',
- 'moving_variance_initializer': 'ones'
- },
- input_shape=(3, 2, 4, 2))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_batchnorm_weights(self):
@@ -328,7 +319,7 @@
norm = normalization_v2.BatchNormalization(fused=True)
self.assertEqual(norm.fused, True)
inp = keras.layers.Input(shape=(4, 4))
- with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'):
+ with self.assertRaisesRegex(ValueError, '4D input tensors'):
norm(inp)
def test_updates_in_wrap_function(self):
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py
index 0421829..1742a91 100644
--- a/tensorflow/python/ops/nn_fused_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py
@@ -43,18 +43,14 @@
return math_ops.cast(y, x.dtype)
def _inference_ref(self, x, scale, offset, mean, var, epsilon, data_format):
- if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
- raise ValueError('data_format must be NCHW or NHWC for 4D tensors or'
- 'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
+ if data_format not in ['NHWC', 'NCHW']:
+ raise ValueError('data_format must be NCHW or NHWC, '
+ 'got %s.' % data_format)
if data_format == 'NCHW':
x = array_ops.transpose(x, [0, 2, 3, 1])
- elif data_format == 'NCDHW':
- x = array_ops.transpose(x, [0, 2, 3, 4, 1])
y = self._batch_norm(x, mean, var, offset, scale, epsilon)
if data_format == 'NCHW':
y = array_ops.transpose(y, [0, 3, 1, 2])
- elif data_format == 'NCDHW':
- y = array_ops.transpose(y, [0, 4, 1, 2, 3])
return self.evaluate(y)
def _test_inference(self,
@@ -106,24 +102,17 @@
def _training_ref(self, x, scale, offset, old_mean, old_var,
exponential_avg_factor, epsilon, data_format):
- if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
- raise ValueError('data_format must be NCHW or NHWC for 4D tensors or'
- 'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
- use_4d_tensor = (x.shape.ndims == 4)
+ if data_format not in ['NHWC', 'NCHW']:
+ raise ValueError('data_format must be NCHW or NHWC, '
+ 'got %s.' % data_format)
if data_format == 'NCHW':
x = array_ops.transpose(x, [0, 2, 3, 1])
- elif data_format == 'NCDHW':
- x = array_ops.transpose(x, [0, 2, 3, 4, 1])
-
- mean_axis = [0, 1, 2] if use_4d_tensor else [0, 1, 2, 3]
batch_mean, batch_var = nn_impl.moments(
- math_ops.cast(x, scale.dtype), mean_axis, keep_dims=False)
+ math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
if data_format == 'NCHW':
y = array_ops.transpose(y, [0, 3, 1, 2])
- elif data_format == 'NCDHW':
- y = array_ops.transpose(y, [0, 4, 1, 2, 3])
# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
# the denominator in the formula to calculate variance, while
@@ -388,18 +377,14 @@
def _runtests(self, x_shape, is_training, gradient_test=False,
cpu_only=False):
- if len(x_shape) == 4:
- data_format_list = ['NHWC', 'NCHW']
- else:
- data_format_list = ['NCDHW', 'NDHWC']
use_gpu_vals = [False]
if test.is_gpu_available(cuda_only=True) and not cpu_only:
use_gpu_vals += [True]
factors = [1.0, 0.6]
for dtype in [np.float16, np.float32]:
for use_gpu in use_gpu_vals:
- for data_format in data_format_list:
- if data_format == 'NHWC' or data_format == 'NDHWC':
+ for data_format in ['NHWC', 'NCHW']:
+ if data_format == 'NHWC':
scale_shape = x_shape[-1:]
else:
scale_shape = x_shape[1:2]
@@ -459,10 +444,6 @@
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
self._runtests(x_shape, False, cpu_only=True)
- def testInferenceShape7(self):
- x_shape = [1, 2, 6, 1, 3]
- self._runtests(x_shape, False)
-
def testTrainingShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, True)
@@ -484,16 +465,11 @@
x_shape = [0, 131, 127, 6]
self._runtests(x_shape, True)
- @test_util.run_deprecated_v1
def testTrainingShape6(self):
x_shape = [1, 1, 1, 1]
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
self._runtests(x_shape, True, cpu_only=True)
- def testTrainingShape7(self):
- x_shape = [1, 2, 6, 1, 3]
- self._runtests(x_shape, True)
-
@test_util.run_deprecated_v1
def testBatchNormGradInferenceShape1(self):
x_shape = [1, 1, 6, 1]
@@ -528,11 +504,6 @@
cpu_only=True)
@test_util.run_deprecated_v1
- def testBatchNormGradInferenceShape7(self):
- x_shape = [1, 2, 6, 1, 3]
- self._runtests(x_shape, is_training=False, gradient_test=True)
-
- @test_util.run_deprecated_v1
def testBatchNormGradTrainingShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, is_training=True, gradient_test=True)
@@ -564,54 +535,42 @@
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
self._runtests(x_shape, is_training=True, gradient_test=True, cpu_only=True)
- @test_util.run_deprecated_v1
- def testBatchNormGradTrainingShape7(self):
- x_shape = [1, 2, 6, 1, 3]
- self._runtests(x_shape, is_training=True, gradient_test=True)
-
def _testBatchNormGradGrad(self, config):
shape = config['shape']
err_tolerance = config['err_tolerance']
dtype = config['dtype']
- rank = len(shape)
- if rank == 4:
- data_format_nhwc, features_nhwc = 'NHWC', shape[3]
- data_format_nchw, features_nchw = 'NCHW', shape[1]
- else:
- data_format_nhwc, features_nhwc = 'NDHWC', shape[4]
- data_format_nchw, features_nchw = 'NCDHW', shape[1]
for is_training in [True, False]:
if test.is_gpu_available(cuda_only=True):
self._test_grad_grad(
shape,
- dtype, [features_nhwc],
+ dtype, [shape[3]],
np.float32,
use_gpu=True,
- data_format=data_format_nhwc,
+ data_format='NHWC',
is_training=is_training,
err_tolerance=err_tolerance)
self._test_grad_grad(
shape,
- dtype, [features_nchw],
+ dtype, [shape[1]],
np.float32,
use_gpu=True,
- data_format=data_format_nchw,
+ data_format='NCHW',
is_training=is_training,
err_tolerance=err_tolerance)
self._test_grad_grad(
shape,
- dtype, [features_nhwc],
+ dtype, [shape[3]],
np.float32,
use_gpu=False,
- data_format=data_format_nhwc,
+ data_format='NHWC',
is_training=is_training,
err_tolerance=err_tolerance)
self._test_grad_grad(
shape,
- dtype, [features_nchw],
+ dtype, [shape[1]],
np.float32,
use_gpu=False,
- data_format=data_format_nchw,
+ data_format='NCHW',
is_training=is_training,
err_tolerance=err_tolerance)
@@ -651,24 +610,6 @@
}
self._testBatchNormGradGrad(config)
- @test_util.run_deprecated_v1
- def testBatchNormGradGradConfig5(self):
- config = {
- 'shape': [2, 3, 2, 2, 2],
- 'err_tolerance': 2e-3,
- 'dtype': np.float32,
- }
- self._testBatchNormGradGrad(config)
-
- @test_util.run_deprecated_v1
- def testBatchNormGradGradConfig6(self):
- config = {
- 'shape': [2, 3, 2, 2, 2],
- 'err_tolerance': 3e-3,
- 'dtype': np.float16,
- }
- self._testBatchNormGradGrad(config)
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index a02e31f..58dd185 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -897,11 +897,6 @@
if data_format == b"NCHW":
x = array_ops.transpose(x, [0, 2, 3, 1])
grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
- elif data_format == b"NCDHW":
- x = array_ops.transpose(x, [0, 2, 3, 4, 1])
- grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1])
- target_data_format = ("NHWC" if data_format in (b"NCHW",
- b"NHWC") else "NDHWC")
args = {
"y_backprop": grad_y,
"x": x,
@@ -909,7 +904,7 @@
"reserve_space_1": pop_mean,
"reserve_space_2": pop_var,
"epsilon": epsilon,
- "data_format": target_data_format,
+ "data_format": "NHWC",
"is_training": is_training
}
if version == 2:
@@ -917,8 +912,6 @@
dx, dscale, doffset, _, _ = grad_fun(**args)
if data_format == b"NCHW":
dx = array_ops.transpose(dx, [0, 3, 1, 2])
- elif data_format == b"NCDHW":
- dx = array_ops.transpose(dx, [0, 4, 1, 2, 3])
return dx, dscale, doffset, None, None
@@ -948,8 +941,8 @@
"""Returns the gradients for the 3 inputs of BatchNorm.
Args:
- grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
- x: A `Tensor` of 4 or 5 dimensions for x.
+ grad_y: A `Tensor` of 4 dimensions for gradient for y.
+ x: A `Tensor` of 4 dimensions for x.
scale: A `Tensor` of 1 dimension for scaling.
pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
is_training=False.
@@ -975,19 +968,11 @@
if data_format == b"NHWC":
keepdims = False
reduce_axis = [0, 1, 2]
- elif data_format == b"NDHWC":
- keepdims = False
- reduce_axis = [0, 1, 2, 3]
- elif data_format == b"NCHW":
+ else:
keepdims = True
reduce_axis = [0, 2, 3]
shape = [1, array_ops.size(scale), 1, 1]
scale = array_ops.reshape(scale, shape)
- else:
- keepdims = True
- reduce_axis = [0, 2, 3, 4]
- shape = [1, array_ops.size(scale), 1, 1, 1]
- scale = array_ops.reshape(scale, shape)
mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
var_x = math_ops.reduce_mean(
@@ -1002,27 +987,19 @@
grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
- if data_format == b"NCHW" or data_format == b"NCDHW":
+ if data_format == b"NCHW":
grad_scale = array_ops.squeeze(grad_scale)
grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
else:
if data_format == b"NHWC":
reduce_axis = [0, 1, 2]
- elif data_format == b"NDHWC":
- reduce_axis = [0, 1, 2, 3]
- elif data_format == b"NCHW":
+ else:
reduce_axis = [0, 2, 3]
shape = [1, array_ops.size(pop_mean), 1, 1]
pop_mean = array_ops.reshape(pop_mean, shape)
pop_var = array_ops.reshape(pop_var, shape)
scale = array_ops.reshape(scale, shape)
- else:
- reduce_axis = [0, 2, 3, 4]
- shape = [1, array_ops.size(pop_mean), 1, 1, 1]
- pop_mean = array_ops.reshape(pop_mean, shape)
- pop_var = array_ops.reshape(pop_var, shape)
- scale = array_ops.reshape(scale, shape)
grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index d22fbf3..89174b2 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -1585,7 +1585,7 @@
(http://arxiv.org/abs/1502.03167).
Args:
- x: Input `Tensor` of 4 or 5 dimensions.
+ x: Input `Tensor` of 4 dimensions.
scale: A `Tensor` of 1 dimension for scaling.
offset: A `Tensor` of 1 dimension for bias.
mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
@@ -1611,8 +1611,7 @@
Variance must be a `Tensor` of the same shape as scale containing
the exponential running variance.
epsilon: A small float number added to the variance of x.
- data_format: The data format for x. Support "NHWC" (default) or "NCHW" for
- 4D tenors and "NDHWC" or "NCDHW" for 5D tensors.
+ data_format: The data format for x. Either "NHWC" (default) or "NCHW".
is_training: A bool value to specify if the operation is used for
training or inference.
name: A name for this operation (optional).
@@ -1623,7 +1622,7 @@
returned.
Returns:
- y: A 4D or 5D Tensor for the normalized, scaled, offsetted x.
+ y: A 4D Tensor for the normalized, scaled, offsetted x.
running_mean: A 1D Tensor for the exponential running mean of x.
The output value is (1 - exponential_avg_factor) * mean +
exponential_avg_factor * batch_mean), where batch_mean