Implement support for keepdims=True on ragged reduce operations.

PiperOrigin-RevId: 335365296
Change-Id: I210e4ecc54cda1954828a59dd725122fd2277fad
diff --git a/tensorflow/python/ops/ragged/ragged_math_ops.py b/tensorflow/python/ops/ragged/ragged_math_ops.py
index 60d608b..73a5358 100644
--- a/tensorflow/python/ops/ragged/ragged_math_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_math_ops.py
@@ -40,12 +40,8 @@
 # pylint: disable=redefined-builtin
 @tf_export('ragged.range')
 @dispatch.add_dispatch_support
-def range(starts,
-          limits=None,
-          deltas=1,
-          dtype=None,
-          name=None,
-          row_splits_dtype=dtypes.int64):
+def range(starts, limits=None, deltas=1, dtype=None,
+          name=None, row_splits_dtype=dtypes.int64):
   """Returns a `RaggedTensor` containing the specified sequences of numbers.
 
   Each row of the returned `RaggedTensor` contains a single sequence:
@@ -108,8 +104,9 @@
 
     result = gen_ragged_math_ops.ragged_range(
         starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
-    return ragged_tensor.RaggedTensor.from_row_splits(
-        result.rt_dense_values, result.rt_nested_splits, validate=False)
+    return ragged_tensor.RaggedTensor.from_row_splits(result.rt_dense_values,
+                                                      result.rt_nested_splits,
+                                                      validate=False)
 
 
 def _infer_matching_dtype(tensors, dtype_hierarchy):
@@ -121,6 +118,7 @@
 
 ops.no_gradient('RaggedRange')
 
+
 #===============================================================================
 # ragged_segment_<AGGREGATE>
 #===============================================================================
@@ -183,8 +181,8 @@
       `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
       `segment_ids` is not required to be sorted.
     num_segments: An `int32` or `int64` scalar.
-    separator: An optional string. Defaults to None. The separator to use when
-      joining. Only used for string types.
+    separator: An optional string. Defaults to None. The separator to
+      use when joining. Only used for string types.
     name: A name prefix for the returned tensor (optional).
 
   Returns:
@@ -263,42 +261,38 @@
 
 def segment_sum(data, segment_ids, num_segments, name=None):
   # For docs, see: _RAGGED_SEGMENT_DOCSTRING
-  return _ragged_segment_aggregate(
-      math_ops.unsorted_segment_sum,
-      data=data,
-      segment_ids=segment_ids,
-      num_segments=num_segments,
-      name=(name or 'RaggedSegmentSum'))
+  return _ragged_segment_aggregate(math_ops.unsorted_segment_sum,
+                                   data=data,
+                                   segment_ids=segment_ids,
+                                   num_segments=num_segments,
+                                   name=(name or'RaggedSegmentSum'))
 
 
 def segment_prod(data, segment_ids, num_segments, name=None):
   # For docs, see: _RAGGED_SEGMENT_DOCSTRING
-  return _ragged_segment_aggregate(
-      math_ops.unsorted_segment_prod,
-      data=data,
-      segment_ids=segment_ids,
-      num_segments=num_segments,
-      name=(name or 'RaggedSegmentProd'))
+  return _ragged_segment_aggregate(math_ops.unsorted_segment_prod,
+                                   data=data,
+                                   segment_ids=segment_ids,
+                                   num_segments=num_segments,
+                                   name=(name or 'RaggedSegmentProd'))
 
 
 def segment_min(data, segment_ids, num_segments, name=None):
   # For docs, see: _RAGGED_SEGMENT_DOCSTRING
-  return _ragged_segment_aggregate(
-      math_ops.unsorted_segment_min,
-      data=data,
-      segment_ids=segment_ids,
-      num_segments=num_segments,
-      name=(name or 'RaggedSegmentMin'))
+  return _ragged_segment_aggregate(math_ops.unsorted_segment_min,
+                                   data=data,
+                                   segment_ids=segment_ids,
+                                   num_segments=num_segments,
+                                   name=(name or 'RaggedSegmentMin'))
 
 
 def segment_max(data, segment_ids, num_segments, name=None):
   # For docs, see: _RAGGED_SEGMENT_DOCSTRING
-  return _ragged_segment_aggregate(
-      math_ops.unsorted_segment_max,
-      data=data,
-      segment_ids=segment_ids,
-      num_segments=num_segments,
-      name=(name or 'RaggedSegmentMax'))
+  return _ragged_segment_aggregate(math_ops.unsorted_segment_max,
+                                   data=data,
+                                   segment_ids=segment_ids,
+                                   num_segments=num_segments,
+                                   name=(name or 'RaggedSegmentMax'))
 
 
 def segment_mean(data, segment_ids, num_segments, name=None):
@@ -307,8 +301,7 @@
                       [data, segment_ids, num_segments]):
     total = segment_sum(data, segment_ids, num_segments)
     ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
-        array_ops.ones_like(data.flat_values),
-        data.nested_row_splits,
+        array_ops.ones_like(data.flat_values), data.nested_row_splits,
         validate=False)
     count = segment_sum(ones, segment_ids, num_segments)
     if ragged_tensor.is_ragged(total):
@@ -323,13 +316,12 @@
                       [data, segment_ids, num_segments]):
     total = segment_sum(data, segment_ids, num_segments)
     ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
-        array_ops.ones_like(data.flat_values),
-        data.nested_row_splits,
+        array_ops.ones_like(data.flat_values), data.nested_row_splits,
         validate=False)
     count = segment_sum(ones, segment_ids, num_segments)
     if ragged_tensor.is_ragged(total):
-      return total.with_flat_values(total.flat_values /
-                                    math_ops.sqrt(count.flat_values))
+      return total.with_flat_values(
+          total.flat_values / math_ops.sqrt(count.flat_values))
     else:
       return total / math_ops.sqrt(count)
 
@@ -463,8 +455,8 @@
       range `[0, rt_input.rank)`.
     keepdims: If true, retains reduced dimensions with length 1.
     separator: An optional string. Defaults to None. The separator to use when
-      joining. The separator must not be set for non-string data types. (i.e. if
-      separator is not None then it uses string ops)
+      joining. The separator must not be set for non-string data types. (i.e.
+      if separator is not None then it uses string ops)
     name: A name prefix for the returned tensor (optional).
 
   Returns:
@@ -478,12 +470,14 @@
   """
   if not ragged_tensor.is_ragged(rt_input):
     if separator is None:
-      return reduce_op(rt_input, axis, keepdims=keepdims, name=name)
+      return reduce_op(rt_input, axis, name=name)
     else:
       # When separator is not None, We infer that dtype is string and
       # reduce_join will be called.
-      return reduce_op(
-          rt_input, axis, keepdims=keepdims, name=name, separator=separator)
+      return reduce_op(rt_input, axis, name=name, separator=separator)
+
+  if keepdims:
+    raise ValueError('keepdims=True is not supported for RaggedTensors.')
 
   if isinstance(axis, ops.Tensor):
     axis = tensor_util.constant_value(axis)
@@ -494,12 +488,7 @@
 
   # When reducing all axes, just ignore splits & reduce the inner values.
   if axis is None:
-    result = reduce_op(rt_input.flat_values, None, keepdims=keepdims, name=name)
-    if keepdims:
-      # Expand the result to the input number of dimensions.
-      for _ in rt_input.shape[1:]:
-        result = array_ops.expand_dims(result, axis=0)
-    return result
+    return reduce_op(rt_input.flat_values, None, name=name)
 
   with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
     if isinstance(axis, (tuple, list)):
@@ -540,21 +529,15 @@
       row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
       num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
       segment_ids = range(row_lengths).values
-      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
-                                         segment_ids, num_segments, separator)
-      if keepdims:
-        result = array_ops.expand_dims(result, axis=0)
-      return result
+      return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
+                                       segment_ids, num_segments, separator)
     elif axis == 1:
       # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
       num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
       segment_ids = segment_id_ops.row_splits_to_segment_ids(
           rt_input.row_splits)
-      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
-                                         segment_ids, num_segments, separator)
-      if keepdims:
-        result = array_ops.expand_dims(result, axis=1)
-      return result
+      return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
+                                       segment_ids, num_segments, separator)
     else:
       # out[i_0, ..., i_[axis-1], i_axis+1], ..., i_N] =
       #     sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
@@ -571,8 +554,7 @@
       reduce_op=math_ops.reduce_sum,
       unsorted_segment_op=math_ops.unsorted_segment_sum,
       rt_input=input_tensor,
-      axis=axis,
-      keepdims=keepdims,
+      axis=axis, keepdims=keepdims,
       name=(name or 'RaggedReduceSum'))
 
 
@@ -616,15 +598,13 @@
     if ragged_tensor.is_ragged(input_tensor):
       ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
           array_ops.ones_like(input_tensor.flat_values),
-          input_tensor.nested_row_splits,
-          validate=False)
+          input_tensor.nested_row_splits, validate=False)
     else:
       ones = array_ops.ones_like(input_tensor)
     count = reduce_sum(ones, axis, keepdims)
     if ragged_tensor.is_ragged(total):
       return ragged_tensor.RaggedTensor.from_nested_row_splits(
-          total.flat_values / count.flat_values,
-          total.nested_row_splits,
+          total.flat_values / count.flat_values, total.nested_row_splits,
           validate=False)
     else:
       return total / count
diff --git a/tensorflow/python/ops/ragged/ragged_reduce_op_test.py b/tensorflow/python/ops/ragged/ragged_reduce_op_test.py
index 2afe086..a39090f 100644
--- a/tensorflow/python/ops/ragged/ragged_reduce_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_reduce_op_test.py
@@ -40,7 +40,8 @@
 
 
 @test_util.run_all_in_graph_and_eager_modes
-class RaggedReduceOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
+class RaggedReduceOpsTest(test_util.TensorFlowTestCase,
+                          parameterized.TestCase):
 
   @parameterized.parameters(
       #=========================================================================
@@ -50,212 +51,92 @@
       #    [9,     ],
       #    [2, 6   ]]
       #=========================================================================
-      # keepdims=True
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=0,
-          keepdims=False,
           expected=[15, 12, 4]  # = [3+1+9+2, 1+5+6, 4]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=-2,
-          keepdims=False,
           expected=[15, 12, 4]  # = [3+1+9+2, 1+5+6, 4]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=1,
-          keepdims=False,
           expected=[8, 6, 9, 8]  # = [3+1+4, 1+5, 9, 2+6]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=-1,
-          keepdims=False,
           expected=[8, 6, 9, 8]  # = [3+1+4, 1+5, 9, 2+6]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_prod,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=0,
-          keepdims=False,
           expected=[54, 30, 4]  # = [3*1*9*2, 1*5*6, 4]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_prod,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=1,
-          keepdims=False,
           expected=[12, 5, 9, 12]  # = [3*1*4, 1*5, 9, 2*6]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_min,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=0,
-          keepdims=False,
           expected=[1, 1, 4]  # = [min(3, 1, 9, 2), min(1, 5, 6), 4]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_min,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=1,
-          keepdims=False,
           expected=[1, 1, 9, 2]  # = [min(3, 1, 4), min(1, 5), 9, min(2, 6)]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_max,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=0,
-          keepdims=False,
           expected=[9, 6, 4]  # = [max(3, 1, 9, 2), max(1, 5, 6), 4]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_max,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=1,
-          keepdims=False,
           expected=[4, 5, 9, 6]  # = [max(3, 1, 4), max(1, 5), 9, max(2, 6)]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_mean,
           rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
           axis=0,
-          keepdims=False,
           expected=[3.75, 4, 4]  # = [mean(3, 1, 9, 2), mean(1, 5, 6), 4]
       ),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_any,
           rt_input=[[True, True], [True, True, False, True], [False, True]],
           axis=0,
-          keepdims=False,
           expected=[True, True, False, True]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_any,
           rt_input=[[True, True], [True, True, False, True], [False, True]],
           axis=1,
-          keepdims=False,
           expected=[True, True, True]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_all,
           rt_input=[[True, True], [True, True, False, True], [False, True]],
           axis=0,
-          keepdims=False,
           expected=[False, True, False, True]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_all,
           rt_input=[[True, True], [True, True, False, True], [False, True]],
           axis=1,
-          keepdims=False,
           expected=[True, False, False]),
-      # keepdims=True
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=0,
-          keepdims=True,
-          expected=[[15, 12, 4]]  # = [[3+1+9+2, 1+5+6, 4]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=-2,
-          keepdims=True,
-          expected=[[15, 12, 4]]  # = [[3+1+9+2, 1+5+6, 4]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=1,
-          keepdims=True,
-          expected=[[8], [6], [9], [8]]  # = [[3+1+4], [1+5], [9], [2+6]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=-1,
-          keepdims=True,
-          expected=[[8], [6], [9], [8]]  # = [[3+1+4], [1+5], [9], [2+6]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_prod,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=0,
-          keepdims=True,
-          expected=[[54, 30, 4]]  # = [[3*1*9*2, 1*5*6, 4]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_prod,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=1,
-          keepdims=True,
-          expected=[[12], [5], [9], [12]]  # = [[3*1*4], [1*5], [9], [2*6]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_min,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=0,
-          keepdims=True,
-          expected=[[1, 1, 4]]  # = [[min(3, 1, 9, 2), min(1, 5, 6), 4]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_min,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=1,
-          keepdims=True,
-          expected=[[1], [1], [9],
-                    [2]]  # = [[min(3, 1, 4)], [min(1, 5)], [9], [min(2, 6)]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_max,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=0,
-          keepdims=True,
-          expected=[[9, 6, 4]]  # = [[max(3, 1, 9, 2), max(1, 5, 6), 4]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_max,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=1,
-          keepdims=True,
-          expected=[[4], [5], [9],
-                    [6]]  # = [[max(3, 1, 4)], [max(1, 5)], [9], [max(2, 6)]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_mean,
-          rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
-          axis=0,
-          keepdims=True,
-          expected=[[3.75, 4, 4]]  # = [[mean(3, 1, 9, 2), mean(1, 5, 6), 4]]
-      ),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_any,
-          rt_input=[[True, True], [True, True, False, True], [False, True]],
-          axis=0,
-          keepdims=True,
-          expected=[[True, True, False, True]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_any,
-          rt_input=[[True, True], [True, True, False, True], [False, True]],
-          axis=1,
-          keepdims=True,
-          expected=[[True], [True], [True]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_all,
-          rt_input=[[True, True], [True, True, False, True], [False, True]],
-          axis=0,
-          keepdims=True,
-          expected=[[False, True, False, True]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_all,
-          rt_input=[[True, True], [True, True, False, True], [False, True]],
-          axis=1,
-          keepdims=True,
-          expected=[[True], [False], [False]]),
 
       #=========================================================================
       # Examples with the following RaggedTensor (ragged_rank=1):
@@ -272,62 +153,52 @@
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=None,
-          keepdims=False,
           expected=0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_prod,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=None,
-          keepdims=False,
           expected=0 * 1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_min,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=None,
-          keepdims=False,
           expected=min(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_max,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=None,
-          keepdims=False,
           expected=max(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_mean,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=None,
-          keepdims=False,
           expected=mean(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)),
       # axis=0
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=0,
-          keepdims=False,
           expected=[0 + 4 + 5 + 7 + 8, 1 + 6 + 9, 2, 3]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_prod,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=0,
-          keepdims=False,
           expected=[0 * 4 * 5 * 7 * 8, 1 * 6 * 9, 2, 3]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_min,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=0,
-          keepdims=False,
           expected=[min(0, 4, 5, 7, 8), min(1, 6, 9), 2, 3]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_max,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=0,
-          keepdims=False,
           expected=[max(0, 4, 5, 7, 8), max(1, 6, 9), 2, 3]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_mean,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=0,
-          keepdims=False,
           expected=[mean(0, 4, 5, 7, 8),
                     mean(1, 6, 9), 2, 3]),
       # axis=1
@@ -337,19 +208,16 @@
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=1,
-          keepdims=False,
           expected=[0 + 1 + 2 + 3, 4, 0, 5 + 6, 7, 8 + 9]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_prod,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=1,
-          keepdims=False,
           expected=[0 * 1 * 2 * 3, 4, 1, 5 * 6, 7, 8 * 9]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_min,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=1,
-          keepdims=False,
           expected=[min(0, 1, 2, 3), 4, _MAX_INT32,
                     min(5, 6), 7,
                     min(8, 9)]),
@@ -357,7 +225,6 @@
           ragged_reduce_op=ragged_math_ops.reduce_max,
           rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
           axis=1,
-          keepdims=False,
           expected=[max(0, 1, 2, 3), 4, _MIN_INT32,
                     max(5, 6), 7,
                     max(8, 9)]),
@@ -369,117 +236,51 @@
       #  [                      ],
       #  [[9   ]                ]]
       #=========================================================================
-      # keepdims=False
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=[],
-          keepdims=False,
           expected=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=None,
-          keepdims=False,
           expected=sum([1, 2, 3, 4, 5, 6, 7, 8, 9])),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=0,
-          keepdims=False,
           expected=[[1 + 6 + 9, 2 + 7], [], [3 + 8, 4, 5]]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=1,
-          keepdims=False,
           expected=[[1 + 3, 2 + 4, 5], [6 + 8, 7], [], [9]]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=2,
-          keepdims=False,
           expected=[[1 + 2, 0, 3 + 4 + 5], [6 + 7, 0, 8], [], [9]]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=[0, 1],
-          keepdims=False,
           expected=[1 + 3 + 6 + 8 + 9, 2 + 4 + 7, 5]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=[0, 2],
-          keepdims=False,
           expected=[1 + 6 + 9 + 2 + 7, 0, 3 + 8 + 4 + 5]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=[1, 2],
-          keepdims=False,
           expected=[1 + 2 + 3 + 4 + 5, 6 + 7 + 8, 0, 9]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=[0, 1, 2],
-          keepdims=False,
           expected=sum([1, 2, 3, 4, 5, 6, 7, 8, 9])),
-      # keepdims=True
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
-          axis=[],
-          keepdims=True,
-          expected=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
-          axis=None,
-          keepdims=True,
-          expected=[[[sum([1, 2, 3, 4, 5, 6, 7, 8, 9])]]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
-          axis=0,
-          keepdims=True,
-          expected=[[[1 + 6 + 9, 2 + 7], [], [3 + 8, 4, 5]]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
-          axis=1,
-          keepdims=True,
-          expected=[[[1 + 3, 2 + 4, 5]], [[6 + 8, 7]], [[]], [[9]]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
-          axis=2,
-          keepdims=True,
-          expected=[[[1 + 2], [0], [3 + 4 + 5]], [[6 + 7], [0], [8]], [],
-                    [[9]]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
-          axis=[0, 1],
-          keepdims=True,
-          expected=[[[1 + 3 + 6 + 8 + 9, 2 + 4 + 7, 5]]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
-          axis=[0, 2],
-          keepdims=True,
-          expected=[[[1 + 6 + 9 + 2 + 7], [0], [3 + 8 + 4 + 5]]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
-          axis=[1, 2],
-          keepdims=True,
-          expected=[[[1 + 2 + 3 + 4 + 5]], [[6 + 7 + 8]], [[0]], [[9]]]),
-      dict(
-          ragged_reduce_op=ragged_math_ops.reduce_sum,
-          rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
-          axis=[0, 1, 2],
-          keepdims=True,
-          expected=[[[sum([1, 2, 3, 4, 5, 6, 7, 8, 9])]]]),
 
       #=========================================================================
       # Examples for ragged_reduce_mean ragged_rank=2:
@@ -491,19 +292,16 @@
           ragged_reduce_op=ragged_math_ops.reduce_mean,
           rt_input=[[[1, 2], [3, 4, 5]], [[6, 7], [8]], [[9]]],
           axis=0,
-          keepdims=False,
           expected=[[mean(1, 6, 9), mean(2, 7)], [mean(3, 8), 4, 5]]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_mean,
           rt_input=[[[1, 2], [3, 4, 5]], [[6, 7], [8]], [[9]]],
           axis=1,
-          keepdims=False,
           expected=[[mean(1, 3), mean(2, 4), 5], [mean(6, 8), 7], [9]]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_mean,
           rt_input=[[[1, 2], [3, 4, 5]], [[6, 7], [8]], [[9]]],
           axis=2,
-          keepdims=False,
           expected=[[mean(1, 2), mean(3, 4, 5)], [mean(6, 7), 8], [9]]),
 
       # Test case for GitHub issue 27497, multiple negative axes.
@@ -511,18 +309,16 @@
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=[-2, -1],
-          keepdims=False,
           expected=[1 + 2 + 3 + 4 + 5, 6 + 7 + 8, 0, 9]),
       dict(
           ragged_reduce_op=ragged_math_ops.reduce_sum,
           rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
           axis=[-3, -2, -1],
-          keepdims=False,
           expected=sum([1, 2, 3, 4, 5, 6, 7, 8, 9])),
   )
-  def testReduce(self, ragged_reduce_op, rt_input, axis, keepdims, expected):
+  def testReduce(self, ragged_reduce_op, rt_input, axis, expected):
     rt_input = ragged_factory_ops.constant(rt_input)
-    reduced = ragged_reduce_op(rt_input, axis, keepdims=keepdims)
+    reduced = ragged_reduce_op(rt_input, axis)
     self.assertAllEqual(reduced, expected)
 
   def testReduceKeepsInnerDimensionShape(self):
@@ -540,8 +336,8 @@
   def testMeanNan(self):
     rt_as_list = [[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]]
     expected = (
-        np.array([0 + 1 + 2 + 3, 4, 0, 5 + 6, 7, 8 + 9]) /
-        np.array([4, 1, 0, 2, 1, 2]))
+        np.array([0 + 1 + 2 + 3, 4, 0, 5 + 6, 7, 8 + 9]) / np.array(
+            [4, 1, 0, 2, 1, 2]))
     rt_input = ragged_factory_ops.constant(rt_as_list)
     reduced = ragged_math_ops.reduce_mean(rt_input, axis=1)
     self.assertEqualWithNan(self.evaluate(reduced), expected)