[TF2XLA] Re-enable single-pass argmax lowering for TF
Rolling forward: the underlying issue with the mhlo converter has been fixed.
PiperOrigin-RevId: 404719818
Change-Id: I5f55249a3e6f05daf81f62177a7bfac6f95a099c
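Background: the single-pass lowering emits one variadic Reduce over the values
together with an iota of indices, whereas the removed two-pass lowering first
reduced the values to find the extremum and then ran a second Reduce over a
Select'ed iota to recover its index. The test added below asserts that only one
reduction appears in the HLO; a minimal sketch of the same check through the
public TF API (the function and variable names here are illustrative, not part
of this change):

    import tensorflow as tf

    @tf.function(jit_compile=True)
    def f(x):
      return tf.math.argmax(x)

    # experimental_get_compiler_ir returns a callable that emits the IR for
    # the given concrete inputs at the requested stage.
    hlo = f.experimental_get_compiler_ir(tf.ones([10], tf.float32))(stage='hlo')
    print(hlo.count('reduce('))  # expect 1 with the single-pass lowering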
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index e2e0fc8..79d6026 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -36,9 +36,7 @@
class CategoricalOp : public XlaOpKernel {
public:
- explicit CategoricalOp(OpKernelConstruction* ctx)
- : XlaOpKernel(ctx),
- is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
+ explicit CategoricalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
// Get the logits
@@ -113,14 +111,8 @@
xla::PrimitiveType xla_output_type;
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(output_type(0), &xla_output_type));
- xla::XlaOp argmax;
- if (is_gpu_) {
- argmax = xla::ArgMaxTwoPass(softmax_entries, xla_output_type,
- /*axis=*/class_dimension);
- } else {
- argmax = xla::ArgMax(softmax_entries, xla_output_type,
- /*axis=*/class_dimension, /*stable=*/true);
- }
+ xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type,
+ /*axis=*/class_dimension, /*stable=*/true);
if (num_samples == 1 && !num_samples_is_dynamic) {
argmax = xla::Reshape(argmax, {batch_size, 1});
@@ -145,7 +137,6 @@
}
private:
- bool is_gpu_;
TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp);
};
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
index 63c30d5..76abbfb 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
@@ -31,9 +31,7 @@
namespace tensorflow {
XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min)
- : XlaOpKernel(ctx),
- is_min_(is_min),
- is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
+ : XlaOpKernel(ctx), is_min_(is_min) {}
void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
const TensorShape input_shape = ctx->InputShape(0);
@@ -68,17 +66,8 @@
xla::XlaOp output;
-  // One pass ArgMin/ArgMax is slow on GPUs.
if (is_min_) {
- if (is_gpu_) {
- output = xla::ArgMinTwoPass(input, index_xla_type, axis);
- } else {
- output = xla::ArgMin(input, index_xla_type, axis, /*stable=*/true);
- }
+ output = xla::ArgMin(input, index_xla_type, axis, /*stable=*/true);
} else {
- if (is_gpu_) {
- output = xla::ArgMaxTwoPass(input, index_xla_type, axis);
- } else {
- output = xla::ArgMax(input, index_xla_type, axis, /*stable=*/true);
- }
+ output = xla::ArgMax(input, index_xla_type, axis, /*stable=*/true);
}
ctx->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.h b/tensorflow/compiler/tf2xla/kernels/index_ops.h
index 4089a20..ef2b9e6 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.h
@@ -30,7 +30,6 @@
private:
const bool is_min_; // Are we computing ArgMin (true) or ArgMax (false)?
- const bool is_gpu_;
};
class XlaArgMaxOp : public XlaArgMinMaxOp {
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 09178a3..d83b887 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -172,49 +172,6 @@
});
}
-XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
- bool is_min, bool tie_low) {
- XlaBuilder* builder = input.builder();
- return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
- XlaOp init_value;
- XlaComputation reducer;
- if (is_min) {
- init_value = MaxValue(builder, input_shape.element_type());
- reducer = CreateScalarMinComputation(input_shape.element_type(), builder);
- } else {
- init_value = MinValue(builder, input_shape.element_type());
- reducer = CreateScalarMaxComputation(input_shape.element_type(), builder);
- }
-
- XlaOp iota = Iota(
- builder, ShapeUtil::ChangeElementType(input_shape, output_type), axis);
- XlaOp reduced_input = Reduce(input, init_value, reducer,
- /*dimensions_to_reduce=*/{axis});
- std::vector<int64_t> broadcast_dims(input_shape.rank() - 1);
- std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
- std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
- if (tie_low) {
- XlaOp max_idx = MaxValue(builder, output_type);
- XlaOp select_mask = Select(Eq(input, reduced_input, broadcast_dims),
- /*on_true=*/iota,
- /*on_false=*/
- max_idx);
- return Reduce(select_mask, max_idx,
- CreateScalarMinComputation(output_type, builder),
- /*dimensions_to_reduce=*/{axis});
- } else {
- XlaOp min_idx = MinValue(builder, output_type);
- XlaOp select_mask = Select(Eq(input, reduced_input, broadcast_dims),
- /*on_true=*/iota,
- /*on_false=*/
- min_idx);
- return Reduce(select_mask, min_idx,
- CreateScalarMaxComputation(output_type, builder),
- /*dimensions_to_reduce=*/{axis});
- }
- });
-}
} // namespace
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis, bool stable,
@@ -227,13 +184,4 @@
return ArgMinMax(input, output_type, axis, /*is_min=*/true, stable, tie_low);
}
-XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
- bool tie_low) {
- return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false, tie_low);
-}
-
-XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis,
- bool tie_low) {
- return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true, tie_low);
-}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index 2712b2a..231b518 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -80,21 +80,17 @@
// use for the output. The `tie_low` argument drives the index selection in case
// of same values. If `true` (default behavior) the lowest index will be
// returned, otherwise the highest. The tie_low argument only applies if `stable`
-// is true or using the ArgMaxTwoPass.
+// is true.
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis,
bool stable = false, bool tie_low = true);
-XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
- bool tie_low = true);
// Returns the argmin of `input` along `axis`. `output_type` is the type to
// use for the output. The `tie_low` argument drives the index selection in case
// of same values. If `true` (default behavior) the lowest index will be
// returned, otherwise the highest. The tie_low argument only applies if `stable`
-// is true or using the ArgMinTwoPass.
+// is true.
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis,
bool stable = false, bool tie_low = true);
-XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis,
- bool tie_low = true);
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc
index 842b063..f9f2321 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc
@@ -57,19 +57,11 @@
input, expected_output, [=](XlaOp op, PrimitiveType type) {
return ArgMin(op, type, axis, /*stable=*/true, tie_low);
});
- TestArgMinMaxImpl(input, expected_output,
- [=](XlaOp op, PrimitiveType type) {
- return ArgMinTwoPass(op, type, axis, tie_low);
- });
} else {
TestArgMinMaxImpl(
input, expected_output, [=](XlaOp op, PrimitiveType type) {
return ArgMax(op, type, axis, /*stable=*/true, tie_low);
});
- TestArgMinMaxImpl(input, expected_output,
- [=](XlaOp op, PrimitiveType type) {
- return ArgMaxTwoPass(op, type, axis, tie_low);
- });
}
}
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index 39021b1..7394e0b 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -1168,6 +1168,20 @@
'Trying to access resource .*'):
my_func_temp()
+ def testSinglePassArgmax(self):
+ with ops.device('device:{}:0'.format(self.device)):
+
+ @def_function.function(jit_compile=True)
+ def f(x):
+ return math_ops.argmax(x)
+
+ hlo = f.experimental_get_compiler_ir(
+ array_ops.ones([10], dtype=dtypes.float32))(
+ stage='hlo')
+
+ # Test that reduction occurs only once.
+ self.assertEqual(hlo.count('reduce('), 1)
+
if __name__ == '__main__':
ops.enable_eager_execution()
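Note on tie-breaking (a usage sketch, not part of the diff above): tf2xla now
always lowers through ArgMax/ArgMin with stable=true, so the default
tie_low=true applies and ties resolve to the lowest index. Observing this from
Python under XLA compilation (the input values and expected result below are
illustrative):

    import tensorflow as tf

    @tf.function(jit_compile=True)
    def argmax_f(x):
      return tf.math.argmax(x)

    # Two tied maxima at indices 1 and 2; with stable=true and the default
    # tie_low=true, the lowered ArgMax picks the lowest tied index.
    print(argmax_f(tf.constant([1.0, 3.0, 3.0, 2.0])).numpy())  # expect 1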