"""Low-level SLURM API to submit, cancel and monitor jobs"""
import datetime
import json
import logging
import os
import time
from contextlib import contextmanager
from pprint import pformat
from typing import Any
from typing import Dict
from typing import Generator
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from urllib.parse import urlparse
from urllib.parse import urlunparse
from pydantic import BaseModel
from .. import defaults
from .. import os_utils
from .. import url_utils
from .. import utils
from ..errors import SlurmInvalidUrlError
from ..errors import SlurmMissingParameterError
from ..log_utils import log_file_monitor_context
from .api import slurm_access
from .api import slurm_auth
from .api import slurm_logs
from .api import slurm_params
from .api import slurm_response
logger = logging.getLogger(__name__)
[docs]
class SlurmBaseRestClient:
"""Low-level SLURM API to submit, cancel and monitor jobs.
This class does not contain any job-related state."""
def __init__(
self,
url: str = "",
user_name: str = "",
token: str = "",
api_version: str = "",
renewal_url: str = "",
parameters: Optional[dict] = None,
log_directory: Optional[str] = None,
std_split: Optional[bool] = False,
request_options: Optional[dict] = None,
use_os_environment: bool = True,
):
"""
:param url: SLURM REST API URL (fallback to SLURM_URL env)
:param user_name: SLURM username (fallback to SLURM_USER or system user)
:param token: SLURM JWT token (fallback to SLURM_TOKEN env)
:param api_version: SLURM API version (e.g. 'v0.0.42')
:param renewal_url: Url for SLURM JWT token renewal (fallback to SLURM_RENEWAL_URL env)
:param parameters: SLURM job parameters
:param log_directory: SLURM log directory
:param std_split: Split standard output and standard error
:param request_options: GET, POST and DELETE options
:param use_os_environment: Use ``SLURM_*`` environment variables
:raises SlurmMissingParameterError:
:raises SlurmTokenInvalidError:
:raises SlurmInvalidUrlError:
"""
if not url and use_os_environment:
url = os.environ.get("SLURM_URL")
if not url:
raise SlurmMissingParameterError(
"'url' must be provided or set environment variable SLURM_URL"
)
self._base_url = _parse_slurm_rest_url(url)
self._auth = slurm_auth.SlurmAuthManager(
user_name=user_name,
token=token,
renewal_url=renewal_url,
use_os_environment=use_os_environment,
)
if not api_version and use_os_environment:
api_version = os.environ.get("SLURM_API_VERSION")
if not api_version:
api_version = defaults.DEFAULT_API_VERSION
version_tuple = slurm_access.parse_version(api_version)
self._api_version_str: str = slurm_access.create_version(version_tuple)
self._api_version_tuple: Tuple[int, int, int] = version_tuple
self._parameters = parameters
if log_directory:
log_directory = str(log_directory).format(user_name=self._auth.user_name)
self._log_directory = log_directory
self._std_split = std_split
self._request_options = request_options
[docs]
def submit_job(
self,
script: str,
parameters: Optional[dict] = None,
metadata: Optional[Union[str, dict]] = None,
request_options: Optional[dict] = None,
) -> int:
"""Returns the SLURM job ID"""
# Job parameters
parameters = utils.merge_mappings(self._parameters, parameters)
self._submit_ensure_name(parameters)
self._submit_ensure_std(parameters)
self._submit_ensure_environment(parameters)
self._submit_ensure_wd(parameters)
if metadata:
if not isinstance(metadata, str):
metadata = json.dumps(metadata)
parameters["comment"] = metadata
slurm_params.coerce_parameters(self._api_version_tuple, parameters)
# Request body
body = {"job": parameters}
if request_options:
request_options["json"] = body
else:
request_options = {"json": body}
# Job script
if self._api_version_tuple < (0, 0, 39):
body["script"] = script
else:
body["job"]["script"] = script
# Submit job
response = self.post(
f"/slurm/{self._api_version_str}/job/submit",
request_options=request_options,
)
# Log submission
job_name = parameters["name"]
job_id = response.job_id
user_msg = response.job_submit_user_msg
if user_msg:
logger.info(
"SLURM job '%s' submitted (Job ID: %s): %s", job_name, job_id, user_msg
)
else:
logger.info("SLURM job '%s' submitted (Job ID: %s)", job_name, job_id)
return job_id
def _submit_ensure_name(self, parameters: dict) -> None:
if not parameters.get("name"):
parameters["name"] = defaults.JOB_NAME
def _submit_ensure_std(self, parameters: dict) -> None:
parameters["standard_input"] = "/dev/null"
if self._log_directory:
os_utils.makedirs(self._log_directory, log_level=logging.DEBUG)
job_name = parameters["name"]
filetemplate = f"{self._log_directory}/{job_name}.%j"
if self._std_split:
parameters["standard_output"] = f"{filetemplate}.out"
parameters["standard_error"] = f"{filetemplate}.err"
else:
parameters["standard_output"] = f"{filetemplate}.outerr"
else:
parameters["standard_output"] = "/dev/null"
parameters["standard_error"] = "/dev/null"
def _submit_ensure_environment(self, parameters: dict) -> None:
if parameters.get("environment"):
return
parameters["environment"] = {"_DUMMY_VAR": "dummy_value"}
def _submit_ensure_wd(self, parameters: dict) -> None:
"""The slurm user does not necessarily have a home directory
and if it has, "~" does not work.
"""
if not parameters.get("current_working_directory"):
parameters["current_working_directory"] = defaults.SLURM_TEMPDIR
[docs]
def cancel_job(self, job_id: int, request_options: Optional[dict] = None) -> None:
request_options = utils.merge_mappings(
request_options, {"path_params": {"job_id": job_id}}
)
_ = self.delete(
f"/slurm/{self._api_version_str}/job/{{job_id}}",
request_options=request_options,
)
[docs]
def clean_job_artifacts(
self,
job_id: int,
job_properties: Optional[BaseModel] = None,
request_options: Optional[dict] = None,
) -> None:
if job_properties is None:
job_properties = self.get_job_properties(
job_id, raise_on_error=False, request_options=request_options
)
if job_properties is None:
return
slurm_logs.remove_slurm_log(job_properties, stderr=False)
slurm_logs.remove_slurm_log(job_properties, stderr=True)
[docs]
@contextmanager
def clean_job_artifacts_context(
self,
job_id: int,
job_properties: Optional[BaseModel] = None,
request_options: Optional[dict] = None,
) -> Generator[None, None, None]:
try:
yield
finally:
self.clean_job_artifacts(
job_id,
job_properties=job_properties,
request_options=request_options,
)
[docs]
def get_job_properties(
self, job_id: int, raise_on_error=True, request_options: Optional[dict] = None
) -> Optional[BaseModel]:
request_options = utils.merge_mappings(
request_options, {"path_params": {"job_id": job_id}}
)
response = self.get(
f"/slurm/{self._api_version_str}/job/{{job_id}}",
request_options=request_options,
raise_on_error=raise_on_error,
)
if response is None:
return
if not response.jobs:
return
return response.jobs[0]
[docs]
def get_all_job_properties(
self,
raise_on_error=True,
request_options: Optional[dict] = None,
filter: Optional[dict] = None,
update_time: Optional[datetime.datetime] = None,
all_users: bool = False,
) -> List[BaseModel]:
# Server-side filter
if update_time:
request_options = utils.merge_mappings(
request_options,
{"query_params": {"update_time": update_time.timestamp()}},
)
response = self.get(
f"/slurm/{self._api_version_str}/jobs",
request_options=request_options,
raise_on_error=raise_on_error,
)
if response is None or not response.jobs:
return list()
# Client-side filter
if not filter:
filter = dict()
if not all_users:
filter.setdefault("user_name", self._auth.user_name)
return [
job_properties
for job_properties in response.jobs
if self._is_part_of(job_properties, filter)
]
def _is_part_of(self, obj: Any, subobj: Any) -> bool:
"""Check that `subobj` is a part of `obj`."""
if isinstance(subobj, dict):
# subobj must be a sub-tree of obj
if not isinstance(obj, (dict, BaseModel)):
return False
for subobj_key, subobj_value in subobj.items():
if isinstance(obj, BaseModel):
if not hasattr(obj, subobj_key):
return False
obj_value = getattr(obj, subobj_key)
else:
if subobj_key not in obj:
return False
obj_value = obj[subobj_key]
if not self._is_part_of(obj_value, subobj_value):
return False
return True
elif isinstance(subobj, list):
# subobj must be equal to obj[:len(subobj)]
if not isinstance(obj, list):
return False
if len(subobj) > len(obj):
return False
for obj_item, subobj_item in zip(obj, subobj):
if not self._is_part_of(obj_item, subobj_item):
return False
return True
else:
return obj == subobj
[docs]
def get_status(self, job_id: int, request_options: Optional[dict] = None) -> str:
job_properties = self.get_job_properties(
job_id, request_options=request_options
)
return slurm_response.slurm_job_state(job_properties)
[docs]
def get_full_status(
self, job_id: int, request_options: Optional[dict] = None
) -> dict:
job_properties = self.get_job_properties(
job_id, request_options=request_options
)
return slurm_response.slurm_job_full_state(job_properties)
[docs]
def is_finished(self, job_id: int) -> bool:
"""Returns `True` when the job is completely finished."""
return self.get_status(job_id) in slurm_response.FINISHED_STATES
[docs]
def is_finishing(self, job_id: int, request_options: Optional[dict] = None) -> bool:
"""Returns `True` when the job is finished by Slurm is still finalizing."""
return self.get_status(job_id) in slurm_response.FINISHING_STATES
[docs]
def wait_finished(self, job_id: int, **kw) -> str:
"""Wait until a job is finished and return the state."""
return self.wait_states(job_id, slurm_response.FINISHED_STATES, **kw)
[docs]
def wait_finishing(self, job_id: int, **kw) -> str:
"""Wait until a job is finished or in the process of finishing and return the state."""
return self.wait_states(job_id, slurm_response.FINISHING_STATES, **kw)
[docs]
def wait_states(
self,
job_id: int,
states: Tuple[str],
progress: bool = False,
timeout: Optional[float] = None,
period: float = 0.5,
) -> str:
"""Wait until a job state is reached and return the state."""
job_state = None
t0 = time.time()
while job_state not in states:
job_state = self.get_status(job_id)
time.sleep(period)
if progress:
print(".", end="", flush=True)
if timeout is not None and (time.time() - t0) > timeout:
raise TimeoutError
status = self.get_full_status(job_id)
logger.info("Job '%s' finished: %s", job_id, pformat(status))
return job_state
[docs]
def print_stdout_stderr(
self, job_id: int, request_options: Optional[dict] = None
) -> None:
job_properties = self.get_job_properties(
job_id, request_options=request_options, raise_on_error=False
)
if job_properties is None:
return
stdout, stderr = slurm_logs.get_stdout_stderr(job_properties)
print()
print("".join(stdout.lines()))
if stderr is not None:
print("".join(stderr.lines()))
[docs]
def log_stdout_stderr(
self,
job_id: int,
logger: Optional[logging.Logger] = None,
level: int = logging.INFO,
request_options: Optional[dict] = None,
**log_options,
) -> None:
job_properties = self.get_job_properties(
job_id, request_options=request_options, raise_on_error=False
)
if job_properties is None:
return
stdout, stderr = slurm_logs.get_stdout_stderr(job_properties)
stdout = "".join(stdout.lines())
if logger is None:
logger = logger
if stderr is None:
logger.log(level, "\n%s", stdout, **log_options)
else:
stderr = "".join(stderr.lines())
logger.log(level, "\n%s\n%s", stdout, stderr, **log_options)
[docs]
@contextmanager
def redirect_stdout_stderr(
self, job_id: int, request_options: Optional[dict] = None
) -> Generator[None, None, None]:
"""Redirect logs files to the local root logger within this context."""
job_properties = self.get_job_properties(
job_id, raise_on_error=False, request_options=request_options
)
if job_properties is None:
yield
return
ofilename = slurm_logs.slurm_log_filename(
job_properties, job_properties.standard_output
)
efilename = slurm_logs.slurm_log_filename(
job_properties, job_properties.standard_error
)
with log_file_monitor_context([ofilename, efilename]):
yield
[docs]
def get(self, path: str, **kwargs) -> Optional[BaseModel]:
return self.request("GET", path, **kwargs)
[docs]
def delete(self, path: str, **kwargs) -> Optional[BaseModel]:
return self.request("DELETE", path, **kwargs)
[docs]
def post(self, path: str, **kwargs) -> Optional[BaseModel]:
return self.request("POST", path, **kwargs)
[docs]
def request(
self,
method: str,
path: str,
request_options: Optional[dict] = None,
raise_on_error: bool = True,
) -> Optional[BaseModel]:
request_options = self._merged_request_options(request_options)
try:
return slurm_access.validated_slurm_request(
method, self._base_url, path, **request_options
)
except Exception as e:
headers = self._auth.retry_headers()
if not headers:
if raise_on_error:
raise
logger.debug(str(e))
return None
request_options["headers"] = headers
try:
return slurm_access.validated_slurm_request(
method, self._base_url, path, **request_options
)
except Exception as e:
if raise_on_error:
raise
logger.debug(str(e))
return None
[docs]
def server_has_api(
self, request_options: Optional[dict] = None
) -> Union[bool, str, List[str]]:
request_options = self._merged_request_options(request_options)
openapi_spec = None
try:
openapi_spec = slurm_access.slurm_request(
"GET", self._base_url, "/openapi", **request_options
)
except Exception:
headers = self._auth.retry_headers()
if not headers:
raise
request_options["headers"] = headers
if openapi_spec is None:
openapi_spec = slurm_access.slurm_request(
"GET", self._base_url, "/openapi", **request_options
)
versions = slurm_response.extract_slurm_api_versions(openapi_spec)
has_api = self._api_version_tuple in versions
versions = [slurm_access.create_version(v) for v in versions]
return has_api, self._api_version_str, versions
def _merged_request_options(
self,
request_options: Optional[dict] = None,
) -> Dict[str, Any]:
# self._auth > request_options > self._request_options
headers = {"headers": self._auth.headers()}
request_options = utils.merge_mappings(request_options, headers)
request_options = utils.merge_mappings(self._request_options, request_options)
return request_options
def _parse_slurm_rest_url(slurm_url: str) -> str:
"""
:raises SlurmInvalidUrlError:
"""
parsed = urlparse(slurm_url)
if not parsed.scheme:
raise SlurmInvalidUrlError(
f"Empty SLURM REST URL scheme: {parsed.scheme}. Expected 'http' or 'https'."
)
if not parsed.hostname:
raise SlurmInvalidUrlError("Missing hostname.")
if parsed.port is None:
if parsed.scheme == "https":
default_port = 443
elif parsed.scheme == "http":
default_port = 6820
else:
default_port = None
if default_port:
parsed = url_utils.set_url_port(parsed, default_port)
return urlunparse(parsed)