# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base testing class for strategies that require multiple nodes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import copy
import threading

import numpy as np

_portpicker_import_error = None
try:
  import portpicker  # pylint: disable=g-import-not-at-top
except ImportError as _error:  # pylint: disable=invalid-name
  _portpicker_import_error = _error
  portpicker = None

# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib


def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None):
  """Creates and starts local servers and returns the cluster_spec dict."""
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  if num_workers > 0:
    cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
  if num_ps > 0:
    cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
  if has_eval:
    cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()]
  if has_chief:
    cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()]

  cs = server_lib.ClusterSpec(cluster_dict)

  for i in range(num_workers):
    server_lib.Server(
        cs,
        job_name='worker',
        protocol=protocol,
        task_index=i,
        config=worker_config,
        start=True)

  for i in range(num_ps):
    server_lib.Server(
        cs,
        job_name='ps',
        protocol=protocol,
        task_index=i,
        config=ps_config,
        start=True)

  if has_chief:
    server_lib.Server(
        cs,
        job_name='chief',
        protocol=protocol,
        task_index=0,
        config=worker_config,
        start=True)

  if has_eval:
    server_lib.Server(
        cs,
        job_name='evaluator',
        protocol=protocol,
        task_index=0,
        config=worker_config,
        start=True)

  return cluster_dict
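
# For illustration only (the port numbers below are arbitrary examples, not
# fixed values): _create_cluster(num_workers=2, num_ps=1, has_chief=True)
# starts four in-process servers and returns a dict such as
#   {'worker': ['localhost:12345', 'localhost:12346'],
#    'ps': ['localhost:12347'],
#    'chief': ['localhost:12348']}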


def create_in_process_cluster(num_workers,
                              num_ps,
                              has_chief=False,
                              has_eval=False):
  """Creates an in-process cluster that consists of only standard servers."""
  # Leave some memory for the CUDA runtime by capping the per-process GPU
  # memory fraction, e.g. with two workers and a chief each server may use at
  # most 0.7 / 3 of the GPU memory.
  gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
  worker_config = config_pb2.ConfigProto()
  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac

  # Enable collective ops; this has no impact on non-collective ops.
  # TODO(yuefengz, tucker): remove this after we move the initialization of
  # the collective mgr to the session level.
  if has_chief:
    worker_config.experimental.collective_group_leader = (
        '/job:chief/replica:0/task:0')
  else:
    worker_config.experimental.collective_group_leader = (
        '/job:worker/replica:0/task:0')

  ps_config = config_pb2.ConfigProto()
  ps_config.device_count['GPU'] = 0

  # Create in-process servers. Once an in-process TensorFlow server is created,
  # there is no way to terminate it, so we create one cluster per test process.
  # We could have started each server in a separate process and then killed
  # that process to terminate the server, but we avoid multiple processes
  # because
  # 1) it is more difficult to manage these processes;
  # 2) there is something global in CUDA such that if we initialize CUDA in the
  #    parent process, a child process cannot initialize it again and thus
  #    cannot use GPUs (https://stackoverflow.com/questions/22950047).
  return _create_cluster(
      num_workers,
      num_ps=num_ps,
      has_chief=has_chief,
      has_eval=has_eval,
      worker_config=worker_config,
      ps_config=ps_config,
      protocol='grpc')
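
# Illustrative use (this mirrors what MultiWorkerTestBase.setUpClass does
# below): a test harness can create a two-worker cluster and point a session
# at the first worker, e.g.
#   cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0)
#   target = 'grpc://' + cluster_spec['worker'][0]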


class MultiWorkerTestBase(test.TestCase):
  """Base class for testing multi-node strategies and datasets."""

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 2 workers."""
    cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0)
    cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]

  def setUp(self):
    # We only cache the session in one test because another test may have a
    # different session config or master target.
    self._thread_local = threading.local()
    self._thread_local.cached_session = None
    self._result = 0
    self._lock = threading.Lock()

  @contextlib.contextmanager
  def test_session(self, graph=None, config=None, target=None):
    """Creates a test session with the master target set to the test cluster.

    This overrides the base class' method, removes arguments that are not
    needed by the multi-node case and creates a test session that connects to
    the local testing cluster.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if self.id().endswith('.test_session'):
      self.skipTest('Not a test.')

    if config is None:
      config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      config = copy.deepcopy(config)
    # Don't perform optimizations for tests so we don't inadvertently run
    # GPU ops on CPU.
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    if target is None:
      target = self._default_target
    if graph is None:
      if getattr(self._thread_local, 'cached_session', None) is None:
        self._thread_local.cached_session = session.Session(
            graph=None, config=config, target=target)
      sess = self._thread_local.cached_session
      with sess.graph.as_default(), sess.as_default():
        yield sess
    else:
      with session.Session(graph=graph, config=config, target=target) as sess:
        yield sess

  def _run_client(self, client_fn, task_type, task_id, num_gpus, *args,
                  **kwargs):
    # Runs a single client and, if it reports success, counts it towards the
    # total checked by `_run_between_graph_clients`.
    result = client_fn(task_type, task_id, num_gpus, *args, **kwargs)
    if np.all(result):
      with self._lock:
        self._result += 1

  def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args,
                                 **kwargs):
    """Runs several clients for between-graph replication.

    Args:
      client_fn: a function that accepts `task_type`, `task_id` and
        `num_gpus`, and returns True if it succeeds.
      cluster_spec: a dict specifying jobs in a cluster.
      num_gpus: number of GPUs per worker.
      *args: will be passed to `client_fn`.
      **kwargs: will be passed to `client_fn`.
    """
    threads = []
    for task_type in [run_config.TaskType.CHIEF, run_config.TaskType.WORKER]:
      for task_id in range(len(cluster_spec.get(task_type, []))):
        t = threading.Thread(
            target=self._run_client,
            args=(client_fn, task_type, task_id, num_gpus) + args,
            kwargs=kwargs)
        t.start()
        threads.append(t)
    for t in threads:
      t.join()
    self.assertEqual(self._result, len(threads))
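
# A minimal usage sketch (hypothetical test case, not part of this module): a
# subclass runs one client per chief/worker task, and the base class asserts
# that every client returned True.
#
#   class ExampleMultiWorkerTest(MultiWorkerTestBase):
#
#     def testClientsSucceed(self):
#       def client_fn(task_type, task_id, num_gpus):
#         del task_type, task_id, num_gpus  # Unused in this sketch.
#         return True
#
#       self._run_between_graph_clients(
#           client_fn, self._cluster_spec, num_gpus=0)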