# blob: 5564e6c02f9f77f555a105d7bda8876951bc062c [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DataAdapter tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class DataAdapterTestBase(test.TestCase):
  """Common fixtures shared by the DataAdapter test cases.

  `setUp` creates one example of each input kind the tests feed to the
  adapters (NumPy arrays, tensors, a batched dataset, a Python generator and
  a `Sequence`), plus a small `Sequential` model used by the training tests.
  """

  def setUp(self):
    """Builds the inputs and the model used by every test."""
    super(DataAdapterTestBase, self).setUp()
    self.batch_size = 5

    # 50 samples with 10 features each, as NumPy arrays and as tensors.
    self.numpy_input = np.zeros((50, 10))
    self.numpy_target = np.ones(50)
    self.tensor_input = constant_op.constant(2.0, shape=(50, 10))
    self.tensor_target = array_ops.ones((50,))

    # The same NumPy data as a shuffled dataset batched by `batch_size`.
    sliced = dataset_ops.DatasetV2.from_tensor_slices(
        (self.numpy_input, self.numpy_target))
    shuffled = sliced.shuffle(50)
    self.dataset_input = shuffled.batch(self.batch_size)

    def single_batch_generator():
      # Yields exactly one (features, targets) batch.
      yield (np.zeros((self.batch_size, 10)), np.ones(self.batch_size))

    self.generator_input = single_batch_generator()
    self.sequence_input = TestSequence(
        batch_size=self.batch_size, feature_shape=10)

    # Single dense layer: 10 input features -> 8 softmax outputs.
    dense = keras.layers.Dense(8, input_shape=(10,), activation='softmax')
    self.model = keras.models.Sequential([dense])
class TestSequence(data_utils.Sequence):
  """A `Sequence` of 10 identical batches, used as a test input.

  Each batch is a `(features, targets)` tuple: an all-zeros array of shape
  `(batch_size, feature_shape)` and an all-ones array of shape
  `(batch_size,)`.
  """

  def __init__(self, batch_size, feature_shape):
    """Stores the batch size and the per-sample feature count."""
    self.batch_size = batch_size
    self.feature_shape = feature_shape

  def __getitem__(self, item):
    """Returns the batch at index `item`; the index itself is not used."""
    features = np.zeros((self.batch_size, self.feature_shape))
    targets = np.ones((self.batch_size,))
    return (features, targets)

  def __len__(self):
    """Returns the fixed number of batches."""
    return 10
class TensorLikeDataAdapterTest(DataAdapterTestBase):
  """Tests `data_adapter.TensorLikeDataAdapter` on NumPy and tensor inputs."""

  def setUp(self):
    """Selects the adapter class under test."""
    super(TensorLikeDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.TensorLikeDataAdapter

  # ---- NumPy inputs ----

  def test_can_handle_numpy(self):
    """NumPy inputs are accepted; dataset, generator and Sequence are not."""
    can_handle = self.adapter_cls.can_handle
    self.assertTrue(can_handle(self.numpy_input))
    self.assertTrue(can_handle(self.numpy_input, self.numpy_target))

    self.assertFalse(can_handle(self.dataset_input))
    self.assertFalse(can_handle(self.generator_input))
    self.assertFalse(can_handle(self.sequence_input))

  def test_iterator_expect_batch_size_numpy(self):
    """Constructing the adapter without `batch_size` raises ValueError."""
    expected_message = r'`batch_size` is required'
    with self.assertRaisesRegexp(ValueError, expected_message):
      self.adapter_cls(self.numpy_input, self.numpy_target)

  def test_size_numpy(self):
    """50 samples with batch size 5 give 10 batches and no partial batch."""
    x, y = self.numpy_input, self.numpy_target
    adapter = self.adapter_cls(x, y, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  def test_batch_size_numpy(self):
    """The adapter reports the batch size it was constructed with."""
    x, y = self.numpy_input, self.numpy_target
    adapter = self.adapter_cls(x, y, batch_size=5)
    self.assertEqual(adapter.batch_size(), 5)

  def test_partial_batch_numpy(self):
    """A batch size that does not divide 50 leaves a partial last batch."""
    x, y = self.numpy_input, self.numpy_target
    adapter = self.adapter_cls(x, y, batch_size=4)
    # 50 samples / 4 per batch = 12 full batches + 1 partial batch of 2.
    self.assertEqual(adapter.get_size(), 13)
    self.assertTrue(adapter.has_partial_batch())
    self.assertEqual(adapter.partial_batch_size(), 2)

  def test_training_numpy(self):
    """`model.fit` runs on the dataset built from NumPy inputs."""
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=5)
    dataset = adapter.get_dataset()
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
    self.model.fit(dataset)

  # ---- Tensor inputs ----

  def test_can_handle(self):
    """Tensor inputs are accepted; dataset, generator and Sequence are not."""
    can_handle = self.adapter_cls.can_handle
    self.assertTrue(can_handle(self.tensor_input))
    self.assertTrue(can_handle(self.tensor_input, self.tensor_target))

    self.assertFalse(can_handle(self.dataset_input))
    self.assertFalse(can_handle(self.generator_input))
    self.assertFalse(can_handle(self.sequence_input))

  def test_training(self):
    """`model.fit` runs on the dataset built from tensor inputs."""
    adapter = self.adapter_cls(
        self.tensor_input, self.tensor_target, batch_size=5)
    dataset = adapter.get_dataset()
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
    self.model.fit(dataset)

  def test_size(self):
    """50 samples with batch size 5 give 10 batches and no partial batch."""
    x, y = self.tensor_input, self.tensor_target
    adapter = self.adapter_cls(x, y, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  def test_batch_size(self):
    """The adapter reports the batch size it was constructed with."""
    x, y = self.tensor_input, self.tensor_target
    adapter = self.adapter_cls(x, y, batch_size=5)
    self.assertEqual(adapter.batch_size(), 5)

  def test_partial_batch(self):
    """A batch size that does not divide 50 leaves a partial last batch."""
    x, y = self.tensor_input, self.tensor_target
    adapter = self.adapter_cls(x, y, batch_size=4)
    # 50 samples / 4 per batch = 12 full batches + 1 partial batch of 2.
    self.assertEqual(adapter.get_size(), 13)
    self.assertTrue(adapter.has_partial_batch())
    self.assertEqual(adapter.partial_batch_size(), 2)
class DatasetAdapterTest(DataAdapterTestBase):
  """Tests `data_adapter.DatasetAdapter` on the dataset fixture."""

  def setUp(self):
    """Selects the adapter class under test."""
    super(DatasetAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.DatasetAdapter

  def test_can_handle(self):
    """Only the dataset input is accepted."""
    can_handle = self.adapter_cls.can_handle
    self.assertFalse(can_handle(self.numpy_input))
    self.assertFalse(can_handle(self.tensor_input))
    self.assertTrue(can_handle(self.dataset_input))
    self.assertFalse(can_handle(self.generator_input))
    self.assertFalse(can_handle(self.sequence_input))

  def test_training(self):
    """`model.fit` runs on the dataset returned by the adapter."""
    adapter = self.adapter_cls(self.dataset_input)
    dataset = adapter.get_dataset()
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
    self.model.fit(dataset)

  def test_size(self):
    """The adapter reports no size for a dataset input."""
    size = self.adapter_cls(self.dataset_input).get_size()
    self.assertIsNone(size)

  def test_batch_size(self):
    """The adapter reports no batch size for a dataset input."""
    batch_size = self.adapter_cls(self.dataset_input).batch_size()
    self.assertIsNone(batch_size)

  def test_partial_batch(self):
    """The adapter reports no partial batch for a dataset input."""
    adapter = self.adapter_cls(self.dataset_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())
class GeneratorDataAdapterTest(DataAdapterTestBase):
  """Tests `data_adapter.GeneratorDataAdapter` on the generator fixture."""

  def setUp(self):
    """Selects the adapter class under test."""
    super(GeneratorDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.GeneratorDataAdapter

  def test_can_handle(self):
    """Only the generator input is accepted."""
    can_handle = self.adapter_cls.can_handle
    self.assertFalse(can_handle(self.numpy_input))
    self.assertFalse(can_handle(self.tensor_input))
    self.assertFalse(can_handle(self.dataset_input))
    self.assertTrue(can_handle(self.generator_input))
    self.assertFalse(can_handle(self.sequence_input))

  def test_training(self):
    """`model.fit` runs on the dataset built from the generator."""
    adapter = self.adapter_cls(self.generator_input)
    dataset = adapter.get_dataset()
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
    self.model.fit(dataset)

  def test_size(self):
    """The adapter reports no size for a generator input."""
    size = self.adapter_cls(self.generator_input).get_size()
    self.assertIsNone(size)

  def test_batch_size(self):
    """The reported batch size equals the size of the yielded batch (5)."""
    batch_size = self.adapter_cls(self.generator_input).batch_size()
    self.assertEqual(batch_size, 5)

  def test_partial_batch(self):
    """The adapter reports no partial batch for a generator input."""
    adapter = self.adapter_cls(self.generator_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())
class KerasSequenceAdapterTest(DataAdapterTestBase):
  """Tests `data_adapter.KerasSequenceAdapter` on the `Sequence` fixture."""

  def setUp(self):
    """Selects the adapter class under test."""
    super(KerasSequenceAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.KerasSequenceAdapter

  def test_can_handle(self):
    """Only the `Sequence` input is accepted."""
    can_handle = self.adapter_cls.can_handle
    self.assertFalse(can_handle(self.numpy_input))
    self.assertFalse(can_handle(self.tensor_input))
    self.assertFalse(can_handle(self.dataset_input))
    self.assertFalse(can_handle(self.generator_input))
    self.assertTrue(can_handle(self.sequence_input))

  def test_training(self):
    """`model.fit` runs on the dataset built from the `Sequence`."""
    adapter = self.adapter_cls(self.sequence_input)
    dataset = adapter.get_dataset()
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
    self.model.fit(dataset)

  def test_size(self):
    """The reported size is 10, the length of the `Sequence` fixture."""
    size = self.adapter_cls(self.sequence_input).get_size()
    self.assertEqual(size, 10)

  def test_batch_size(self):
    """The reported batch size is 5, the size of the fixture's batches."""
    batch_size = self.adapter_cls(self.sequence_input).batch_size()
    self.assertEqual(batch_size, 5)

  def test_partial_batch(self):
    """The adapter reports no partial batch for the `Sequence` input."""
    adapter = self.adapter_cls(self.sequence_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  test.main()