Cleans up a newly added XlaVariadicSort op to take Input instead of Attr.
`dimension` and `is_stable` were passed as attributes before but are now
allowed to be inputs. However, they are constrained to be compile-time
constants for now.
Also add proper kernel registrations for CPU and GP (which was missing in the
previous change).
PiperOrigin-RevId: 350341893
Change-Id: Iab0554ed9ba12343f388f9963e4624c9823a227e
diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
index 6c6c490..7ddb1a6 100644
--- a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
+++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
@@ -74,6 +74,9 @@
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaKeyValueSort").Device(DEVICE), \
XlaCompileOnDemandOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("XlaVariadicSort").HostMemory("dimension").Device(DEVICE), \
+ XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaWhile").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDequantize").Device(DEVICE), \
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
index 7b2acae..bc7ef63 100644
--- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
@@ -61,17 +61,22 @@
: XlaOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("T", &input_types_));
OP_REQUIRES_OK(context, context->GetAttr("comparator", &comparator_));
- OP_REQUIRES_OK(context, context->GetAttr("dimension", &dimension_));
OP_REQUIRES_OK(context, context->GetAttr("is_stable", &is_stable_));
}
void Compile(XlaOpKernelContext* context) override {
- std::vector<xla::XlaOp> inputs(input_types_.size());
+ std::vector<xla::XlaOp> inputs;
+ std::vector<TensorShape> input_shapes;
+ OP_REQUIRES_OK(context,
+ context->InputList("inputs", &inputs, &input_shapes));
+ int64 dimension;
+ OP_REQUIRES_OK(context,
+ context->ConstantInputAsIntScalar("dimension", &dimension));
+
std::vector<xla::PrimitiveType> input_xla_types(input_types_.size());
std::vector<XlaCompiler::Argument> comparator_args(2 * input_types_.size());
- for (int i = 0; i < input_types_.size(); ++i) {
- inputs[i] = context->Input(i);
+ for (int i = 0; i < inputs.size(); ++i) {
OP_REQUIRES_OK(context, DataTypeToPrimitiveType(input_types_[i],
&input_xla_types[i]));
XlaCompiler::Argument comparator_arg;
@@ -101,12 +106,12 @@
xla::ShapeUtil::Compatible(comparator.xla_output_shape,
expected_comparator_output_shape),
errors::InvalidArgument(
- "Invalid output shape of XlaReduce reducer. Expected ",
+ "Invalid output shape of XlaVariadicSort comparator. Expected ",
xla::ShapeUtil::HumanString(expected_comparator_output_shape),
" got ", xla::ShapeUtil::HumanString(comparator.xla_output_shape)));
xla::XlaOp outputs =
- xla::Sort(inputs, *comparator.computation, dimension_, is_stable_);
+ xla::Sort(inputs, *comparator.computation, dimension, is_stable_);
for (int i = 0; i < input_types_.size(); ++i) {
xla::XlaOp output_handle =
@@ -119,12 +124,12 @@
private:
DataTypeVector input_types_;
const NameAttrList* comparator_;
- int64 dimension_;
bool is_stable_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaVariadicSortOp);
};
-REGISTER_XLA_OP(Name("XlaVariadicSort"), XlaVariadicSortOp);
+REGISTER_XLA_OP(Name("XlaVariadicSort").CompileTimeConstantInput("dimension"),
+ XlaVariadicSortOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 00d1fef..c2e449d 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -730,16 +730,16 @@
)doc");
REGISTER_OP("XlaVariadicSort")
- .Input("input: T")
- .Output("output: T")
+ .Input("inputs: T")
+ .Input("dimension: int32")
+ .Output("outputs: T")
.Attr("T: list(type) >= 1")
.Attr("comparator: func")
- .Attr("dimension: int")
.Attr("is_stable: bool")
.SetShapeFn([](shape_inference::InferenceContext* c) {
- for (int i = 0; i < c->num_inputs(); ++i) {
- c->set_output(i, c->input(i));
- }
+ std::vector<shape_inference::ShapeHandle> input_shapes;
+ TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
+ TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes));
return Status::OK();
})
.Doc(R"doc(
@@ -750,13 +750,13 @@
Sorts one or more tensors, with support for custom comparator, dimension, and
is_stable attributes.
-input: A list of `Tensor` of identical shape by possibly different types.
+inputs: A list of `Tensor` of identical shape but possibly different types.
+dimension: The dimension along which to sort. Must be a compile-time constant.
+is_stable: Whether to use stable sort.
comparator: A comparator function to apply to 2*N scalars and returning a
boolean. N is the number of sort inputs. If you want to sort in ascending
order then the comparator should perform a less-than comparison.
-output: A list of `Tensor` of type T.
-dimension: The dimension along which to sort.
-is_stable: Whether to use stable sort.
+outputs: A list of `Tensor` of same shape and types as the `input`.
)doc");
// TODO(b/37549631) setting the While Op to always be stateful is too