Cleans up a newly added XlaVariadicSort op to take Input instead of Attr.

`dimension` and `is_stable` were passed as attributes before but are now
allowed to be inputs. However, they are constrained to be compile-time
constants for now.

Also add proper kernel registrations for CPU and GP (which was missing in the
previous change).

PiperOrigin-RevId: 350341893
Change-Id: Iab0554ed9ba12343f388f9963e4624c9823a227e
diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
index 6c6c490..7ddb1a6 100644
--- a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
+++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
@@ -74,6 +74,9 @@
                           XlaCompileOnDemandOp);                               \
   REGISTER_KERNEL_BUILDER(Name("XlaKeyValueSort").Device(DEVICE),              \
                           XlaCompileOnDemandOp);                               \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("XlaVariadicSort").HostMemory("dimension").Device(DEVICE),          \
+      XlaCompileOnDemandOp);                                                   \
   REGISTER_KERNEL_BUILDER(Name("XlaWhile").Device(DEVICE),                     \
                           XlaCompileOnDemandOp);                               \
   REGISTER_KERNEL_BUILDER(Name("XlaDequantize").Device(DEVICE),                \
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
index 7b2acae..bc7ef63 100644
--- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
@@ -61,17 +61,22 @@
       : XlaOpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("T", &input_types_));
     OP_REQUIRES_OK(context, context->GetAttr("comparator", &comparator_));
-    OP_REQUIRES_OK(context, context->GetAttr("dimension", &dimension_));
     OP_REQUIRES_OK(context, context->GetAttr("is_stable", &is_stable_));
   }
 
   void Compile(XlaOpKernelContext* context) override {
-    std::vector<xla::XlaOp> inputs(input_types_.size());
+    std::vector<xla::XlaOp> inputs;
+    std::vector<TensorShape> input_shapes;
+    OP_REQUIRES_OK(context,
+                   context->InputList("inputs", &inputs, &input_shapes));
+    int64 dimension;
+    OP_REQUIRES_OK(context,
+                   context->ConstantInputAsIntScalar("dimension", &dimension));
+
     std::vector<xla::PrimitiveType> input_xla_types(input_types_.size());
     std::vector<XlaCompiler::Argument> comparator_args(2 * input_types_.size());
 
-    for (int i = 0; i < input_types_.size(); ++i) {
-      inputs[i] = context->Input(i);
+    for (int i = 0; i < inputs.size(); ++i) {
       OP_REQUIRES_OK(context, DataTypeToPrimitiveType(input_types_[i],
                                                       &input_xla_types[i]));
       XlaCompiler::Argument comparator_arg;
@@ -101,12 +106,12 @@
         xla::ShapeUtil::Compatible(comparator.xla_output_shape,
                                    expected_comparator_output_shape),
         errors::InvalidArgument(
-            "Invalid output shape of XlaReduce reducer. Expected ",
+            "Invalid output shape of XlaVariadicSort comparator. Expected ",
             xla::ShapeUtil::HumanString(expected_comparator_output_shape),
             " got ", xla::ShapeUtil::HumanString(comparator.xla_output_shape)));
 
     xla::XlaOp outputs =
-        xla::Sort(inputs, *comparator.computation, dimension_, is_stable_);
+        xla::Sort(inputs, *comparator.computation, dimension, is_stable_);
 
     for (int i = 0; i < input_types_.size(); ++i) {
       xla::XlaOp output_handle =
@@ -119,12 +124,12 @@
  private:
   DataTypeVector input_types_;
   const NameAttrList* comparator_;
-  int64 dimension_;
   bool is_stable_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaVariadicSortOp);
 };
 
-REGISTER_XLA_OP(Name("XlaVariadicSort"), XlaVariadicSortOp);
+REGISTER_XLA_OP(Name("XlaVariadicSort").CompileTimeConstantInput("dimension"),
+                XlaVariadicSortOp);
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 00d1fef..c2e449d 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -730,16 +730,16 @@
 )doc");
 
 REGISTER_OP("XlaVariadicSort")
-    .Input("input: T")
-    .Output("output: T")
+    .Input("inputs: T")
+    .Input("dimension: int32")
+    .Output("outputs: T")
     .Attr("T: list(type) >= 1")
     .Attr("comparator: func")
-    .Attr("dimension: int")
     .Attr("is_stable: bool")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      for (int i = 0; i < c->num_inputs(); ++i) {
-        c->set_output(i, c->input(i));
-      }
+      std::vector<shape_inference::ShapeHandle> input_shapes;
+      TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
+      TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes));
       return Status::OK();
     })
     .Doc(R"doc(
@@ -750,13 +750,13 @@
 Sorts one or more tensors, with support for custom comparator, dimension, and
 is_stable attributes.
 
-input: A list of `Tensor` of identical shape by possibly different types.
+inputs: A list of `Tensor` of identical shape but possibly different types.
+dimension: The dimension along which to sort. Must be a compile-time constant.
+is_stable: Whether to use stable sort.
 comparator: A comparator function to apply to 2*N scalars and returning a
   boolean. N is the number of sort inputs. If you want to sort in ascending
   order then the comparator should perform a less-than comparison.
-output: A list of `Tensor` of type T.
-dimension: The dimension along which to sort.
-is_stable: Whether to use stable sort.
+outputs: A list of `Tensor` of same shape and types as the `input`.
 )doc");
 
 // TODO(b/37549631) setting the While Op to always be stateful is too