# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base testing class for strategies that require multiple nodes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import copy
import json
import os
import subprocess
import sys
import threading
import unittest
import six
_portpicker_import_error = None
try:
import portpicker # pylint: disable=g-import-not-at-top
except ImportError as _error: # pylint: disable=invalid-name
_portpicker_import_error = _error
portpicker = None
# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
original_run_std_server = dc._run_std_server # pylint: disable=protected-access
ASSIGNED_PORTS = set()
lock = threading.Lock()
def pick_unused_port():
"""Returns an unused and unassigned local port."""
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
global ASSIGNED_PORTS
with lock:
while True:
try:
port = portpicker.pick_unused_port()
except portpicker.NoFreePortFoundError:
raise unittest.SkipTest('Flakes in portpicker library do not represent '
'TensorFlow errors.')
if port > 10000 and port not in ASSIGNED_PORTS:
ASSIGNED_PORTS.add(port)
logging.info('Using local port %r', port)
return port
def _create_cluster(num_workers,
num_ps,
has_chief=False,
has_eval=False,
protocol='grpc',
worker_config=None,
ps_config=None,
eval_config=None):
"""Creates and starts local servers and returns the cluster_spec dict."""
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
worker_ports = [pick_unused_port() for _ in range(num_workers)]
ps_ports = [pick_unused_port() for _ in range(num_ps)]
cluster_dict = {}
if num_workers > 0:
cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
if num_ps > 0:
cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
if has_eval:
cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
if has_chief:
cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()]
cs = server_lib.ClusterSpec(cluster_dict)
for i in range(num_workers):
server_lib.Server(
cs,
job_name='worker',
protocol=protocol,
task_index=i,
config=worker_config,
start=True)
for i in range(num_ps):
server_lib.Server(
cs,
job_name='ps',
protocol=protocol,
task_index=i,
config=ps_config,
start=True)
if has_chief:
server_lib.Server(
cs,
job_name='chief',
protocol=protocol,
task_index=0,
config=worker_config,
start=True)
if has_eval:
server_lib.Server(
cs,
job_name='evaluator',
protocol=protocol,
task_index=0,
config=eval_config,
start=True)
return cluster_dict
def create_in_process_cluster(num_workers,
num_ps,
has_chief=False,
has_eval=False,
rpc_layer='grpc'):
"""Create an in-process cluster that consists of only standard server."""
# Leave some memory for cuda runtime.
gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
  # Enable collective ops, which has no impact on non-collective ops.
  # TODO(yuefengz, tucker): remove this after we move the initialization of
  # the collective mgr to the session level.
if has_chief:
worker_config.experimental.collective_group_leader = (
'/job:chief/replica:0/task:0')
else:
worker_config.experimental.collective_group_leader = (
'/job:worker/replica:0/task:0')
ps_config = config_pb2.ConfigProto()
ps_config.device_count['GPU'] = 0
eval_config = config_pb2.ConfigProto()
eval_config.experimental.collective_group_leader = ''
  # Create in-process servers. Once an in-process TensorFlow server is created,
  # there is no way to terminate it, so we create one cluster per test process.
  # We could have started each server in a separate process and killed that
  # process to terminate the server, but we avoid multiple processes because:
  # 1) they are more difficult to manage;
  # 2) CUDA holds global state such that if we initialize CUDA in the parent
  # process, a child process cannot initialize it again and thus cannot use
  # GPUs (https://stackoverflow.com/questions/22950047).
cluster = None
try:
cluster = _create_cluster(
num_workers,
num_ps=num_ps,
has_chief=has_chief,
has_eval=has_eval,
worker_config=worker_config,
ps_config=ps_config,
eval_config=eval_config,
protocol=rpc_layer)
except errors.UnknownError as e:
if 'Could not start gRPC server' in e.message:
      raise unittest.SkipTest('Cannot start std servers.')
else:
raise
return cluster
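# A minimal usage sketch (illustrative only; `MyStrategyTest` is a hypothetical
# subclass, not part of this module). The returned dict maps job names to
# 'host:port' address lists:
#
#   class MyStrategyTest(test.TestCase):
#
#     @classmethod
#     def setUpClass(cls):
#       # Starts 2 worker servers and 1 ps server in this process, e.g.
#       # {'worker': ['localhost:23456', 'localhost:23457'],
#       #  'ps': ['localhost:23458']}.
#       cls._cluster_spec = create_in_process_cluster(
#           num_workers=2, num_ps=1)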
# TODO(rchao): Remove `test_obj` once estimator repo picks up the updated
# nightly TF.
def create_cluster_spec(has_chief=False,
num_workers=1,
num_ps=0,
has_eval=False,
test_obj=None):
"""Create a cluster spec with tasks with unused local ports."""
del test_obj
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
cluster_spec = {}
if has_chief:
cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
if num_workers:
cluster_spec['worker'] = [
'localhost:%s' % pick_unused_port() for _ in range(num_workers)
]
if num_ps:
cluster_spec['ps'] = [
'localhost:%s' % pick_unused_port() for _ in range(num_ps)
]
if has_eval:
cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
return cluster_spec
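# For illustration, a call such as
#
#   create_cluster_spec(has_chief=True, num_workers=2, num_ps=1)
#
# returns a dict of the form (ports vary per run):
#
#   {'chief': ['localhost:<port>'],
#    'worker': ['localhost:<port>', 'localhost:<port>'],
#    'ps': ['localhost:<port>']}
#
# Unlike create_in_process_cluster, this only reserves addresses; it does not
# start any servers.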
@contextlib.contextmanager
def skip_if_grpc_server_cant_be_started(test_obj):
try:
yield
except errors.UnknownError as e:
if 'Could not start gRPC server' in e.message:
reason = 'Cannot start std servers.'
test_obj.test_skipped_reason = reason
test_obj.skipTest(reason)
else:
raise
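# Typical use inside a test (a sketch; `start_std_server` stands in for any
# code that may fail with 'Could not start gRPC server'):
#
#   with skip_if_grpc_server_cant_be_started(self):
#     start_std_server()  # The test is skipped instead of failing here.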
class MultiWorkerTestBase(test.TestCase):
"""Base class for testing multi node strategy and dataset."""
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=1)
cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]
def setUp(self):
    # We cache the session per test rather than per class because another
    # test may use a different session config or master target.
self._thread_local = threading.local()
self._thread_local.cached_session = None
self._coord = coordinator.Coordinator()
@contextlib.contextmanager
def session(self, graph=None, config=None, target=None):
"""Create a test session with master target set to the testing cluster.
Creates a test session that connects to the local testing cluster.
Args:
graph: Optional graph to use during the returned session.
config: An optional config_pb2.ConfigProto to use to configure the
session.
target: the target of session to connect to.
Yields:
A Session object that should be used as a context manager to surround
the graph building and execution code in a test case.
"""
config = self._create_config(config)
if target is None:
target = self._default_target
with session.Session(graph=graph, config=config, target=target) as sess:
yield sess
@contextlib.contextmanager
# TODO(b/117573461): Overwrite self.evaluate() to use this function.
def cached_session(self, graph=None, config=None, target=None):
"""Create a test session with master target set to the testing cluster.
Creates a test session that connects to the local testing cluster.
The session is only created once per test and then reused.
Args:
graph: Optional graph to use during the returned session.
config: An optional config_pb2.ConfigProto to use to configure the
session.
target: the target of session to connect to.
Yields:
A Session object that should be used as a context manager to surround
the graph building and execution code in a test case. Note that the
session will live until the end of the test.
"""
config = self._create_config(config)
if target is None:
target = self._default_target
if getattr(self._thread_local, 'cached_session', None) is None:
self._thread_local.cached_session = session.Session(
graph=None, config=config, target=target)
sess = self._thread_local.cached_session
with sess.graph.as_default(), sess.as_default():
yield sess
def _create_config(self, config):
if config is None:
config = config_pb2.ConfigProto(allow_soft_placement=True)
else:
config = copy.deepcopy(config)
    # Don't perform optimizations for tests so we don't inadvertently run
    # GPU ops on the CPU.
config.graph_options.optimizer_options.opt_level = -1
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
return config
def _run_client(self, client_fn, task_type, task_id, num_gpus, eager_mode,
*args, **kwargs):
def wrapped_client_fn():
with self._coord.stop_on_exception():
client_fn(task_type, task_id, num_gpus, *args, **kwargs)
if eager_mode:
with context.eager_mode():
wrapped_client_fn()
else:
with context.graph_mode():
wrapped_client_fn()
def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args,
**kwargs):
"""Runs several clients for between-graph replication.
Args:
client_fn: a function that needs to accept `task_type`, `task_id`,
`num_gpus`.
cluster_spec: a dict specifying jobs in a cluster.
num_gpus: number of GPUs per worker.
*args: will be passed to `client_fn`.
**kwargs: will be passed to `client_fn`.
"""
threads = []
for task_type in ['chief', 'worker']:
for task_id in range(len(cluster_spec.get(task_type, []))):
t = threading.Thread(
target=self._run_client,
args=(client_fn, task_type, task_id, num_gpus,
context.executing_eagerly()) + args,
kwargs=kwargs)
t.start()
threads.append(t)
self._coord.join(threads)
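# A minimal sketch of how MultiWorkerTestBase is typically used (the subclass
# and its test body are hypothetical):
#
#   class MyMultiWorkerTest(MultiWorkerTestBase):
#
#     def test_something(self):
#       # Connects to the in-process cluster created in setUpClass.
#       with self.cached_session() as sess:
#         ...  # Build and run a graph against the local cluster.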
class MockOsEnv(collections_abc.Mapping):
"""A class that allows per-thread TF_CONFIG."""
def __init__(self, *args):
self._dict = dict()
self._thread_local = threading.local()
super(MockOsEnv, self).__init__(*args)
def get(self, key, default=None):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
if key == 'TF_CONFIG':
return dict.get(self._thread_local.dict, key, default)
else:
return dict.get(self._dict, key, default)
def __getitem__(self, key):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
if key == 'TF_CONFIG':
return dict.__getitem__(self._thread_local.dict, key)
else:
return dict.__getitem__(self._dict, key)
def __setitem__(self, key, val):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
if key == 'TF_CONFIG':
return dict.__setitem__(self._thread_local.dict, key, val)
else:
return dict.__setitem__(self._dict, key, val)
def __iter__(self):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
for x in self._thread_local.dict:
yield x
for x in self._dict:
yield x
def __len__(self):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
return self._thread_local.dict.__len__() + self._dict.__len__()
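# MockOsEnv keeps 'TF_CONFIG' in thread-local storage so that worker threads
# in the same test process each see their own value, while every other
# environment variable is shared. A sketch of the behavior:
#
#   mock_env = MockOsEnv()
#   mock_env['PATH'] = '/usr/bin'    # Shared across all threads.
#   mock_env['TF_CONFIG'] = '{...}'  # Visible only to the setting thread.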
class IndependentWorkerTestBase(test.TestCase):
"""Testing infra for independent workers."""
def _make_mock_run_std_server(self):
def _mock_run_std_server(*args, **kwargs):
"""Returns the std server once all threads have started it."""
with skip_if_grpc_server_cant_be_started(self):
ret = original_run_std_server(*args, **kwargs)
      # Wait for all std servers to be brought up in order to reduce the chance
      # of remote sessions taking local ports that have been assigned to std
      # servers. Only wait on this barrier the first time this function is run
      # for each thread.
if not getattr(self._thread_local, 'server_started', False):
self._barrier.wait()
self._thread_local.server_started = True
return ret
return _mock_run_std_server
def setUp(self):
self._mock_os_env = MockOsEnv()
self._mock_context = test.mock.patch.object(os, 'environ',
self._mock_os_env)
self._coord = coordinator.Coordinator()
super(IndependentWorkerTestBase, self).setUp()
self._mock_context.__enter__()
    # A threading.local object shared by all threads; each thread sees its own
    # attributes on it.
self._thread_local = threading.local()
def tearDown(self):
self._mock_context.__exit__(None, None, None)
super(IndependentWorkerTestBase, self).tearDown()
def _task_thread(self, task_fn, tf_config, executing_eagerly, *args,
**kwargs):
with self._coord.stop_on_exception():
os.environ['TF_CONFIG'] = json.dumps(tf_config)
# Force the new thread simulating a worker to run in the same context
# mode as the parent thread does.
if executing_eagerly:
with context.eager_mode():
task_fn(*args, **kwargs)
else:
with ops.Graph().as_default(), context.graph_mode():
task_fn(*args, **kwargs)
def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
*args, **kwargs):
"""Run tasks in a thread.
If `tf_config` is provided, use it for the new thread; if not, construct one
from `cluster_spec`, `task_type`, and `task_id`, and provide it to the new
thread to be set as `TF_CONFIG` environment.
Arguments:
task_fn: The function to run in the new thread.
cluster_spec: The cluster spec.
task_type: The task type.
task_id: The task id.
*args: Additional positional arguments to provide to the thread's task_fn.
**kwargs: Additional keyword arguments to provide to the thread's task_fn.
If `tf_config` is provided, that dict will be used for the TF_CONFIG for
the new thread.
Returns:
The thread that has started.
"""
tf_config = kwargs.pop('tf_config', None)
if tf_config is None:
if task_type:
tf_config = {
'cluster': cluster_spec,
'task': {
'type': task_type,
'index': task_id
}
}
else:
tf_config = {
'cluster': cluster_spec,
}
t = threading.Thread(
target=self._task_thread,
args=(task_fn, tf_config, context.executing_eagerly()) + args,
kwargs=kwargs)
t.start()
return t
def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args,
**kwargs):
    # The task_fn is responsible for creating its own std server.
threads = {}
for task_type in cluster_spec.keys():
threads[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id,
*args, **kwargs)
threads[task_type].append(t)
return threads
def join_independent_workers(self, worker_threads):
with skip_if_grpc_server_cant_be_started(self):
self._coord.join(worker_threads)
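# A sketch of the intended flow for IndependentWorkerTestBase subclasses (the
# test class and `_worker_fn` are hypothetical):
#
#   class MyIndependentWorkerTest(IndependentWorkerTestBase):
#
#     def test_workers(self):
#       cluster_spec = create_cluster_spec(num_workers=2)
#       # Each thread gets a per-thread TF_CONFIG via MockOsEnv; _worker_fn
#       # is expected to create its own std server.
#       threads = self.run_multiple_tasks_in_threads(
#           self._worker_fn, cluster_spec)
#       self.join_independent_workers(threads['worker'])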
class MultiWorkerMultiProcessTest(test.TestCase):
"""Testing infra for independent workers using multiple processes."""
def _run_task_in_process(self, cmd_args, cluster_spec, task_type, task_id):
env = os.environ.copy()
env['TF_CONFIG'] = json.dumps({
'cluster': cluster_spec,
'task': {
'type': task_type,
'index': task_id
}
})
return subprocess.Popen(
cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
def run_multiple_tasks_in_processes(self, cmd_args, cluster_spec):
"""Run `cmd_args` in a process for each task in `cluster_spec`."""
processes = {}
for task_type in cluster_spec.keys():
processes[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
p = self._run_task_in_process(cmd_args, cluster_spec, task_type,
task_id)
processes[task_type].append(p)
return processes
def join_independent_workers(self, worker_processes):
return_codes = []
for p in nest.flatten(worker_processes):
try:
# Calling p.wait() will hang if we don't consume its output.
p.communicate()
except ValueError:
# The output of the process may have been consumed, in which case
# calling `p.communicate()` will raise a ValueError.
pass
finally:
return_codes.append(p.returncode)
for return_code in return_codes:
self.assertEqual(return_code, 0)
def stream_stderr(self, processes, print_only_first=False):
"""Consume stderr of all processes and print to stdout.
To reduce the amount of logging, caller can set print_only_first to True.
In that case, this function only prints stderr from the first process of
each type.
Arguments:
processes: A dictionary from process type string -> list of processes.
print_only_first: If true, only print output from first process of each
type.
"""
def _stream_stderr_single_process(process, type_string, index,
print_to_stdout):
"""Consume a single process's stderr and optionally print to stdout."""
while True:
output = process.stderr.readline()
if not output and process.poll() is not None:
break
if output and print_to_stdout:
print('{}{} {}'.format(type_string, index, output.strip()))
sys.stdout.flush()
stream_threads = []
for process_type, process_list in six.iteritems(processes):
for i in range(len(process_list)):
print_to_stdout = (not print_only_first) or (i == 0)
thread = threading.Thread(
target=_stream_stderr_single_process,
args=(process_list[i], process_type, i, print_to_stdout))
thread.start()
stream_threads.append(thread)
for thread in stream_threads:
thread.join()
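# A sketch of the multi-process flow (the test method and the worker binary
# path are hypothetical):
#
#   def test_workers(self):
#     cluster_spec = create_cluster_spec(num_workers=2)
#     processes = self.run_multiple_tasks_in_processes(
#         ['python', 'my_worker_main.py'], cluster_spec)
#     self.stream_stderr(processes, print_only_first=True)
#     self.join_independent_workers(processes)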
def get_tf_config_task():
return json.loads(os.environ['TF_CONFIG'])['task']
def get_tf_config_cluster_spec():
return json.loads(os.environ['TF_CONFIG'])['cluster']
def get_task_type():
return get_tf_config_task()['type']
def get_task_index():
return get_tf_config_task()['index']
def is_chief():
return ('chief' not in get_tf_config_cluster_spec()
and get_task_type() == 'worker'
and get_task_index() == 0)
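# These helpers read the TF_CONFIG environment variable set for the current
# task. For example, with (ports illustrative):
#
#   TF_CONFIG = '{"cluster": {"worker": ["localhost:12345",
#                                        "localhost:23456"]},
#                 "task": {"type": "worker", "index": 0}}'
#
# get_task_type() returns 'worker', get_task_index() returns 0, and is_chief()
# returns True because the cluster spec has no 'chief' job and this is
# worker 0.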