"""ZMQServer Driver."""
from typing import Any, Dict, Optional
from schema import Use, Or
import zmq
from testplan.common.config import ConfigOption
from testplan.common.utils.documentation_helper import emphasized
from testplan.common.utils.timing import retry_until_timeout
from ..base import (
Driver,
DriverConfig,
)
from ..connection import (
Direction,
Protocol,
ConnectionExtractor,
)
class ZMQServerConfig(DriverConfig):
    """
    Configuration object for
    :py:class:`~testplan.testing.multitest.driver.zmq.server.ZMQServer` driver.
    """

    @classmethod
    def get_options(cls) -> Dict[Any, Any]:
        """
        Return the option schema used to validate the driver configuration
        and to assign default values.

        Options declared here: ``host`` (default ``"localhost"``), ``port``
        (default ``0``, coerced with ``int``) and ``message_pattern``
        (default ``zmq.PAIR``).
        """
        # Only these four socket types are accepted as a message pattern.
        allowed_patterns = Or(zmq.PAIR, zmq.REP, zmq.PUB, zmq.PUSH)

        schema_options: Dict[Any, Any] = {
            ConfigOption("host", default="localhost"): str,
            ConfigOption("port", default=0): Use(int),
            ConfigOption("message_pattern", default=zmq.PAIR): allowed_patterns,
        }
        return schema_options
class ZMQServer(Driver):
    """
    The ZMQServer can receive multiple connections from different ZMQClients.
    The socket can be of type:

    * zmq.PAIR
    * zmq.REP
    * zmq.PUB
    * zmq.PUSH

    {emphasized_members_docs}

    :param name: Name of ZMQServer.
    :type name: ``str``
    :param host: Host name to bind to. Default: 'localhost'
    :type host: ``str``
    :param port: Port number to bind to. Default: 0 (Random port)
    :type port: ``int``
    :param message_pattern: Message pattern. Default: ``zmq.PAIR``
    :type message_pattern: ``int``
    """

    CONFIG = ZMQServerConfig
    EXTRACTORS = [ConnectionExtractor(Protocol.TCP, Direction.LISTENING)]

    def __init__(
        self,
        name: str,
        host: str = "localhost",
        port: int = 0,
        message_pattern: int = zmq.PAIR,
        **options: Any,
    ) -> None:
        # Keep this as the first statement: ``locals()`` is captured here, so
        # at this point it only holds the constructor arguments.
        options.update(self.filter_locals(locals()))
        super(ZMQServer, self).__init__(**options)

        # Runtime state; all of it stays ``None`` until ``starting`` runs.
        self._host: Optional[str] = None
        self._port: Optional[int] = None
        self._zmq_context: Optional[zmq.Context[Any]] = None
        self._socket: Optional[zmq.Socket[Any]] = None

    @emphasized  # type: ignore[prop-decorator]
    @property
    def host(self) -> Optional[str]:
        """Target host name."""
        return self._host

    @emphasized  # type: ignore[prop-decorator]
    @property
    def port(self) -> Optional[int]:
        """Port number assigned."""
        return self._port

    @property
    def socket(self) -> Optional[zmq.Socket[Any]]:
        """
        Returns the underlying ``zmq.sugar.socket.Socket`` object.
        """
        return self._socket

    @property
    def connection_identifier(self) -> Optional[int]:
        """Identifier of this connection: the port number assigned."""
        return self.port

    @property
    def local_port(self) -> Optional[int]:
        """Local port of the connection: the port number assigned."""
        return self.port

    @property
    def local_host(self) -> Optional[str]:
        """Local host of the connection: the host name bound to."""
        return self.host

    def send(self, data: Any, timeout: int = 30) -> Any:
        """
        Try to send the message until it either sends or hits timeout.

        :param data: Message passed to the socket's ``send`` method.
        :param timeout: Timeout to retry sending the message
        :type timeout: ``int``
        :return: The result of the retried ``send`` call.
        :raises RuntimeError: If the socket has not been created yet.
        """
        if self._socket is None:
            raise RuntimeError("self._socket must not be None")

        # Non-blocking send; a ``zmq.ZMQError`` triggers another attempt
        # until ``timeout`` is exhausted.
        return retry_until_timeout(
            exception=zmq.ZMQError,
            item=self._socket.send,
            kwargs={"data": data, "flags": zmq.NOBLOCK},
            timeout=timeout,
            raise_on_timeout=True,
        )

    def receive(self, timeout: int = 30) -> Any:
        """
        Try to receive a message until one has been received or the
        timeout is hit.

        :param timeout: Timeout to retry receiving the message
        :type timeout: ``int``
        :return: The received message
        :rtype: ``object`` or ``str`` or ``zmq.sugar.frame.Frame``
        :raises RuntimeError: If the socket has not been created yet.
        """
        if self._socket is None:
            raise RuntimeError("self._socket must not be None")

        # Non-blocking receive; a ``zmq.ZMQError`` triggers another attempt
        # until ``timeout`` is exhausted.
        return retry_until_timeout(
            exception=zmq.ZMQError,
            item=self._socket.recv,
            kwargs={"flags": zmq.NOBLOCK},
            timeout=timeout,
            raise_on_timeout=True,
        )

    def starting(self) -> None:
        """
        Start the ZMQServer.

        Creates the ZMQ context and socket, binds the socket over TCP to the
        configured host, and records the host and port that were bound.
        """
        super(ZMQServer, self).starting()
        # pylint: disable=abstract-class-instantiated
        self._zmq_context = zmq.Context()
        self._socket = self._zmq_context.socket(self.cfg.message_pattern)

        endpoint = f"tcp://{self.cfg.host}"
        if self.cfg.port == 0:
            # Port 0 means "pick one for me": use the port the socket chose.
            port = self._socket.bind_to_random_port(endpoint)
        else:
            self._socket.bind(f"{endpoint}:{self.cfg.port}")
            port = self.cfg.port

        self._host = self.cfg.host
        self._port = port

    def _close_zmq_resources(self) -> None:
        """
        Close the socket and terminate the context.

        Shared by ``stopping`` and ``aborting``. Each resource is skipped if
        it was never created or is already closed, so calling this more than
        once is harmless.
        """
        # NOTE(review): ``term()`` may block while a closed socket still has
        # unsent messages under the default linger setting - confirm whether
        # the abort path should close with ``linger=0``.
        if self._socket is not None and not self._socket.closed:
            self._socket.close()
        if self._zmq_context is not None and not self._zmq_context.closed:
            self._zmq_context.term()

    def stopping(self) -> None:
        """
        Stop the ZMQServer.
        """
        super(ZMQServer, self).stopping()
        self._close_zmq_resources()

    def aborting(self) -> None:
        """Abort logic that stops the server."""
        self._close_zmq_resources()