Source code for lenstronomy.Sampling.Pool.multiprocessing
"""
This file is taken from schwimmbad (https://github.com/adrn/schwimmbad) and an explicit fork by Aymeric Galan.
It replaces the standard-library ``multiprocessing`` dependency with ``multiprocess``. For parallel execution,
``multiprocessing`` serializes objects only with ``pickle``, whereas ``dill`` serialization (used by
``multiprocess``) is required here.
The class also adds an ``is_master()`` method.
"""
# Standard library
import signal
import functools
import multiprocess
from multiprocess.pool import Pool
# Public API of this module: only ``MultiPool`` is exported by ``from ... import *``.
__all__ = ["MultiPool"]
def _initializer_wrapper(actual_initializer, *rest):
    """Worker start-up hook: ignore ``SIGINT``, then run the user's initializer (if any).

    Workers deliberately ignore ``SIGINT``. When ``^C`` is pressed at a terminal, it is
    up to the parent process to terminate the workers. A ``SIGINT`` sent directly to a
    worker therefore has no effect.

    :param actual_initializer: user-supplied initializer callable, or ``None`` for none
    :param rest: positional arguments forwarded to ``actual_initializer``
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    if actual_initializer is None:
        # No user initializer was supplied; ignoring SIGINT is all there is to do.
        return
    actual_initializer(*rest)
class CallbackWrapper:
    """Adapt a per-item callback to a callback that receives a whole batch of results.

    ``MultiPool.map`` hands an instance of this class to ``map_async``, which calls it
    once with the complete list of results. The wrapper then calls the wrapped callback
    on each element in turn.
    """

    def __init__(self, callback):
        """
        :param callback: callable invoked once per result item
        """
        self.callback = callback

    def __call__(self, tasks):
        """Invoke the wrapped callback on every element of ``tasks``, in iteration order.

        :param tasks: iterable of results
        :return: None
        """
        for item in tasks:
            self.callback(item)
class MultiPool(Pool):
    """A modified version of :class:`multiprocessing.pool.Pool` that has better behavior
    with regard to ``KeyboardInterrupts`` in the :func:`map` method.

    The class is built on :class:`multiprocess.pool.Pool` and additionally exposes
    ``size``, ``rank``, :meth:`is_master`, :meth:`is_worker` and :meth:`enabled`.
    ``rank`` is always 0, so the calling process is always the master.
    (Original author: `Peter K. G. Williams <peter@newton.cx>`_)
    """

    # Timeout in seconds for each ``AsyncResult.get`` call in :meth:`map`. The wait is
    # retried indefinitely; a finite timeout is used so that ``KeyboardInterrupt`` can
    # reach the master (see the comment in :meth:`map`).
    wait_timeout = 3600

    def __init__(self, processes=None, initializer=None, initargs=(), **kwargs):
        """
        :param processes: The number of worker processes to use; defaults to the number of CPUs.
        :type processes: int, optional
        :param initializer: If specified, a callable that will be invoked by each worker process when it starts.
        :type initializer: callable, optional
        :param initargs: Arguments for ``initializer``; it will be called as ``initializer(*initargs)``.
        :type initargs: iterable, optional
        :param kwargs: Extra arguments passed to the :class:`multiprocessing.pool.Pool` superclass.
        """
        # Wrap the user initializer so that every worker first ignores SIGINT;
        # interrupt handling is left to the master (see ``map``).
        new_initializer = functools.partial(_initializer_wrapper, initializer)
        super().__init__(processes, new_initializer, initargs, **kwargs)
        # Number of worker processes, as resolved by the superclass
        # (it fills in a default when ``processes`` is None).
        self.size = self._processes
        # Single-node pool: the calling process is always rank 0, i.e. the master.
        self.rank = 0

    def is_master(self):
        """:return: True if this process is the master (rank 0); always True for this pool."""
        return self.rank == 0

    def is_worker(self):
        """:return: True if this process is a worker (rank != 0); always False for this pool."""
        return self.rank != 0

    @staticmethod
    def enabled():
        """:return: True; this pool type is always available."""
        return True

    def map(self, func, iterable, chunksize=None, callback=None):
        """Equivalent to the built-in ``map()`` function and
        :meth:`multiprocessing.pool.Pool.map()`, but without swallowing ``KeyboardInterrupt``.

        :param func: A function or callable object that is executed on each element of
            the specified ``tasks`` iterable. This object must be picklable
            (i.e. it can't be a function scoped within a function or a
            ``lambda`` function). This should accept a single positional
            argument and return a single object.
        :type func: callable
        :param iterable: A list or iterable of tasks. Each task can be itself an iterable
            (e.g., tuple) of values or data to pass in to the worker function.
        :type iterable: iterable
        :param chunksize: Number of tasks submitted to a worker at a time; forwarded
            unchanged to :meth:`map_async`. ``None`` lets the pool choose.
        :type chunksize: int, optional
        :param callback: An optional callable, executed on the master process, that is
            called once per result, in task order. Note that it only starts running
            after *all* tasks have finished, because ``map_async`` delivers the full
            result list at once. It does not fire as each individual worker completes.
            It runs in the pool's result-handling thread, so it should return quickly.
            This is useful for, e.g., saving results to a file.
        :type callback: callable, optional
        :return: A list of results from the output of each ``worker()`` call, in task order.
        :raises KeyboardInterrupt: re-raised after all workers have been terminated and
            joined, if the master is interrupted while waiting for results.
        """
        callbackwrapper = None if callback is None else CallbackWrapper(callback)

        # The key magic is that we must call r.get() with a timeout, because
        # a Condition.wait() without a timeout swallows KeyboardInterrupts.
        async_result = self.map_async(
            func, iterable, chunksize=chunksize, callback=callbackwrapper
        )
        while True:
            try:
                return async_result.get(self.wait_timeout)
            except multiprocess.TimeoutError:
                # Results not ready yet; keep waiting.
                continue
            except KeyboardInterrupt:
                # Workers ignore SIGINT themselves, so the master must shut them down
                # before propagating the interrupt.
                self.terminate()
                self.join()
                raise