# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.data service ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from absl.testing import parameterized

from tensorflow.python.data.experimental.kernel_tests.service import test_base as data_service_test_base
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
NO_WORK_DIR = data_service_test_base.NO_WORK_DIR


class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):

  @combinations.generate(
combinations.times(test_base.default_test_combinations(),
data_service_test_base.all_cluster_configurations()))
def testDistributeBasic(self, work_dir, fault_tolerant_mode):
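    """Tests distributing a simple range dataset over each cluster config."""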
cluster = data_service_test_base.TestCluster(
num_workers=1,
work_dir=work_dir,
fault_tolerant_mode=fault_tolerant_mode)
num_elements = 10
ds = self.make_distributed_range_dataset(num_elements, cluster)
self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(compression=[None, "AUTO"])))
def testDistributeCompression(self, compression):
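    """Tests data service with the default and explicit "AUTO" compression."""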
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 10
ds = self.make_distributed_range_dataset(
num_elements, cluster, compression=compression)
self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
def testDistributeInvalidCompression(self):
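    """Tests that an unrecognized compression argument raises ValueError."""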
cluster = data_service_test_base.TestCluster(num_workers=1)
with self.assertRaisesRegex(ValueError, "Invalid compression argument"):
self.make_distributed_range_dataset(10, cluster, compression="foo")

  @combinations.generate(test_base.eager_only_combinations())
def testDistributeSparse(self):
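    """Tests distributing a dataset of sparse tensors."""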
cluster = data_service_test_base.TestCluster(num_workers=1)
element = sparse_tensor.SparseTensor(
indices=[[0]],
values=constant_op.constant([0], dtype=dtypes.int32),
dense_shape=[1])
ds = dataset_ops.Dataset.from_tensors(element)
ds = self.make_distributed_dataset(ds, cluster)
results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
self.assertAllEqual(results, [[0]])

  @combinations.generate(test_base.eager_only_combinations())
def testDistributeRagged(self):
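    """Tests distributing a dataset of ragged tensors."""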
cluster = data_service_test_base.TestCluster(num_workers=1)
ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
ds = ds.map(math_ops.range)
ds = ds.apply(batching.dense_to_ragged_batch(2))
ds = self.make_distributed_dataset(ds, cluster)
results = [elem.to_tensor() for elem in ds]
self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

  @combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
init_source=["textfile", "keyvaluetensor", "dataset"])))
def testDistributeLookupTable(self, init_source):
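    """Tests distributing a dataset that uses a static lookup table."""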
cluster = data_service_test_base.TestCluster(num_workers=1)
initializer = self.lookupTableInitializer(init_source, [10, 11])
table = lookup_ops.StaticHashTable(initializer, -1)
ds = dataset_ops.Dataset.range(3)
ds = ds.map(table.lookup)
ds = self.make_distributed_dataset(ds, cluster)
self.evaluate(lookup_ops.tables_initializer())
self.assertDatasetProduces(ds, [10, 11, -1], requires_initialization=True)

  @combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(value_rank=[0, 1])))
def testDistributeMutableHashTable(self, value_rank):
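    """Tests mutable hash table lookups with scalar and nested values."""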
def value(v):
for _ in range(value_rank):
v = [v, v]
return v
v1 = value(10)
v2 = value(11)
default_value = value(-1)
cluster = data_service_test_base.TestCluster(num_workers=1)
table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64,
default_value)
self.evaluate(table.insert([0, 1], [v1, v2]))
ds = dataset_ops.Dataset.range(3)
ds = ds.map(table.lookup)
ds = self.make_distributed_dataset(ds, cluster)
self.assertDatasetProduces(
ds, [v1, v2, default_value], requires_initialization=True)

  @combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(shuffle_seed=[None, 10])))
def testShuffleOrder(self, shuffle_seed):
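    """Tests that workers produce the same shuffle order iff a seed is set."""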
random_seed.set_random_seed(None)
num_elements = 100
cluster = data_service_test_base.TestCluster(num_workers=2)
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.shuffle(num_elements, seed=shuffle_seed)
ds = self.make_distributed_dataset(ds, cluster)
output = self.getDatasetOutput(ds)
    # The output will be two copies of range(num_elements), one from each
    # worker, non-deterministically interleaved together. If the workers
    # shuffle their elements in the same order, the first_order and
    # second_order mappings computed below will be equal.
first_order = {}
second_order = {}
for element in output:
if element in first_order:
second_order[element] = len(second_order)
else:
first_order[element] = len(first_order)
if shuffle_seed is None:
self.assertNotEqual(first_order, second_order)
else:
self.assertEqual(first_order, second_order)

  @combinations.generate(test_base.default_test_combinations())
def testMultipleEpochs(self):
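    """Tests iterating over the same distributed dataset several times."""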
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 3
ds = self.make_distributed_range_dataset(num_elements, cluster)
for _ in range(10):
self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
def testRepeatedDataset(self):
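    """Tests repeating a distributed dataset."""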
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 10
num_repetitions = 5
ds = self.make_distributed_range_dataset(num_elements, cluster)
ds = ds.repeat(num_repetitions)
self.assertDatasetProduces(
ds, expected_output=num_repetitions * list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
def testConcurrentEpoch(self):
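    """Tests reading from several distributed datasets concurrently."""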
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 10
num_datasets = 3
get_nexts = []
results = []
for _ in range(num_datasets):
ds = self.make_distributed_range_dataset(num_elements, cluster)
get_nexts.append(self.getNext(ds))
results.append([])
for _ in range(num_elements):
for dataset_ind in range(num_datasets):
result = self.evaluate(get_nexts[dataset_ind]())
results[dataset_ind].append(result)
for result in results:
self.assertEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.default_test_combinations())
def testMultiWorker(self):
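    """Tests that every worker in the cluster produces the full dataset."""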
num_workers = 3
cluster = data_service_test_base.TestCluster(num_workers=num_workers)
num_elements = 10
ds = self.make_distributed_range_dataset(num_elements, cluster)
self.assertDatasetProduces(
ds, num_workers * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
def testMaxOutstandingRequests(self):
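    """Tests reading with max_outstanding_requests limited to one."""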
num_workers = 3
cluster = data_service_test_base.TestCluster(num_workers=num_workers)
num_elements = 10
ds = self.make_distributed_range_dataset(
num_elements, cluster, max_outstanding_requests=1)
self.assertDatasetProduces(
ds, num_workers * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
def testInsideFunction(self):
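    """Tests iterating over a distributed dataset inside a tf.function."""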
num_workers = 3
cluster = data_service_test_base.TestCluster(num_workers=num_workers)
num_elements = 10
@def_function.function
def f():
ds = self.make_distributed_range_dataset(num_elements, cluster)
result = tensor_array_ops.TensorArray(
dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
i = 0
for elem in ds:
result = result.write(i, elem)
i += 1
return result.stack()
result = list(f().numpy())
self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.default_test_combinations())
def testSharedJobName(self):
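    """Tests that two datasets sharing a job name split the data between them."""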
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 1000
def make_ds():
return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)
ds1 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
ds2 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
get_next_1 = self.getNext(ds1)
get_next_2 = self.getNext(ds2)
results = []
for _ in range(num_elements // 5):
results.append(self.evaluate(get_next_1()))
results.append(self.evaluate(get_next_2()))
results += self.getIteratorOutput(get_next_1)
results += self.getIteratorOutput(get_next_2)
self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.default_test_combinations())
def testDifferentJobNames(self):
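    """Tests that datasets with different job names read data independently."""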
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 10
ds1 = self.make_distributed_range_dataset(
num_elements, cluster, job_name="job_name1")
ds2 = self.make_distributed_range_dataset(
num_elements, cluster, job_name="job_name2")
self.assertDatasetProduces(ds1, list(range(num_elements)))
self.assertDatasetProduces(ds2, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
def testSharedJobNameMultiIteration(self):
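    """Tests iterating datasets that share a job across multiple epochs."""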
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 10
ds1 = self.make_distributed_range_dataset(
num_elements, cluster, job_name="job_name")
ds2 = self.make_distributed_range_dataset(
num_elements, cluster, job_name="job_name")
# iteration 1
self.assertDatasetProduces(ds1, list(range(num_elements)))
self.assertDatasetProduces(ds2, [])
# iteration 2
self.assertDatasetProduces(ds2, list(range(num_elements)))
self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.default_test_combinations())
def testSharedJobNameRepeat(self):
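    """Tests sharing a job name between repeated datasets."""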
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 100
num_repetitions = 3
ds1 = self.make_distributed_range_dataset(
num_elements, cluster, job_name="job_name")
ds1 = ds1.repeat(num_repetitions)
ds2 = self.make_distributed_range_dataset(
num_elements, cluster, job_name="job_name")
ds2 = ds2.repeat(num_repetitions)
results = []
get_next_1 = self.getNext(ds1)
get_next_2 = self.getNext(ds2)
for _ in range((num_elements * num_repetitions) // 5):
results.append(self.evaluate(get_next_1()))
for _ in range((num_elements * num_repetitions) // 5):
results.append(self.evaluate(get_next_2()))
results += self.getIteratorOutput(get_next_1)
results += self.getIteratorOutput(get_next_2)
self.assertCountEqual(num_repetitions * list(range(num_elements)), results)

  @combinations.generate(
combinations.times(test_base.eager_only_combinations(),
combinations.combine(job_name=[None, "test"])))
def testGcUnusedJob(self, job_name):
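    """Tests that a job with no outstanding iterators is garbage collected."""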
cluster = data_service_test_base.TestCluster(
num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
num_elements = 100
ds = self.make_distributed_range_dataset(
num_elements, cluster, job_name=job_name)
it = iter(ds)
self.assertEqual(next(it).numpy(), 0)
self.assertEqual(cluster.workers[0].num_tasks(), 1)
del it
while cluster.workers[0].num_tasks() > 0:
time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
def testDontGcUsedJob(self):
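    """Tests that a job with an outstanding iterator is not garbage collected."""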
cluster = data_service_test_base.TestCluster(
num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
num_elements = 10
it1 = iter(
self.make_distributed_range_dataset(
num_elements, cluster, job_name="test1"))
it2 = iter(
self.make_distributed_range_dataset(
num_elements, cluster, job_name="test2"))
it3 = iter( # this iterator keeps the task alive. pylint: disable=unused-variable
self.make_distributed_range_dataset(
num_elements, cluster, job_name="test2"))
self.assertEqual(cluster.workers[0].num_tasks(), 2)
del it1
del it2
    # Check that only the first job is garbage collected. The second job will
    # not be garbage collected, since it3 still holds an iterator for it.
while cluster.workers[0].num_tasks() > 1:
time.sleep(0.1)
self.assertEqual(cluster.workers[0].num_tasks(), 1)

  @combinations.generate(test_base.eager_only_combinations())
def testGcErrorMessage(self):
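    """Tests the error message when reading from a garbage-collected job."""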
cluster = data_service_test_base.TestCluster(
num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
num_elements = 100
ds = self.make_distributed_range_dataset(
num_elements, cluster, job_name="test")
it = iter(ds)
self.assertEqual(next(it).numpy(), 0)
self.assertEqual(cluster.workers[0].num_tasks(), 1)
del it
while cluster.workers[0].num_tasks() > 0:
time.sleep(0.1)
ds = self.make_distributed_range_dataset(
num_elements, cluster, job_name="test")
with self.assertRaisesRegex(
errors.FailedPreconditionError,
"The requested job has been garbage collected due to inactivity"):
list(ds)

  @combinations.generate(test_base.default_test_combinations())
def testApplyDeterminismOption(self):
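    """Tests that the experimental_deterministic=False option takes effect."""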
elements = list(range(10))
cluster = data_service_test_base.TestCluster(num_workers=1)
def dataset_fn(delay_ms):
def interleave_fn(x):
ds = dataset_ops.Dataset.from_tensors(x)
if math_ops.equal(x, 0):
ds = ds.apply(testing.sleep(delay_ms * 1000))
else:
ds = ds.apply(testing.sleep(0))
return ds
ds = dataset_ops.Dataset.from_tensor_slices(elements)
ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10)
opts = dataset_ops.Options()
opts.experimental_deterministic = False
ds = ds.with_options(opts)
ds = self.make_distributed_dataset(ds, cluster)
return ds
self.checkDeterminism(
dataset_fn=dataset_fn,
expect_determinism=False,
expected_elements=elements)

  def run_stateful(self, external_state_policy):
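    """Distributes a stateful (randomized) dataset under the given policy."""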
num_elements = 10
ds = dataset_ops.Dataset.range(num_elements).map(
lambda _: random_ops.random_uniform(()))
options = dataset_ops.Options()
options.experimental_external_state_policy = external_state_policy
ds = ds.with_options(options)
cluster = data_service_test_base.TestCluster(num_workers=3)
ds = self.make_distributed_dataset(ds, cluster)
self.getDatasetOutput(ds)

  @combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(external_state_policy=[
distribute_options.ExternalStatePolicy.IGNORE,
distribute_options.ExternalStatePolicy.WARN
])))
def testStatefulNoError(self, external_state_policy):
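    """Tests that the IGNORE and WARN policies permit stateful datasets."""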
self.run_stateful(external_state_policy)

  @combinations.generate(test_base.default_test_combinations())
def testStatefulError(self):
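    """Tests that the FAIL policy rejects stateful datasets."""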
with self.assertRaises(errors.FailedPreconditionError):
self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.default_test_combinations())
def testDistributeFromInterleave(self):
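    """Tests distributing datasets created inside an interleave function."""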
cluster = data_service_test_base.TestCluster(num_workers=1)
ds = dataset_ops.Dataset.range(2)
def interleave_fn(_):
dataset = dataset_ops.Dataset.range(2)
dataset = self.make_distributed_dataset(dataset, cluster)
return dataset
ds = ds.interleave(interleave_fn, cycle_length=2)
self.assertDatasetProduces(ds, [0, 0, 1, 1])

  @combinations.generate(test_base.default_test_combinations())
def testDistributeNonStringAddresses(self):
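    """Tests that a non-string service argument raises ValueError."""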
ds = dataset_ops.Dataset.range(10)
with self.assertRaisesRegex(ValueError, "service must be a string"):
ds = ds.apply(
data_service_ops.distribute(
processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.default_test_combinations())
def testDistributeEmptyAddress(self):
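    """Tests that an empty service argument raises ValueError."""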
ds = dataset_ops.Dataset.range(10)
with self.assertRaisesWithLiteralMatch(ValueError,
"service must not be empty"):
ds = ds.apply(
data_service_ops.distribute(
processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.default_test_combinations())
def testDistributeExplicitProtocol(self):
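    """Tests passing an explicit grpc:// protocol in the service address."""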
cluster = data_service_test_base.TestCluster(num_workers=1)
ds = dataset_ops.Dataset.range(10)
ds = ds.apply(
data_service_ops.distribute(
processing_mode="parallel_epochs",
service="grpc://" + cluster.dispatcher_address()))
self.assertDatasetProduces(ds, list(range(10)))

  @combinations.generate(test_base.default_test_combinations())
def testDistributeInvalidProtocol(self):
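    """Tests the error raised for an unregistered protocol."""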
cluster = data_service_test_base.TestCluster(num_workers=1)
ds = dataset_ops.Dataset.range(10)
with self.assertRaisesRegex(
errors.NotFoundError,
"No credentials factory has been registered for protocol grp"):
ds = ds.apply(
data_service_ops.distribute(
processing_mode="parallel_epochs",
service="grp://" + cluster.dispatcher_address()))
self.getDatasetOutput(ds)

  @combinations.generate(test_base.eager_only_combinations())
def testDistributeInvalidProcessingMode(self):
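    """Tests that an unrecognized processing mode raises ValueError."""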
ds = dataset_ops.Dataset.range(10)
with self.assertRaisesRegex(ValueError,
"invalid is not a valid processing mode"):
ds = ds.apply(
data_service_ops.distribute(
processing_mode="invalid", service="grpc://localhost:5000"))

  @combinations.generate(test_base.default_test_combinations())
def testZipDifferentProcessingModesDatasets(self):
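    """Tests zipping datasets distributed with different processing modes."""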
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 100
ds1 = dataset_ops.Dataset.range(num_elements)
ds1 = self.make_distributed_dataset(
ds1, cluster, processing_mode="distributed_epoch")
ds2 = dataset_ops.Dataset.range(num_elements)
ds2 = self.make_distributed_dataset(
ds2, cluster, processing_mode="parallel_epochs")
ds = dataset_ops.Dataset.zip((ds1, ds2))
self.assertDatasetProduces(
ds,
list(zip(range(num_elements), range(num_elements))),
assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
def testZipDifferentProcessingModesDatasetsSharedJobName(self):
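    """Tests that different processing modes cannot share a job name."""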
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 100
ds1 = dataset_ops.Dataset.range(num_elements)
ds1 = self.make_distributed_dataset(
ds1, cluster, processing_mode="distributed_epoch", job_name="job_name")
ds2 = dataset_ops.Dataset.range(num_elements)
ds2 = self.make_distributed_dataset(
ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
ds = dataset_ops.Dataset.zip((ds1, ds2))
with self.assertRaisesRegex(errors.FailedPreconditionError,
"but there is already an existing job"):
self.getDatasetOutput(ds)

  @combinations.generate(test_base.default_test_combinations())
def testFromDatasetId(self):
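    """Tests reading a registered dataset with from_dataset_id."""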
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 10
ds = dataset_ops.Dataset.range(num_elements)
dataset_id = data_service_ops.register_dataset(cluster.dispatcher_address(),
ds)
from_dataset_id_ds = data_service_ops.from_dataset_id(
"parallel_epochs", cluster.dispatcher_address(), dataset_id,
ds.element_spec)
self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
def testRegisteringDatasetAsTfFunction(self):
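    """Tests calling register_dataset as a wrapped tf.function."""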
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 10
ds = dataset_ops.Dataset.range(num_elements)
register_func = def_function.function(data_service_ops.register_dataset)
dataset_id = register_func(
(constant_op.constant("grpc"),
constant_op.constant(cluster.dispatcher_address())), ds)
from_dataset_id_ds = data_service_ops.from_dataset_id(
"parallel_epochs", cluster.dispatcher_address(), dataset_id,
ds.element_spec)
self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
def testFromDatasetIdMultipleComponents(self):
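    """Tests from_dataset_id with a nested element structure."""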
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 10
ds = dataset_ops.Dataset.range(num_elements)
ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
dataset_id = data_service_ops.register_dataset(cluster.dispatcher_address(),
ds)
from_dataset_id_ds = data_service_ops.from_dataset_id(
"parallel_epochs", cluster.dispatcher_address(), dataset_id,
ds.element_spec)
output = self.getDatasetOutput(from_dataset_id_ds)
for i in range(num_elements):
self.assertEqual(i, output[i]["a"][0])
self.assertEqual(i, output[i]["a"][1])
self.assertEqual(i, output[i]["b"])

  @combinations.generate(test_base.default_test_combinations())
def testFromDatasetIdWrongElementSpec(self):
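    """Tests the error raised when from_dataset_id gets a wrong element_spec."""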
cluster = data_service_test_base.TestCluster(num_workers=1)
num_elements = 10
ds = dataset_ops.Dataset.range(num_elements)
dataset_id = data_service_ops.register_dataset(cluster.dispatcher_address(),
ds)
wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
from_dataset_id_ds = data_service_ops.from_dataset_id(
"parallel_epochs", cluster.dispatcher_address(), dataset_id, wrong_spec)
with self.assertRaisesRegex(errors.FailedPreconditionError,
"Expected a tensor of type variant"):
self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
def testFromDatasetIdNotRegistered(self):
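    """Tests the error raised for a dataset id that was never registered."""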
cluster = data_service_test_base.TestCluster(num_workers=1)
dataset_id = 0
element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
from_dataset_id_ds = data_service_ops.from_dataset_id(
"parallel_epochs", cluster.dispatcher_address(), dataset_id,
element_spec)
with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
def testCancellation(self):
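    """Tests that a blocked element request is cancelled with its iterator."""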
self.skipTest("b/162521601")
sleep_microseconds = int(1e6) * 1000
cluster = data_service_test_base.TestCluster(num_workers=1)
# Create a dataset which produces the first element quickly, and the second
# element slowly. Fetching the first element triggers prefetching of the
# second element, which we should be able to cancel.
slow = dataset_ops.Dataset.range(1)
slow = slow.apply(testing.sleep(sleep_microseconds))
ds = dataset_ops.Dataset.range(1).concatenate(slow)
ds = self.make_distributed_dataset(ds, cluster)
ds = ds.prefetch(1)
get_next = self.getNext(ds)
self.assertEqual(0, self.evaluate(get_next()))
# Without properly implemented cancellation, we will hang here while trying
# to garbage collect the dataset iterator.

  @combinations.generate(test_base.default_test_combinations())
def testRegisterEquivalentDatasets(self):
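    """Tests that registering identical datasets yields the same dataset id."""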
ds_1 = dataset_ops.Dataset.range(10)
ds_2 = dataset_ops.Dataset.range(10)
cluster = data_service_test_base.TestCluster(num_workers=1)
id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_1)
id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_2)
self.assertEqual(self.evaluate(id_1), self.evaluate(id_2))

  @combinations.generate(test_base.default_test_combinations())
def testRegisterDifferentDatasets(self):
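    """Tests that registering different datasets yields distinct dataset ids."""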
ds_1 = dataset_ops.Dataset.range(10)
ds_2 = dataset_ops.Dataset.range(20)
cluster = data_service_test_base.TestCluster(num_workers=1)
id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_1)
id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_2)
self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2))

  @combinations.generate(test_base.default_test_combinations())
def testTwoLevelDistribute(self):
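    """Tests feeding one data service cluster's output through another."""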
cluster_1_size = 3
cluster_1 = data_service_test_base.TestCluster(num_workers=cluster_1_size)
cluster_2 = data_service_test_base.TestCluster(num_workers=1)
num_sizes = 10
size_repeats = 5
strings = ["a" * i for i in range(num_sizes)] * size_repeats
ds = dataset_ops.Dataset.from_tensor_slices(strings)
ds = ds.shuffle(len(strings))
ds = self.make_distributed_dataset(ds, cluster_1)
# Large enough so that all strings of the same size are windowed together.
window_size = cluster_1_size * size_repeats
batch_size = size_repeats
def key_func(x):
return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)
ds = ds.apply(
grouping.group_by_window(
key_func=key_func,
reduce_func=lambda _, x: x.batch(batch_size),
window_size=window_size))
ds = self.make_distributed_dataset(ds, cluster_2)
get_next = self.getNext(ds)
for _ in range(num_sizes):
element = self.evaluate(get_next())
for _ in range(1, cluster_1_size):
self.assertAllEqual(self.evaluate(get_next()), element)
self.assertEmpty(self.getIteratorOutput(get_next))

  @combinations.generate(test_base.default_test_combinations())
def testDistributeLargeGraph(self):
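    """Tests distributing a dataset graph above the 4MB gRPC message limit."""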
cluster = data_service_test_base.TestCluster(
num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    # Larger than the default OSS gRPC message size limit of 4MB.
tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
ds = dataset_ops.Dataset.from_tensors(tensor)
ds = self.make_distributed_dataset(ds, cluster)
self.assertDatasetProduces(ds, [tensor])

  @combinations.generate(
combinations.times(test_base.graph_only_combinations(),
combinations.combine(use_resource=False)) +
combinations.times(test_base.default_test_combinations(),
combinations.combine(use_resource=True)))
def testVariables(self, use_resource):
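    """Tests distributing datasets that capture resource or ref variables."""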
cluster = data_service_test_base.TestCluster(num_workers=1)
if not use_resource:
with variable_scope.variable_scope("foo", use_resource=False):
v = variables.VariableV1(10, dtype=dtypes.int64)
else:
v = variables.Variable(10, dtype=dtypes.int64)
ds = dataset_ops.Dataset.range(3)
ds = ds.map(lambda x: x + v)
ds = self.make_distributed_dataset(ds, cluster)
self.evaluate(v.initializer)
self.assertDatasetProduces(
        ds, list(range(10, 13)), requires_initialization=True)


if __name__ == "__main__":
test.main()