# Source code for pyslurmutils.client.rest.api.slurm_auth
import datetime
import getpass
import logging
import os
import shlex
import subprocess
from typing import List
from typing import Optional
from urllib.parse import ParseResult
from urllib.parse import urlparse
import jwt
from ... import url_utils
from ...errors import SlurmInvalidUrlError
from ...errors import SlurmMissingParameterError
from ...errors import SlurmTokenInvalidError
from ...errors import SlurmTokenRenewalError
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
# [docs]
class SlurmAuthManager:
    """
    Manages the SLURM user name and JWT token used for authentication.

    A new token can be requested with ``scontrol token``, executed locally
    first and, when that yields nothing, over SSH on the host given by the
    renewal URL.
    """

    def __init__(
        self,
        user_name: Optional[str] = None,
        token: Optional[str] = None,
        renewal_url: Optional[str] = None,
        renewal_margin_seconds: int = 60,
        renewal_lifespan: int = 3600,
        renewal_timeout: int = 6,
        use_env_credentials: bool = True,
        export_slurm_env_vars: bool = True,
    ):
        """
        Initializes the SLURM authentication manager.

        :param user_name: SLURM username (fallback to SLURM_USER or system user)
        :param token: SLURM JWT token (fallback to SLURM_TOKEN env)
        :param renewal_url: Url for SLURM JWT token renewal (fallback to SLURM_RENEWAL_URL env)
        :param renewal_margin_seconds: Time before expiry to trigger renewal
        :param renewal_lifespan: Requested lifespan in seconds for a new token
        :param renewal_timeout: Token renewal timeout in seconds (applied to the SSH renewal command)
        :param use_env_credentials: use the local environment variables ``SLURM_USER``,
            ``SLURM_TOKEN`` and ``SLURM_RENEWAL_URL`` as fallbacks. When enabled, the
            validated token is also written back to ``SLURM_TOKEN`` (or the variable is
            removed when there is no token).
        :param export_slurm_env_vars: export local ``SLURM_ENV_*`` environment variables to
            the job environment, stripping the prefix (e.g. ``SLURM_ENV_FOO`` -> ``FOO``).
            Accepted for interface compatibility; not used by this class itself.
        :raises SlurmTokenInvalidError:
        :raises SlurmInvalidUrlError:
        """
        # Must be assigned before the token: the ``_token`` setter reads it.
        self._use_env_credentials = use_env_credentials

        if not user_name and use_env_credentials:
            user_name = os.environ.get("SLURM_USER")
        if not user_name:
            user_name = getpass.getuser()
        self._user_name = user_name.strip()

        if not token and use_env_credentials:
            token = os.environ.get("SLURM_TOKEN")
        self._token = token

        if not renewal_url and use_env_credentials:
            renewal_url = os.environ.get("SLURM_RENEWAL_URL")
        self._renewal_url = _parse_renewal_url(renewal_url) if renewal_url else None

        self._renewal_lifespan = renewal_lifespan
        self._renewal_margin_seconds = datetime.timedelta(
            seconds=renewal_margin_seconds
        )
        self._renewal_timeout = renewal_timeout

    @property
    def user_name(self) -> str:
        """The SLURM user name, stripped of surrounding whitespace."""
        return self._user_name

    @property
    def _token(self) -> Optional[str]:
        """The validated SLURM JWT token, or ``None`` when there is none."""
        return self.__token

    @_token.setter
    def _token(self, value: Optional[str]) -> None:
        """
        Validates and stores the token; mirrors it to ``SLURM_TOKEN`` when
        environment credentials are enabled.

        :param value: Raw token string, or ``None``/empty to clear the token
        :raises TypeError: If the value is neither ``str`` nor ``None``
        :raises SlurmTokenInvalidError:
        """
        if value is not None and not isinstance(value, str):
            raise TypeError(f"Token must be str or None, got {type(value).__name__}")

        # Validation happens first so a bad token leaves the stored one untouched.
        parsed = _parse_token(value) if value else None
        self.__token = parsed

        if not self._use_env_credentials:
            return
        if parsed is None:
            os.environ.pop("SLURM_TOKEN", None)
        else:
            os.environ["SLURM_TOKEN"] = parsed

    def _token_renewal_required(self, use_offset: bool = False) -> bool:
        """
        Determines whether the current token should be renewed.

        The token is decoded without signature verification; only its ``exp``
        claim is inspected. A missing, undecodable or unusable ``exp`` counts
        as "renewal required".

        :param use_offset: Use configured renewal_margin_seconds for early renewal
        :return: True if token should be renewed
        """
        token = self._token
        if not token:
            return True
        try:
            payload = jwt.decode(token, options={"verify_signature": False})
        except jwt.DecodeError:
            return True

        exp = payload.get("exp")
        if not exp:
            return True
        try:
            # Build the aware UTC datetime directly instead of going through
            # naive local time and converting afterwards.
            expiry_time = datetime.datetime.fromtimestamp(
                exp, tz=datetime.timezone.utc
            )
        except (TypeError, ValueError, OverflowError, OSError):
            # Malformed or out-of-range "exp" claim: treat the token as unusable.
            return True

        no_margin = datetime.timedelta(0)
        margin = self._renewal_margin_seconds if use_offset else no_margin
        time_left = expiry_time - datetime.datetime.now(datetime.timezone.utc)

        if time_left > margin:
            return False
        if time_left <= no_margin:
            logger.warning("SLURM token expired on %s", expiry_time)
        else:
            # Still valid, but inside the early-renewal window.
            logger.info(
                "SLURM token expires on %s (within the renewal margin)", expiry_time
            )
        return True

    def _create_new_token(self) -> str:
        """
        Attempts to renew the SLURM token (local first, then remote).

        :return: Output of the renewal command that produced a result
        :raises SlurmTokenRenewalError:
        :raises SlurmTokenInvalidError:
        :raises SlurmMissingParameterError:
        """
        new_token = self._create_new_token_local()
        if new_token:
            return new_token
        return self._create_new_token_ssh()

    def _create_new_token_local(self) -> Optional[str]:
        """
        Attempts to renew the token using local `scontrol`.

        Failures are not raised (best effort); they result in ``None``.

        :return: Command output if successful, else None
        """
        return self._execute_renewal(self._renewal_command())

    def _create_new_token_ssh(self) -> str:
        """
        Attempts to renew the token via SSH to the configured renewal host.

        :return: Command output if successful
        :raises SlurmTokenRenewalError:
        :raises SlurmTokenInvalidError:
        :raises SlurmMissingParameterError:
        """
        if not self._renewal_url:
            # Without a token the more specific "missing token" error is raised.
            self._raise_on_missing_token()
            raise SlurmMissingParameterError(
                "SLURM token renewal failed: 'renewal_url' must be provided or set environment variable SLURM_RENEWAL_URL."
            )
        return self._execute_renewal(
            self._ssh_renewal_command(),
            timeout=self._renewal_timeout,
            raise_on_error=True,
        )

    def _ssh_renewal_command(self) -> List[str]:
        """
        Constructs the `ssh` command that runs `scontrol token` on the renewal host.

        The remote command is passed as one shell-quoted string: ssh hands it
        to the remote shell, so unquoted arguments containing spaces or shell
        metacharacters would be re-split or interpreted there.

        :return: List of command components
        """
        user = self._renewal_url.username or self._user_name
        destination = f"{user}@{self._renewal_url.hostname}"
        remote_cmd = shlex.join(self._renewal_command())
        return ["ssh", destination, "-p", str(self._renewal_url.port), remote_cmd]

    def _renewal_command(self) -> List[str]:
        """
        Constructs the `scontrol token` command.

        :return: List of command components
        """
        return [
            "scontrol",
            "token",
            f"username={self._user_name}",
            f"lifespan={self._renewal_lifespan}",
        ]

    def _execute_renewal(
        self, cmd: List[str], raise_on_error: bool = False, **kwargs
    ) -> Optional[str]:
        """
        Executes a command to retrieve a new SLURM token.

        :param cmd: Command as list of strings
        :param raise_on_error: Raise when the command fails or prints nothing
        :param kwargs: Additional subprocess parameters (e.g., timeout)
        :return: Raw command output (stdout and stderr combined) if non-empty, else None.
            The output is not validated here.
        :raises SlurmTokenRenewalError: If the renewal failed and ``raise_on_error`` is set
        :raises SlurmMissingParameterError: If the renewal failed, ``raise_on_error`` is set
            and no token is currently stored
        """
        s_cmd = shlex.join(cmd)
        try:
            output = subprocess.check_output(
                cmd, stderr=subprocess.STDOUT, universal_newlines=True, **kwargs
            )
        except Exception as ex:
            if raise_on_error:
                self._raise_on_missing_token()
                raise SlurmTokenRenewalError(
                    f"SLURM token renewal failed: {s_cmd}"
                ) from ex
            # Best effort: keep a trace of why the attempt produced nothing.
            logger.debug("SLURM token renewal command failed: %s", s_cmd, exc_info=True)
            return None

        if output:
            logger.info("Successfully renewed SLURM token via: %s", s_cmd)
            return output

        # The command succeeded but printed nothing: there is no token to return.
        if raise_on_error:
            self._raise_on_missing_token()
            raise SlurmTokenRenewalError(
                f"SLURM token renewal failed (no output): {s_cmd}"
            )
        return None

    def _raise_on_missing_token(self):
        """
        Raises when no token is currently stored.

        :raises SlurmMissingParameterError:
        """
        if not self._token:
            raise SlurmMissingParameterError(
                "SLURM authentication failed: 'token' must be provided or set environment variable SLURM_TOKEN or SLURM_RENEWAL_URL."
            )
def _parse_token(token: str) -> str:
    """
    Extracts the JWT from raw `scontrol token` style output and checks it decodes.

    Only the last line of the stripped input is kept, and every occurrence of
    ``SLURM_JWT=`` is removed from it. The signature is not verified.

    :param token: Raw token string
    :return: Cleaned token string
    :raises SlurmTokenInvalidError: If the cleaned token is empty or cannot be decoded
    """
    last_line = token.strip().rsplit("\n", 1)[-1]
    cleaned = last_line.replace("SLURM_JWT=", "")
    if cleaned:
        try:
            jwt.decode(cleaned, options={"verify_signature": False})
        except jwt.DecodeError:
            raise SlurmTokenInvalidError("SLURM token is invalid.") from None
        return cleaned
    raise SlurmTokenInvalidError("Slurm token is empty.")
def _parse_renewal_url(renewal_url: str, default_port: int = 22) -> ParseResult:
    """
    Parses and validates an ``ssh://[user@]host[:port]`` renewal URL.

    :param renewal_url: URL of the host used for token renewal
    :param default_port: Port applied when the URL does not specify one
    :return: Parsed URL; when the port is absent, the result of
        ``url_utils.set_url_port(parsed, default_port)``
    :raises SlurmInvalidUrlError: If the scheme is not ``ssh``, the hostname is
        missing or the port is not a valid port number
    """
    parsed = urlparse(renewal_url)
    if parsed.scheme != "ssh":
        raise SlurmInvalidUrlError(
            f"Invalid SLURM renewal URL scheme: {parsed.scheme}. Expected 'ssh'."
        )
    if not parsed.hostname:
        raise SlurmInvalidUrlError("Missing hostname.")
    try:
        # Accessing ``port`` raises ValueError for a non-numeric or
        # out-of-range port; report it as an invalid URL like the other checks.
        port = parsed.port
    except ValueError as ex:
        raise SlurmInvalidUrlError(f"Invalid SLURM renewal URL port: {ex}") from ex
    if port is None:
        return url_utils.set_url_port(parsed, default_port)
    return parsed