Source code for pyslurmutils.tests.test_rest_executor
import os
import sys
import time
import pytest
from ..client import os_utils
from ..client.errors import RemoteExit
from ..concurrent.futures import SlurmRestExecutor
[docs]
@pytest.fixture(params=["tcp", "file"])
def slurm_executor_kwargs(
    request, slurm_data_directory, slurm_client_kwargs, slurm_parameters
):
    """Keyword arguments used to construct a ``SlurmRestExecutor``.

    Parametrized over two modes:

    - ``"tcp"``: ``data_directory`` is ``None``.
    - ``"file"``: ``data_directory`` is the ``slurm_data_directory`` fixture.
      When ``slurm_parameters["mock"]`` is falsy, the requesting test is
      marked ``xfail``, because distributed-filesystem I/O is unreliable.

    The entries of ``slurm_client_kwargs`` are merged in after
    ``data_directory``, so they take precedence on a key clash.
    """
    if request.param == "tcp":
        data_directory = None
    else:
        data_directory = slurm_data_directory
        if not slurm_parameters["mock"]:
            unreliable_io = pytest.mark.xfail(
                reason="Distributed filesystems I/O is not reliable"
            )
            request.node.add_marker(unreliable_io)

    executor_kwargs = {"data_directory": data_directory}
    executor_kwargs.update(slurm_client_kwargs)
    return executor_kwargs
[docs]
def test_rest_executor_submit(slurm_executor_kwargs):
    """``submit`` returns futures that resolve to the remote call's outcome.

    Three cases are covered:

    - a call returning ``None``;
    - a call returning a value;
    - a call whose ``TypeError`` is re-raised by ``Future.result``.
    """
    with SlurmRestExecutor(**slurm_executor_kwargs, max_workers=8) as slurm_executor:
        sleep_future = slurm_executor.submit(time.sleep, 0)
        sum_future = slurm_executor.submit(sum, [1, 1])
        # int + str inside sum() raises TypeError in the worker.
        failing_future = slurm_executor.submit(sum, [1, "a"])

        assert sleep_future.result(timeout=60) is None
        assert sum_future.result(timeout=60) == 2
        with pytest.raises(TypeError):
            failing_future.result(timeout=60)
[docs]
def test_rest_executor_map(slurm_executor_kwargs):
    """``map`` applies the function to every input and yields results in input order.

    Each ``[k, k]`` pair is summed remotely. The ordered comparison against
    ``[2, 4, 6]`` pins both the values and their order.
    """
    with SlurmRestExecutor(**slurm_executor_kwargs, max_workers=8) as slurm_executor:
        # list() drains the result iterator directly. The previous identity
        # comprehension ``[r for r in ...]`` added nothing.
        results = list(
            slurm_executor.map(sum, [[1, 1], [2, 2], [3, 3]], timeout=60)
        )
        assert results == [2, 4, 6], str(results)
[docs]
def test_rest_executor_initializer(slurm_executor_kwargs):
    """The executor ``initializer`` prepares state used by the submitted tasks.

    A single worker (``max_workers=1``) is allowed five tasks
    (``max_tasks_per_worker=5``). Every task increments the counter that
    ``_initializer`` sets to zero, so the collected results must be 1..5.
    """
    with SlurmRestExecutor(
        **slurm_executor_kwargs,
        initializer=_initializer,
        max_workers=1,
        max_tasks_per_worker=5,
    ) as slurm_executor:
        futures = [slurm_executor.submit(_increment_global_value) for _ in range(5)]
        observed = set()
        for fut in futures:
            observed.add(fut.result(timeout=60))
        assert observed == set(range(1, 6)), str(observed)
def _initializer():
    """Set the module-level counter ``GLOBAL_VALUE`` to zero.

    Used as the executor ``initializer`` in ``test_rest_executor_initializer``.
    """
    global GLOBAL_VALUE
    GLOBAL_VALUE = 0  # (re)creates the module global
def _increment_global_value():
    """Add one to ``GLOBAL_VALUE`` and return the updated value.

    In this module ``GLOBAL_VALUE`` is only created by ``_initializer``.
    Calling this function before that raises ``NameError``.
    """
    global GLOBAL_VALUE
    GLOBAL_VALUE = GLOBAL_VALUE + 1
    return GLOBAL_VALUE
[docs]
def test_rest_executor_max_tasks_per_worker(slurm_executor_kwargs):
    """Workers are restarted after ``max_tasks_per_worker`` tasks.

    Eight workers, each limited to two tasks, are fed three rounds of eight
    submissions. Distinct ``SLURM_JOB_ID`` values are accumulated across rounds.
    """
    # Cumulative number of distinct SLURM job ids expected after each round:
    #   round 1 -> each worker executes one job             (8 ids)
    #   round 2 -> each worker executes another job         (still 8 ids)
    #   round 3 -> each worker reached the limit and had to
    #              be restarted                             (16 ids)
    # ``_job_ident`` sleeps so that jobs don't finish before the submit loop
    # of a round is done. In production this does not matter, but here it is
    # needed to test ``max_tasks_per_worker``.
    expected_totals = (8, 8, 16)
    with SlurmRestExecutor(
        **slurm_executor_kwargs, max_workers=8, max_tasks_per_worker=2
    ) as slurm_executor:
        job_idents = set()
        for expected in expected_totals:
            futures = [slurm_executor.submit(_job_ident) for _ in range(8)]
            job_idents |= {fut.result(timeout=60) for fut in futures}
            assert len(job_idents) == expected, str(len(job_idents))
def _job_ident():
    """Return the ``SLURM_JOB_ID`` environment variable of the executing process.

    Sleeps 0.3 s first, so the job stays busy while the caller is still
    submitting the rest of a round.
    """
    time.sleep(0.3)
    job_id = os.environ["SLURM_JOB_ID"]
    return job_id
[docs]
@pytest.mark.parametrize("abort_delay", [0, 0.1, 1])
@pytest.mark.skipif(
    sys.platform == "win32", reason="Signal propagation not reliable on Windows"
)
def test_rest_executor_cancel(slurm_executor_kwargs, abort_delay):
    """Aborting a submitted future cancels its SLURM job.

    The task would sleep for 10 s. It is aborted either immediately
    (``abort_delay == 0``) or after waiting ``abort_delay`` seconds.
    ``result`` must then raise ``RemoteExit`` reporting the job as CANCELLED.
    """
    with SlurmRestExecutor(**slurm_executor_kwargs, max_workers=8) as slurm_executor:
        sleeping_future = slurm_executor.submit(time.sleep, 10)
        if abort_delay:  # a zero delay aborts without waiting at all
            time.sleep(abort_delay)
        sleeping_future.abort()

        with pytest.raises(RemoteExit, match=r"SLURM job \d+ CANCELLED"):
            sleeping_future.result(timeout=60)
[docs]
def test_slurm_tmp_path(slurm_executor_kwargs, slurm_tmp_path):
    """A file created by a remote task under ``slurm_tmp_path`` exists locally afterwards."""
    print(slurm_tmp_path)
    target = slurm_tmp_path / "test.txt"
    with SlurmRestExecutor(**slurm_executor_kwargs) as slurm_executor:
        # Wait for the remote task; its return value is not needed.
        slurm_executor.submit(_touch, str(target)).result(timeout=60)
        # NOTE(review): presumably refreshes the NFS cache so that a file
        # written by the remote job becomes visible here — confirm in os_utils.
        os_utils.nfs_cache_refresh(slurm_tmp_path)
        assert target.exists()
def _touch(filename):
    """Create ``filename`` as an empty file, truncating it if it already exists."""
    with open(filename, mode="w") as _stream:
        pass  # opening in "w" mode is enough to create/truncate the file