import multiprocessing
import multiprocessing.pool
import multiprocessing.util as util
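# torch is not referenced below; presumably it is imported for its side
# effects (ensuring torch is loaded before worker processes are created).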
import torch
from . import Queue


def clean_worker(*args, **kwargs):
import gc
multiprocessing.pool.worker(*args, **kwargs)
# Regular multiprocessing workers don't fully clean up after themselves,
# so we have to explicitly trigger garbage collection to make sure that all
# destructors are called...
gc.collect()


class Pool(multiprocessing.pool.Pool):
"""Pool implementation with customizable pickling reducers.
This is useful to control how data is shipped between processes
and makes it possible to use shared memory without useless
copies induces by the default pickling methods of the original
objects passed as arguments to dispatch.
`forward_reducers` and `backward_reducers` are expected to be
dictionaries with key/values being `(type, callable)` pairs where
`callable` is a function that, given an instance of `type`, will return a
tuple `(constructor, tuple_of_objects)` to rebuild an instance out of the
pickled `tuple_of_objects` as would return a `__reduce__` method.
See the standard library documentation about pickling for more details.
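
    Example (an illustrative sketch, not part of this module's API;
    `reduce_list` is a hypothetical reducer shown only for shape):

        def reduce_list(obj):
            # (constructor, tuple_of_objects), as `__reduce__` would return.
            return (list, (list(obj),))

        pool = Pool(processes=2, forward_reducers={list: reduce_list})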
"""

    def __init__(self, processes=None, forward_reducers=None,
backward_reducers=None, **kwargs):
if forward_reducers is None:
forward_reducers = dict()
if backward_reducers is None:
backward_reducers = dict()
self._forward_reducers = forward_reducers
self._backward_reducers = backward_reducers
poolargs = dict(processes=processes)
poolargs.update(kwargs)
super(Pool, self).__init__(**poolargs)

    def _setup_queues(self):
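        # Python 3.4+ pools store their multiprocessing context in `_ctx`;
        # fall back to the top-level module on older versions.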
context = getattr(self, '_ctx', multiprocessing)
self._inqueue = Queue(context, self._forward_reducers)
self._outqueue = Queue(context, self._backward_reducers)
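        # Pool's task- and result-handler threads use these fast-path hooks
        # directly, so dispatches and results go through the reducer-aware
        # `_send`/`_recv` endpoints of our Queue.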
self._quick_put = self._inqueue._send
self._quick_get = self._outqueue._recv

    def _repopulate_pool(self):
"""Bring the number of pool processes up to the specified number,
for use after reaping workers which have exited.
"""
        for _ in range(self._processes - len(self._pool)):
            # Use clean_worker instead of multiprocessing.pool.worker so that
            # garbage collection runs before each worker process exits.
args = (self._inqueue, self._outqueue,
self._initializer,
self._initargs, self._maxtasksperchild)
if hasattr(self, '_wrap_exception'):
args += (self._wrap_exception,)
w = self.Process(target=clean_worker, args=args)
self._pool.append(w)
w.name = w.name.replace('Process', 'PoolWorker')
w.daemon = True
w.start()
util.debug('added worker')
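

# Usage sketch (illustrative only, not part of the original module). With no
# reducers supplied, dispatch falls back to default pickling; `abs` is
# picklable by reference, so no custom reducer is needed. Because of the
# relative import above, run this as a module (`python -m <package>.pool`).
if __name__ == '__main__':
    with Pool(processes=2) as pool:
        print(pool.map(abs, [-1, 2, -3]))  # -> [1, 2, 3]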