Optimize FusedBatchNormGrad on CPU device
PiperOrigin-RevId: 275572837
Change-Id: I6f7e19a80eb1c7822ddc549b8d18307d7c2b3133
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 381d879..83602cd 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -490,9 +490,9 @@
auto x_mean_rest_by_depth =
mean.reshape(one_by_depth).broadcast(bcast_spec);
auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth);
- auto coef0 = (variance + epsilon).rsqrt();
- auto coef0_rest_by_depth =
- coef0.reshape(one_by_depth).broadcast(bcast_spec);
+ auto coef0_one_by_depth =
+ (variance.reshape(one_by_depth) + epsilon).rsqrt();
+ auto coef0_rest_by_depth = coef0_one_by_depth.broadcast(bcast_spec);
auto x_scaled = x_centered * coef0_rest_by_depth;
auto y_backprop_rest_by_depth =
@@ -524,12 +524,12 @@
// (y_backprop_rest_by_depth * x_centered).mean(reduce_dims)
scratch_tensor.device(d) = y_backprop_rest_by_depth * x_centered;
redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, &scratch_one_by_depth);
- auto y_backprop_centered_mean = scratch_vector / static_cast<U>(rest_size);
+ auto y_backprop_centered_mean =
+ scratch_vector.reshape(one_by_depth) / static_cast<U>(rest_size);
- auto coef1 = (scale * coef0).reshape(one_by_depth).broadcast(bcast_spec);
- auto coef2 = (coef0.square() * y_backprop_centered_mean)
- .reshape(one_by_depth)
- .eval()
+ auto coef1 = (scale.reshape(one_by_depth) * coef0_one_by_depth)
+ .broadcast(bcast_spec);
+ auto coef2 = (coef0_one_by_depth.square() * y_backprop_centered_mean)
.broadcast(bcast_spec);
x_backprop.reshape(rest_by_depth).device(d) =
@@ -634,7 +634,8 @@
x_backprop.reshape(rest_by_depth).device(d) =
(y_backprop_rest_by_depth *
- ((scratch1 * scale).reshape(one_by_depth).broadcast(rest_by_one)))
+ ((scratch1.reshape(one_by_depth) * scale.reshape(one_by_depth))
+ .broadcast(rest_by_one)))
.template cast<T>();
scale_backprop = scratch2 * scratch1; // DEFAULT DEVICE
}