# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fault tolerance test for parameter server training in TF2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
import os
import threading
import time
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator as thread_coordinator
from tensorflow.python.training import server_lib
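# Substrings used by the tests below: the two RPC error markers identify which
# task a failed RPC came from, and the thread name identifies the
# coordinator's worker-preemption handling thread.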
_RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker"
_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
_WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler"
class Model(object):

  def __init__(self, coordinator):
    self.cluster_coord = coordinator
    self.strategy = self.cluster_coord.strategy
    with self.cluster_coord.strategy.scope():
      self.build()

  def build(self):
    self.w = variables.Variable(
        initial_value=random_ops.random_uniform((10, 10)),
        dtype=dtypes.float32)
    self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
    # Allow external control to make the model run its train_fn in an infinite
    # loop. This allows us to reliably test worker preemption in the middle of
    # function execution.
    self.do_infinite_step = variables.Variable(False)

    def dataset_fn():
      data = random_ops.random_uniform((10, 10))
      dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
      return dataset
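    # `create_per_worker_dataset` runs `dataset_fn` once on each worker to
    # build a per-worker dataset; iterating it yields an iterator whose
    # per-worker components can be passed to scheduled functions.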
    self.iterator = iter(
        self.cluster_coord.create_per_worker_dataset(dataset_fn))
  def _train_fn_internal(self, iterator):
    x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
    x = math_ops.matmul(random_ops.random_uniform((10, 10)), x)
    self.w.assign_add(x)
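  # `train_fn` always runs at least one step; while `do_infinite_step` is True
  # it keeps stepping, and `iterations` is incremented only after the loop
  # exits, i.e. once per completed call.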
  @def_function.function
  def train_fn(self, iterator):
    self._train_fn_internal(iterator)
    while self.do_infinite_step:
      self._train_fn_internal(iterator)
    self.iterations.assign_add(1)
  def schedule_training_functions(self, num_steps):
    with self.strategy.scope():
      for _ in range(num_steps):
        self.cluster_coord.schedule(self.train_fn, args=(self.iterator,))

  def join_training_functions(self):
    self.do_infinite_step.assign(False)
    self.cluster_coord.join()
class BaseFaultToleranceTest(object):  # pylint: disable=missing-docstring

  def setUp(self, num_workers, num_ps):
    super(BaseFaultToleranceTest, self).setUp()

    # Set the environment variable to prevent hanging upon job failure and
    # restart. Note that it defaults to 'use_caller' at Google, but defaults
    # to False in OSS.
    os.environ["GRPC_FAIL_FAST"] = "use_caller"

    self._cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
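    # Add a "chief" task on an unused local port; the coordinator (this
    # process) joins the cluster as the chief.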
    self._cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")
    # The strategy's constructor connects to the cluster.
    self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)
    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.num_workers = num_workers
  def tearDown(self):
    super(BaseFaultToleranceTest, self).tearDown()
    self._cluster.stop()
  def _restart(self, downtime_secs, job):
    """Kills `job` (index: 0) and restarts it after `downtime_secs`.

    Args:
      downtime_secs: secs before restarting the job.
      job: a string specifying the job to restart.
    """
    self._cluster.kill_task(job, 0)
    time.sleep(downtime_secs)
    self.assertFalse(context.check_alive("/job:%s/replica:0/task:0" % job))
    self._cluster.start_task(job, 0)
    while not context.check_alive("/job:%s/replica:0/task:0" % job):
      time.sleep(1)
  def _restart_in_thread(self, downtime_secs, restart_job):

    def _restart_fn():
      with self.thread_coord.stop_on_exception():
        self._restart(downtime_secs, restart_job)

    restart_thread = threading.Thread(target=_restart_fn)
    restart_thread.start()
    return restart_thread
  def _ensure_threads_closed(self):
    """Ensure worker and preemption threads are closed."""
    # Wait for threads to close.
    self.cluster_coord = None
    gc.collect()
    time.sleep(1)

    # Verify thread names.
    running_threads = set()
    for thread in threading.enumerate():
      logging.info("Running thread name:%s", thread.name)
      if thread.name is not None:
        running_threads.add(thread.name)
    # TODO(xingliu): Verify worker threads are closed.
    self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
  def testClusterCoordinatorDestroyed(self):
    self._ensure_threads_closed()
  def testWorkerPreemptionBetweenFunctions(self):
    model = Model(self.cluster_coord)
    model.schedule_training_functions(2)
    model.join_training_functions()
    self.assertEqual(model.iterations.numpy(), 2)
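    # Restart the worker while no function is executing; scheduling and
    # joining afterwards should succeed as if nothing had happened.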
    self._restart(downtime_secs=2, job="worker")

    model.schedule_training_functions(2)
    model.join_training_functions()
    self.assertEqual(model.iterations.numpy(), 4)
  def testWorkerPreemptionMidstFunction(self):
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)
    model.schedule_training_functions(4)

    # The model runs infinite training steps, so at this point we expect
    # `self.num_workers` infinite closures to be in flight and
    # `4 - self.num_workers` closures to be waiting in the queue.
    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
           < self.num_workers):
      time.sleep(0.1)
    self.assertFalse(self.cluster_coord.done())
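    # Kill and restart the worker while the infinite closures are still
    # running; the coordinator should recover and eventually complete every
    # scheduled step once `join_training_functions` stops the infinite loop.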
    self._restart(downtime_secs=2, job="worker")

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 4)
  def testOneWorkerPreemptionWithCancellation(self):

    @def_function.function
    def normal_function():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    @def_function.function
    def error_function():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      check_ops.assert_non_positive_v2(
          math_ops.reduce_sum(math_ops.matmul(x, y)))
      return x

    @def_function.function
    def long_function():
      x = random_ops.random_uniform((1000, 1000))
      for _ in math_ops.range(10000):
        a = random_ops.random_uniform((1000, 1000))
        b = random_ops.random_uniform((1000, 1000))
        x += math_ops.matmul(a, b)
      return x
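    # `error_function` is all but guaranteed to fail: the summed matmul of
    # uniform [0, 1) tensors is positive, so the assert raises
    # InvalidArgumentError, which should also cancel the still-pending
    # `long_function`.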
    for _ in range(3):
      self.cluster_coord.schedule(normal_function)
    long_function_result = self.cluster_coord.schedule(long_function)
    self.cluster_coord.schedule(error_function)

    time.sleep(1)  # Let it run a couple steps.
    self._restart(1, "worker")

    with self.assertRaises(errors.InvalidArgumentError):
      self.cluster_coord.join()

    with self.assertRaises(errors.CancelledError):
      long_function_result.fetch()

    for _ in range(3):
      self.cluster_coord.schedule(normal_function)
    self.cluster_coord.join()
  def testHandleDatasetCreationFailure(self):
    model = Model(self.cluster_coord)

    restart_thread = self._restart_in_thread(5, "worker")
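    # The worker is restarted while per-worker resources (the dataset above)
    # may still be getting created, so this exercises handling of failures
    # during dataset creation on the restarted worker.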
    model.schedule_training_functions(3)
    model.join_training_functions()

    self.thread_coord.join([restart_thread])
    self.assertGreaterEqual(model.iterations.numpy(), 3)
  def testWorkerPreemptionErrorType(self):

    @def_function.function
    def worker_train_fn():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.
    self._restart(2, "worker")
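    # The preemption should surface as an UnavailableError whose message names
    # the worker, not the parameter server, as the failing remote target.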
    try:
      self.thread_coord.join([run_thread])
    except errors.UnavailableError as e:
      logging.info("Got exception %r, error message is %s", e, e)
      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
      self.assertTrue("failed to connect to all addresses" in str(e) or
                      "Unable to find a context_id" in str(e) or
                      "Socket closed" in str(e) or
                      "Connection reset by peer" in str(e) or
                      "Transport closed" in str(e))
  def testWorkerPreemptionErrorTypeWithPythonFunction(self):
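    # Same flow as testWorkerPreemptionErrorType, but `worker_train_fn` below
    # runs eagerly rather than inside a tf.function.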
    def worker_train_fn():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.
    self._restart(2, "worker")

    try:
      self.thread_coord.join([run_thread])
    except errors.UnavailableError as e:
      logging.info("Got exception %r, error message is %s", e, e)
      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
      self.assertTrue("failed to connect to all addresses" in str(e) or
                      "Unable to find a context_id" in str(e) or
                      "Socket closed" in str(e) or
                      "Connection reset by peer" in str(e) or
                      "Transport closed" in str(e))
  def testPSPreemptionErrorType(self):

    with ops.device("/job:ps/replica:0/task:0"):
      v = variables.Variable(
          initial_value=random_ops.random_uniform((2, 10)),
          dtype=dtypes.float32)

    @def_function.function
    def worker_train_fn():
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(v, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.

    # Use a short restart delay to cover the case where the RPC channel is
    # reused.
    self._restart(1, "ps")
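    # PS preemption may surface either as an UnavailableError (a plain
    # connection failure) or, when the old RPC channel is reused, as an
    # AbortedError complaining about a changed device incarnation.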
    try:
      self.thread_coord.join([run_thread])
    except (errors.UnavailableError, errors.AbortedError) as e:
      logging.info("Got exception %r, error message is %s", e, e)
      self.assertIn(_RPC_ERROR_FROM_PS, str(e))  # pylint: disable=g-assert-in-except

      if isinstance(e, errors.UnavailableError):
        self.assertTrue("failed to connect to all addresses" in str(e) or
                        "Unable to find a context_id" in str(e) or
                        "Socket closed" in str(e) or
                        "Connection reset by peer" in str(e) or
                        "Transport closed" in str(e))

      if isinstance(e, errors.AbortedError):
        self.assertIn("RecvTensor expects a different device incarnation",
                      str(e))
      self._ensure_threads_closed()
  def testTwoWorkersPreempted(self):
    if self.num_workers < 2:
      self.skipTest("Worker number is less than 2.")
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)
    model.schedule_training_functions(10)

    # The model runs infinite training steps, so at this point we expect two
    # infinite closures to be in flight and eight closures to be waiting in
    # the queue.
    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
           < 2):
      time.sleep(0.1)
    self.assertFalse(self.cluster_coord.done())

    self._cluster.kill_task("worker", 0)
    self._cluster.kill_task("worker", 1)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:1"))
    self._cluster.start_task("worker", 0)
    self._cluster.start_task("worker", 1)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:1"))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)
  def testWorkerContinuousFailure(self):
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)
    model.schedule_training_functions(10)

    # The model runs infinite training steps, so at this point we expect
    # `self.num_workers` infinite closures to be in flight and
    # `10 - self.num_workers` closures to be waiting in the queue.
    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
           < self.num_workers):
      time.sleep(0.1)
    self.assertFalse(self.cluster_coord.done())

    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)
  def testNumpyFetchedAfterWorkerFailure(self):

    with self.strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
      return v + 1, v - 1

    remote_value = self.cluster_coord.schedule(worker_fn)
    # Fetching before the worker task is killed should succeed.
    self.assertEqual((1, -1), remote_value.fetch())
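    # The fetched result is held locally by the RemoteValue, so a later fetch
    # should not need to reach the (soon to be killed) worker.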
    self._cluster.kill_task("worker", 0)
    # So should fetching after the worker task is killed.
    self.assertEqual((1, -1), remote_value.fetch())
  def testClusterStateNotDisrupted(self):
    # This test has side effects and can disrupt other tests, even though the
    # resources it creates are not used by the tests that follow.
    # TODO(b/155209534): enable this test.
    # self.testPSPreemptionErrorType()

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.testWorkerPreemptionMidstFunction()

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.testWorkerPreemptionErrorType()

    # In the previous tests, workers may fail after training has finished. The
    # following tests, however, start by creating resources, and failures at
    # that stage are not handled.
    # TODO(b/153888707): enable the following two tests.
    # self.testTwoWorkersPreempted()
    # self.testWorkerContinuousFailure()
class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Multi worker fault tolerance tests.

  This covers the ordinary cases where multiple workers and PS are used.
  """

  def setUp(self):
    super(MultiWorkerFaultToleranceTest, self).setUp(2, 2)
class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Single worker fault tolerance tests.

  This covers the cases that ensure training can continue in a single-worker
  cluster even if the only worker becomes unavailable at some point and later
  recovers (with multiple workers, training may succeed simply because the
  workers that did not fail carried it). Single-worker clusters are rarely
  used in practice, but these tests are important to ensure correct behavior.
  """

  def setUp(self):
    super(SingleWorkerFaultToleranceTest, self).setUp(1, 1)
if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  multi_process_runner.test_main()