import logging
import warnings
import weakref
from concurrent import futures
from contextlib import ExitStack
from contextlib import contextmanager
from pprint import pformat
from typing import Any
from typing import Optional
from typing import Tuple
from typing import Type
from ..client import SlurmPyConnRestClient
from ..client import defaults
from ..client import errors
from ..client.job_io.local import ExecuteContextReturnType
from ..client.job_io.local import ExecutorShutdown
from ..client.job_io.local import FileConnection
from ..client.job_io.local import RemoteExecutor
from ..client.job_io.local import RemoteWorkerProxy
from ..client.job_io.local import TcpConnection
from ..client.rest.api import slurm_response
logger = logging.getLogger(__name__)
[docs]
class SlurmRestFuture(futures.Future):
def __init__(self) -> None:
self._job_id = None
self._slurm_client = None
self._delayed_cancel_job = False
super().__init__()
[docs]
def job_submitted(self, job_id: int, slurm_client: SlurmPyConnRestClient) -> None:
"""The SLURM job was submitted. Beware that the Slurm job may be running other tasks as well."""
self._job_id = job_id
self._slurm_client = weakref.proxy(slurm_client)
if self._delayed_cancel_job:
slurm_client.cancel_job(job_id)
@property
def job_id(self) -> Optional[int]:
return self._job_id
@property
def slurm_client(self) -> Optional[SlurmPyConnRestClient]:
return self._slurm_client
[docs]
def cancel_job(self) -> None:
warnings.warn(
"`cancel_job()` is deprecated and will be removed in a future release. Use `abort()` instead.",
DeprecationWarning,
stacklevel=2,
)
[docs]
def abort(self) -> bool:
"""Cancel the Slurm job, even when it is already running. Beware that the Slurm job may be running other tasks as well."""
slurm_client = self.slurm_client
# The SLURM job was asked to be cancelled but didn't start yet: use `_delayed_cancel_job` to cancel it after it started.
if slurm_client is None:
self._delayed_cancel_job = True
else:
slurm_client.cancel_job(self.job_id)
return self.aborted()
[docs]
def aborted(self) -> bool:
slurm_client = self.slurm_client
if slurm_client is None:
return False
status = slurm_client.get_status(self.job_id)
return status == "CANCELLED"
[docs]
class SlurmRestExecutor(RemoteExecutor):
_FUTURE_CLASS = SlurmRestFuture
def __init__(
self,
url: str = "",
user_name: str = "",
token: str = "",
api_version: str = "",
renewal_url: str = "",
parameters: Optional[dict] = None,
log_directory: Optional[str] = None,
std_split: Optional[bool] = False,
request_options: Optional[dict] = None,
pre_script: Optional[str] = None,
post_script: Optional[str] = None,
python_cmd: Optional[str] = None,
initializer: Optional[callable] = None,
initargs: Optional[tuple] = None,
initkwargs: Optional[tuple] = None,
data_directory: Optional[str] = None,
max_workers: Optional[int] = None,
max_tasks_per_worker: Optional[int] = 1,
lazy_scheduling: bool = True,
conservative_scheduling: bool = False,
cleanup_job_artifacts: bool = False,
use_os_environment: bool = True,
):
"""
:param url: SLURM REST API URL (fallback to SLURM_URL env)
:param user_name: SLURM username (fallback to SLURM_USER or system user)
:param token: SLURM JWT token (fallback to SLURM_TOKEN env)
:param api_version: SLURM API version (e.g. 'v0.0.42')
:param renewal_url: Url for SLURM JWT token renewal (fallback to SLURM_RENEWAL_URL env)
:param parameters: SLURM job parameters
:param log_directory: SLURM log directory
:param std_split: Split standard output and standard error
:param request_options: GET, POST and DELETE options
:param pre_script: Shell script to execute at the start of a job
:param post_script: Shell script to execute at the end of a job
:param python_cmd: Python command
:param initializer: execute when starting a job
:param initargs: parameters for `initializer`
:param initkwargs: parameters for `initializer`
:param data_directory: communicate with the Slum job through files when specified
:param max_workers: maximum number of Slum jobs that can run at any given time. `None` means unlimited.
:param max_tasks_per_worker: maximum number of tasks each Slum job can receive before exiting. `None` means unlimited.
:param lazy_scheduling: schedule SLURM jobs only when needed. Can only be disabled when `max_workers` is specified.
:param conservative_scheduling: schedule the least amount of workers at the expense of tasks staying longer in the queue.
:param cleanup_job_artifacts: cleanup job artifacts like logs.
:param use_os_environment: Use ``SLURM_*`` environment variables
"""
self._slurm_client = SlurmPyConnRestClient(
url=url,
user_name=user_name,
token=token,
api_version=api_version,
renewal_url=renewal_url,
log_directory=log_directory,
parameters=parameters,
std_split=std_split,
request_options=request_options,
pre_script=pre_script,
post_script=post_script,
python_cmd=python_cmd,
use_os_environment=use_os_environment,
)
self._proxy_kwargs = {
"max_tasks": max_tasks_per_worker,
"initializer": initializer,
"initargs": initargs,
"initkwargs": initkwargs,
}
if data_directory:
user_name = self._slurm_client._auth.user_name
data_directory = str(data_directory).format(user_name=user_name)
self._file_connection_kwargs = {
"directory": data_directory,
"basename": defaults.JOB_NAME,
}
else:
self._file_connection_kwargs = None
self._cleanup_job_artifacts = cleanup_job_artifacts
super().__init__(
max_workers=max_workers,
max_tasks_per_worker=max_tasks_per_worker,
lazy_scheduling=lazy_scheduling,
conservative_scheduling=conservative_scheduling,
)
[docs]
@contextmanager
def execute_context(self) -> ExecuteContextReturnType:
with ExitStack() as stack:
job_id = None
first_submit_kw = None
def initialize(submit_kw):
"""
Initialize SLURM worker: submit the SLURM job and initialize
the communication with the worker.
"""
nonlocal job_id, first_submit_kw
first_submit_kw = submit_kw
if submit_kw is None:
submit_kw = dict()
job_id = self._slurm_client.submit_script(worker_proxy, **submit_kw)
log_ctx = self._slurm_client.redirect_stdout_stderr(job_id)
_ = stack.enter_context(log_ctx)
if self._cleanup_job_artifacts:
cleanup_ctx = self._slurm_client.clean_job_artifacts_context(job_id)
_ = stack.enter_context(cleanup_ctx)
worker_proxy.initialize()
def execute(
task: callable, args: tuple, kwargs: dict, future: SlurmRestFuture
) -> Any:
"""
Send a task to the SLURM worker. Start the worker when not already running.
"""
submit_kw = kwargs.pop(defaults.SLURM_ARGUMENTS_NAME, None)
if job_id is None:
initialize(submit_kw)
elif submit_kw != first_submit_kw:
logger.warning(
"SLURM submit arguments\n %s\n are ignored in favor of the arguments of the first task\n %s",
pformat(submit_kw),
pformat(first_submit_kw),
)
future.job_submitted(job_id, self._slurm_client)
try:
result, exc_info = worker_proxy.execute_without_reraise(
task, args=args, kwargs=kwargs
)
except (ExecutorShutdown, errors.RemoteExit):
raise
except Exception as ex:
exc_type, error_msg = status_error()
if exc_type:
raise exc_type(error_msg) from ex
raise
if exc_info is not None:
errors.reraise_remote_exception_from_tb(exc_info)
return result
def status_error() -> Tuple[Optional[Type[Exception]], Optional[str]]:
"""Returns status exception class and message in case of a status error:
remote Slurm job exited or local executor is shutting down.
"""
if job_id is None:
return None, None
try:
status = self._slurm_client.get_status(job_id)
except Exception as e:
if self._shutdown_flag:
return ExecutorShutdown, "Slurm REST executor is shutting down"
logger.warning("failed getting the job state: %s", e, exc_info=True)
status = None
if status in slurm_response.FINISHING_STATES:
return errors.RemoteExit, f"SLURM job {job_id} {status}"
return None, None
def raise_on_status_error() -> None:
"""Raise exception when there is a status error."""
exc_type, error_message = status_error()
if exc_type:
raise exc_type(error_message) from None
def worker_exit_msg() -> Optional[str]:
"""Return exit message in case the executor worker must exit, i.e.
when there is a status error."""
_, exit_msg = status_error()
return exit_msg
if self._file_connection_kwargs:
conn_ctx = FileConnection(
**self._file_connection_kwargs,
raise_on_status_error=raise_on_status_error,
)
else:
conn_ctx = TcpConnection(raise_on_status_error=raise_on_status_error)
connection = stack.enter_context(conn_ctx)
proxy_ctx = RemoteWorkerProxy(connection, **self._proxy_kwargs)
worker_proxy = stack.enter_context(proxy_ctx)
if not self._lazy_scheduling:
try:
initialize(None)
except Exception as e:
logger.warning(
"SLURM worker initialization failed: %s", e, exc_info=True
)
yield (execute, worker_exit_msg)