Fix data race in GatherOp::Compute().
Don't modify a member variable (batch_dims_) inside Compute().
PiperOrigin-RevId: 347665956
Change-Id: I77dcf4be15e3961d2808b2d0ca323ab1a8d7a3a8
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index e9e6a93..e0b909a 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -88,29 +88,31 @@
axis = params.dims() + axis;
}
- if (batch_dims_ != 0) {
- OP_REQUIRES(
- c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(),
- errors::InvalidArgument("Expected batch_dims in the range [",
- -indices.dims(), ", ", indices.dims(),
- "], but got ", batch_dims_));
+ // Modify only a local copy of batch_dims_.
+ int32 batch_dims = batch_dims_;
+ if (batch_dims != 0) {
+ OP_REQUIRES(c,
+ batch_dims >= -indices.dims() && batch_dims <= indices.dims(),
+ errors::InvalidArgument("Expected batch_dims in the range [",
+ -indices.dims(), ", ", indices.dims(),
+ "], but got ", batch_dims));
- if (batch_dims_ < 0) {
- batch_dims_ = indices.dims() + batch_dims_;
+ if (batch_dims < 0) {
+ batch_dims = indices.dims() + batch_dims;
}
- if (!axis_is_set) axis = batch_dims_;
+ if (!axis_is_set) axis = batch_dims;
- OP_REQUIRES(c, batch_dims_ < params.dims(),
- errors::InvalidArgument("batch_dims (", batch_dims_,
+ OP_REQUIRES(c, batch_dims < params.dims(),
+ errors::InvalidArgument("batch_dims (", batch_dims,
") must be less than rank(params) (",
params.dims(), ")."));
- OP_REQUIRES(c, axis >= batch_dims_,
- errors::InvalidArgument("batch_dims (", batch_dims_,
+ OP_REQUIRES(c, axis >= batch_dims,
+ errors::InvalidArgument("batch_dims (", batch_dims,
") must be less than or equal to ",
"axis (", axis, ")."));
- for (int i = 0; i < batch_dims_; ++i) {
+ for (int i = 0; i < batch_dims; ++i) {
OP_REQUIRES(c, params.dim_size(i) == indices.dim_size(i),
errors::InvalidArgument(
"params.shape[", i, "]: ", params.dim_size(i),
@@ -136,15 +138,15 @@
int64 outer_size = 1;
int64 inner_size = 1;
- for (int i = 0; i < batch_dims_; ++i) {
+ for (int i = 0; i < batch_dims; ++i) {
result_shape.AddDim(params.dim_size(i));
batch_size *= params.dim_size(i);
}
- for (int i = batch_dims_; i < axis; ++i) {
+ for (int i = batch_dims; i < axis; ++i) {
result_shape.AddDim(params.dim_size(i));
outer_size *= params.dim_size(i);
}
- for (int i = batch_dims_; i < indices.dims(); ++i) {
+ for (int i = batch_dims; i < indices.dims(); ++i) {
result_shape.AddDim(indices.dim_size(i));
}
for (int i = axis + 1; i < params.dims(); ++i) {
@@ -159,7 +161,7 @@
int64 bad_i = -1;
auto indices_flat = indices.flat<Index>();
- if (batch_dims_ > 0) {
+ if (batch_dims > 0) {
auto params_flat = params.shaped<T, 4>(
{batch_size, outer_size, gather_dim_size, inner_size});
auto out_flat = out->shaped<T, 4>(