Change tfxla.variadic_reduce to point to XlaVariadicReduceV2.

PiperOrigin-RevId: 385492367
Change-Id: I668b618e7b18abed199b445641b469069f35de2f
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 946d201..fb832e1 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -367,9 +367,8 @@
           args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
           expected=np.array([0, 45, 120, 231], dtype=dtype))
 
-  @parameterized.parameters(False, True)
   @test_util.disable_mlir_bridge('Not supported yet')
-  def testVariadicReduceKahanSum(self, use_v2):
+  def testVariadicReduceKahanSum(self):
     for dtype in set(self.numeric_types).intersection(
         set([np.float32, np.complex64])):
 
@@ -389,17 +388,10 @@
           reducer = kahan_sum_reducer.get_concrete_function(
               (arg, arg), (arg, arg))
 
-          if use_v2:
-            return xla.variadic_reduce_v2((x, array_ops.zeros_like(x)),
-                                          init_values=(arg, arg),
-                                          dimensions_to_reduce=dims,
-                                          reducer=reducer)[output_idx]
-          else:
-            return xla.variadic_reduce((x, array_ops.zeros_like(x)),
-                                       init_value=(arg, arg),
-                                       dimensions_to_reduce=dims,
-                                       reducer=reducer)[output_idx]
-
+          return xla.variadic_reduce((x, array_ops.zeros_like(x)),
+                                     init_values=(arg, arg),
+                                     dimensions_to_reduce=dims,
+                                     reducer=reducer)[output_idx]
         return fn
 
       xs = np.array([1e5, np.pi, -1e5, np.exp(1.)])
@@ -459,9 +451,9 @@
       reducer_func = reducer_add.get_concrete_function(arg_spec, arg_spec)
 
       def reduce(values, *, dimensions_to_reduce):
-        return xla.variadic_reduce_v2((values,), (init_val,),  # pylint: disable=cell-var-from-loop
-                                      dimensions_to_reduce=dimensions_to_reduce,
-                                      reducer=reducer_func)[0]  # pylint: disable=cell-var-from-loop
+        return xla.variadic_reduce((values,), (init_val,),  # pylint: disable=cell-var-from-loop
+                                   dimensions_to_reduce=dimensions_to_reduce,
+                                   reducer=reducer_func)[0]  # pylint: disable=cell-var-from-loop
       # Reduce dimension 0
       self._assertOpOutputMatchesExpected(
           functools.partial(reduce, dimensions_to_reduce=(0,)),
@@ -500,9 +492,9 @@
                                                        arg_spec_1, arg_spec_2)  # pylint: disable=cell-var-from-loop
 
       def reduce(*values, dimensions_to_reduce):
-        return xla.variadic_reduce_v2(values, (init_val_1, init_val_2,),  # pylint: disable=cell-var-from-loop
-                                      dimensions_to_reduce=dimensions_to_reduce,
-                                      reducer=reducer_func)  # pylint: disable=cell-var-from-loop
+        return xla.variadic_reduce(values, (init_val_1, init_val_2,),  # pylint: disable=cell-var-from-loop
+                                   dimensions_to_reduce=dimensions_to_reduce,
+                                   reducer=reducer_func)  # pylint: disable=cell-var-from-loop
 
       # Reduce dimension 0
       self._assertOpOutputMatchesExpected(
@@ -898,7 +890,7 @@
     arg_spec = array_ops.zeros([], dtype)  # pylint: disable=cell-var-from-loop
     reducer_func = reducer_add.get_concrete_function(arg_spec, arg_spec)
 
-    res = xla.variadic_reduce_v2(
+    res = xla.variadic_reduce(
         (array_ops.placeholder(np.float32, shape=(3, 4, 5)),),
         (array_ops.placeholder(np.float32, shape=()),),
         dimensions_to_reduce=(1,),
@@ -930,7 +922,7 @@
                      array_ops.placeholder(np.int32, shape=()),
                      array_ops.placeholder(np.int32, shape=()))
 
-      return xla.variadic_reduce_v2(
+      return xla.variadic_reduce(
           inputs,
           init_values,
           dimensions_to_reduce=dimensions_to_reduce,
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index db52308..1c13640 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -379,8 +379,7 @@
 
 recv = gen_xla_ops.xla_recv
 reduce = gen_xla_ops.xla_reduce
-variadic_reduce = gen_xla_ops.xla_variadic_reduce
-variadic_reduce_v2 = gen_xla_ops.xla_variadic_reduce_v2
+variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2
 
 ops.no_gradient("XlaVariadicReduce")