#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Defines WorkQueue for delegating asynchronous work to subprocesses."""
import collections
import logging
import multiprocessing
import os
import Queue
import signal
import sys
import traceback


def logger():
    return logging.getLogger(__name__)


def worker_sigterm_handler(_signum, _frame):
    """Raises SystemExit so atexit/finally handlers can be executed."""
    sys.exit()


def _flush_queue(queue):
    """Flushes all pending items from a Queue."""
    try:
        while True:
            queue.get_nowait()
    except Queue.Empty:
        pass


class TaskError(Exception):
    """An error for an exception raised in a worker process.

    Exceptions raised in the worker will not be printed by default, and will
    also not halt execution. We catch these exceptions in the worker process
    and pass them through the queue. Results are checked, and if the result is
    a TaskError the TaskError is raised in the caller's process. The message
    for the TaskError is the stack trace of the original exception, and will
    be printed if the TaskError is not caught.
    """
    def __init__(self, trace):
        super(TaskError, self).__init__(trace)
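

# A caller-side sketch of the contract described in TaskError's docstring:
# get_result() re-raises a TaskError in the caller's process, and its message
# is the worker's formatted stack trace (`fallible_task` is hypothetical):
#
#     work_queue.add_task(fallible_task)
#     try:
#         result = work_queue.get_result()
#     except TaskError as ex:
#         logger().error('task failed in a worker:\n%s', ex)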


def worker_main(task_queue, result_queue):
    """Main loop for worker processes.

    Args:
        task_queue: A multiprocessing.Queue of Tasks to retrieve work from.
        result_queue: A multiprocessing.Queue to push results to.
    """
    os.setpgrp()
    signal.signal(signal.SIGTERM, worker_sigterm_handler)
    try:
        while True:
            logger().debug('worker %d waiting for work', os.getpid())
            task = task_queue.get()
            logger().debug('worker %d running task', os.getpid())
            result = task.run()
            logger().debug('worker %d putting result', os.getpid())
            result_queue.put(result)
    except SystemExit:
        pass
    except:  # pylint: disable=bare-except
        logger().debug('worker %d raised exception', os.getpid())
        trace = ''.join(traceback.format_exception(*sys.exc_info()))
        result_queue.put(TaskError(trace))
    finally:
        # multiprocessing.Process.terminate() doesn't kill our descendants.
        signal.signal(signal.SIGTERM, signal.SIG_IGN)
        logger().debug('worker %d killing process group', os.getpid())
        os.kill(0, signal.SIGTERM)
        signal.signal(signal.SIGTERM, signal.SIG_DFL)
        logger().debug('worker %d exiting', os.getpid())
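

# Note on worker_main's shutdown sequence: os.setpgrp() puts each worker (and
# anything its tasks spawn) in its own process group, and os.kill(0, SIGTERM)
# signals that entire group on the way out. SIGTERM is set to SIG_IGN first so
# the worker does not re-enter its own handler while tearing down.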


class Task(object):
    """A task to be executed by a worker process."""
    def __init__(self, func, args, kwargs):
        """Creates a task.

        Args:
            func: An invocable object to be executed by a worker process.
            args: Arguments to be passed to the task.
            kwargs: Keyword arguments to be passed to the task.
        """
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def run(self):
        """Invokes the task."""
        return self.func(*self.args, **self.kwargs)
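

# A Task just defers a call: Task(f, (1, 2), {'y': 3}).run() is equivalent to
# f(1, 2, y=3). Callers normally don't construct Tasks directly; add_task()
# packs its *args/**kwargs into a Task for them.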


class ProcessPoolWorkQueue(object):
    """A pool of processes for executing work asynchronously."""

    join_timeout = 8  # Timeout in seconds for join before trying SIGKILL.

    def __init__(self, num_workers=multiprocessing.cpu_count()):
        """Creates a WorkQueue.

        Worker processes are spawned immediately and remain live until both
        terminate() and join() are called.

        Args:
            num_workers: Number of worker processes to spawn.
        """
        if sys.platform == 'win32':
            # TODO(danalbert): Port ProcessPoolWorkQueue to Windows.
            # Our implementation of ProcessPoolWorkQueue depends on process
            # groups, which are not supported on Windows.
            raise NotImplementedError
        self.task_queue = multiprocessing.Queue()
        self.result_queue = multiprocessing.Queue()
        self.workers = []
        # multiprocessing.JoinableQueue's join() can't be used to implement
        # finished() because it has no non-blocking variant.
        self.num_tasks = 0
        self._spawn_workers(num_workers)

    def add_task(self, func, *args, **kwargs):
        """Queues up a new task for execution.

        Tasks are executed in order of insertion as worker processes become
        available.

        Args:
            func: An invocable object to be executed by a worker process.
            args: Arguments to be passed to the task.
            kwargs: Keyword arguments to be passed to the task.
        """
        self.task_queue.put(Task(func, args, kwargs))
        self.num_tasks += 1

    def get_result(self):
        """Gets a result from the queue, blocking until one is available."""
        result = self.result_queue.get()
        if isinstance(result, TaskError):
            raise result
        self.num_tasks -= 1
        return result

    def terminate(self):
        """Terminates all worker processes."""
        for worker in self.workers:
            logger().info('terminating %d', worker.pid)
            worker.terminate()
        self._flush()

    def _flush(self):
        """Flushes all pending tasks and results.

        If there are still items pending in the queues when terminate is
        called, the subsequent join will hang waiting for the queues to be
        emptied.

        We call _flush after all workers have been terminated to ensure that
        we can exit cleanly.
        """
        _flush_queue(self.task_queue)
        _flush_queue(self.result_queue)

    def join(self):
        """Waits for all worker processes to exit."""
        for worker in self.workers:
            logger().info('joining %d', worker.pid)
            worker.join(self.join_timeout)
            if worker.is_alive():
                logger().error(
                    'worker %d will not die; sending SIGKILL', worker.pid)
                os.killpg(worker.pid, signal.SIGKILL)
                worker.join()
        self.workers = []

    def finished(self):
        """Returns True once every queued task's result has been retrieved."""
        return self.num_tasks == 0

    def _spawn_workers(self, num_workers):
        """Spawns the worker processes.

        Args:
            num_workers: Number of worker processes to spawn.
        """
        for _ in range(num_workers):
            worker = multiprocessing.Process(
                target=worker_main, args=(self.task_queue, self.result_queue))
            worker.start()
            self.workers.append(worker)
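

# A sketch of the pool's lifecycle (`some_task` is a hypothetical callable):
# terminate() SIGTERMs the workers and flushes the queues, and join() then
# reaps them; both must be called before the pool is fully shut down.
#
#     work_queue = ProcessPoolWorkQueue(num_workers=4)
#     try:
#         work_queue.add_task(some_task)
#         result = work_queue.get_result()
#     finally:
#         work_queue.terminate()
#         work_queue.join()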


class DummyWorkQueue(object):
    """A fake WorkQueue that does not parallelize.

    Useful for debugging when trying to determine if an issue is being caused
    by multiprocessing-specific behavior.
    """
    def __init__(self):
        """Creates a DummyWorkQueue."""
        self.task_queue = collections.deque()

    def add_task(self, func, *args, **kwargs):
        """Queues up a new task for execution.

        Tasks are executed when get_result is called.

        Args:
            func: An invocable object to be executed by a worker process.
            args: Arguments to be passed to the task.
            kwargs: Keyword arguments to be passed to the task.
        """
        self.task_queue.append(Task(func, args, kwargs))

    def get_result(self):
        """Executes a task and returns the result."""
        task = self.task_queue.popleft()
        try:
            return task.run()
        except:  # pylint: disable=bare-except
            trace = ''.join(traceback.format_exception(*sys.exc_info()))
            raise TaskError(trace)

    def terminate(self):
        """Does nothing."""
        pass

    def join(self):
        """Does nothing."""
        pass

    def finished(self):
        """Returns True if all tasks have completed execution."""
        return len(self.task_queue) == 0
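

# DummyWorkQueue mirrors ProcessPoolWorkQueue's interface, so swapping it in
# while debugging is a one-line change, e.g.:
#
#     work_queue = DummyWorkQueue()  # Instead of WorkQueue().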


if sys.platform == 'win32':
    WorkQueue = DummyWorkQueue
else:
    WorkQueue = ProcessPoolWorkQueue