[TF:XLA] Support batch_dims in ResourceGatherOp.
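The XLA kernel for ResourceGather previously only handled the default
batch_dims == 0; it now routes through the same implementation as
Gather/GatherV2. A minimal sketch of the behavior this enables (names,
shapes, and values are illustrative, not taken from this change):

    import numpy as np
    import tensorflow as tf

    v = tf.Variable(np.arange(12).reshape(3, 4))  # params, shape [3, 4]
    idx = tf.constant([[0, 2], [1, 3], [0, 1]])   # indices, shape [3, 2]
    # With batch_dims=1, row b of idx indexes into row b of v:
    #   out[b, i] = v[b, idx[b, i]]
    out = tf.gather(v, idx, batch_dims=1)         # [[0, 2], [5, 7], [8, 9]]

A gather that reads directly from a resource variable can lower to
ResourceGather, which XLA can now compile with batch_dims != 0.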
PiperOrigin-RevId: 266950621
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 489ffd3..84a0e78 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -25,8 +25,10 @@
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -150,6 +152,85 @@
return Status::OK();
}
+Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
+ const xla::XlaOp input,
+ const TensorShape& input_shape,
+ int batch_dims, xla::XlaOp* gather_output) {
+ auto indices = context->Input(1);
+ auto indices_shape = context->InputShape(1);
+
+ absl::optional<int64> axis;
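+ // GatherV2 carries `axis` as a third input; Gather and ResourceGather do
+ // not, in which case `axis` stays unset here and is defaulted below.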
+ if (context->num_inputs() == 3) {
+ const TensorShape axis_shape = context->InputShape(2);
+ if (!TensorShapeUtils::IsScalar(axis_shape)) {
+ return errors::InvalidArgument("axis must be scalar");
+ }
+ DataType axis_type = context->input_type(2);
+ if (axis_type != DT_INT32 && axis_type != DT_INT64) {
+ return errors::InvalidArgument("axis must be int32 or int64");
+ }
+
+ int64 axis_input;
+ TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input));
+
+ const auto params_dims = input_shape.dims();
+ if (-params_dims > axis_input || axis_input >= params_dims) {
+ return errors::InvalidArgument("Expected axis in the range [",
+ -params_dims, ", ", params_dims,
+ "), but got ", axis_input);
+ }
+ if (axis_input < 0) {
+ axis_input += params_dims;
+ }
+ axis = axis_input;
+ }
+
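+ // A negative batch_dims counts back from the end of the indices' shape;
+ // normalize it before validating it and before defaulting `axis` to it.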
+ if (batch_dims != 0) {
+ if (batch_dims < 0) {
+ batch_dims = indices_shape.dims() + batch_dims;
+ }
+
+ axis = axis.value_or(batch_dims);
+
+ if (batch_dims < -indices_shape.dims() ||
+ batch_dims >= indices_shape.dims()) {
+ return errors::InvalidArgument(
+ "Expected batch_dims in the range [", -indices_shape.dims(), ", ",
+ indices_shape.dims(), "), but got ", batch_dims);
+ }
+
+ if (batch_dims >= input_shape.dims()) {
+ return errors::InvalidArgument("batch_dims (", batch_dims,
+ ") must be less than rank(input) (",
+ input_shape.dims(), ").");
+ }
+
+ if (*axis < batch_dims) {
+ return errors::InvalidArgument("batch_dims (", batch_dims,
+ ") must be less than or equal to ",
+ "axis (", *axis, ").");
+ }
+ }
+
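+ // With batch_dims == 0, an unspecified axis means a gather along dimension 0.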
+ axis = axis.value_or(0);
+ DataType index_type = context->input_type(1);
+ if (index_type != DT_INT32 && index_type != DT_INT64) {
+ return errors::InvalidArgument("indices must be int32 or int64");
+ }
+
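+ // Batched gathers map directly onto TorchIndexSelect; the unbatched path
+ // reuses the existing XlaGather lowering.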
+ if (batch_dims > 0) {
+ *gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims);
+ } else {
+ // XlaGather() manages degenerate cases, like empty-indices, which are
+ // error conditions and caught above if batch_dims is not 0.
+ TF_RETURN_IF_ERROR(
+ XlaGather(input, input_shape, indices, indices_shape, *axis,
+ /*indices_are_nd=*/false, context->expected_output_dtype(0),
+ index_type, context->builder(), gather_output));
+ }
+ return Status::OK();
+}
class GatherOp : public XlaOpKernel {
public:
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
@@ -164,76 +245,11 @@
void Compile(XlaOpKernelContext* context) override {
auto input = context->Input(0);
auto input_shape = context->InputShape(0);
- auto indices = context->Input(1);
- auto indices_shape = context->InputShape(1);
-
- absl::optional<int64> axis;
- if (context->num_inputs() == 3) {
- const TensorShape axis_shape = context->InputShape(2);
- OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
- errors::InvalidArgument("axis must be scalar"));
- DataType axis_type = input_type(2);
- OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
- errors::InvalidArgument("axis must be int32 or int64"));
-
- int64 axis_input;
- OP_REQUIRES_OK(context,
- context->ConstantInputAsIntScalar(2, &axis_input));
-
- const auto params_dims = input_shape.dims();
- OP_REQUIRES(context,
- -params_dims <= axis_input && axis_input < params_dims,
- errors::InvalidArgument("Expected axis in the range [",
- -params_dims, ", ", params_dims,
- "), but got ", axis_input));
- if (axis_input < 0) {
- axis_input += params_dims;
- }
- axis = axis_input;
- }
-
- if (batch_dims_ != 0) {
- if (batch_dims_ < 0) {
- batch_dims_ = indices_shape.dims() + batch_dims_;
- }
-
- axis = axis.value_or(batch_dims_);
-
- OP_REQUIRES(context,
- batch_dims_ >= -indices_shape.dims() &&
- batch_dims_ < indices_shape.dims(),
- errors::InvalidArgument("Expected batch_dims in the range [",
- -indices_shape.dims(), ", ",
- indices_shape.dims(), "), but got ",
- batch_dims_));
-
- OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
- errors::InvalidArgument("batch_dims (", batch_dims_,
- ") must be less than rank(input) (",
- input_shape.dims(), ")."));
-
- OP_REQUIRES(context, *axis >= batch_dims_,
- errors::InvalidArgument("batch_dims (", batch_dims_,
- ") must be less than or equal to ",
- "axis (", *axis, ")."));
- }
-
- axis = axis.value_or(0);
- DataType index_type = input_type(1);
- OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
- errors::InvalidArgument("indices must be int32 or int64"));
xla::XlaOp gather;
- if (batch_dims_ > 0) {
- gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
- } else {
- // XlaGather() manages degenerate cases, like empty-indices, which are
- // error conditions and caught above if batch_dims is not 0.
- OP_REQUIRES_OK(
- context, XlaGather(input, input_shape, indices, indices_shape, *axis,
- /*indices_are_nd=*/false, input_type(0),
- index_type, context->builder(), &gather));
- }
+ OP_REQUIRES_OK(context,
+ XlaGatherWithBatchDimsOpImpl(context, input, input_shape,
+ batch_dims_, &gather));
context->SetOutput(0, gather);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
index 9234628..7bd2523 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
@@ -39,6 +39,13 @@
DataType index_type, xla::XlaBuilder* builder,
xla::XlaOp* gather_output);
+// The implementation of Gather and ResourceGather through XLA. Uses `input` as
+// the input instead of context->Input(0) in order to allow ResourceGather to
+// handle obtaining the data from the ResourceVariable.
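+// Expects the gather indices to be context->Input(1) and, when present, the
+// compile-time constant axis to be context->Input(2).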
+Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
+ const xla::XlaOp input,
+ const TensorShape& input_shape,
+ int batch_dims, xla::XlaOp* gather_output);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index 7b4125a..60424f8 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -19,6 +19,7 @@
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -122,27 +123,24 @@
class ResourceGatherOp : public XlaOpKernel {
public:
- explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_dims", &batch_dims_));
+ }
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
DataType type = ctx->expected_output_dtype(0);
- TensorShape resource_shape;
- xla::XlaOp resource_handle;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
- &resource_handle));
+ TensorShape input_shape;
+ xla::XlaOp input;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &input_shape, &input));
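+ // The variable's value, not the DT_RESOURCE handle, is the gather input,
+ // so the shared batch_dims implementation applies unchanged.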
- auto indices = ctx->Input(1);
- auto indices_shape = ctx->InputShape(1);
- DataType index_type = ctx->input_type(1);
xla::XlaOp gather;
- OP_REQUIRES_OK(
- ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape,
- /*axis=*/0, /*indices_are_nd=*/false, type, index_type,
- builder, &gather));
+ OP_REQUIRES_OK(ctx, XlaGatherWithBatchDimsOpImpl(ctx, input, input_shape,
+ batch_dims_, &gather));
ctx->SetOutput(0, gather);
}
+
+ private:
+ int32 batch_dims_;
};
REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp);
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index e5b741b..18682e0 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -861,7 +861,6 @@
],
# TODO(b/128347673): Re-enable.
tags = ["no_windows"],
- xla_enable_strict_auto_jit = True,
)
tf_py_test(
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 70c6c7e..14a4c53 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -986,7 +986,9 @@
x = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
container=ops.get_default_graph()._container)
- with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
+ with self.assertRaisesOpError(
+ "(Resource .*/var5/.* does not exist|Read of uninitialized variable)"
+ ):
resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
@test_util.run_deprecated_v1