Source code for testplan.runners.pools.connection

"""Connections module."""

import abc
import queue
import time
import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import zmq

from testplan.common import entity
from testplan.common.serialization import deserialize, serialize
from testplan.common.utils import logger
from testplan.runners.pools.communication import Message

if TYPE_CHECKING:
    from testplan.runners.pools.base import Worker


class Client(logger.Loggable, metaclass=abc.ABCMeta):
    """
    Workers are Client in Pool/Worker communication.
    Abstract base class for workers to communicate with its pool.
    """

    def __init__(self) -> None:
        super().__init__()
        # Toggled by connect()/disconnect(); send_and_receive() does nothing
        # and returns None while this is False.
        self.active = False

    @abc.abstractmethod
    def connect(self, server) -> None:
        """
        Connect client to server.

        Subclasses may call this base implementation to mark the client
        as active.

        :param server: Endpoint to connect to.
        """
        self.active = True

    @abc.abstractmethod
    def disconnect(self) -> None:
        """
        Disconnect client from server.

        Subclasses may call this base implementation to mark the client
        as inactive.
        """
        self.active = False

    @abc.abstractmethod
    def send(self, message: Message) -> None:
        """
        Sends a message to server.

        :param message: Message to be sent.
        """

    @abc.abstractmethod
    def receive(self) -> Optional[Message]:
        """Receives response to the message sent."""

    def send_and_receive(
        self,
        message: Message,
        expect: Union[None, Tuple, List, Message] = None,
    ) -> Optional[Message]:
        """
        Send and receive shortcut. Optionally assert that the response is
        of the type expected. I.e For a TaskSending message, an Ack is
        expected.

        :param message: Message sent.
        :param expect: Expected command of message received. A tuple or list
            means "any one of these commands".
        :return: Message received, or None if the client is not active.
        :raises RuntimeError: If sending or receiving raised, or if nothing
            was received while ``expect`` was given.
        :raises AssertionError: If the received command does not match
            ``expect``.
        """
        if not self.active:
            return None

        try:
            self.send(message)
        except Exception as exc:
            self.logger.exception("Exception on transport send: %s.", exc)
            raise RuntimeError(f"On transport send - {exc}.") from exc

        try:
            received = self.receive()
        except Exception as exc:
            self.logger.exception("Exception on transport receive: %s.", exc)
            raise RuntimeError(f"On transport receive - {exc}.") from exc

        if expect is None:
            return received

        if received is None:
            raise RuntimeError(f"Received None when {expect} was expected.")

        # Raise AssertionError explicitly rather than using ``assert`` so the
        # check is still enforced when Python runs with -O, while callers
        # keep seeing the same exception type as before.
        if isinstance(expect, (tuple, list)):
            if received.cmd not in expect:
                raise AssertionError(f"{received.cmd} not in {expect}")
        elif received.cmd != expect:
            raise AssertionError(f"{received.cmd} != {expect}")

        return received
class QueueClient(Client):
    """
    Queue based client implementation, for thread pool workers to
    communicate with its pool.

    :param recv_sleep: Sleep duration in seconds between polls of the
        response queue in :py:meth:`receive`.
    """

    def __init__(self, recv_sleep: float = 0.05) -> None:
        super().__init__()
        self._recv_sleep = recv_sleep
        # Request queue of the pool; set by connect(), cleared by disconnect().
        self.requests: Optional[queue.Queue] = None

        # single-producer(pool) single-consumer(worker) FIFO queue
        self.responses = []

    def connect(self, requests: queue.Queue) -> None:
        """
        Connect to the request queue of Pool.

        :param requests: request queue of pool that worker should write to.
        """
        self.requests = requests
        self.active = True

    def disconnect(self) -> None:
        """Disconnect worker from pool."""
        self.active = False
        self.requests = None

    def send(self, message: Message) -> None:
        """
        Worker sends a message. Does nothing when the client is inactive.

        :param message: Message to be sent.
        """
        if self.active:
            self.requests.put(message)

    def receive(self) -> Optional[Message]:
        """
        Worker receives response to the message sent, this method blocks
        for as long as the client is active and no response is available.

        :return: Oldest pending response, or None if the client became
            inactive while waiting.
        """
        while self.active:
            try:
                # pop(0) takes the oldest entry so that responses are
                # consumed in arrival order (FIFO), as documented for
                # ``self.responses``; a bare pop() would be LIFO.
                return self.responses.pop(0)
            except IndexError:
                time.sleep(self._recv_sleep)
        return None

    def respond(self, message: Message) -> None:
        """
        Used by :py:class:`~testplan.runners.pools.base.Pool` to respond to
        worker request.

        :param message: Respond message.
        :raises RuntimeError: If the worker is not active.
        """
        if not self.active:
            raise RuntimeError("Responding to inactive worker")
        self.responses.append(message)
class ZMQClient(Client):
    """
    ZMQ based client implementation for process worker to communicate with
    its pool.

    :param address: Pool server address to connect to.
    :param recv_sleep: Sleep duration in msg receive loop.
    :param recv_timeout: Maximum time in seconds that :py:meth:`receive`
        keeps polling for a reply before giving up.
    """

    def __init__(
        self,
        address: str,
        recv_sleep: float = 0.05,
        recv_timeout: float = 5,
    ) -> None:
        super().__init__()
        self._address = address
        self._recv_sleep = recv_sleep
        self._recv_timeout = recv_timeout
        self._context = None
        self._sock = None
        self.connect()  # auto connect

    def connect(self) -> None:
        """Connect to a ZMQ Server"""
        # pylint: disable=abstract-class-instantiated
        self._context = zmq.Context()
        self._sock = self._context.socket(zmq.REQ)
        self._sock.connect(f"tcp://{self._address}")
        self.active = True

    def disconnect(self) -> None:
        """
        Disconnect from Server.

        Safe to call more than once: the socket and context are only
        closed if they are still set.
        """
        self.active = False
        if self._sock is not None:
            self._sock.close()
            self._sock = None
        if self._context is not None:
            self._context.destroy()
            self._context = None
        self._address = None

    def send(self, message: Message) -> None:
        """
        Worker sends a message. Does nothing when the client is inactive.

        :param message: Message to be sent.
        """
        if self.active:
            self._sock.send(serialize(message))

    def receive(self) -> Optional[Message]:
        """
        Worker tries to receive the response to the message sent until
        timeout.

        :return: Response to the message sent, or None on timeout or when
            the client is inactive.
        :raises Exception: Whatever deserialization of the payload raised.
        """
        # Monotonic clock: the timeout must not be affected by wall-clock
        # adjustments made while we are polling.
        started = time.monotonic()
        while self.active:
            try:
                payload = self._sock.recv(flags=zmq.NOBLOCK)
            except zmq.Again:
                # Nothing available yet - give up on timeout, else poll again.
                if time.monotonic() - started > self._recv_timeout:
                    print(
                        f"Transport receive timeout {self._recv_timeout}s"
                        f" reached!"
                    )
                    return None
                time.sleep(self._recv_sleep)
                continue

            try:
                return deserialize(payload)
            except Exception as exc:
                print(f"Deserialization error. - {exc}")
                raise
        return None
class ZMQClientProxy:
    """
    Representative of a process worker's transport in local worker object.

    Holds the socket and address taken from the server it is connected to,
    and uses that socket to send responses.
    """

    def __init__(self) -> None:
        # Not usable until connect() has been called.
        self.active = False
        self.connection = None
        self.address = None

    def connect(self, server) -> None:
        """
        Attach this proxy to a server.

        :param server: Object exposing ``sock`` and ``address`` attributes.
        """
        self.connection = server.sock
        self.address = server.address
        self.active = True

    def disconnect(self) -> None:
        """Detach from the server and drop the socket/address references."""
        self.active = False
        self.connection = self.address = None

    def respond(self, message: Message) -> None:
        """
        Used by :py:class:`~testplan.runners.pools.base.Pool` to respond to
        worker request.

        :param message: Respond message.
        :raises RuntimeError: If the proxy is not active.
        """
        if not self.active:
            raise RuntimeError("Responding to inactive worker")
        self.connection.send(serialize(message))
class Server(entity.Resource, metaclass=abc.ABCMeta):
    """
    Abstract base class for pools to communicate to its workers.
    """

    def __init__(self) -> None:
        super().__init__()

    def starting(self) -> None:
        """Server starting logic: move the status straight to STARTED."""
        # Start is async
        self.status.change(self.status.STARTED)

    def stopping(self) -> None:
        """Server stopping logic: move the status straight to STOPPED."""
        # Stop is async
        self.status.change(self.status.STOPPED)

    def aborting(self) -> None:
        """Abort policy - no abort actions are required in the base class."""

    @abc.abstractmethod
    def register(self, worker: "Worker") -> None:
        """
        Register a new worker. Workers should be registered after the
        connection manager is started and will be automatically
        unregistered when it is stopped.

        The base implementation only checks the server status; subclasses
        call it before connecting the worker transport.

        :param worker: Worker to register.
        :raises RuntimeError: If the server is not in the STARTED state.
        """
        if self.status != self.status.STARTED:
            current = self.status.tag
            raise RuntimeError(
                "Can only register workers when started."
                f" Current state is {current}."
            )

    @abc.abstractmethod
    def accept(self) -> Optional[Message]:
        """
        Accepts a new message from worker. This method should not block - if
        no message is queued for receiving it should return None.

        :return: Message received from worker transport, or None.
        """
class QueueServer(Server):
    """
    Queue based server implementation, for thread pool to get requests from
    workers.
    """

    def __init__(self) -> None:
        super().__init__()

        # multi-producer(workers) single-consumer(pool) FIFO queue
        self.requests = None

    def starting(self) -> None:
        """Create the request queue, then run the base starting logic."""
        self.requests = queue.Queue()
        super().starting()

    def register(self, worker) -> None:
        """
        Register a worker and connect its transport to the request queue.

        :param worker: Worker whose ``transport`` is to be connected.
        """
        super().register(worker)
        transport = worker.transport
        transport.connect(self.requests)

    def accept(self) -> Optional[Message]:
        """
        Accepts the next request in the request queue.

        :return: Message received from worker transport, or None.
        """
        try:
            request = self.requests.get_nowait()
        except queue.Empty:
            # Nothing pending - report that without blocking.
            return None
        else:
            return request
class ZMQServer(Server):
    """
    ZMQ based server implementation, for process/remote/treadmill pool to get
    request from workers.
    """

    def __init__(self) -> None:
        super().__init__()

        # Here, context is a factory class provided by ZMQ that creates
        # sockets. Context and other attributes below are set when starting
        # and cleaned up when stopping.
        self._zmq_context = None
        self._sock = None
        self._address = None

    @property
    def sock(self):
        """Socket bound by :py:meth:`starting`, or None when not started."""
        return self._sock

    @property
    def address(self):
        """``host:port`` string the socket is bound to, or None."""
        return self._address

    def starting(self) -> None:
        """
        Create a ZMQ context and socket to handle TCP communication.

        Binds to the parent pool's configured host and port; a configured
        port of 0 means "bind to a random port".

        :raises RuntimeError: If no parent pool has been set.
        """
        if self.parent is None:
            raise RuntimeError("Parent pool was not set - cannot start.")

        host = self.parent.cfg.host
        port = self.parent.cfg.port

        # pylint: disable=abstract-class-instantiated
        self._zmq_context = zmq.Context()
        try:
            self._sock = self._zmq_context.socket(zmq.REP)
            if port == 0:
                port = self._sock.bind_to_random_port(f"tcp://{host}")
            else:
                self._sock.bind(f"tcp://{host}:{port}")
        except Exception:
            # Do not leave a half-initialised socket/context behind when
            # socket creation or binding fails (e.g. port already in use).
            self._close()
            raise

        self._address = f"{host}:{port}"
        super().starting()

    def _close(self) -> None:
        """Closes TCP connections managed by this object."""
        self.logger.debug("Closing TCP connections for %s", self.parent)
        if self._sock is not None:
            self._sock.close()
            self._sock = None
        if self._zmq_context is not None:
            self._zmq_context.destroy()
            self._zmq_context = None
        self._address = None

    def stopping(self) -> None:
        """
        Terminate the ZMQ context and socket when stopping. We require that
        all workers are stopped before stopping the connection manager, so
        that we can safely remove references to connection sockets from the
        worker.
        """
        self._close()
        super().stopping()

    def aborting(self) -> None:
        """Terminate the ZMQ context and socket when aborting."""
        # _close() checks the socket and the context independently, so call
        # it unconditionally: a context whose socket was never created must
        # still be destroyed.
        self._close()
        super().aborting()

    def register(self, worker) -> None:
        """
        Register a new worker and connect its transport to this server.

        :param worker: Worker whose ``transport`` is to be connected.
        """
        super().register(worker)
        worker.transport.connect(self)

    def accept(self) -> Optional[Message]:
        """
        Accepts a new message from worker. Doesn't block if no message is
        queued for receiving.

        :return: Message received from worker transport, or None.
        """
        try:
            return deserialize(self._sock.recv(flags=zmq.NOBLOCK))
        except zmq.Again:
            return None

    def __del__(self) -> None:
        """
        Check that ZMQ sockets are properly closed when this manager is
        garbage-collected. If not we close them now as a fallback.
        """
        # Use getattr() with a default here - there is no guarantee that
        # __init__() has completed successfully when __del__() is called.
        sock_open = getattr(self, "_sock", None) is not None
        context_open = getattr(self, "_zmq_context", None) is not None
        if sock_open or context_open:
            warnings.warn("Pool TCP connections were not closed.")
            self._close()