# blob: 964634400b96fc35d90349229711f13796e12d08 [file] [log] [blame]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dataset_creator."""
from absl.testing import parameterized
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.keras import combinations
from tensorflow.python.keras.distribute import multi_worker_testing_utils
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core as core_layers
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.utils import dataset_creator
from tensorflow.python.platform import test
from tensorflow.python.training.server_lib import ClusterSpec
class DatasetCreatorTest(test.TestCase, parameterized.TestCase):
  """Tests for `dataset_creator.DatasetCreator`.

  Covers argument validation, use with `Model.fit` with and without a
  parameter server strategy, and propagation of `InputOptions` to the
  distributed dataset / iterator built by the data handler.
  """

  def test_dataset_creator(self):
    """Checks validation errors and the dataset returned by a valid creator."""
    # A non-callable `dataset_fn` is rejected at construction time.
    with self.assertRaisesRegex(
        TypeError, "`dataset_fn` for `DatasetCreator` must be a `callable`."):
      dataset_creator.DatasetCreator(2)

    # A callable that does not return a dataset is rejected when the creator
    # is invoked.
    dataset_fn = lambda: 3
    with self.assertRaisesRegex(
        TypeError, "The `callable` provided to `DatasetCreator` must return "
        "a Dataset."):
      dataset_creator.DatasetCreator(dataset_fn)()

    # A valid callable: the first element of the produced dataset must match
    # the first element of an identically built dataset.
    dataset_fn = lambda: dataset_ops.DatasetV2.from_tensor_slices([1, 1])
    got = dataset_creator.DatasetCreator(dataset_fn)()
    self.assertEqual(
        next(iter(got)),
        next(iter(dataset_ops.DatasetV2.from_tensor_slices([1, 1]))))

  def _get_dataset_fn(self):
    """Returns a `dataset_fn(input_context)` usable with `DatasetCreator`.

    The returned function builds an endlessly repeated dataset of the single
    example `([1.], [1.])`, shards it by the input pipeline id, batches it
    with the per-replica batch size derived from a global batch size of 64,
    and prefetches 2 batches.
    """

    def dataset_fn(input_context):
      global_batch_size = 64
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = dataset_ops.DatasetV2.from_tensors(([1.], [1.])).repeat()
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
      dataset = dataset.batch(batch_size)
      dataset = dataset.prefetch(2)
      return dataset

    return dataset_fn

  @combinations.generate(combinations.combine(use_input_options=[True, False]))
  def test_dataset_creator_model_fit_without_strategy(self, use_input_options):
    """Fits a model on a `DatasetCreator` without any distribution strategy.

    Args:
      use_input_options: Whether to pass a default `InputOptions` object to
        the `DatasetCreator` (True) or `None` (False).
    """
    model = sequential.Sequential([core_layers.Dense(10)])
    model.compile(gradient_descent.SGD(), loss="mse")

    input_options = distribute_lib.InputOptions() if use_input_options else None
    history = model.fit(
        dataset_creator.DatasetCreator(self._get_dataset_fn(), input_options),
        epochs=10,
        steps_per_epoch=10,
        verbose=0)
    # One loss entry is recorded per epoch.
    self.assertLen(history.history["loss"], 10)

  def _get_parameter_server_strategy(self):
    """Returns a `ParameterServerStrategyV2` over a fresh in-process cluster.

    The cluster is created with 2 workers and 1 parameter server, using the
    "grpc" RPC layer.
    """
    cluster_def = multi_worker_testing_utils.create_in_process_cluster(
        num_workers=2, num_ps=1, rpc_layer="grpc")
    return parameter_server_strategy_v2.ParameterServerStrategyV2(
        SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc"))

  @combinations.generate(combinations.combine(use_input_options=[True, False]))
  def test_dataset_creator_usage_in_parameter_server_model_fit(
      self, use_input_options):
    """Fits a model on a `DatasetCreator` under a parameter server strategy.

    Args:
      use_input_options: Whether to pass a default `InputOptions` object to
        the `DatasetCreator` (True) or `None` (False).
    """
    strategy = self._get_parameter_server_strategy()
    # NOTE(review): only model construction is placed inside the scope; the
    # extent of this `with` body was reconstructed -- confirm against upstream.
    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
    model.compile(gradient_descent.SGD(), loss="mse")

    input_options = distribute_lib.InputOptions() if use_input_options else None
    history = model.fit(
        dataset_creator.DatasetCreator(self._get_dataset_fn(), input_options),
        epochs=10,
        steps_per_epoch=10,
        verbose=0)
    # One loss entry is recorded per epoch.
    self.assertLen(history.history["loss"], 10)

  def test_dataset_creator_input_options(self):
    """Checks `InputOptions` reach the dataset held by the data handler."""
    dataset_fn = lambda _: dataset_ops.DatasetV2.from_tensor_slices([1, 1])
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=True,
        experimental_per_replica_buffer_size=2)
    x = dataset_creator.DatasetCreator(dataset_fn, input_options=input_options)
    # NOTE(review): the data handler is built inside the strategy scope and
    # inspected after leaving it; scope extent was reconstructed -- confirm.
    with collective_all_reduce_strategy.CollectiveAllReduceStrategy().scope():
      data_handler = data_adapter.get_data_handler(
          x,
          steps_per_epoch=2,
          model=sequential.Sequential([core_layers.Dense(10)]))

    # Ensuring the resulting `DistributedDatasetsFromFunction` has the right
    # options.
    self.assertTrue(data_handler._dataset._options.experimental_fetch_to_device)
    self.assertEqual(
        data_handler._dataset._options.experimental_per_replica_buffer_size, 2)

  def test_dataset_creator_input_options_with_cluster_coordinator(self):
    """Checks `InputOptions` reach the iterator rebuilt on a worker."""
    dataset_fn = lambda _: dataset_ops.DatasetV2.from_tensor_slices([1, 1])
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=True,
        experimental_per_replica_buffer_size=2)
    x = dataset_creator.DatasetCreator(dataset_fn, input_options=input_options)
    strategy = self._get_parameter_server_strategy()
    # NOTE(review): model, coordinator and data handler are all created inside
    # the scope; the extent of this `with` body was reconstructed -- confirm.
    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
      # The coordinator is attached to the model by hand before the data
      # handler is requested.
      model._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
          strategy)
      data_handler = data_adapter.get_data_handler(
          x, steps_per_epoch=2, model=model)

    # Take the first per-worker value, rebuild it on the first worker of the
    # coordinator's cluster, and fetch the concrete iterator it holds.
    iter_rv = iter(data_handler._dataset)._values[0]
    iter_rv._rebuild_on(model._cluster_coordinator._cluster.workers[0])
    distributed_iterator = iter_rv._get_values()

    # Ensuring the resulting `DistributedIterator` has the right options.
    self.assertTrue(distributed_iterator._options.experimental_fetch_to_device)
    self.assertEqual(
        distributed_iterator._options.experimental_per_replica_buffer_size, 2)


if __name__ == "__main__":
  # Enable V2 behavior first, then hand control to the test runner.
  v2_compat.enable_v2_behavior()
  test.main()