# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.ragged.ragged_tensor_shape."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged.ragged_tensor_shape import RaggedTensorDynamicShape
from tensorflow.python.platform import googletest
# pylint: disable=g-long-lambda
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
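  """Tests for RaggedTensorDynamicShape and ragged shape broadcasting."""
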
def assertShapeEq(self, x, y):
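    """Asserts that two RaggedTensorDynamicShapes have identical dim sizes."""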
assert isinstance(x, RaggedTensorDynamicShape)
assert isinstance(y, RaggedTensorDynamicShape)
self.assertLen(x.partitioned_dim_sizes, len(y.partitioned_dim_sizes))
for x_dims, y_dims in zip(x.partitioned_dim_sizes, y.partitioned_dim_sizes):
self.assertAllEqual(x_dims, y_dims)
    self.assertAllEqual(x.inner_dim_sizes, y.inner_dim_sizes)

@parameterized.parameters([
dict(value='x', expected_dim_sizes=[]),
dict(value=['a', 'b', 'c'], expected_dim_sizes=[3]),
dict(value=[['a', 'b', 'c'], ['d', 'e', 'f']], expected_dim_sizes=[2, 3]),
dict(
value=[[['a', 'b', 'c'], ['d', 'e', 'f']]],
expected_dim_sizes=[1, 2, 3]),
dict(
value=ragged_factory_ops.constant_value([['a', 'b', 'c'], ['d',
'e']]),
expected_dim_sizes=[2, [3, 2]]),
dict(
value=ragged_factory_ops.constant_value([[['a', 'b', 'c'], ['d',
'e']]]),
expected_dim_sizes=[1, [2], [3, 2]]),
dict(
value=ragged_factory_ops.constant_value(
[[['a', 'b', 'c'], ['d', 'e', 'f']]], ragged_rank=1),
expected_dim_sizes=[1, [2], 3]),
dict(
value=ragged_factory_ops.constant_value(
[[[[1], [2]], [[3], [4]]], [[[5], [6]]]], ragged_rank=1),
expected_dim_sizes=[2, [2, 1], 2, 1]),
dict(
value=ragged_factory_ops.constant_value([[10, 20], [30]]),
expected_dim_sizes=[2, [2, 1]]),
# Docstring examples:
dict(value=[[1, 2, 3], [4, 5, 6]], expected_dim_sizes=[2, 3]),
dict(
value=ragged_factory_ops.constant_value([[1, 2], [], [3, 4, 5]]),
expected_dim_sizes=[3, [2, 0, 3]]),
dict(
value=ragged_factory_ops.constant_value([[[1, 2], [3, 4]], [[5, 6]]],
ragged_rank=1),
expected_dim_sizes=[2, [2, 1], 2]),
dict(
value=ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
expected_dim_sizes=[2, [2, 1], [2, 1, 2]]),
dict(
value=lambda: ragged_tensor.RaggedTensor.from_uniform_row_length(
ragged_factory_ops.constant([[1, 2], [3, 4, 5], [], [6]]),
uniform_row_length=2),
expected_dim_sizes=[2, 2, [2, 3, 0, 1]]),
])
def testFromTensor(self, value, expected_dim_sizes):
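    """Tests that from_tensor(value) matches the shape built from dim sizes.

    For example, the ragged constant [['a', 'b', 'c'], ['d', 'e']] has a
    uniform outer dimension of size 2 and a ragged dimension with row lengths
    [3, 2], i.e. dim sizes [2, [3, 2]].
    """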
if callable(value):
value = value()
shape = RaggedTensorDynamicShape.from_tensor(value)
expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
    self.assertShapeEq(shape, expected)

@parameterized.parameters([
dict(dim_sizes=[], rank=0, expected_dim_sizes=[]),
dict(dim_sizes=[], rank=3, expected_dim_sizes=[1, 1, 1]),
dict(dim_sizes=[3], rank=1, expected_dim_sizes=[3]),
dict(dim_sizes=[3], rank=3, expected_dim_sizes=[1, 1, 3]),
dict(dim_sizes=[2, 3], rank=3, expected_dim_sizes=[1, 2, 3]),
dict(dim_sizes=[3, [3, 2, 4]], rank=2, expected_dim_sizes=[3, [3, 2, 4]]),
dict(
dim_sizes=[3, [3, 2, 4]],
rank=4,
expected_dim_sizes=[1, 1, 3, [3, 2, 4]]),
dict(
dim_sizes=[3, [3, 2, 4], 2, 3],
rank=5,
expected_dim_sizes=[1, 3, [3, 2, 4], 2, 3]),
])
def testBroadcastToRank(self, dim_sizes, rank, expected_dim_sizes):
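    """Tests that broadcast_to_rank pads a shape with outer dimensions of 1.

    For example, broadcasting dim sizes [3] to rank 3 yields [1, 1, 3].
    """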
shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
broadcasted_shape = shape.broadcast_to_rank(rank)
self.assertShapeEq(broadcasted_shape, expected)
    self.assertEqual(broadcasted_shape.rank, rank)

@parameterized.parameters([
#=========================================================================
# dimension[axis] is uniform inner; and row_lengths is a scalar
#=========================================================================
# shape: [BROADCAST(UNIFORM), UNIFORM, UNIFORM]
dict(axis=0,
row_length=3,
original_dim_sizes=[1, 4, 5],
broadcast_dim_sizes=[3, 4, 5]),
# shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
dict(axis=2,
row_length=5,
original_dim_sizes=[3, 4, 1],
broadcast_dim_sizes=[3, 4, 5]),
# shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)]
dict(axis=2,
row_length=5,
original_dim_sizes=[3, [3, 2, 8], 1],
broadcast_dim_sizes=[3, [3, 2, 8], 5]),
# shape: [UNIFORM, RAGGED, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
dict(axis=5,
row_length=5,
original_dim_sizes=[2, [2, 1], [3, 2, 8], 3, 4, 1],
broadcast_dim_sizes=[2, [2, 1], [3, 2, 8], 3, 4, 5]),
#=========================================================================
# dimension[axis] is uniform inner; and row_lengths is a vector
#=========================================================================
# shape: [UNIFORM, BROADCAST(UNIFORM)]
dict(axis=1,
row_length=[2, 0, 1],
original_dim_sizes=[3, 1],
broadcast_dim_sizes=[3, [2, 0, 1]]),
# shape: [UNIFORM, BROADCAST(UNIFORM), UNIFORM]
dict(axis=1,
row_length=[2, 0, 1],
original_dim_sizes=[3, 1, 5],
broadcast_dim_sizes=[3, [2, 0, 1], 5]),
# shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
dict(axis=2,
row_length=[2, 0, 1, 3, 8, 2, 3, 4, 1, 8, 7, 0],
original_dim_sizes=[4, 3, 1],
broadcast_dim_sizes=[4, 3, [2, 0, 1, 3, 8, 2, 3, 4, 1, 8, 7, 0]]),
# shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)]
dict(axis=2,
row_length=[2, 5, 3],
original_dim_sizes=[2, [2, 1], 1],
broadcast_dim_sizes=[2, [2, 1], [2, 5, 3]]),
# shape: [UNIFORM, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM), UNIFORM]
dict(axis=4,
row_length=list(range(18)),
original_dim_sizes=[2, [2, 1], 3, 2, 1, 8],
broadcast_dim_sizes=[2, [2, 1], 3, 2, list(range(18)), 8]),
#=========================================================================
# dimension[axis] is uniform partitioned; and row_lengths is a scalar
#=========================================================================
# shape: [BROADCAST(UNIFORM), RAGGED]
dict(axis=0,
row_length=3,
original_dim_sizes=[1, [5]],
broadcast_dim_sizes=[3, [5, 5, 5]]),
# shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED]
dict(axis=0,
row_length=2,
original_dim_sizes=[1, 3, [3, 0, 2]],
broadcast_dim_sizes=[2, 3, [3, 0, 2, 3, 0, 2]]),
# shape: [BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM, UNIFORM]
dict(axis=0,
row_length=3,
original_dim_sizes=[1, [3], [3, 5, 2], 9, 4, 5],
broadcast_dim_sizes=[3, [3, 3, 3], [3, 5, 2, 3, 5, 2, 3, 5, 2],
9, 4, 5]),
# shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED, UNIFORM]
dict(axis=0,
row_length=2,
original_dim_sizes=[1, 2, [2, 1], [3, 5, 2], 2],
broadcast_dim_sizes=[2, 2, [2, 1, 2, 1], [3, 5, 2, 3, 5, 2], 2]),
# shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM]
dict(axis=1,
row_length=2,
original_dim_sizes=[3, 1, [4, 0, 2], 5],
broadcast_dim_sizes=[3, 2, [4, 0, 2, 4, 0, 2], 5]),
# shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED]
dict(axis=1,
row_length=1,
original_dim_sizes=[2, 3, (1, 2, 3, 4, 5, 6)],
broadcast_dim_sizes=[2, 3, (1, 2, 3, 4, 5, 6)]),
#=========================================================================
# dimension[axis] is uniform partitioned; and row_lengths is a vector
#=========================================================================
# shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM]
dict(axis=1,
row_length=[4, 1, 2],
original_dim_sizes=[
3, # axis=0
1, # axis=1 (broadcast)
[3, 1, 2], # axis=2
5], # axis=3
broadcast_dim_sizes=[
3, # axis=0
[4, 1, 2], # axis=1 (broadcast)
[3, 3, 3, 3, 1, 2, 2], # axis=2
5]), # axis=3
# shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, RAGGED]
dict(axis=1,
row_length=[2, 0, 3],
original_dim_sizes=[
3, # axis=0
1, # axis=1 (broadcast)
[3, 1, 2], # axis=2
[3, 1, 4, 1, 5, 9]], # axis=3
broadcast_dim_sizes=[
3, # axis=0
[2, 0, 3], # axis=1 (broadcast)
[3, 3, 2, 2, 2], # axis=2
[3, 1, 4, 3, 1, 4, 5, 9, 5, 9, 5, 9]]), # axis=3
# shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM]
dict(axis=2,
row_length=[4, 1, 2],
original_dim_sizes=[
3, # axis=0
[2, 0, 1], # axis=1
1, # axis=2 (broadcast)
[3, 2, 1], # axis=3
[1, 0, 1, 0, 2, 3], # axis=4
5], # axis=5
broadcast_dim_sizes=[
3, # axis=0
               [2, 0, 1],  # axis=1
[4, 1, 2], # axis=2 (broadcast)
[3, 3, 3, 3, 2, 1, 1], # axis=3
[1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, # axis=4
2, 3, 3],
5]), # axis=5
dict(axis=0,
row_length=2,
original_dim_sizes=[1, 1, 2, (2, 1)],
broadcast_dim_sizes=[2, 1, 2, (2, 1, 2, 1)]),
dict(axis=1,
row_length=(2, 1),
original_dim_sizes=[2, 1, 2, (2, 1, 2, 1)],
broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
dict(axis=2,
row_length=2,
original_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)],
broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
dict(axis=3,
row_length=(2, 1, 2, 1, 2, 1),
original_dim_sizes=[2, (2, 1), 2, 1],
broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
]) # pyformat: disable
def testBroadcastDimension(self, axis, row_length, original_dim_sizes,
broadcast_dim_sizes):
"""Tests for the broadcast_dimension method.
Verifies that:
* `original.broadcast_dimension(axis, row_length) == broadcast`
* `broadcast.broadcast_dimension(axis, row_length) == broadcast`
* `broadcast.broadcast_dimension(axis, 1) == broadcast`
Args:
axis: The axis to broadcast
row_length: The slice lengths to broadcast to.
original_dim_sizes: The dimension sizes before broadcasting.
original_dim_sizes[axis] should be equal to `1` or `row_length`.
broadcast_dim_sizes: THe dimension sizes after broadcasting.
"""
original_shape = RaggedTensorDynamicShape.from_dim_sizes(original_dim_sizes)
bcast_shape = RaggedTensorDynamicShape.from_dim_sizes(broadcast_dim_sizes)
self.assertEqual(original_shape.rank, bcast_shape.rank)
# shape[axis].value == 1 and row_length > 1:
bcast1 = original_shape.broadcast_dimension(axis, row_length)
# shape[axis].value > 1 and row_length == shape[axis].value:
bcast2 = bcast_shape.broadcast_dimension(axis, row_length)
# shape[axis].value > 1 and row_length == 1:
bcast3 = bcast_shape.broadcast_dimension(axis, 1)
self.assertShapeEq(bcast1, bcast_shape)
self.assertShapeEq(bcast2, bcast_shape)
    self.assertShapeEq(bcast3, bcast_shape)

@parameterized.parameters(
[
# Broadcast scalar
dict(x_dims=[], y_dims=[], expected_dims=[]),
dict(x_dims=[], y_dims=[2], expected_dims=[2]),
dict(x_dims=[], y_dims=[2, 3], expected_dims=[2, 3]),
dict(
x_dims=[],
y_dims=[2, (2, 3), (5, 7, 2, 0, 9)],
expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]),
# Broadcast vector
dict(x_dims=[3], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]),
dict(x_dims=[1], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]),
dict(x_dims=[3], y_dims=[4, 2, 1], expected_dims=[4, 2, 3]),
dict(
x_dims=[3],
y_dims=[3, (2, 3, 1), 1],
expected_dims=[3, (2, 3, 1), 3]),
dict(x_dims=[1], y_dims=[3, (2, 1, 3)], expected_dims=[3, (2, 1, 3)]),
dict(
x_dims=[1],
y_dims=[3, (2, 1, 3), 8],
expected_dims=[3, (2, 1, 3), 8]),
dict(
x_dims=[1],
y_dims=[2, (2, 3), (5, 7, 2, 0, 9)],
expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]),
# Mixed broadcasting
dict(
x_dims=[
1, # axis=0
3, # axis=1
(3, 0, 2), # axis=2
1, # axis=3
2, # axis=4
],
y_dims=[
2, # axis=0
1, # axis=1
1, # axis=2
(7, 2), # axis=3
1, # axis=4
],
expected_dims=[
2, # axis=0
3, # axis=1
(3, 0, 2, 3, 0, 2), # axis=2
(7, 7, 7, 7, 7, 2, 2, 2, 2, 2), # axis=3
2, # axis=4
]),
dict(
x_dims=[2, (2, 1), 2, 1],
y_dims=[1, 1, 2, (2, 1)],
expected_dims=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
])
def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
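    """Tests broadcast_dynamic_shape (in both argument orders)."""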
x_shape = RaggedTensorDynamicShape.from_dim_sizes(x_dims)
y_shape = RaggedTensorDynamicShape.from_dim_sizes(y_dims)
expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dims)
result1 = ragged_tensor_shape.broadcast_dynamic_shape(x_shape, y_shape)
result2 = ragged_tensor_shape.broadcast_dynamic_shape(y_shape, x_shape)
self.assertShapeEq(expected, result1)
    self.assertShapeEq(expected, result2)

def testRepr(self):
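    """Tests that repr() of a shape shows its partitioned and inner sizes."""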
shape = RaggedTensorDynamicShape.from_dim_sizes([2, (2, 1), 2, 1])
self.assertRegex(
repr(shape), r'RaggedTensorDynamicShape\('
r'partitioned_dim_sizes=\(<[^>]+>, <[^>]+>\), '
        r'inner_dim_sizes=<[^>]+>\)')

@parameterized.parameters([
dict(
x=[[10], [20], [30]], # shape=[3, 1]
dim_sizes=[3, 2],
expected=[[10, 10], [20, 20], [30, 30]]),
dict(
x=[[10], [20], [30]], # shape=[3, 1]
dim_sizes=[3, [3, 0, 2]],
expected=ragged_factory_ops.constant_value(
[[10, 10, 10], [], [30, 30]], dtype=np.int32)),
dict(
x=[[[1, 2, 3]], [[4, 5, 6]]], # shape = [2, 1, 3]
dim_sizes=[2, [2, 3], 3],
expected=ragged_factory_ops.constant_value(
[[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6], [4, 5, 6]]],
dtype=np.int32,
ragged_rank=1)),
dict(
x=[[[1]], [[2]]], # shape = [2, 1, 1]
dim_sizes=[2, [2, 3], [0, 2, 1, 2, 0]],
expected=ragged_factory_ops.constant_value(
[[[], [1, 1]], [[2], [2, 2], []]], dtype=np.int32,
ragged_rank=2)),
dict(
x=10,
dim_sizes=[3, [3, 0, 2]],
expected=ragged_factory_ops.constant_value([[10, 10, 10], [],
[10, 10]])),
dict(
x=ragged_factory_ops.constant_value([[[1], [2]], [[3]]],
ragged_rank=1),
dim_sizes=[2, [2, 1], 2],
expected=ragged_factory_ops.constant_value(
[[[1, 1], [2, 2]], [[3, 3]]], ragged_rank=1)),
])
def testRaggedBroadcastTo(self, x, dim_sizes, expected):
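    """Tests that broadcast_to(x, shape) produces the expected ragged value."""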
shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
result = ragged_tensor_shape.broadcast_to(x, shape)
self.assertEqual(
getattr(result, 'ragged_rank', 0), getattr(expected, 'ragged_rank', 0))
    self.assertAllEqual(result, expected)

@parameterized.parameters(
[
dict(
doc='x.shape=[3, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]],
dtype=np.int32),
y=[[10], [20], [30]],
expected=ragged_factory_ops.constant_value([[11, 12, 13], [],
[34, 35]])),
dict(
doc='x.shape=[3, (D1)]; y.shape=[]; bcast.shape=[3, (D1)]',
x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]],
dtype=np.int32),
y=10,
expected=ragged_factory_ops.constant_value([[11, 12, 13], [],
[14, 15]])),
dict(
doc='x.shape=[1, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
x=ragged_factory_ops.constant_value([[1, 2, 3]], dtype=np.int32),
y=[[10], [20], [30]],
expected=ragged_factory_ops.constant_value(
[[11, 12, 13], [21, 22, 23], [31, 32, 33]], dtype=np.int32)),
dict(
doc=('x.shape=[2, (D1), 1]; y.shape=[1, (D2)]; '
'bcast.shape=[2, (D1), (D2)]'),
x=ragged_factory_ops.constant_value([[[1], [2], [3]], [[4]]],
ragged_rank=1),
y=ragged_factory_ops.constant_value([[10, 20, 30]]),
expected=ragged_factory_ops.constant_value([[[11, 21, 31],
[12, 22, 32],
[13, 23, 33]],
[[14, 24, 34]]])),
dict(
doc=('x.shape=[2, (D1), 1]; y.shape=[1, 1, 4]; '
'bcast.shape=[2, (D1), 4]'),
x=ragged_factory_ops.constant_value([[[10], [20]], [[30]]],
ragged_rank=1),
y=[[[1, 2, 3, 4]]],
expected=ragged_factory_ops.constant_value(
[[[11, 12, 13, 14], [21, 22, 23, 24]], [[31, 32, 33, 34]]],
ragged_rank=1)),
dict(
doc=('x.shape=[2, (D1), 2, 1]; y.shape=[2, (D2)]; '
                 'bcast.shape=[2, (D1), (2), (D2)]'),
x=ragged_factory_ops.constant_value(
[[[[1], [2]], [[3], [4]]], [[[5], [6]]]], ragged_rank=1),
y=ragged_factory_ops.constant_value([[10, 20], [30]]),
expected=ragged_factory_ops.constant_value([[[[11, 21], [32]],
[[13, 23], [34]]],
[[[15, 25], [36]]]])),
])
def testRaggedAddWithBroadcasting(self, x, y, expected, doc):
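    """Tests that `x + y` broadcasts ragged operands before adding elementwise.

    The `doc` parameter summarizes the shapes involved in each case.
    """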
expected_rrank = getattr(expected, 'ragged_rank', 0)
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
result = x + y
result_rrank = getattr(result, 'ragged_rank', 0)
self.assertEqual(expected_rrank, result_rrank)
if hasattr(expected, 'tolist'):
expected = expected.tolist()
    self.assertAllEqual(result, expected)


if __name__ == '__main__':
googletest.main()