# Source code for lenstronomy.Sampling.Pool.multiprocessing

"""
This file is taken from schwimmbad (https://github.com/adrn/schwimmbad) and from an explicit fork by
Aymeric Galan that replaces the ``multiprocessing`` dependency with ``multiprocess``: for multi-threading,
``multiprocessing`` supports only pickle and not dill, and dill is required here.

The class is also extended with an ``is_master()`` definition.

"""


# Standard library (signal, functools) and the third-party ``multiprocess`` package

import signal
import functools
import multiprocess
from multiprocess.pool import Pool

__all__ = ["MultiPool"]


def _initializer_wrapper(actual_initializer, *rest):
    """Make the calling process ignore SIGINT, then run the real initializer.

    A ``^C`` typed on a terminal is expected to be handled by the parent process,
    which is then responsible for shutting this process down. A SIGINT sent by
    hand directly to this process therefore has no effect.

    :param actual_initializer: initializer to run after the signal handler is installed,
        or ``None`` to run nothing
    :type actual_initializer: callable or None
    :param rest: positional arguments forwarded unchanged to ``actual_initializer``
    :return: None (any value returned by ``actual_initializer`` is discarded)
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    if actual_initializer is None:
        return
    actual_initializer(*rest)


class CallbackWrapper(object):
    """Adapt a per-item callback so it can be invoked with a whole iterable.

    Calling the wrapper with an iterable invokes the stored callback once per
    element, in iteration order.
    """

    def __init__(self, callback):
        """
        :param callback: callable accepting a single positional argument
        :type callback: callable
        """
        self.callback = callback

    def __call__(self, tasks):
        """Invoke the stored callback on every element of ``tasks``.

        :param tasks: iterable of items to hand to the callback one at a time
        :type tasks: iterable
        :return: None
        """
        for item in tasks:
            self.callback(item)


class MultiPool(Pool):
    """A modified version of :class:`multiprocess.pool.Pool` that has better
    behavior with regard to ``KeyboardInterrupts`` in the :func:`map` method.

    (Original author: `Peter K. G. Williams <peter@newton.cx>`_)
    """

    # Timeout passed to every ``get()`` call on the async result inside ``map()``.
    wait_timeout = 3600

    def __init__(self, processes=None, initializer=None, initargs=(), **kwargs):
        """
        :param processes: The number of worker processes to use; defaults to the number of CPUs.
        :type processes: int, optional
        :param initializer: If specified, a callable that will be invoked by each worker process
            when it starts.
        :type initializer: callable, optional
        :param initargs: Arguments for ``initializer``; it will be called as
            ``initializer(*initargs)``.
        :type initargs: iterable, optional
        :param kwargs: Extra arguments passed to the :class:`multiprocess.pool.Pool` superclass.
        """
        # Every worker first runs _initializer_wrapper (which ignores SIGINT) and
        # only then the user-supplied initializer, if any.
        new_initializer = functools.partial(_initializer_wrapper, initializer)
        super(MultiPool, self).__init__(processes, new_initializer, initargs, **kwargs)
        # Worker count as stored by the superclass constructor.
        self.size = self._processes
        # The object holding the pool always reports rank 0, i.e. master.
        self.rank = 0

    def is_master(self):
        """Tell whether this object is the master, i.e. has rank 0.

        :return: ``True`` if ``self.rank == 0``
        :rtype: bool
        """
        return self.rank == 0

    def is_worker(self):
        """Tell whether this object is a worker, i.e. has a non-zero rank.

        :return: ``True`` if ``self.rank != 0``
        :rtype: bool
        """
        return self.rank != 0

    @staticmethod
    def enabled():
        """Report whether this pool type can be used.

        :return: always ``True``
        :rtype: bool
        """
        return True

    def map(self, func, iterable, chunksize=None, callback=None):
        """Equivalent to the built-in ``map()`` function and
        :meth:`multiprocess.pool.Pool.map()`, without catching ``KeyboardInterrupt``.

        :param func: A function or callable object that is executed on each element of the
            specified ``tasks`` iterable. This object must be picklable (i.e. it can't be a
            function scoped within a function or a ``lambda`` function). This should accept a
            single positional argument and return a single object.
        :type func: callable
        :param iterable: A list or iterable of tasks. Each task can be itself an iterable
            (e.g., tuple) of values or data to pass in to the worker function.
        :type iterable: iterable
        :param chunksize: Forwarded unchanged to ``map_async()`` of the superclass.
        :type chunksize: int, optional
        :param callback: An optional callback function (or callable) that is called with the
            result from each worker run and is executed on the master process. This is useful
            for, e.g., saving results to a file, since the callback is only called on the
            master thread.
        :type callback: callable, optional
        :return: A list of results from the output of each ``worker()`` call.
        :raises KeyboardInterrupt: re-raised after the pool has been terminated and joined.
        """
        if callback is None:
            callbackwrapper = None
        else:
            # map_async hands its callback the complete result list; the wrapper
            # fans that out into one ``callback`` call per result.
            callbackwrapper = CallbackWrapper(callback)

        # The key magic is that we must call r.get() with a timeout, because
        # a Condition.wait() without a timeout swallows KeyboardInterrupts.
        r = self.map_async(
            func, iterable, chunksize=chunksize, callback=callbackwrapper
        )

        while True:
            try:
                return r.get(self.wait_timeout)
            except multiprocess.TimeoutError:
                # Not finished yet: keep waiting in another bounded get().
                pass
            except KeyboardInterrupt:
                # Stop the workers before propagating the interrupt.
                self.terminate()
                self.join()
                raise