Source code for pyslurmutils.client.job_io.local._executor

import logging
import queue
import threading
import weakref
from concurrent import futures
from contextlib import contextmanager
from typing import Any
from typing import Callable
from typing import Generator
from typing import Optional
from typing import Set
from typing import Tuple

from ... import errors

logger = logging.getLogger(__name__)

ExecuteType = Callable[[callable, tuple, dict, futures.Future], Any]
WorkerExitMessageType = Callable[[], str]
ExecuteContextReturnType = Generator[
    Tuple[ExecuteType, WorkerExitMessageType], None, None
]


class ExecutorShutdown(Exception):
    pass


class RemoteExecutor(futures.Executor):
    """Asynchronous executor which delegates task execution to :math:`WorkerProxy` instances."""

    _FUTURE_CLASS = futures.Future

    def __init__(
        self,
        max_workers: Optional[int] = None,
        max_tasks_per_worker: Optional[int] = None,
        lazy_scheduling: bool = True,
        conservative_scheduling: bool = False,
    ):
        """
        :param max_workers: maximum number of workers that can run at any given time. `None` means unlimited.
        :param max_tasks_per_worker: maximum number of tasks each worker can receive before exiting. `None` means unlimited.
        :param lazy_scheduling: schedule workers only when needed. Can only be disabled when `max_workers` is specified.
        :param conservative_scheduling: schedule the least amount of workers at the expense of tasks staying longer in the queue.
        """
        if max_workers is None:
            if not lazy_scheduling:
                raise ValueError(
                    "Cannot disable lazy scheduling when there is no maximum number of workers"
                )
        else:
            if not isinstance(max_workers, int):
                raise TypeError("max_workers must be an integer")
            elif max_workers <= 0:
                raise ValueError("max_workers must be greater than 0")
        self._max_workers = max_workers

        if max_tasks_per_worker is not None:
            if not isinstance(max_tasks_per_worker, int):
                raise TypeError("max_tasks_per_worker must be an integer")
            elif max_tasks_per_worker <= 0:
                raise ValueError("max_tasks_per_worker must be >= 1")
        self._max_tasks_per_worker = max_tasks_per_worker

        self._lazy_scheduling = lazy_scheduling
        self._task_queue = queue.Queue()
        self._shutdown_lock = threading.Lock()
        self._workers: Set[threading.Thread] = set()
        self._futures = weakref.WeakSet()
        self._shutdown_flag = False
        self._conservative_scheduling = conservative_scheduling

        if not lazy_scheduling:
            for _ in range(self._max_workers):
                self._start_worker()

    def submit(self, task: Callable, *args, **kwargs) -> futures.Future:
        """
        :param task: function to be executed remotely
        :param args: positional arguments of the function to be executed remotely
        :param kwargs: named arguments of the function to be executed remotely
        :returns: future object to retrieve the result
        """
        with self._shutdown_lock:
            if self._shutdown_flag:
                raise RuntimeError("ThreadPool has been shut down")

            future = self._FUTURE_CLASS()
            # self._futures.add(future)
            self._task_queue.put((task, args, kwargs, future))
            logger.debug(
                "%s: submitted (task capacity = %s)",
                type(self).__name__,
                self._task_capacity(),
            )

            if self._lazy_scheduling and self._require_more_workers():
                self._start_worker()
        return future

    def shutdown(self, wait: bool = True, cancel: bool = False):
        """
        :param wait: wait for all workers to exit
        :param cancel: cancel all pending tasks (running once cannot be cancelled)
        """
        with self._shutdown_lock:
            self._shutdown_flag = True
            workers = list(self._workers)

            if cancel:
                for future in list(self._futures):
                    future.cancel()

            # Sentinel to shut down workers
            for _ in range(len(workers)):
                self._task_queue.put((None, None, None, None))

        if wait:
            for worker in workers:
                worker.join()
        self._workers.clear()

    def _task_capacity(self) -> int:
        """The total number of tasks that the running workers can execute."""
        return sum(worker.task_capacity for worker in self._workers)

    def _idle_workers(self) -> int:
        """The total number of idle workers"""
        return sum(worker.idle for worker in self._workers)

    def _max_task_queue_size(self) -> int:
        if self._conservative_scheduling:
            return self._task_capacity()
        else:
            return self._idle_workers()

    def _require_more_workers(self) -> bool:
        if self._shutdown_flag:
            # Do not start workers when shutting down
            return False
        if self._max_workers and len(self._workers) >= self._max_workers:
            # Already the maximum amount of workers
            return False
        if self._task_queue.qsize() <= self._max_task_queue_size():
            # Workers have enough capacity to drain the task queue
            return False
        return True

    def _start_worker(self) -> None:
        if self._shutdown_flag:
            # Do not start workers when shutting down
            return False

        used_indices = {w.idx for w in self._workers}
        idx = 0
        while idx in used_indices:
            idx += 1

        thread_name = f"{type(self).__name__}-{idx}"
        task_capacity = self._max_tasks_per_worker or float("inf")
        worker = threading.Thread(
            target=self._worker_main,
            name=thread_name,
            daemon=True,
        )
        worker.task_capacity = task_capacity
        worker.idle = False
        worker.idx = idx
        worker.start()
        self._workers.add(worker)
        return True

    def _worker_main(self):
        worker = threading.current_thread()
        logger.info(
            "%s: started (task capacity = %s)", worker.name, worker.task_capacity
        )
        exit_reason = "cancelled"
        try:
            with self.execute_context() as (execute, worker_exit_msg):
                while not self._shutdown_flag:
                    try:
                        worker.idle = True
                        task, args, kwargs, future = self._task_queue.get(timeout=1)
                    except queue.Empty:
                        exit_msg = worker_exit_msg()
                        if exit_msg:
                            exit_reason = exit_msg
                            break
                        continue

                    try:
                        worker.idle = False

                        if task is None:  # Sentinel to shut down the worker
                            exit_reason = "requested to stop"
                            break

                        worker.task_capacity -= 1
                        if future.set_running_or_notify_cancel():
                            logger.info(
                                "%s: start executing %s (task capacity = %s)",
                                worker.name,
                                task,
                                worker.task_capacity,
                            )
                            try:
                                result = execute(task, args, kwargs, future)
                            except Exception as exc:
                                future.set_exception(exc)
                                # self = None  # TODO: break a reference cycle with the exception 'exc' but we need self later
                                if isinstance(
                                    exc, (ExecutorShutdown, errors.RemoteExit)
                                ):
                                    exit_reason = str(exc)
                                    logger.info(
                                        "%s: failed executing %s%s (task capacity = %s)",
                                        worker.name,
                                        task,
                                        exit_reason,
                                        worker.task_capacity,
                                    )
                                    break
                                else:
                                    logger.error(
                                        "%s: failed executing %s (task capacity = %s)",
                                        worker.name,
                                        task,
                                        worker.task_capacity,
                                        exc_info=True,
                                    )
                            else:
                                future.set_result(result)
                                logger.info(
                                    "%s: succeeded executing %s (task capacity = %s)",
                                    worker.name,
                                    task,
                                    worker.task_capacity,
                                )
                        else:
                            logger.warning(
                                "%s: future cancelled (task capacity = %s)",
                                worker.name,
                                worker.task_capacity,
                            )
                    finally:
                        self._task_queue.task_done()

                    if worker.task_capacity <= 0:
                        exit_reason = "task capacity reached"
                        break
        except BaseException:
            logger.critical("Exception in worker", exc_info=True)
            raise

        try:
            logger.debug("%s: %s", worker.name, exit_reason)
            with self._shutdown_lock:
                self._workers.discard(worker)
                if not self._shutdown_flag:
                    if self._lazy_scheduling:
                        while self._require_more_workers():
                            self._start_worker()
                    else:
                        self._start_worker()
            logger.debug("%s: exiting", worker.name)
        except BaseException:
            logger.critical("Exception in worker exit", exc_info=True)
            raise

    @contextmanager
    def execute_context(
        self,
    ) -> ExecuteContextReturnType:
        def execute(
            task: callable, args: tuple, kwargs: dict, future: futures.Future
        ) -> Any:
            return task(*args, **kwargs)

        def worker_exit_msg() -> Optional[str]:
            """Return exit message in case the executor worker must exit."""
            return

        yield (execute, worker_exit_msg)