[TF2XLA] Re-enable single-pass argmax lowering for TF

Rolling forward: the underlying issue with the mhlo converter was fixed.
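
For context, a minimal NumPy sketch contrasting the two lowerings
(illustrative only; these function names are not part of the XLA API).
The two-pass form reduces once for the max value and once more for the
index; the single-pass form folds (value, index) pairs through a single
reduction, which is roughly how the stable xla::ArgMax lowers:

  import numpy as np

  def argmax_two_pass(x, axis):
      # Pass 1: reduce to the max. Pass 2: reduce an iota masked by
      # equality with the max, keeping the lowest matching index.
      m = x.max(axis=axis, keepdims=True)
      iota = np.indices(x.shape)[axis]
      sentinel = np.iinfo(iota.dtype).max
      return np.where(x == m, iota, sentinel).min(axis=axis)

  def argmax_single_pass(x, axis):
      # One reduction over (value, index) pairs; ties resolve to the
      # lower index, as with ArgMax(..., /*stable=*/true, /*tie_low=*/true).
      def fold(acc, cur):
          (av, ai), (cv, ci) = acc, cur
          return cur if cv > av or (cv == av and ci < ai) else acc
      moved = np.moveaxis(x, axis, -1)
      out = np.empty(moved.shape[:-1], dtype=np.int64)
      for i in np.ndindex(out.shape):
          acc = (moved[i][0], 0)
          for j in range(1, moved[i].shape[0]):
              acc = fold(acc, (moved[i][j], j))
          out[i] = acc[1]
      return out

  x = np.array([[1., 3., 3.], [2., 0., 2.]])
  assert (argmax_two_pass(x, 1) == argmax_single_pass(x, 1)).all()
  assert (argmax_single_pass(x, 1) == x.argmax(axis=1)).all()

The single-pass form should map to a single reduce op in the emitted
HLO, which is what the new compiler-IR test below checks.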

PiperOrigin-RevId: 404719818
Change-Id: I5f55249a3e6f05daf81f62177a7bfac6f95a099c
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index e2e0fc8..79d6026 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -36,9 +36,7 @@
 
 class CategoricalOp : public XlaOpKernel {
  public:
-  explicit CategoricalOp(OpKernelConstruction* ctx)
-      : XlaOpKernel(ctx),
-        is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
+  explicit CategoricalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
 
   void Compile(XlaOpKernelContext* ctx) override {
     // Get the logits
@@ -113,14 +111,8 @@
     xla::PrimitiveType xla_output_type;
     OP_REQUIRES_OK(ctx,
                    DataTypeToPrimitiveType(output_type(0), &xla_output_type));
-    xla::XlaOp argmax;
-    if (is_gpu_) {
-      argmax = xla::ArgMaxTwoPass(softmax_entries, xla_output_type,
-                                  /*axis=*/class_dimension);
-    } else {
-      argmax = xla::ArgMax(softmax_entries, xla_output_type,
-                           /*axis=*/class_dimension, /*stable=*/true);
-    }
+    xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type,
+                                    /*axis=*/class_dimension, /*stable=*/true);
 
     if (num_samples == 1 && !num_samples_is_dynamic) {
       argmax = xla::Reshape(argmax, {batch_size, 1});
@@ -145,7 +137,6 @@
   }
 
  private:
-  bool is_gpu_;
   TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp);
 };
 
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
index 63c30d5..76abbfb 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
@@ -31,9 +31,7 @@
 
 namespace tensorflow {
 XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min)
-    : XlaOpKernel(ctx),
-      is_min_(is_min),
-      is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
+    : XlaOpKernel(ctx), is_min_(is_min) {}
 
 void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
   const TensorShape input_shape = ctx->InputShape(0);
@@ -68,17 +66,8 @@
   xla::XlaOp output;
-  // One pass ArgMin/ArgMax is slow on GPUs.
   if (is_min_) {
-    if (is_gpu_) {
-      output = xla::ArgMinTwoPass(input, index_xla_type, axis);
-    } else {
-      output = xla::ArgMin(input, index_xla_type, axis, /*stable=*/true);
-    }
+    output = xla::ArgMin(input, index_xla_type, axis, /*stable=*/true);
   } else {
-    if (is_gpu_) {
-      output = xla::ArgMaxTwoPass(input, index_xla_type, axis);
-    } else {
-      output = xla::ArgMax(input, index_xla_type, axis, /*stable=*/true);
-    }
+    output = xla::ArgMax(input, index_xla_type, axis, /*stable=*/true);
   }
 
   ctx->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.h b/tensorflow/compiler/tf2xla/kernels/index_ops.h
index 4089a20..ef2b9e6 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.h
@@ -30,7 +30,6 @@
 
  private:
   const bool is_min_;  // Are we computing ArgMin (true) or ArgMax (false)?
-  const bool is_gpu_;
 };
 
 class XlaArgMaxOp : public XlaArgMinMaxOp {
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 09178a3..d83b887 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -172,49 +172,6 @@
   });
 }
 
-XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
-                       bool is_min, bool tie_low) {
-  XlaBuilder* builder = input.builder();
-  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
-    XlaOp init_value;
-    XlaComputation reducer;
-    if (is_min) {
-      init_value = MaxValue(builder, input_shape.element_type());
-      reducer = CreateScalarMinComputation(input_shape.element_type(), builder);
-    } else {
-      init_value = MinValue(builder, input_shape.element_type());
-      reducer = CreateScalarMaxComputation(input_shape.element_type(), builder);
-    }
-
-    XlaOp iota = Iota(
-        builder, ShapeUtil::ChangeElementType(input_shape, output_type), axis);
-    XlaOp reduced_input = Reduce(input, init_value, reducer,
-                                 /*dimensions_to_reduce=*/{axis});
-    std::vector<int64_t> broadcast_dims(input_shape.rank() - 1);
-    std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
-    std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
-    if (tie_low) {
-      XlaOp max_idx = MaxValue(builder, output_type);
-      XlaOp select_mask = Select(Eq(input, reduced_input, broadcast_dims),
-                                 /*on_true=*/iota,
-                                 /*on_false=*/
-                                 max_idx);
-      return Reduce(select_mask, max_idx,
-                    CreateScalarMinComputation(output_type, builder),
-                    /*dimensions_to_reduce=*/{axis});
-    } else {
-      XlaOp min_idx = MinValue(builder, output_type);
-      XlaOp select_mask = Select(Eq(input, reduced_input, broadcast_dims),
-                                 /*on_true=*/iota,
-                                 /*on_false=*/
-                                 min_idx);
-      return Reduce(select_mask, min_idx,
-                    CreateScalarMaxComputation(output_type, builder),
-                    /*dimensions_to_reduce=*/{axis});
-    }
-  });
-}
 }  // namespace
 
 XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis, bool stable,
@@ -227,13 +184,4 @@
   return ArgMinMax(input, output_type, axis, /*is_min=*/true, stable, tie_low);
 }
 
-XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
-                    bool tie_low) {
-  return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false, tie_low);
-}
-
-XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis,
-                    bool tie_low) {
-  return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true, tie_low);
-}
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index 2712b2a..231b518 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -80,21 +80,21 @@
-// use for the output. The `tie_low` argument drives the index selection is case
-// of same values. If `true` (default behavior) the lowest index will be
-// returned, otherwise the higher. The tie_low argument only applies if `stable`
-// is true or using the ArgMaxTwoPass.
+// use for the output. The `tie_low` argument drives the index selection in case
+// of equal values. If `true` (the default), the lowest index is returned,
+// otherwise the highest. The `tie_low` argument only applies if `stable` is
+// true.
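+// For example, with input {1, 3, 3} and `stable` set, ArgMax returns index 1
+// when `tie_low` is true and index 2 when it is false.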
 XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis,
              bool stable = false, bool tie_low = true);
-XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
-                    bool tie_low = true);
 
 // Returns the argmin of `input` along `axis`. `output_type` is the type to
-// use for the output. The `tie_low` argument drives the index selection is case
-// of same values. If `true` (default behavior) the lowest index will be
-// returned, otherwise the higher. The tie_low argument only applies if `stable`
-// is true or using the ArgMinTwoPass.
+// use for the output. The `tie_low` argument drives the index selection in case
+// of equal values. If `true` (the default), the lowest index is returned,
+// otherwise the highest. The `tie_low` argument only applies if `stable` is
+// true.
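+// For example, with input {3, 1, 1} and `stable` set, ArgMin returns index 1
+// when `tie_low` is true and index 2 when it is false.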
 XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis,
              bool stable = false, bool tie_low = true);
-XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis,
-                    bool tie_low = true);
 
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc
index 842b063..f9f2321 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc
@@ -57,19 +57,11 @@
           input, expected_output, [=](XlaOp op, PrimitiveType type) {
             return ArgMin(op, type, axis, /*stable=*/true, tie_low);
           });
-      TestArgMinMaxImpl(input, expected_output,
-                        [=](XlaOp op, PrimitiveType type) {
-                          return ArgMinTwoPass(op, type, axis, tie_low);
-                        });
     } else {
       TestArgMinMaxImpl(
           input, expected_output, [=](XlaOp op, PrimitiveType type) {
             return ArgMax(op, type, axis, /*stable=*/true, tie_low);
           });
-      TestArgMinMaxImpl(input, expected_output,
-                        [=](XlaOp op, PrimitiveType type) {
-                          return ArgMaxTwoPass(op, type, axis, tie_low);
-                        });
     }
   }
 
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index 39021b1..7394e0b 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -1168,6 +1168,20 @@
                                   'Trying to access resource .*'):
         my_func_temp()
 
+  def testSinglePassArgmax(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def f(x):
+        return math_ops.argmax(x)
+
+      hlo = f.experimental_get_compiler_ir(
+          array_ops.ones([10], dtype=dtypes.float32))(
+              stage='hlo')
+
+      # The single-pass lowering emits exactly one reduce op.
+      self.assertEqual(hlo.count(' reduce('), 1)
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()