Change tfxla.variadic_reduce to point to XlaVariadicReduceV2.
PiperOrigin-RevId: 385492367
Change-Id: I668b618e7b18abed199b445641b469069f35de2f
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 946d201..fb832e1 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -367,9 +367,8 @@
args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
expected=np.array([0, 45, 120, 231], dtype=dtype))
- @parameterized.parameters(False, True)
@test_util.disable_mlir_bridge('Not supported yet')
- def testVariadicReduceKahanSum(self, use_v2):
+ def testVariadicReduceKahanSum(self):
for dtype in set(self.numeric_types).intersection(
set([np.float32, np.complex64])):
@@ -389,17 +388,10 @@
reducer = kahan_sum_reducer.get_concrete_function(
(arg, arg), (arg, arg))
- if use_v2:
- return xla.variadic_reduce_v2((x, array_ops.zeros_like(x)),
- init_values=(arg, arg),
- dimensions_to_reduce=dims,
- reducer=reducer)[output_idx]
- else:
- return xla.variadic_reduce((x, array_ops.zeros_like(x)),
- init_value=(arg, arg),
- dimensions_to_reduce=dims,
- reducer=reducer)[output_idx]
-
+ return xla.variadic_reduce((x, array_ops.zeros_like(x)),
+ init_values=(arg, arg),
+ dimensions_to_reduce=dims,
+ reducer=reducer)[output_idx]
return fn
xs = np.array([1e5, np.pi, -1e5, np.exp(1.)])
@@ -459,9 +451,9 @@
reducer_func = reducer_add.get_concrete_function(arg_spec, arg_spec)
def reduce(values, *, dimensions_to_reduce):
- return xla.variadic_reduce_v2((values,), (init_val,), # pylint: disable=cell-var-from-loop
- dimensions_to_reduce=dimensions_to_reduce,
- reducer=reducer_func)[0] # pylint: disable=cell-var-from-loop
+ return xla.variadic_reduce((values,), (init_val,), # pylint: disable=cell-var-from-loop
+ dimensions_to_reduce=dimensions_to_reduce,
+ reducer=reducer_func)[0] # pylint: disable=cell-var-from-loop
# Reduce dimension 0
self._assertOpOutputMatchesExpected(
functools.partial(reduce, dimensions_to_reduce=(0,)),
@@ -500,9 +492,9 @@
arg_spec_1, arg_spec_2) # pylint: disable=cell-var-from-loop
def reduce(*values, dimensions_to_reduce):
- return xla.variadic_reduce_v2(values, (init_val_1, init_val_2,), # pylint: disable=cell-var-from-loop
- dimensions_to_reduce=dimensions_to_reduce,
- reducer=reducer_func) # pylint: disable=cell-var-from-loop
+ return xla.variadic_reduce(values, (init_val_1, init_val_2,), # pylint: disable=cell-var-from-loop
+ dimensions_to_reduce=dimensions_to_reduce,
+ reducer=reducer_func) # pylint: disable=cell-var-from-loop
# Reduce dimension 0
self._assertOpOutputMatchesExpected(
@@ -898,7 +890,7 @@
arg_spec = array_ops.zeros([], dtype) # pylint: disable=cell-var-from-loop
reducer_func = reducer_add.get_concrete_function(arg_spec, arg_spec)
- res = xla.variadic_reduce_v2(
+ res = xla.variadic_reduce(
(array_ops.placeholder(np.float32, shape=(3, 4, 5)),),
(array_ops.placeholder(np.float32, shape=()),),
dimensions_to_reduce=(1,),
@@ -930,7 +922,7 @@
array_ops.placeholder(np.int32, shape=()),
array_ops.placeholder(np.int32, shape=()))
- return xla.variadic_reduce_v2(
+ return xla.variadic_reduce(
inputs,
init_values,
dimensions_to_reduce=dimensions_to_reduce,
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index db52308..1c13640 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -379,8 +379,7 @@
recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
-variadic_reduce = gen_xla_ops.xla_variadic_reduce
-variadic_reduce_v2 = gen_xla_ops.xla_variadic_reduce_v2
+variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2
ops.no_gradient("XlaVariadicReduce")