Fix data race in GatherOp::Compute().

Don't modify a member variable (batch_dims_) inside Compute().

PiperOrigin-RevId: 347665956
Change-Id: I77dcf4be15e3961d2808b2d0ca323ab1a8d7a3a8
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index e9e6a93..e0b909a 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -88,29 +88,31 @@
       axis = params.dims() + axis;
     }
 
-    if (batch_dims_ != 0) {
-      OP_REQUIRES(
-          c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(),
-          errors::InvalidArgument("Expected batch_dims in the range [",
-                                  -indices.dims(), ", ", indices.dims(),
-                                  "], but got ", batch_dims_));
+    // Modify only a local copy of batch_dims_.
+    int32 batch_dims = batch_dims_;
+    if (batch_dims != 0) {
+      OP_REQUIRES(c,
+                  batch_dims >= -indices.dims() && batch_dims <= indices.dims(),
+                  errors::InvalidArgument("Expected batch_dims in the range [",
+                                          -indices.dims(), ", ", indices.dims(),
+                                          "], but got ", batch_dims));
 
-      if (batch_dims_ < 0) {
-        batch_dims_ = indices.dims() + batch_dims_;
+      if (batch_dims < 0) {
+        batch_dims = indices.dims() + batch_dims;
       }
 
-      if (!axis_is_set) axis = batch_dims_;
+      if (!axis_is_set) axis = batch_dims;
 
-      OP_REQUIRES(c, batch_dims_ < params.dims(),
-                  errors::InvalidArgument("batch_dims (", batch_dims_,
+      OP_REQUIRES(c, batch_dims < params.dims(),
+                  errors::InvalidArgument("batch_dims (", batch_dims,
                                           ") must be less than rank(params) (",
                                           params.dims(), ")."));
 
-      OP_REQUIRES(c, axis >= batch_dims_,
-                  errors::InvalidArgument("batch_dims (", batch_dims_,
+      OP_REQUIRES(c, axis >= batch_dims,
+                  errors::InvalidArgument("batch_dims (", batch_dims,
                                           ") must be less than or equal to ",
                                           "axis (", axis, ")."));
-      for (int i = 0; i < batch_dims_; ++i) {
+      for (int i = 0; i < batch_dims; ++i) {
         OP_REQUIRES(c, params.dim_size(i) == indices.dim_size(i),
                     errors::InvalidArgument(
                         "params.shape[", i, "]: ", params.dim_size(i),
@@ -136,15 +138,15 @@
     int64 outer_size = 1;
     int64 inner_size = 1;
 
-    for (int i = 0; i < batch_dims_; ++i) {
+    for (int i = 0; i < batch_dims; ++i) {
       result_shape.AddDim(params.dim_size(i));
       batch_size *= params.dim_size(i);
     }
-    for (int i = batch_dims_; i < axis; ++i) {
+    for (int i = batch_dims; i < axis; ++i) {
       result_shape.AddDim(params.dim_size(i));
       outer_size *= params.dim_size(i);
     }
-    for (int i = batch_dims_; i < indices.dims(); ++i) {
+    for (int i = batch_dims; i < indices.dims(); ++i) {
       result_shape.AddDim(indices.dim_size(i));
     }
     for (int i = axis + 1; i < params.dims(); ++i) {
@@ -159,7 +161,7 @@
 
     int64 bad_i = -1;
     auto indices_flat = indices.flat<Index>();
-    if (batch_dims_ > 0) {
+    if (batch_dims > 0) {
       auto params_flat = params.shaped<T, 4>(
           {batch_size, outer_size, gather_dim_size, inner_size});
       auto out_flat = out->shaped<T, 4>(