# Source code for testplan.common.remote.ssh_client

"""Module containing SSH client functionality for remote operations."""

import os
import time
import shlex
import logging
import paramiko
import getpass
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

from testplan.common.utils.timing import retry_until_timeout
from testplan.common.utils.logger import TESTPLAN_LOGGER

logger = logging.getLogger(name=__name__)

# Baseline keyword arguments for ``paramiko.SSHClient.connect``. The login
# name defaults to the local user running this process; it is resolved once,
# when this module is imported.
DEFAULT_PARAMIKO_CONFIG = dict(username=getpass.getuser())


class SSHClient:
    """
    SSH client for remote operations including file transfers and command
    execution. Wraps paramiko functionality in a convenient interface.

    Connections are opened lazily: the first access to :attr:`ssh_client`
    (or :attr:`sftp_client`) establishes the underlying session.
    """

    def __init__(
        self,
        host: str,
        port: int = 22,
        logger: Optional[logging.Logger] = None,
        **args: Any,
    ) -> None:
        """
        Initialize the SSH client. No connection is made here.

        :param host: Host to connect to
        :type host: ``str``
        :param port: Port to connect to
        :type port: ``int``
        :param logger: Logger for diagnostics; defaults to ``TESTPLAN_LOGGER``
        :type logger: ``logging.Logger`` or ``NoneType``
        :param args: Extra keyword arguments forwarded verbatim to
            ``paramiko.SSHClient.connect`` (e.g. ``username``, ``password``,
            ``key_filename``). They take precedence over
            ``DEFAULT_PARAMIKO_CONFIG`` and over the ``hostname``/``port``
            derived from ``host``/``port``.
        """
        self.host = host
        self.port = port
        self.logger = logger or TESTPLAN_LOGGER
        # Later entries win: explicit **args override host/port and defaults.
        self.paramiko_args = {
            **DEFAULT_PARAMIKO_CONFIG,
            "hostname": self.host,
            "port": self.port,
            **args,
        }
        self._ssh_client: Optional[paramiko.SSHClient] = None
        self._sftp_client: Optional[paramiko.SFTPClient] = None

    @property
    def ssh_client(self) -> paramiko.SSHClient:
        """
        Get the underlying paramiko SSH client, connecting on first use.

        :return: Paramiko SSH client instance
        :rtype: ``paramiko.SSHClient``
        """
        if self._ssh_client is None:
            self.connect()
        return self._ssh_client  # type: ignore[return-value]

    @property
    def sftp_client(self) -> paramiko.SFTPClient:
        """
        Get the underlying paramiko SFTP client, opening it on first use.

        :return: Paramiko SFTP client instance
        :rtype: ``paramiko.sftp_client.SFTPClient``
        """
        if self._sftp_client is None:
            self.open_sftp()
        return self._sftp_client  # type: ignore[return-value]

    def connect(self) -> paramiko.SSHClient:
        """
        Establish an SSH connection, or return the existing one.

        :return: The connected paramiko client
        :rtype: ``paramiko.SSHClient``
        :raises: Whatever ``paramiko.SSHClient.connect`` raises (e.g.
            authentication or socket errors). The partially set-up client is
            closed before the exception propagates.
        """
        if self._ssh_client is not None:
            self.logger.warning("SSH connection already established")
            return self._ssh_client

        ssh_client = paramiko.SSHClient()
        # paramiko's base MissingHostKeyPolicy is a no-op, so unknown host
        # keys are accepted; no known_hosts file is loaded here either.
        ssh_client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy())
        try:
            ssh_client.connect(**self.paramiko_args)
        except Exception:
            # Release any transport/socket opened before the failure.
            ssh_client.close()
            raise

        self.logger.debug(
            "Connected to %s@%s:%s",
            self.paramiko_args["username"],
            self.paramiko_args["hostname"],
            self.paramiko_args["port"],
        )
        self._ssh_client = ssh_client
        return self._ssh_client

    def exec_command(
        self,
        cmd: Union[str, List[str]],
        label: Optional[str] = None,
        check: bool = True,
        env: Optional[Dict[str, Any]] = None,
        timeout: int = 30,
    ) -> Tuple[int, str, str]:
        """
        Run a command on the remote host and wait for it to finish.

        :param cmd: Command to execute (either a string or list of arguments;
            list items are converted with ``str`` and shell-quoted)
        :type cmd: ``str`` or ``List[str]``
        :param label: Label for identifying the command in logs (defaults to
            ``hash(command) % 1000``, which is not stable across processes)
        :type label: ``str`` or ``NoneType``
        :param check: If True, raises exception when command fails
        :type check: ``bool``
        :param env: Environment variables to set for the command. Values are
            converted with ``str``.
        :type env: ``Dict`` or ``NoneType``
        :param timeout: paramiko channel timeout in seconds: a blocking read
            that receives no data for this long raises ``socket.timeout``
        :type timeout: ``int``
        :return: Tuple of (exit_code, stdout_str, stderr_str); output is
            decoded as UTF-8 with undecodable bytes replaced
        :rtype: ``tuple`` of (``int``, ``str``, ``str``)
        :raises: ``RuntimeError`` if command fails and check is True
        """
        if isinstance(cmd, list):
            cmd = [str(a) for a in cmd]
            # for logging, easy to copy and execute
            cmd_string = shlex.join(cmd)
        else:
            cmd_string = cmd

        if not label:
            label = str(hash(cmd_string) % 1000)

        # The signature allows non-str values, but SSH env requests carry
        # strings, so coerce once and reuse for both delivery paths below.
        str_env = {str(k): str(v) for k, v in (env or {}).items()}
        if str_env:
            # Warning: paramiko exec_command may silently ignore some env var
            # thus we prepend env to cmd string
            env_prefix = " ".join(
                f"{k}={shlex.quote(v)}" for k, v in str_env.items()
            )
            cmd_string = f"{env_prefix} {cmd_string}"

        self.logger.debug(
            "ssh_client executing command [%s]: '%s'",
            label,
            cmd_string,
        )

        start_time = time.time()
        _, stdout, stderr = self.ssh_client.exec_command(
            command=cmd_string,
            timeout=timeout,
            environment=str_env or None,
        )
        # Drain output BEFORE waiting for the exit status: paramiko documents
        # that recv_exit_status() can hang if unread output fills the window.
        # NOTE(review): heavy stderr while stdout is still open can still
        # stall until the channel timeout fires.
        stdout_str = stdout.read().decode("utf-8", errors="replace").strip()
        stderr_str = stderr.read().decode("utf-8", errors="replace").strip()
        exit_code = stdout.channel.recv_exit_status()
        # Measured after completion; exec_command() itself returns immediately.
        elapsed = time.time() - start_time

        if exit_code != 0:
            self.logger.warning(
                "Failed executing command [%s] after %.2f sec.", label, elapsed
            )
            if stdout_str:
                self.logger.warning("Stdout:\n%s", stdout_str)
            if stderr_str:
                self.logger.warning("Stderr:\n%s", stderr_str)
            if check:
                raise RuntimeError(
                    f"Command '{cmd_string}' failed with exit code {exit_code}.\n"
                    f"Stdout:\n{stdout_str}\nStderr:\n{stderr_str}"
                )
        else:
            self.logger.debug(
                "Command [%s] executed successfully in %.2f sec.",
                label,
                elapsed,
            )

        return exit_code, stdout_str, stderr_str

    def open_sftp(self) -> paramiko.SFTPClient:
        """
        Open an SFTP session over the SSH connection, or return the open one.

        :return: SFTP client object
        :rtype: ``paramiko.sftp_client.SFTPClient``
        """
        if self._sftp_client is not None:
            self.logger.warning("SFTP session already open")
            return self._sftp_client

        self._sftp_client = self.ssh_client.open_sftp()
        self.logger.debug("Opened SFTP session to %s:%s", self.host, self.port)
        return self._sftp_client

    def listdir_iter(self, path: str) -> Any:
        """
        List entries of a directory on the remote host.

        :param path: Path to the directory to list
        :type path: ``str``
        :return: Iterator produced by ``SFTPClient.listdir_iter``
        :rtype: iterator
        """
        self.logger.debug("Listing directory: %s", path)
        return self.sftp_client.listdir_iter(path)

    def open_file(self, path: str, mode: str) -> paramiko.SFTPFile:
        """
        Open a file on the remote host using SFTP.

        :param path: Path to the file to open
        :type path: ``str``
        :param mode: Mode in which to open the file (e.g., 'r', 'w', 'rb')
        :type mode: ``str``
        :return: File object for the remote file
        :rtype: ``paramiko.sftp_file.SFTPFile``
        """
        self.logger.debug("Opening remote file: %s in mode %s", path, mode)
        return self.sftp_client.open(path, mode)

    def close(self) -> None:
        """
        Close the SFTP session and SSH connection, if open.

        The SSH connection is closed even if closing SFTP raises, and both
        handles are cleared so a later access reconnects.

        :return: None
        """
        sftp, self._sftp_client = self._sftp_client, None
        ssh, self._ssh_client = self._ssh_client, None
        try:
            if sftp is not None:
                sftp.close()
                self.logger.debug("Closed SFTP session")
        finally:
            if ssh is not None:
                ssh.close()
                self.logger.debug("Closed SSH connection")