#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Defines WorkQueue for delegating asynchronous work to subprocesses."""
from __future__ import absolute_import

import collections
import logging
import multiprocessing
import os
import signal
import sys
import traceback


def logger():
    """Returns the module logger."""
    return logging.getLogger(__name__)


def worker_sigterm_handler(_signum, _frame):
    """Raises SystemExit so atexit/finally handlers can be executed."""
    sys.exit()


class TaskError(Exception):
    """An error for an exception raised in a worker process.

    Exceptions raised in the worker will not be printed by default, and will
    also not halt execution. We catch these exceptions in the worker process
    and pass them through the queue. Results are checked, and if the result
    is a TaskError the TaskError is raised in the caller's process. The
    message for the TaskError is the stack trace of the original exception,
    and will be printed if the TaskError is not caught.
    """
    def __init__(self, trace):
        super(TaskError, self).__init__(trace)
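

# A minimal sketch of the error path described above (the task function and
# queue below are hypothetical, not part of this module's API): a task that
# raises in the worker surfaces in the caller as a TaskError whose message
# is the worker-side stack trace.
#
#     def flaky_task(worker):
#         raise ValueError('boom')
#
#     work_queue = WorkQueue()
#     work_queue.add_task(flaky_task)
#     try:
#         work_queue.get_result()
#     except TaskError as ex:
#         print(ex)  # The worker's formatted traceback.
#     finally:
#         work_queue.terminate()
#         work_queue.join()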


def create_windows_process_group():
    """Creates a Win32 job object that kills its processes when closed.

    Returns:
        The job object handle, which must be passed to kill_process_group()
        to terminate the group.
    """
    import ndk.win32
    job = ndk.win32.CreateJobObject()
    limit_info = ndk.win32.JOBOBJECT_EXTENDED_LIMIT_INFORMATION(
        BasicLimitInformation=ndk.win32.JOBOBJECT_BASIC_LIMIT_INFORMATION(
            LimitFlags=ndk.win32.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE))
    ndk.win32.SetInformationJobObject(
        job, ndk.win32.JobObjectExtendedLimitInformation, limit_info)
    ndk.win32.AssignProcessToJobObject(job, ndk.win32.GetCurrentProcess())
    return job


def assign_self_to_new_process_group():
    """Assigns this process to a new process group or Win32 job object.

    Returns:
        A token identifying the group, suitable for kill_process_group().
    """
    if sys.platform == 'win32':
        return create_windows_process_group()
    else:
        return os.setpgrp()


def kill_process_group(group):
    """Kills the process group this process belongs to.

    On Windows, closing the job object handle kills every process assigned
    to the job. On POSIX, signalling PID 0 sends SIGTERM to every process
    in the caller's process group.
    """
    if sys.platform == 'win32':
        import ndk.win32
        ndk.win32.CloseHandle(group)
    else:
        os.kill(0, signal.SIGTERM)
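

# A sketch of how the two helpers above are meant to pair up (the loop body
# is hypothetical; the real caller is Worker.main below):
#
#     group = assign_self_to_new_process_group()
#     try:
#         spawn_child_processes()  # Made-up name; children die with us.
#     finally:
#         kill_process_group(group)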


class Worker(object):
    """A worker process that runs tasks from a shared queue."""
    IDLE_STATUS = 'IDLE'
    EXCEPTION_STATUS = 'EXCEPTION'
    def __init__(self, data, task_queue, result_queue, manager):
        """Creates a Worker object.

        Args:
            data: Arbitrary data exposed to tasks via the data attribute of
                the worker passed to each task function.
            task_queue: A multiprocessing.Queue of Tasks to retrieve work
                from.
            result_queue: A multiprocessing.Queue to push results to.
            manager: A multiprocessing.Manager used to create the shared
                status value and its lock.
        """
        self.data = data
        self.task_queue = task_queue
        self.result_queue = result_queue
        # For multiprocessing.Manager.Value, the type is actually ignored.
        # https://stackoverflow.com/a/21290961/632035
        self._status = manager.Value('', self.IDLE_STATUS)
        self._status_lock = manager.Lock()
        self.process = multiprocessing.Process(target=self.main)
    @property
    def status(self):
        """The worker's current status string."""
        with self._status_lock:
            return self._status.value

    @status.setter
    def status(self, value):
        with self._status_lock:
            self._status.value = value

    def put_result(self, result, status):
        """Sets the worker's status and posts a result atomically."""
        with self._status_lock:
            self._status.value = status
            self.result_queue.put(result)

    @property
    def pid(self):
        """The PID of the worker process."""
        return self.process.pid

    def is_alive(self):
        """Returns True if the worker process is running."""
        return self.process.is_alive()

    def start(self):
        """Starts the worker process."""
        self.process.start()

    def terminate(self):
        """Terminates the worker process."""
        self.process.terminate()

    def join(self, timeout=None):
        """Joins the worker process, optionally with a timeout in seconds."""
        self.process.join(timeout)
    def main(self):
        """Main loop for worker processes."""
        group = assign_self_to_new_process_group()
        signal.signal(signal.SIGTERM, worker_sigterm_handler)
        try:
            while True:
                logger().debug('worker %d waiting for work', os.getpid())
                task = self.task_queue.get()
                logger().debug('worker %d running task', os.getpid())
                result = task.run(self)
                logger().debug('worker %d putting result', os.getpid())
                self.put_result(result, self.IDLE_STATUS)
        except SystemExit:
            pass
        except:  # pylint: disable=bare-except
            logger().debug('worker %d raised exception', os.getpid())
            trace = ''.join(traceback.format_exception(*sys.exc_info()))
            self.put_result(TaskError(trace), self.EXCEPTION_STATUS)
        finally:
            # multiprocessing.Process.terminate() doesn't kill our
            # descendants, so kill the whole process group, ignoring the
            # SIGTERM we send to ourselves while doing so.
            signal.signal(signal.SIGTERM, signal.SIG_IGN)
            logger().debug('worker %d killing process group', os.getpid())
            kill_process_group(group)
            signal.signal(signal.SIGTERM, signal.SIG_DFL)
            logger().debug('worker %d exiting', os.getpid())
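

# Worker status lives in manager-backed shared state, so the parent process
# can observe a worker's progress while tasks run. A hypothetical monitoring
# loop over a ProcessPoolWorkQueue's workers:
#
#     for worker in work_queue.workers:
#         if worker.status != Worker.IDLE_STATUS:
#             print('worker %d: %s' % (worker.pid, worker.status))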


class Task(object):
    """A task to be executed by a worker process."""
    def __init__(self, func, args, kwargs):
        """Creates a task.

        Args:
            func: A callable object to be executed by a worker process.
            args: Arguments to be passed to the task.
            kwargs: Keyword arguments to be passed to the task.
        """
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def run(self, worker):
        """Invokes the task, passing the worker as the first argument."""
        return self.func(worker, *self.args, **self.kwargs)
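

# Task functions receive the worker object first, followed by the arguments
# given to add_task; worker.data is the worker_data the queue was created
# with. A hypothetical task function:
#
#     def compile_one(worker, source_path):
#         worker.status = 'compiling ' + source_path
#         return run_compiler(worker.data, source_path)  # Made-up helper.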


class ProcessPoolWorkQueue(object):
    """A pool of processes for executing work asynchronously."""
    join_timeout = 8  # Timeout for join before trying SIGKILL.

    def __init__(self, num_workers=multiprocessing.cpu_count(),
                 task_queue=None, result_queue=None, worker_data=None):
        """Creates a WorkQueue.

        Worker processes are spawned immediately and remain live until both
        terminate() and join() are called.

        Args:
            num_workers: Number of worker processes to spawn.
            task_queue: multiprocessing.Queue for tasks. Allows multiple
                work queues to share a single task queue. If None, the work
                queue creates its own.
            result_queue: multiprocessing.Queue for results. Allows
                multiple work queues to share a single result queue. If
                None, the work queue creates its own.
            worker_data: Data to be passed to every task run by this work
                queue.
        """
        self.manager = multiprocessing.Manager()

        self.task_queue = task_queue
        self.owns_task_queue = False
        if task_queue is None:
            self.task_queue = self.manager.Queue()
            self.owns_task_queue = True

        self.result_queue = result_queue
        self.owns_result_queue = False
        if result_queue is None:
            self.result_queue = self.manager.Queue()
            self.owns_result_queue = True

        self.worker_data = worker_data

        self.workers = []
        # multiprocessing.JoinableQueue's join can't be used to implement
        # finished() because it doesn't come in a non-blocking flavor, so
        # track outstanding tasks by hand.
        self.num_tasks = 0
        self._spawn_workers(num_workers)
    def add_task(self, func, *args, **kwargs):
        """Queues up a new task for execution.

        Tasks are executed in order of insertion as worker processes become
        available.

        Args:
            func: A callable object to be executed by a worker process.
            args: Arguments to be passed to the task.
            kwargs: Keyword arguments to be passed to the task.
        """
        self.task_queue.put(Task(func, args, kwargs))
        self.num_tasks += 1

    def get_result(self):
        """Gets a result from the queue, blocking until one is available.

        Raises:
            TaskError: The task raised an exception in the worker process.
        """
        result = self.result_queue.get()
        if isinstance(result, TaskError):
            raise result
        self.num_tasks -= 1
        return result
    def terminate(self):
        """Terminates all worker processes."""
        for worker in self.workers:
            logger().debug('terminating %d', worker.pid)
            worker.terminate()

    def join(self):
        """Waits for all worker processes to exit."""
        for worker in self.workers:
            logger().debug('joining %d', worker.pid)
            worker.join(self.join_timeout)
            if worker.is_alive():
                logger().error(
                    'worker %d will not die; sending SIGKILL', worker.pid)
                # Each worker is the leader of its own process group, so
                # its PID doubles as the process group ID.
                os.killpg(worker.pid, signal.SIGKILL)
                worker.join()
        self.workers = []

    def finished(self):
        """Returns True if all tasks have completed execution."""
        return self.num_tasks == 0
    def _spawn_workers(self, num_workers):
        """Spawns the worker processes.

        Args:
            num_workers: Number of worker processes to spawn.
        """
        for _ in range(num_workers):
            worker = Worker(
                self.worker_data, self.task_queue, self.result_queue,
                self.manager)
            worker.start()
            self.workers.append(worker)
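

# A typical drain loop, as a sketch: run_job, jobs, and handle are
# hypothetical, while the queue API is as defined above.
#
#     work_queue = WorkQueue(num_workers=4)
#     for job in jobs:
#         work_queue.add_task(run_job, job)
#     try:
#         while not work_queue.finished():
#             handle(work_queue.get_result())
#     finally:
#         work_queue.terminate()
#         work_queue.join()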


class DummyWorker(object):
    """A stand-in for Worker that exposes only the task-facing data."""
    def __init__(self, data):
        self.data = data


class DummyWorkQueue(object):
    """A fake WorkQueue that does not parallelize.

    Useful for debugging when trying to determine if an issue is being
    caused by multiprocess-specific behavior.
    """
    # pylint: disable=unused-argument
    def __init__(self, num_workers=None, task_queue=None, result_queue=None,
                 worker_data=None):
        """Creates a DummyWorkQueue."""
        self.task_queue = collections.deque()
        self.worker_data = worker_data
    # pylint: enable=unused-argument
    def add_task(self, func, *args, **kwargs):
        """Queues up a new task for execution.

        Tasks are executed when get_result is called.

        Args:
            func: A callable object to be executed.
            args: Arguments to be passed to the task.
            kwargs: Keyword arguments to be passed to the task.
        """
        self.task_queue.append(Task(func, args, kwargs))

    def get_result(self):
        """Executes a task and returns the result.

        Raises:
            TaskError: The task raised an exception.
        """
        task = self.task_queue.popleft()
        try:
            return task.run(DummyWorker(self.worker_data))
        except:  # pylint: disable=bare-except
            trace = ''.join(traceback.format_exception(*sys.exc_info()))
            raise TaskError(trace)
    def terminate(self):
        """Does nothing; there are no worker processes to terminate."""

    def join(self):
        """Does nothing; there are no worker processes to join."""

    @property
    def num_tasks(self):
        """Number of tasks that have not yet been executed."""
        return len(self.task_queue)

    @property
    def workers(self):
        """An empty list, since this work queue has no workers."""
        return []

    def finished(self):
        """Returns True if all tasks have completed execution."""
        return self.num_tasks == 0
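

# DummyWorkQueue is API-compatible with ProcessPoolWorkQueue, so debugging a
# suspected multiprocessing issue can be as simple as pointing the WorkQueue
# alias below at DummyWorkQueue instead:
#
#     WorkQueue = DummyWorkQueue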
WorkQueue = ProcessPoolWorkQueue