"""Module containing SSH client functionality for remote operations."""
import os
import time
import shlex
import logging
import paramiko
import getpass
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from testplan.common.utils.timing import retry_until_timeout
from testplan.common.utils.logger import TESTPLAN_LOGGER
logger = logging.getLogger(__name__)

# Baseline keyword arguments for paramiko connections, computed once at import
# time from the local user name. SSHClient merges these before any explicitly
# supplied connection arguments, so the caller's values take precedence.
DEFAULT_PARAMIKO_CONFIG = dict(username=getpass.getuser())
class SSHClient:
    """
    SSH client for remote operations including file transfers and command execution.

    Wraps paramiko functionality in a convenient interface. The SSH connection
    and the SFTP session are opened lazily, on first access of
    :attr:`ssh_client` / :attr:`sftp_client`, and released by :meth:`close`.
    Instances may also be used as context managers; leaving the ``with``
    block calls :meth:`close`.
    """

    def __init__(
        self,
        host: str,
        port: int = 22,
        logger: Optional[logging.Logger] = None,
        **args: Any,
    ) -> None:
        """
        Initialize the SSH client. No connection is made here.

        :param host: Host to connect to
        :type host: ``str``
        :param port: Port to connect to
        :type port: ``int``
        :param logger: Logger for diagnostic messages; defaults to
            ``TESTPLAN_LOGGER``
        :type logger: ``logging.Logger`` or ``NoneType``
        :param args: Extra keyword arguments forwarded verbatim to
            ``paramiko.SSHClient.connect`` (e.g. ``username``, ``password``,
            ``key_filename``). They take precedence over
            ``DEFAULT_PARAMIKO_CONFIG`` and over the ``hostname``/``port``
            derived from ``host``/``port``.
        """
        self.host = host
        self.port = port
        self.logger = logger or TESTPLAN_LOGGER
        # Later entries win: defaults < host/port < caller-supplied kwargs.
        self.paramiko_args = {
            **DEFAULT_PARAMIKO_CONFIG,
            "hostname": self.host,
            "port": self.port,
            **args,
        }
        self._ssh_client: Optional[paramiko.SSHClient] = None
        self._sftp_client: Optional[paramiko.SFTPClient] = None

    def __enter__(self) -> "SSHClient":
        """
        Enter a ``with`` block. Connections are still opened lazily.

        :return: This client
        :rtype: ``SSHClient``
        """
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """
        Leave a ``with`` block, closing any open SFTP session and SSH connection.

        Exceptions raised inside the block are not suppressed.
        """
        self.close()

    @property
    def ssh_client(self) -> paramiko.SSHClient:
        """
        Get the underlying paramiko SSH client, connecting on first access.

        :return: Paramiko SSH client instance
        :rtype: ``paramiko.SSHClient``
        """
        if not self._ssh_client:
            self.connect()
        return self._ssh_client  # type: ignore[return-value]

    @property
    def sftp_client(self) -> paramiko.SFTPClient:
        """
        Get the underlying paramiko SFTP client, opening a session on first access.

        :return: Paramiko SFTP client instance
        :rtype: ``paramiko.sftp_client.SFTPClient``
        """
        if not self._sftp_client:
            self.open_sftp()
        return self._sftp_client  # type: ignore[return-value]

    def connect(self) -> paramiko.SSHClient:
        """
        Establish the SSH connection, or return the already established one.

        :return: The connected paramiko SSH client
        :rtype: ``paramiko.SSHClient``
        :raises: Whatever ``paramiko.SSHClient.connect`` raises (e.g.
            authentication or socket errors); the half-open client is closed
            before the exception propagates.
        """
        if self._ssh_client is not None:
            self.logger.warning("SSH connection already established")
            return self._ssh_client
        ssh_client = paramiko.SSHClient()
        # NOTE(review): paramiko's base MissingHostKeyPolicy is documented as a
        # no-op, so unknown host keys are accepted without verification —
        # confirm this is intended for the environments this runs in.
        ssh_client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy())
        try:
            ssh_client.connect(**self.paramiko_args)
        except BaseException:
            # Do not leak the transport of a failed connection attempt.
            ssh_client.close()
            raise
        self.logger.debug(
            "Connected to %s@%s:%s",
            self.paramiko_args["username"],
            self.paramiko_args["hostname"],
            self.paramiko_args["port"],
        )
        self._ssh_client = ssh_client
        return self._ssh_client

    def exec_command(
        self,
        cmd: Union[str, List[str]],
        label: Optional[str] = None,
        check: bool = True,
        env: Optional[Dict[str, Any]] = None,
        timeout: int = 30,
    ) -> Tuple[int, str, str]:
        """
        Run a command on the remote host and wait for it to finish.

        :param cmd: Command to execute (either a string or list of arguments;
            list items are converted with ``str`` and shell-quoted)
        :type cmd: ``str`` or ``List[str]``
        :param label: Label for identifying the command in logs (defaults to
            a number derived from the hash of the command)
        :type label: ``str`` or ``NoneType``
        :param check: If True, raises exception when command fails
        :type check: ``bool``
        :param env: Environment variables to set for the command. Keys and
            values are converted with ``str``.
        :type env: ``Dict`` or ``NoneType``
        :param timeout: Timeout in seconds passed to paramiko's
            ``exec_command`` (the channel timeout used while reading output)
        :type timeout: ``int``
        :return: Tuple of (exit_code, stdout_str, stderr_str); output is
            decoded as UTF-8 (undecodable bytes replaced) and stripped
        :rtype: ``tuple`` of (``int``, ``str``, ``str``)
        :raises: ``RuntimeError`` if command fails and check is True
        """
        if isinstance(cmd, list):
            # Quote each argument so the logged string can be copied and run.
            cmd_string = shlex.join(str(arg) for arg in cmd)
        else:
            cmd_string = cmd
        if not label:
            label = str(hash(cmd_string) % 1000)

        # paramiko needs str values for its ``environment`` argument, and the
        # prepended assignments below stringify them too — use one mapping.
        env_vars = {str(k): str(v) for k, v in env.items()} if env else None
        if env_vars:
            # The server may silently ignore env vars sent through paramiko's
            # ``environment``, so also prepend them as shell assignments.
            # Note: such assignments only apply to the first simple command.
            env_str = " ".join(
                f"{key}={shlex.quote(value)}" for key, value in env_vars.items()
            )
            cmd_string = f"{env_str} {cmd_string}"

        self.logger.debug(
            "ssh_client executing command [%s]: '%s'",
            label,
            cmd_string,
        )
        start_time = time.monotonic()
        _, stdout, stderr = self.ssh_client.exec_command(
            command=cmd_string,
            timeout=timeout,
            environment=env_vars,
        )
        # Drain the output BEFORE asking for the exit status: paramiko warns
        # that recv_exit_status() can hang indefinitely when the output is
        # larger than the channel window and nobody is reading it.
        stdout_str = stdout.read().decode("utf-8", errors="replace").strip()
        stderr_str = stderr.read().decode("utf-8", errors="replace").strip()
        exit_code = stdout.channel.recv_exit_status()
        # exec_command() returns without waiting for the command, so the
        # duration is only meaningful once the exit status has arrived.
        elapsed = time.monotonic() - start_time

        if exit_code != 0:
            self.logger.warning(
                "Failed executing command [%s] after %.2f sec.", label, elapsed
            )
            if stdout_str:
                self.logger.warning("Stdout:\n%s", stdout_str)
            if stderr_str:
                self.logger.warning("Stderr:\n%s", stderr_str)
            if check:
                raise RuntimeError(
                    f"Command '{cmd_string}' failed with exit code {exit_code}.\n"
                    f"Stdout:\n{stdout_str}\nStderr:\n{stderr_str}"
                )
        else:
            self.logger.debug(
                "Command [%s] executed successfully in %.2f sec.",
                label,
                elapsed,
            )
        return exit_code, stdout_str, stderr_str

    def open_sftp(self) -> paramiko.SFTPClient:
        """
        Open an SFTP session over the SSH connection, or return the open one.

        Connects first if no SSH connection exists yet.

        :return: SFTP client object
        :rtype: ``paramiko.sftp_client.SFTPClient``
        """
        if self._sftp_client is not None:
            self.logger.warning("SFTP session already open")
            return self._sftp_client
        self._sftp_client = self.ssh_client.open_sftp()
        self.logger.debug("Opened SFTP session to %s:%s", self.host, self.port)
        return self._sftp_client

    def listdir_iter(self, path: str) -> Any:
        """
        Iterate over the entries of a directory on the remote host.

        :param path: Path to the directory to list
        :type path: ``str``
        :return: Iterator over the directory entries, as returned by
            paramiko's ``SFTPClient.listdir_iter`` (``SFTPAttributes`` objects;
            the entry name is in their ``filename`` attribute)
        :rtype: ``generator`` of ``paramiko.SFTPAttributes``
        """
        self.logger.debug("Listing directory: %s", path)
        return self.sftp_client.listdir_iter(path)

    def open_file(self, path: str, mode: str) -> paramiko.SFTPFile:
        """
        Open a file on the remote host using SFTP.

        :param path: Path to the file to open
        :type path: ``str``
        :param mode: Mode in which to open the file (e.g., 'r', 'w', 'rb')
        :type mode: ``str``
        :return: File object for the remote file
        :rtype: ``paramiko.sftp_file.SFTPFile``
        """
        self.logger.debug("Opening remote file: %s in mode %s", path, mode)
        return self.sftp_client.open(path, mode)

    def close(self) -> None:
        """
        Close the SFTP session and the SSH connection, if open.

        Safe to call repeatedly. The SSH connection is closed even if closing
        the SFTP session raises; the stored references are cleared up front so
        a later access reconnects instead of reusing a half-closed client.

        :return: None
        """
        sftp_client, self._sftp_client = self._sftp_client, None
        ssh_client, self._ssh_client = self._ssh_client, None
        try:
            if sftp_client is not None:
                sftp_client.close()
                self.logger.debug("Closed SFTP session")
        finally:
            if ssh_client is not None:
                ssh_client.close()
                self.logger.debug("Closed SSH connection")