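Optimize FusedBatchNormGrad on CPU device

Reshape variance, scale, scratch_vector and scratch1 to one_by_depth before
combining them, so that coef0, coef1, coef2 and the scale term are computed in
one_by_depth shape and only then broadcast. This also removes the intermediate
.eval() from the coef2 expression.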

PiperOrigin-RevId: 275572837
Change-Id: I6f7e19a80eb1c7822ddc549b8d18307d7c2b3133
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 381d879..83602cd 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -490,9 +490,9 @@
     auto x_mean_rest_by_depth =
         mean.reshape(one_by_depth).broadcast(bcast_spec);
     auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth);
-    auto coef0 = (variance + epsilon).rsqrt();
-    auto coef0_rest_by_depth =
-        coef0.reshape(one_by_depth).broadcast(bcast_spec);
+    auto coef0_one_by_depth =
+        (variance.reshape(one_by_depth) + epsilon).rsqrt();
+    auto coef0_rest_by_depth = coef0_one_by_depth.broadcast(bcast_spec);
     auto x_scaled = x_centered * coef0_rest_by_depth;
 
     auto y_backprop_rest_by_depth =
@@ -524,12 +524,12 @@
     //     (y_backprop_rest_by_depth * x_centered).mean(reduce_dims)
     scratch_tensor.device(d) = y_backprop_rest_by_depth * x_centered;
     redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, &scratch_one_by_depth);
-    auto y_backprop_centered_mean = scratch_vector / static_cast<U>(rest_size);
+    auto y_backprop_centered_mean =
+        scratch_vector.reshape(one_by_depth) / static_cast<U>(rest_size);
 
-    auto coef1 = (scale * coef0).reshape(one_by_depth).broadcast(bcast_spec);
-    auto coef2 = (coef0.square() * y_backprop_centered_mean)
-                     .reshape(one_by_depth)
-                     .eval()
+    auto coef1 = (scale.reshape(one_by_depth) * coef0_one_by_depth)
+                     .broadcast(bcast_spec);
+    auto coef2 = (coef0_one_by_depth.square() * y_backprop_centered_mean)
                      .broadcast(bcast_spec);
 
     x_backprop.reshape(rest_by_depth).device(d) =
@@ -634,7 +634,8 @@
 
     x_backprop.reshape(rest_by_depth).device(d) =
         (y_backprop_rest_by_depth *
-         ((scratch1 * scale).reshape(one_by_depth).broadcast(rest_by_one)))
+         ((scratch1.reshape(one_by_depth) * scale.reshape(one_by_depth))
+              .broadcast(rest_by_one)))
             .template cast<T>();
     scale_backprop = scratch2 * scratch1;  // DEFAULT DEVICE
   }