blob: 2908346e4b7c682c31137747a54c1acccc277969 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.padded_batch()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
count=[32, 34],
padded_shapes=[[None], [25]],
drop_remainder=[True, False])))
def testPaddedBatchDataset(self, count, padded_shapes, drop_remainder):
seq_lens = np.random.randint(20, size=(count,)).astype(np.int32)
batch_size = 4
dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
lambda x: array_ops.fill([x], x)).padded_batch(
batch_size=batch_size,
drop_remainder=drop_remainder,
padded_shapes=padded_shapes)
num_full_batches = len(seq_lens) // batch_size
get_next = self.getNext(dataset)
for i in range(num_full_batches):
result = self.evaluate(get_next())
padded_len = padded_shapes[0]
if padded_len is None or padded_len == -1:
padded_len = np.max(result) if result.size > 0 else 0
self.assertEqual((batch_size, padded_len), result.shape)
for j in range(batch_size):
seq_len = seq_lens[(i * batch_size) + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
if not drop_remainder and len(seq_lens) % batch_size > 0:
result = self.evaluate(get_next())
padded_len = padded_shapes[0]
if padded_len is None or padded_len == -1:
padded_len = np.max(result) if result.size > 0 else 0
self.assertEqual((len(seq_lens) % batch_size, padded_len), result.shape)
for j in range(len(seq_lens) % batch_size):
seq_len = seq_lens[num_full_batches * batch_size + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchShortPadding(self):
dataset = (
dataset_ops.Dataset.from_tensor_slices(
[6, 5, 5, 5, 5]).map(lambda x: array_ops.fill([x], x)).padded_batch(
batch_size=4, padded_shapes=[5]))
self.assertDatasetProduces(
dataset, expected_error=(errors.DataLossError, ''))
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchEmptyTensors(self):
dataset = (
dataset_ops.Dataset.from_tensor_slices(
[0, 0, 0, 0]).map(lambda x: array_ops.fill([x], x)).padded_batch(
batch_size=4, padded_shapes=[-1]))
self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]])
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchWithDifferetElementTypes(self):
dataset = dataset_ops.Dataset.from_tensor_slices(
([0, 1, 2, 3], ['a', 'b', 'c', 'd']))
dataset = dataset.padded_batch(2)
self.assertDatasetProduces(dataset, [([0, 1], ['a', 'b']),
([2, 3], ['c', 'd'])])
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchWithEmptyTuple(self):
dataset = dataset_ops.Dataset.from_tensor_slices(([0, 1, 2, 3], ()))
dataset = dataset.padded_batch(2)
self.assertDatasetProduces(dataset, [([0, 1], ()), ([2, 3], ())])
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchWithNoneElement(self):
dataset = dataset_ops.Dataset.from_tensor_slices(([0, 1, 2, 3], None))
with self.assertRaises(TypeError):
dataset = dataset.padded_batch(2)
@combinations.generate(test_base.default_test_combinations())
def testDefaultPaddedShapes(self):
def fill(x):
return array_ops.fill([x], x)
dataset = (
dataset_ops.Dataset.from_tensor_slices(
[1, 2, 3, 4]).map(fill).padded_batch(batch_size=2))
self.assertDatasetProduces(
dataset,
expected_output=[[[1, 0], [2, 2]], [[3, 3, 3, 0], [4, 4, 4, 4]]])
@combinations.generate(test_base.default_test_combinations())
def testNestedDefaultPaddedShapes(self):
def fill_tuple(x):
return (x, array_ops.fill([x], x))
dataset = (
dataset_ops.Dataset.from_tensor_slices(
[1, 2, 3, 4]).map(fill_tuple).padded_batch(batch_size=2))
self.assertDatasetProduces(
dataset,
expected_output=[([1, 2], [[1, 0], [2, 2]]),
([3, 4], [[3, 3, 3, 0], [4, 4, 4, 4]])])
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
padding_values=[(-1, '<end>', {'structure': ''}),
(-1, '<end>', None)])))
def testPaddedBatchDatasetNonDefaultPadding(self, padding_values):
def fill_tuple(x):
filled = array_ops.fill([x], x)
return (filled, string_ops.as_string(filled), {
'structure': string_ops.as_string(filled)
})
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
dataset = (
dataset_ops.Dataset.from_tensor_slices(random_seq_lens).map(fill_tuple)
.padded_batch(
4, padded_shapes=([-1], [-1], {'structure': [-1]}),
padding_values=padding_values))
get_next = self.getNext(dataset)
for i in range(8):
result = self.evaluate(get_next())
padded_len = np.max(result[0])
self.assertEqual((4, padded_len), result[0].shape)
self.assertEqual((4, padded_len), result[1].shape)
self.assertEqual((4, padded_len), result[2]['structure'].shape)
for j in range(4):
seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[0][j, seq_len:],
[-1] * (padded_len - seq_len))
self.assertAllEqual(result[1][j, :seq_len],
[compat.as_bytes(str(seq_len))] * seq_len)
self.assertAllEqual(result[1][j, seq_len:],
[b'<end>'] * (padded_len - seq_len))
self.assertAllEqual(result[2]['structure'][j, :seq_len],
[compat.as_bytes(str(seq_len))] * seq_len)
self.assertAllEqual(result[2]['structure'][j, seq_len:],
[b''] * (padded_len - seq_len))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchDatasetUnicode(self):
# See GitHub issue 16149
def generator():
data = [[u'Простой', u'тест', u'юникода'],
[u'никогда', u'не', u'бывает', u'простым']]
for seq in data:
yield seq, [0, 1, 2, 3]
dataset = dataset_ops.Dataset.from_generator(
generator, (dtypes.string, dtypes.int32),
(tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
padded_dataset = dataset.padded_batch(
2, padded_shapes=([None], [None]), padding_values=('', 0))
next_element = self.getNext(padded_dataset)
self.evaluate(next_element())
@combinations.generate(test_base.graph_only_combinations())
def testPaddedBatchDatasetShapeSpecifications(self):
int_placeholder = array_ops.placeholder(dtypes.int32)
float_placeholder = array_ops.placeholder(dtypes.float32)
string_placeholder = array_ops.placeholder(dtypes.string)
input_dataset = dataset_ops.Dataset.from_tensors(
(int_placeholder, float_placeholder, string_placeholder))
# Test different ways of specifying the `padded_shapes` argument.
dynamic_padding_from_tensor_shapes = input_dataset.padded_batch(
32,
padded_shapes=(tensor_shape.TensorShape([None]),
tensor_shape.TensorShape([None, None]),
tensor_shape.TensorShape([37])))
dynamic_padding_from_lists = input_dataset.padded_batch(
32, padded_shapes=([None], [None, None], [37]))
dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch(
32, padded_shapes=([-1], [-1, -1], [37]))
dynamic_padding_from_tensors = input_dataset.padded_batch(
32,
padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64),
constant_op.constant([-1, -1], dtype=dtypes.int64),
constant_op.constant([37], dtype=dtypes.int64)))
for dataset in [
dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors
]:
dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
self.assertEqual([None, None], dataset_output_shapes[0].as_list())
self.assertEqual([None, None, None], dataset_output_shapes[1].as_list())
self.assertEqual([None, 37], dataset_output_shapes[2].as_list())
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchSparseError(self):
st = sparse_tensor.SparseTensorValue(
indices=[[0, 0]], values=([42]), dense_shape=[1, 1])
with self.assertRaises(TypeError):
_ = dataset_ops.Dataset.from_tensors(st).repeat(10).padded_batch(10)
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchRaggedError(self):
rt = ragged_tensor_value.RaggedTensorValue(
np.array([0, 42]), np.array([0, 2], dtype=np.int64))
with self.assertRaises(TypeError):
_ = dataset_ops.Dataset.from_tensors(rt).repeat(10).padded_batch(10)
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchShapeErrorWrongRank(self):
with self.assertRaisesRegex(
ValueError, r'The padded shape \(1,\) is not compatible with the '
r'corresponding input component shape \(\).'):
_ = dataset_ops.Dataset.range(10).padded_batch(5, padded_shapes=[1])
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchShapeErrorTooSmall(self):
with self.assertRaisesRegex(
ValueError, r'The padded shape \(1,\) is not compatible with the '
r'corresponding input component shape \(3,\).'):
_ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
5, padded_shapes=[1])
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchShapeErrorShapeNotRank1(self):
with self.assertRaisesRegex(
ValueError, r'Padded shape .* must be a 1-D tensor '
r'of tf.int64 values, but its shape was \(2, 2\).'):
_ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
5, padded_shapes=[[1, 1], [1, 1]])
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchShapeErrorShapeNotInt(self):
with self.assertRaisesRegex(
TypeError, r'Padded shape .* must be a 1-D tensor '
r'of tf.int64 values, but its element type was float32.'):
_ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
5, padded_shapes=constant_op.constant([1.5, 2., 3.]))
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchShapeErrorWrongRankFromTensor(self):
with self.assertRaisesRegex(
ValueError, r'The padded shape \(1,\) is not compatible with the '
r'corresponding input component shape \(\).'):
shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64)
_ = dataset_ops.Dataset.range(10).padded_batch(
5, padded_shapes=shape_as_tensor)
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchShapeErrorDefaultShapeWithUnknownRank(self):
with self.assertRaisesRegex(ValueError, r'`padded_shapes`.*unknown rank'):
ds = dataset_ops.Dataset.from_generator(
lambda: iter([1, 2, 3]), output_types=dtypes.int32)
ds.padded_batch(2)
@combinations.generate(test_base.graph_only_combinations())
def testPaddedBatchShapeErrorPlaceholder(self):
with self.assertRaisesRegex(
ValueError,
r'The padded shape \((\?|None), (\?|None)\) is not compatible with the '
r'corresponding input component shape \(\).'):
shape_as_tensor = array_ops.placeholder(dtypes.int64, shape=[2])
_ = dataset_ops.Dataset.range(10).padded_batch(
5, padded_shapes=shape_as_tensor)
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchBfloat16(self):
ds = dataset_ops.Dataset.range(5)
ds = ds.map(lambda x: math_ops.cast(x, dtypes.bfloat16))
ds = ds.padded_batch(10)
self.assertDatasetProduces(
ds, expected_output=[[0.0, 1.0, 2.0, 3.0, 4.0]])
@combinations.generate(test_base.default_test_combinations())
def testDefaultPaddedValueShapes(self):
def fill(x):
return array_ops.fill([x], x)
dataset = dataset_ops.Dataset.zip(
(dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4]).map(fill),
dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4]).map(fill)))
dataset = dataset.padded_batch(batch_size=2, padding_values=-1)
self.assertDatasetProduces(
dataset,
expected_output=[([[1, -1], [2, 2]], [[1, -1], [2, 2]]),
([[3, 3, 3, -1], [4, 4, 4, 4]], [[3, 3, 3, -1],
[4, 4, 4, 4]])])
class PaddedBatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
parameterized.TestCase):
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
checkpoint_test_base.default_test_combinations()))
def test(self, verify_fn):
def build_dataset(seq_lens):
return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
lambda x: array_ops.fill([x], x)).padded_batch(
batch_size=4, padded_shapes=[-1])
seq_lens = np.random.randint(1, 20, size=(32,)).astype(np.int32)
verify_fn(self, lambda: build_dataset(seq_lens), num_outputs=8)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
checkpoint_test_base.default_test_combinations()))
def testNonDefaultPadding(self, verify_fn):
def build_dataset(seq_lens):
def fill_tuple(x):
filled = array_ops.fill([x], x)
return (filled, string_ops.as_string(filled))
padded_shape = [-1]
return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
fill_tuple).padded_batch(
batch_size=4,
padded_shapes=(padded_shape, padded_shape),
padding_values=(-1, '<end>'))
seq_lens = np.random.randint(1, 20, size=(32,)).astype(np.int32)
verify_fn(self, lambda: build_dataset(seq_lens), num_outputs=8)
if __name__ == '__main__':
test.main()