import os
import threading
from contextlib import contextmanager
from typing import Callable
from typing import Generator
from typing import List

import pytest

from ..client.job_io.local import Connection
from ..client.job_io.local import FileConnection
from ..client.job_io.local import LocalWorkerProxy
from ..client.job_io.local import RemoteWorkerProxy
from ..client.job_io.local import TcpConnection
from ..client.job_io.remote import file_main
from ..client.job_io.remote import tcp_main
[docs]
def test_single_task(communication):
    """A single ``sum`` task runs through the proxy and returns its result."""
    connection, _ = communication
    proxy_ctx = (
        LocalWorkerProxy()
        if connection is None
        else RemoteWorkerProxy(connection)
    )
    with proxy_ctx as proxy:
        result = proxy.execute(sum, ([1, 2],), None)
        assert result == 3
[docs]
def test_single_task_exception(communication):
    """An exception raised inside the task is re-raised to the caller."""
    connection, _ = communication
    if connection is not None:
        proxy_ctx = RemoteWorkerProxy(connection)
    else:
        proxy_ctx = LocalWorkerProxy()
    # sum() over an int and a str raises TypeError inside the worker.
    bad_args = ([1, "2"],)
    with proxy_ctx as proxy, pytest.raises(TypeError):
        proxy.execute(sum, bad_args, None)
[docs]
def test_max_tasks(communication):
    """With ``max_tasks=2`` the proxy can still run two successful tasks."""
    connection, _ = communication
    limit = 2
    if connection is None:
        proxy_ctx = LocalWorkerProxy(max_tasks=limit)
    else:
        proxy_ctx = RemoteWorkerProxy(connection, max_tasks=limit)
    cases = [([1, 2], 3), ([3, 4], 7)]
    with proxy_ctx as proxy:
        for numbers, expected in cases:
            assert proxy.execute(sum, (numbers,), None) == expected
[docs]
def test_max_tasks_exception(communication):
    """After two tasks (one succeeding, one raising) with ``max_tasks=2``,
    further execution is rejected."""
    connection, _ = communication
    options = {"max_tasks": 2}
    if connection is None:
        proxy_ctx = LocalWorkerProxy(**options)
    else:
        proxy_ctx = RemoteWorkerProxy(connection, **options)
    with proxy_ctx as proxy:
        first = proxy.execute(sum, ([1, 2],), None)
        assert first == 3
        with pytest.raises(TypeError):
            proxy.execute(sum, ([1, "2"],), None)
        # Two tasks have now been submitted, so a third must be refused.
        with pytest.raises(RuntimeError, match="cannot send data after stopped"):
            proxy.execute(sum, ([3, 4],), None)
[docs]
def test_no_max_tasks(communication):
    """With ``max_tasks=None`` the remote job is still running after the proxy exits."""
    connection, remote = communication
    if connection is None:
        pytest.skip("test needs a remote")
    with RemoteWorkerProxy(connection, max_tasks=None) as proxy:
        assert proxy.execute(sum, ([1, 2],), None) == 3
    # Give the remote thread a second to exit; it is expected to keep running.
    remote.join(timeout=1)
    assert remote.is_alive()
[docs]
def test_initializer(communication):
    """Values stored by the initializer are visible to later tasks."""
    global _INITIALIZED
    _INITIALIZED = None
    connection, _ = communication
    init_options = {"initializer": _initializer, "initargs": ([10],)}
    if connection is None:
        proxy_ctx = LocalWorkerProxy(**init_options)
    else:
        proxy_ctx = RemoteWorkerProxy(connection, **init_options)
    with proxy_ctx as proxy:
        # 1 + 2 from the task arguments plus 10 stored by the initializer.
        total = proxy.execute(_sum_to_initialized, ([1, 2],), None)
        assert total == 13
def _initializer(value: List[int]) -> None:
    """Store ``value`` in the module-level ``_INITIALIZED``.

    If ``_INITIALIZED`` currently holds an exception instance, that exception is
    raised instead and ``_INITIALIZED`` is left unchanged. Storing an exception
    there therefore makes this initializer fail.

    Args:
        value: List stored for later use by ``_sum_to_initialized``.

    Raises:
        Exception: The exception instance stored in ``_INITIALIZED``, if any.
    """
    global _INITIALIZED
    # ``globals().get`` tolerates the name not being defined yet. The previous
    # ``try: ... except NameError`` also wrapped the ``raise``, so a stored
    # NameError instance was silently swallowed instead of raised.
    previous = globals().get("_INITIALIZED")
    if isinstance(previous, Exception):
        raise previous
    _INITIALIZED = value
def _sum_to_initialized(value: List[int]) -> int:
    """Return the sum of ``value`` concatenated with the module-level ``_INITIALIZED``."""
    # ``_INITIALIZED`` is only read here, so no ``global`` declaration is needed.
    combined = value + _INITIALIZED
    return sum(combined)
[docs]
def test_initializer_exception(communication):
    """A failing initializer surfaces from ``initialize()`` and stops the proxy."""
    global _INITIALIZED
    # NOTE(review): the proxies below use ``_failing_initializer``, which does not
    # read ``_INITIALIZED``; this assignment looks unnecessary here. Confirm before
    # removing.
    _INITIALIZED = RuntimeError("intentional for testing")
    connection, _ = communication
    options = {"initializer": _failing_initializer}
    proxy_ctx = (
        LocalWorkerProxy(**options)
        if connection is None
        else RemoteWorkerProxy(connection, **options)
    )
    with proxy_ctx as proxy:
        with pytest.raises(RuntimeError, match="intentional for testing"):
            proxy.initialize()
        # Once initialization has failed, the proxy no longer accepts tasks.
        with pytest.raises(RuntimeError, match="cannot send data after stopped"):
            proxy.execute(sum, ([1, 2],), None)
def _failing_initializer() -> None:
    """Initializer that always fails with a RuntimeError."""
    error = RuntimeError("intentional for testing")
    raise error
[docs]
@pytest.fixture(params=["tcp", "file", "local"])
def communication(request, tmp_path):
    """Yield ``(connection, remote_thread)`` for each communication backend.

    For ``"tcp"`` and ``"file"``, a client-side connection is opened, the
    matching environment variables are exported, and the remote main function
    runs in a background thread. For any other parameter (``"local"``), both
    items are ``None``.
    """
    mode = request.param
    if mode == "tcp":
        conn_ctx, env_factory, remote_main = _tcp_connection(), _remote_tcp_env, tcp_main
    elif mode == "file":
        conn_ctx = _file_connection(tmp_path)
        env_factory, remote_main = remote_file_env, file_main
    else:
        yield None, None
        return
    # Enter in order connection -> environment -> remote thread; exit in reverse.
    with conn_ctx as conn, env_factory(conn), _remote_job(remote_main) as remote:
        yield conn, remote
@contextmanager
def _tcp_connection() -> Generator[Connection, None, None]:
    """Open the client-side ``TcpConnection`` and yield it while the context is active."""
    tcp_conn = TcpConnection()
    with tcp_conn as opened:
        yield opened
@contextmanager
def _file_connection(tmp_path) -> Generator[Connection, None, None]:
    """Open the client-side ``FileConnection`` and yield it while the context is active.

    The connection is constructed from ``str(tmp_path)`` and ``"test"``.
    """
    directory = str(tmp_path)
    with FileConnection(directory, "test") as file_conn:
        yield file_conn
@contextmanager
def _remote_tcp_env(connection: TcpConnection) -> Generator[None, None, None]:
    """Export the TCP connection's address to the remote job via environment variables.

    Sets ``_PYSLURMUTILS_HOST`` to ``connection.host`` and ``_PYSLURMUTILS_PORT``
    to ``str(connection.port)`` for the duration of the context. On exit, each
    variable is restored to the value it had before entering, or removed if it
    was unset. An outer environment is therefore not clobbered, and a variable
    that is already gone does not raise ``KeyError``.

    Args:
        connection: The client-side TCP connection whose address is exported.
    """
    new_values = {
        "_PYSLURMUTILS_HOST": connection.host,
        "_PYSLURMUTILS_PORT": str(connection.port),
    }
    saved = {name: os.environ.get(name) for name in new_values}
    try:
        # Inside ``try`` so a failure after the first assignment is still undone.
        os.environ.update(new_values)
        yield
    finally:
        for name, previous in saved.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
[docs]
@contextmanager
def remote_file_env(connection: FileConnection) -> Generator[None, None, None]:
    """Export the file connection's paths to the remote job via environment variables.

    Sets ``_PYSLURMUTILS_INFILE`` to ``connection.input_filename`` and
    ``_PYSLURMUTILS_OUTFILE`` to ``connection.output_filename`` for the duration
    of the context. On exit, each variable is restored to the value it had before
    entering, or removed if it was unset.

    Args:
        connection: The client-side file connection (as yielded by
            ``_file_connection``) whose filenames are exported.
    """
    new_values = {
        "_PYSLURMUTILS_INFILE": connection.input_filename,
        "_PYSLURMUTILS_OUTFILE": connection.output_filename,
    }
    saved = {name: os.environ.get(name) for name in new_values}
    try:
        # Inside ``try`` so a failure after the first assignment is still undone.
        os.environ.update(new_values)
        yield
    finally:
        for name, previous in saved.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
@contextmanager
def _remote_job(
    remote_main: Callable[[], object],
) -> Generator[threading.Thread, None, None]:
    """Run ``remote_main`` in a local daemon thread while the context is active.

    The started thread is yielded. On exit the thread is joined for at most
    60 seconds. On a clean exit, the context then asserts that the thread has
    finished.

    Args:
        remote_main: Zero-argument callable used as the thread target.
    """
    thread = threading.Thread(target=remote_main, daemon=True)
    thread.start()
    try:
        yield thread
    finally:
        thread.join(timeout=60)
    # Checked outside ``finally``: asserting there would replace an exception
    # propagating out of the ``with`` body with an AssertionError.
    assert not thread.is_alive(), "remote job thread did not exit within 60 seconds"