Source code for pyslurmutils.tests.test_remote_executor
import threading
import time
from concurrent.futures import as_completed
from ..client.job_io.local._executor import RemoteExecutor
[docs]
def test_remote_executor_nomax():
# All tasks are executed sequentially (time = ntasks * sleep)
with RemoteExecutor(conservative_scheduling=True) as executor:
nworkers = _assert_executor(executor, ntasks=10)
assert nworkers == 1
# All tasks are executed in parallel (time = sleep)
with RemoteExecutor() as executor:
nworkers = _assert_executor(executor, ntasks=10)
assert nworkers == 10
# Sequential blocks of `max_tasks_per_worker tasks` (time = max_tasks_per_worker * sleep)
with RemoteExecutor(
max_tasks_per_worker=3, conservative_scheduling=True
) as executor:
nworkers = _assert_executor(executor, ntasks=10)
assert nworkers == 4
# All tasks are executed in parallel (time = sleep)
with RemoteExecutor(max_tasks_per_worker=1) as executor:
nworkers = _assert_executor(executor, ntasks=10)
assert nworkers == 10
[docs]
def test_remote_executor_max_workers():
# Tasks executed in parallel blocks of `max_workers` (time = ceil(ntasks/max_workers) * sleep)
with RemoteExecutor(max_workers=3) as executor:
nworkers = _assert_executor(executor, ntasks=10)
assert nworkers == 3
# Tasks executed in parallel blocks of `max_workers` (time = ceil(ntasks/max_workers) * sleep)
with RemoteExecutor(max_workers=3, max_tasks_per_worker=2) as executor:
nworkers = _assert_executor(executor, ntasks=10)
assert nworkers > 3
# Tasks executed in parallel blocks of `max_workers` (time = ceil(ntasks/max_workers) * sleep)
with RemoteExecutor(
max_workers=3, max_tasks_per_worker=2, lazy_scheduling=False
) as executor:
nworkers = _assert_executor(executor, ntasks=10)
assert nworkers > 3
def _assert_executor(executor, ntasks) -> int:
futures_list = []
data = list(range(ntasks))
for i in data:
future = executor.submit(_example_task, i)
futures_list.append(future)
results = list()
thread_ids = set()
for future in as_completed(futures_list):
i, thread_id = future.result()
results.append(i)
thread_ids.add(thread_id)
assert data == sorted(results)
return len(thread_ids)
def _example_task(i):
time.sleep(0.3)
return i, id(threading.current_thread())