blob: 120ad59cd4f4b8e066f2db97ad2d68b795ce0fe1 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Eager mode tests for the experimental `replicate` transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
@test_util.run_all_in_graph_and_eager_modes
class LocalReplicateTest(test_base.DatasetTestBase):
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
super(LocalReplicateTest, self).__init__(methodName)
self._device0 = "/device:CPU:0"
self._device1 = "/device:CPU:1"
self._device2 = "/device:CPU:2"
@test_util.run_v1_only("V2 doesnt support multiple devices")
def testBasic(self):
with ops.device(self._device0):
dataset0 = dataset_ops.Dataset.range(100)
replicated_ds = distribute.replicate(dataset0,
[self._device1, self._device2])
dataset1 = replicated_ds[self._device1]
dataset2 = replicated_ds[self._device2]
logging.info("Producing 0")
with ops.device(self._device0):
self.assertDatasetProduces(dataset0, range(100))
logging.info("Producing 1")
with ops.device(self._device1):
self.assertDatasetProduces(dataset1, range(100))
logging.info("Producing 2")
with ops.device(self._device2):
self.assertDatasetProduces(dataset2, range(100))
@test_util.run_v1_only("V2 doesnt support multiple devices")
def testVariableInput(self):
with ops.device(self._device0):
counter_var = variable_scope.get_variable(
"counter", (), dtypes.int32, use_resource=True)
dataset0 = dataset_ops.Dataset.range(100).map(
lambda _: counter_var.assign_add(1))
# We don't support stateful ops in functions as of now.
with self.assertRaises(errors.FailedPreconditionError):
replicated_ds = distribute.replicate(dataset0,
[self._device1, self._device2])
self.evaluate(replicated_ds[self._device1]._variant_tensor)
JOB_NAME = "remote_device"
def _get_server_def(job_name, local_server_port, remote_server_addresses,
task_index):
"""Returns a server def with a single job + multiple tasks."""
cluster_def = cluster_pb2.ClusterDef()
job_def = cluster_def.job.add()
job_def.name = job_name
job_def.tasks[0] = "localhost:%d" % local_server_port
for i, remote_server_address in enumerate(remote_server_addresses, start=1):
job_def.tasks[i] = remote_server_address
server_def = tensorflow_server_pb2.ServerDef(
cluster=cluster_def,
job_name=job_name,
task_index=task_index,
protocol="grpc")
return server_def
# Pure eager mode test that sets up a cluster of processes.
class RemoteReplicateTest(test_base.DatasetTestBase):
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
super(RemoteReplicateTest, self).__init__(methodName)
self._cached_server1 = server_lib.Server.create_local_server()
self._cached_server2 = server_lib.Server.create_local_server()
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
self._device0 = "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME
self._device1 = "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME
self._device2 = "/job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME
def setUp(self):
super(RemoteReplicateTest, self).setUp()
# Start the local server.
local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
context.set_server_def(
server_def=_get_server_def(
JOB_NAME,
local_server_port=local_port,
remote_server_addresses=[
self._cached_server1_target, self._cached_server2_target
],
task_index=0))
@test_util.run_v2_only
def testBasic(self):
with ops.device(self._device0):
dataset0 = dataset_ops.Dataset.range(100)
replicated_ds = distribute.replicate(dataset0,
[self._device1, self._device2])
dataset1 = replicated_ds[self._device1]
dataset2 = replicated_ds[self._device2]
with ops.device(self._device0):
self.assertDatasetProduces(dataset0, range(100))
with ops.device(self._device1):
self.assertDatasetProduces(dataset1, range(100))
with ops.device(self._device2):
self.assertDatasetProduces(dataset2, range(100))
@test_util.run_v2_only
def testMap(self):
with ops.device(self._device0):
dataset0 = dataset_ops.Dataset.range(100).map(lambda x: x * 2)
replicated_ds = distribute.replicate(dataset0,
[self._device1, self._device2])
dataset1 = replicated_ds[self._device1]
dataset2 = replicated_ds[self._device2]
with ops.device(self._device0):
self.assertDatasetProduces(dataset0, range(0, 200, 2))
with ops.device(self._device1):
self.assertDatasetProduces(dataset1, range(0, 200, 2))
with ops.device(self._device2):
self.assertDatasetProduces(dataset2, range(0, 200, 2))
@test_util.run_v2_only
def testVariableInput(self):
with ops.device(self._device0):
counter_var = variable_scope.get_variable(
"counter", (), dtypes.int32, use_resource=True)
dataset0 = dataset_ops.Dataset.range(100).map(
lambda _: counter_var.assign_add(1))
# We don't support stateful ops in functions as of now.
with self.assertRaises(errors.FailedPreconditionError):
replicated_ds = distribute.replicate(dataset0,
[self._device1, self._device2])
self.evaluate(replicated_ds[self._device1]._variant_tensor)
if __name__ == "__main__":
ops.enable_eager_execution(
config=config_pb2.ConfigProto(device_count={"CPU": 3}))
test.main()