Implement support for keepdims=True on ragged reduce operations.
PiperOrigin-RevId: 335365296
Change-Id: I210e4ecc54cda1954828a59dd725122fd2277fad
diff --git a/tensorflow/python/ops/ragged/ragged_math_ops.py b/tensorflow/python/ops/ragged/ragged_math_ops.py
index 60d608b..73a5358 100644
--- a/tensorflow/python/ops/ragged/ragged_math_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_math_ops.py
@@ -40,12 +40,8 @@
# pylint: disable=redefined-builtin
@tf_export('ragged.range')
@dispatch.add_dispatch_support
-def range(starts,
- limits=None,
- deltas=1,
- dtype=None,
- name=None,
- row_splits_dtype=dtypes.int64):
+def range(starts, limits=None, deltas=1, dtype=None,
+ name=None, row_splits_dtype=dtypes.int64):
"""Returns a `RaggedTensor` containing the specified sequences of numbers.
Each row of the returned `RaggedTensor` contains a single sequence:
@@ -108,8 +104,9 @@
result = gen_ragged_math_ops.ragged_range(
starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
- return ragged_tensor.RaggedTensor.from_row_splits(
- result.rt_dense_values, result.rt_nested_splits, validate=False)
+ return ragged_tensor.RaggedTensor.from_row_splits(result.rt_dense_values,
+ result.rt_nested_splits,
+ validate=False)
def _infer_matching_dtype(tensors, dtype_hierarchy):
@@ -121,6 +118,7 @@
ops.no_gradient('RaggedRange')
+
#===============================================================================
# ragged_segment_<AGGREGATE>
#===============================================================================
@@ -183,8 +181,8 @@
`int32`. `segment_ids.shape` must be a prefix of `data.shape`.
`segment_ids` is not required to be sorted.
num_segments: An `int32` or `int64` scalar.
- separator: An optional string. Defaults to None. The separator to use when
- joining. Only used for string types.
+ separator: An optional string. Defaults to None. The separator to
+ use when joining. Only used for string types.
name: A name prefix for the returned tensor (optional).
Returns:
@@ -263,42 +261,38 @@
def segment_sum(data, segment_ids, num_segments, name=None):
# For docs, see: _RAGGED_SEGMENT_DOCSTRING
- return _ragged_segment_aggregate(
- math_ops.unsorted_segment_sum,
- data=data,
- segment_ids=segment_ids,
- num_segments=num_segments,
- name=(name or 'RaggedSegmentSum'))
+ return _ragged_segment_aggregate(math_ops.unsorted_segment_sum,
+ data=data,
+ segment_ids=segment_ids,
+ num_segments=num_segments,
+ name=(name or'RaggedSegmentSum'))
def segment_prod(data, segment_ids, num_segments, name=None):
# For docs, see: _RAGGED_SEGMENT_DOCSTRING
- return _ragged_segment_aggregate(
- math_ops.unsorted_segment_prod,
- data=data,
- segment_ids=segment_ids,
- num_segments=num_segments,
- name=(name or 'RaggedSegmentProd'))
+ return _ragged_segment_aggregate(math_ops.unsorted_segment_prod,
+ data=data,
+ segment_ids=segment_ids,
+ num_segments=num_segments,
+ name=(name or 'RaggedSegmentProd'))
def segment_min(data, segment_ids, num_segments, name=None):
# For docs, see: _RAGGED_SEGMENT_DOCSTRING
- return _ragged_segment_aggregate(
- math_ops.unsorted_segment_min,
- data=data,
- segment_ids=segment_ids,
- num_segments=num_segments,
- name=(name or 'RaggedSegmentMin'))
+ return _ragged_segment_aggregate(math_ops.unsorted_segment_min,
+ data=data,
+ segment_ids=segment_ids,
+ num_segments=num_segments,
+ name=(name or 'RaggedSegmentMin'))
def segment_max(data, segment_ids, num_segments, name=None):
# For docs, see: _RAGGED_SEGMENT_DOCSTRING
- return _ragged_segment_aggregate(
- math_ops.unsorted_segment_max,
- data=data,
- segment_ids=segment_ids,
- num_segments=num_segments,
- name=(name or 'RaggedSegmentMax'))
+ return _ragged_segment_aggregate(math_ops.unsorted_segment_max,
+ data=data,
+ segment_ids=segment_ids,
+ num_segments=num_segments,
+ name=(name or 'RaggedSegmentMax'))
def segment_mean(data, segment_ids, num_segments, name=None):
@@ -307,8 +301,7 @@
[data, segment_ids, num_segments]):
total = segment_sum(data, segment_ids, num_segments)
ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
- array_ops.ones_like(data.flat_values),
- data.nested_row_splits,
+ array_ops.ones_like(data.flat_values), data.nested_row_splits,
validate=False)
count = segment_sum(ones, segment_ids, num_segments)
if ragged_tensor.is_ragged(total):
@@ -323,13 +316,12 @@
[data, segment_ids, num_segments]):
total = segment_sum(data, segment_ids, num_segments)
ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
- array_ops.ones_like(data.flat_values),
- data.nested_row_splits,
+ array_ops.ones_like(data.flat_values), data.nested_row_splits,
validate=False)
count = segment_sum(ones, segment_ids, num_segments)
if ragged_tensor.is_ragged(total):
- return total.with_flat_values(total.flat_values /
- math_ops.sqrt(count.flat_values))
+ return total.with_flat_values(
+ total.flat_values / math_ops.sqrt(count.flat_values))
else:
return total / math_ops.sqrt(count)
@@ -463,8 +455,8 @@
range `[0, rt_input.rank)`.
keepdims: If true, retains reduced dimensions with length 1.
separator: An optional string. Defaults to None. The separator to use when
- joining. The separator must not be set for non-string data types. (i.e. if
- separator is not None then it uses string ops)
+ joining. The separator must not be set for non-string data types. (i.e.
+ if separator is not None then it uses string ops)
name: A name prefix for the returned tensor (optional).
Returns:
@@ -478,12 +470,14 @@
"""
if not ragged_tensor.is_ragged(rt_input):
if separator is None:
- return reduce_op(rt_input, axis, keepdims=keepdims, name=name)
+ return reduce_op(rt_input, axis, name=name)
else:
# When separator is not None, We infer that dtype is string and
# reduce_join will be called.
- return reduce_op(
- rt_input, axis, keepdims=keepdims, name=name, separator=separator)
+ return reduce_op(rt_input, axis, name=name, separator=separator)
+
+ if keepdims:
+ raise ValueError('keepdims=True is not supported for RaggedTensors.')
if isinstance(axis, ops.Tensor):
axis = tensor_util.constant_value(axis)
@@ -494,12 +488,7 @@
# When reducing all axes, just ignore splits & reduce the inner values.
if axis is None:
- result = reduce_op(rt_input.flat_values, None, keepdims=keepdims, name=name)
- if keepdims:
- # Expand the result to the input number of dimensions.
- for _ in rt_input.shape[1:]:
- result = array_ops.expand_dims(result, axis=0)
- return result
+ return reduce_op(rt_input.flat_values, None, name=name)
with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
if isinstance(axis, (tuple, list)):
@@ -540,21 +529,15 @@
row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
segment_ids = range(row_lengths).values
- result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
- segment_ids, num_segments, separator)
- if keepdims:
- result = array_ops.expand_dims(result, axis=0)
- return result
+ return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
+ segment_ids, num_segments, separator)
elif axis == 1:
# out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
segment_ids = segment_id_ops.row_splits_to_segment_ids(
rt_input.row_splits)
- result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
- segment_ids, num_segments, separator)
- if keepdims:
- result = array_ops.expand_dims(result, axis=1)
- return result
+ return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
+ segment_ids, num_segments, separator)
else:
# out[i_0, ..., i_[axis-1], i_axis+1], ..., i_N] =
# sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
@@ -571,8 +554,7 @@
reduce_op=math_ops.reduce_sum,
unsorted_segment_op=math_ops.unsorted_segment_sum,
rt_input=input_tensor,
- axis=axis,
- keepdims=keepdims,
+ axis=axis, keepdims=keepdims,
name=(name or 'RaggedReduceSum'))
@@ -616,15 +598,13 @@
if ragged_tensor.is_ragged(input_tensor):
ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
array_ops.ones_like(input_tensor.flat_values),
- input_tensor.nested_row_splits,
- validate=False)
+ input_tensor.nested_row_splits, validate=False)
else:
ones = array_ops.ones_like(input_tensor)
count = reduce_sum(ones, axis, keepdims)
if ragged_tensor.is_ragged(total):
return ragged_tensor.RaggedTensor.from_nested_row_splits(
- total.flat_values / count.flat_values,
- total.nested_row_splits,
+ total.flat_values / count.flat_values, total.nested_row_splits,
validate=False)
else:
return total / count
diff --git a/tensorflow/python/ops/ragged/ragged_reduce_op_test.py b/tensorflow/python/ops/ragged/ragged_reduce_op_test.py
index 2afe086..a39090f 100644
--- a/tensorflow/python/ops/ragged/ragged_reduce_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_reduce_op_test.py
@@ -40,7 +40,8 @@
@test_util.run_all_in_graph_and_eager_modes
-class RaggedReduceOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
+class RaggedReduceOpsTest(test_util.TensorFlowTestCase,
+ parameterized.TestCase):
@parameterized.parameters(
#=========================================================================
@@ -50,212 +51,92 @@
# [9, ],
# [2, 6 ]]
#=========================================================================
- # keepdims=True
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=0,
- keepdims=False,
expected=[15, 12, 4] # = [3+1+9+2, 1+5+6, 4]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=-2,
- keepdims=False,
expected=[15, 12, 4] # = [3+1+9+2, 1+5+6, 4]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=1,
- keepdims=False,
expected=[8, 6, 9, 8] # = [3+1+4, 1+5, 9, 2+6]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=-1,
- keepdims=False,
expected=[8, 6, 9, 8] # = [3+1+4, 1+5, 9, 2+6]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_prod,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=0,
- keepdims=False,
expected=[54, 30, 4] # = [3*1*9*2, 1*5*6, 4]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_prod,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=1,
- keepdims=False,
expected=[12, 5, 9, 12] # = [3*1*4, 1*5, 9, 2*6]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_min,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=0,
- keepdims=False,
expected=[1, 1, 4] # = [min(3, 1, 9, 2), min(1, 5, 6), 4]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_min,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=1,
- keepdims=False,
expected=[1, 1, 9, 2] # = [min(3, 1, 4), min(1, 5), 9, min(2, 6)]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_max,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=0,
- keepdims=False,
expected=[9, 6, 4] # = [max(3, 1, 9, 2), max(1, 5, 6), 4]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_max,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=1,
- keepdims=False,
expected=[4, 5, 9, 6] # = [max(3, 1, 4), max(1, 5), 9, max(2, 6)]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_mean,
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
axis=0,
- keepdims=False,
expected=[3.75, 4, 4] # = [mean(3, 1, 9, 2), mean(1, 5, 6), 4]
),
dict(
ragged_reduce_op=ragged_math_ops.reduce_any,
rt_input=[[True, True], [True, True, False, True], [False, True]],
axis=0,
- keepdims=False,
expected=[True, True, False, True]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_any,
rt_input=[[True, True], [True, True, False, True], [False, True]],
axis=1,
- keepdims=False,
expected=[True, True, True]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_all,
rt_input=[[True, True], [True, True, False, True], [False, True]],
axis=0,
- keepdims=False,
expected=[False, True, False, True]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_all,
rt_input=[[True, True], [True, True, False, True], [False, True]],
axis=1,
- keepdims=False,
expected=[True, False, False]),
- # keepdims=True
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=0,
- keepdims=True,
- expected=[[15, 12, 4]] # = [[3+1+9+2, 1+5+6, 4]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=-2,
- keepdims=True,
- expected=[[15, 12, 4]] # = [[3+1+9+2, 1+5+6, 4]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=1,
- keepdims=True,
- expected=[[8], [6], [9], [8]] # = [[3+1+4], [1+5], [9], [2+6]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=-1,
- keepdims=True,
- expected=[[8], [6], [9], [8]] # = [[3+1+4], [1+5], [9], [2+6]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_prod,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=0,
- keepdims=True,
- expected=[[54, 30, 4]] # = [[3*1*9*2, 1*5*6, 4]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_prod,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=1,
- keepdims=True,
- expected=[[12], [5], [9], [12]] # = [[3*1*4], [1*5], [9], [2*6]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_min,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=0,
- keepdims=True,
- expected=[[1, 1, 4]] # = [[min(3, 1, 9, 2), min(1, 5, 6), 4]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_min,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=1,
- keepdims=True,
- expected=[[1], [1], [9],
- [2]] # = [[min(3, 1, 4)], [min(1, 5)], [9], [min(2, 6)]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_max,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=0,
- keepdims=True,
- expected=[[9, 6, 4]] # = [[max(3, 1, 9, 2), max(1, 5, 6), 4]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_max,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=1,
- keepdims=True,
- expected=[[4], [5], [9],
- [6]] # = [[max(3, 1, 4)], [max(1, 5)], [9], [max(2, 6)]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_mean,
- rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]],
- axis=0,
- keepdims=True,
- expected=[[3.75, 4, 4]] # = [[mean(3, 1, 9, 2), mean(1, 5, 6), 4]]
- ),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_any,
- rt_input=[[True, True], [True, True, False, True], [False, True]],
- axis=0,
- keepdims=True,
- expected=[[True, True, False, True]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_any,
- rt_input=[[True, True], [True, True, False, True], [False, True]],
- axis=1,
- keepdims=True,
- expected=[[True], [True], [True]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_all,
- rt_input=[[True, True], [True, True, False, True], [False, True]],
- axis=0,
- keepdims=True,
- expected=[[False, True, False, True]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_all,
- rt_input=[[True, True], [True, True, False, True], [False, True]],
- axis=1,
- keepdims=True,
- expected=[[True], [False], [False]]),
#=========================================================================
# Examples with the following RaggedTensor (ragged_rank=1):
@@ -272,62 +153,52 @@
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=None,
- keepdims=False,
expected=0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9),
dict(
ragged_reduce_op=ragged_math_ops.reduce_prod,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=None,
- keepdims=False,
expected=0 * 1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9),
dict(
ragged_reduce_op=ragged_math_ops.reduce_min,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=None,
- keepdims=False,
expected=min(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)),
dict(
ragged_reduce_op=ragged_math_ops.reduce_max,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=None,
- keepdims=False,
expected=max(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)),
dict(
ragged_reduce_op=ragged_math_ops.reduce_mean,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=None,
- keepdims=False,
expected=mean(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)),
# axis=0
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=0,
- keepdims=False,
expected=[0 + 4 + 5 + 7 + 8, 1 + 6 + 9, 2, 3]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_prod,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=0,
- keepdims=False,
expected=[0 * 4 * 5 * 7 * 8, 1 * 6 * 9, 2, 3]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_min,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=0,
- keepdims=False,
expected=[min(0, 4, 5, 7, 8), min(1, 6, 9), 2, 3]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_max,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=0,
- keepdims=False,
expected=[max(0, 4, 5, 7, 8), max(1, 6, 9), 2, 3]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_mean,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=0,
- keepdims=False,
expected=[mean(0, 4, 5, 7, 8),
mean(1, 6, 9), 2, 3]),
# axis=1
@@ -337,19 +208,16 @@
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=1,
- keepdims=False,
expected=[0 + 1 + 2 + 3, 4, 0, 5 + 6, 7, 8 + 9]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_prod,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=1,
- keepdims=False,
expected=[0 * 1 * 2 * 3, 4, 1, 5 * 6, 7, 8 * 9]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_min,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=1,
- keepdims=False,
expected=[min(0, 1, 2, 3), 4, _MAX_INT32,
min(5, 6), 7,
min(8, 9)]),
@@ -357,7 +225,6 @@
ragged_reduce_op=ragged_math_ops.reduce_max,
rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]],
axis=1,
- keepdims=False,
expected=[max(0, 1, 2, 3), 4, _MIN_INT32,
max(5, 6), 7,
max(8, 9)]),
@@ -369,117 +236,51 @@
# [ ],
# [[9 ] ]]
#=========================================================================
- # keepdims=False
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=[],
- keepdims=False,
expected=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=None,
- keepdims=False,
expected=sum([1, 2, 3, 4, 5, 6, 7, 8, 9])),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=0,
- keepdims=False,
expected=[[1 + 6 + 9, 2 + 7], [], [3 + 8, 4, 5]]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=1,
- keepdims=False,
expected=[[1 + 3, 2 + 4, 5], [6 + 8, 7], [], [9]]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=2,
- keepdims=False,
expected=[[1 + 2, 0, 3 + 4 + 5], [6 + 7, 0, 8], [], [9]]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=[0, 1],
- keepdims=False,
expected=[1 + 3 + 6 + 8 + 9, 2 + 4 + 7, 5]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=[0, 2],
- keepdims=False,
expected=[1 + 6 + 9 + 2 + 7, 0, 3 + 8 + 4 + 5]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=[1, 2],
- keepdims=False,
expected=[1 + 2 + 3 + 4 + 5, 6 + 7 + 8, 0, 9]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=[0, 1, 2],
- keepdims=False,
expected=sum([1, 2, 3, 4, 5, 6, 7, 8, 9])),
- # keepdims=True
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
- axis=[],
- keepdims=True,
- expected=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
- axis=None,
- keepdims=True,
- expected=[[[sum([1, 2, 3, 4, 5, 6, 7, 8, 9])]]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
- axis=0,
- keepdims=True,
- expected=[[[1 + 6 + 9, 2 + 7], [], [3 + 8, 4, 5]]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
- axis=1,
- keepdims=True,
- expected=[[[1 + 3, 2 + 4, 5]], [[6 + 8, 7]], [[]], [[9]]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
- axis=2,
- keepdims=True,
- expected=[[[1 + 2], [0], [3 + 4 + 5]], [[6 + 7], [0], [8]], [],
- [[9]]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
- axis=[0, 1],
- keepdims=True,
- expected=[[[1 + 3 + 6 + 8 + 9, 2 + 4 + 7, 5]]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
- axis=[0, 2],
- keepdims=True,
- expected=[[[1 + 6 + 9 + 2 + 7], [0], [3 + 8 + 4 + 5]]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
- axis=[1, 2],
- keepdims=True,
- expected=[[[1 + 2 + 3 + 4 + 5]], [[6 + 7 + 8]], [[0]], [[9]]]),
- dict(
- ragged_reduce_op=ragged_math_ops.reduce_sum,
- rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
- axis=[0, 1, 2],
- keepdims=True,
- expected=[[[sum([1, 2, 3, 4, 5, 6, 7, 8, 9])]]]),
#=========================================================================
# Examples for ragged_reduce_mean ragged_rank=2:
@@ -491,19 +292,16 @@
ragged_reduce_op=ragged_math_ops.reduce_mean,
rt_input=[[[1, 2], [3, 4, 5]], [[6, 7], [8]], [[9]]],
axis=0,
- keepdims=False,
expected=[[mean(1, 6, 9), mean(2, 7)], [mean(3, 8), 4, 5]]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_mean,
rt_input=[[[1, 2], [3, 4, 5]], [[6, 7], [8]], [[9]]],
axis=1,
- keepdims=False,
expected=[[mean(1, 3), mean(2, 4), 5], [mean(6, 8), 7], [9]]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_mean,
rt_input=[[[1, 2], [3, 4, 5]], [[6, 7], [8]], [[9]]],
axis=2,
- keepdims=False,
expected=[[mean(1, 2), mean(3, 4, 5)], [mean(6, 7), 8], [9]]),
# Test case for GitHub issue 27497, multiple negative axes.
@@ -511,18 +309,16 @@
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=[-2, -1],
- keepdims=False,
expected=[1 + 2 + 3 + 4 + 5, 6 + 7 + 8, 0, 9]),
dict(
ragged_reduce_op=ragged_math_ops.reduce_sum,
rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]],
axis=[-3, -2, -1],
- keepdims=False,
expected=sum([1, 2, 3, 4, 5, 6, 7, 8, 9])),
)
- def testReduce(self, ragged_reduce_op, rt_input, axis, keepdims, expected):
+ def testReduce(self, ragged_reduce_op, rt_input, axis, expected):
rt_input = ragged_factory_ops.constant(rt_input)
- reduced = ragged_reduce_op(rt_input, axis, keepdims=keepdims)
+ reduced = ragged_reduce_op(rt_input, axis)
self.assertAllEqual(reduced, expected)
def testReduceKeepsInnerDimensionShape(self):
@@ -540,8 +336,8 @@
def testMeanNan(self):
rt_as_list = [[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]]
expected = (
- np.array([0 + 1 + 2 + 3, 4, 0, 5 + 6, 7, 8 + 9]) /
- np.array([4, 1, 0, 2, 1, 2]))
+ np.array([0 + 1 + 2 + 3, 4, 0, 5 + 6, 7, 8 + 9]) / np.array(
+ [4, 1, 0, 2, 1, 2]))
rt_input = ragged_factory_ops.constant(rt_as_list)
reduced = ragged_math_ops.reduce_mean(rt_input, axis=1)
self.assertEqualWithNan(self.evaluate(reduced), expected)