Source code for testplan.testing.multitest.driver.sqlite

"""Small wrapper driver around sqlite3 library."""

import os
import sqlite3
import functools
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

from contextlib import contextmanager

from testplan.common.config import ConfigOption
from testplan.common.utils.documentation_helper import emphasized

from .base import Driver, DriverConfig


class Sqlite3Config(DriverConfig):
    """
    Config schema for the
    :py:class:`~testplan.testing.multitest.driver.sqlite.Sqlite3` driver.
    """

    @classmethod
    def get_options(cls) -> Dict[Any, Any]:
        """
        Return the option schema used to validate the driver's settings and
        fill in defaults: ``db_path`` is a string with no default, and
        ``connect_at_start`` is a boolean defaulting to ``True``.
        """
        schema: Dict[Any, Any] = {"db_path": str}
        schema[ConfigOption("connect_at_start", default=True)] = bool
        return schema
def _rollback_on_error(func: Callable[..., Any]) -> Callable[..., Any]: """Rollback the databse if db operation raises.""" @functools.wraps(func) def wrap(self: Any, *args: Any) -> Any: try: return func(self, *args) except Exception as exc: self.logger.error( "Exception while executing: %s%s%s", args, os.sep, exc ) self.db.rollback() raise return wrap
class Sqlite3(Driver):
    """
    Basic sqlite3 driver to add to a MultiTest environment, connect to a
    database and perform sql queries etc.

    {emphasized_members_docs}

    :param db_path: Path to the database file to connect to. In case a relative
        path is provided it will be appended to the runpath.
    :type db_path: ``str``
    :param connect_at_start: Connect to the database when driver starts.
        Default: True
    :type connect_at_start: ``bool``
    """

    CONFIG = Sqlite3Config

    def __init__(
        self,
        name: str,
        db_path: str,
        connect_at_start: bool = True,
        **options: Any,
    ) -> None:
        # filter_locals() is handed everything in locals(), so no other local
        # names may be bound before this call.
        options.update(self.filter_locals(locals()))
        super(Sqlite3, self).__init__(**options)
        # Set by connect(); None before connecting and after the driver
        # stops or aborts.
        self.db: Optional[sqlite3.Connection] = None
        self.cursor: Optional[sqlite3.Cursor] = None

    @emphasized  # type: ignore[prop-decorator]
    @property
    def db_path(self) -> str:
        """Database file path."""
        # os.path.join discards runpath when cfg.db_path is absolute, so an
        # absolute db_path is used unchanged.
        return os.path.join(self.runpath, self.cfg.db_path)

    def _require_db(self) -> sqlite3.Connection:
        """Return the open connection, raising RuntimeError if there is none."""
        if self.db is None:
            raise RuntimeError("self.db must not be None")
        return self.db

    def _require_cursor(self) -> sqlite3.Cursor:
        """Return the active cursor, raising RuntimeError if there is none."""
        if self.cursor is None:
            raise RuntimeError("self.cursor must not be None")
        return self.cursor

    def _close_connection(self) -> None:
        """Close the connection (if any) and forget the db and cursor."""
        if self.db is not None:
            self.db.close()
        self.db = None
        self.cursor = None

    def connect(self) -> None:
        """
        Connect to the database and set the internal db cursor.

        If a connection from an earlier call is still open, it is closed once
        the new one has been created instead of being leaked; uncommitted
        changes on it are discarded.
        """
        previous = self.db
        self.db = sqlite3.connect(self.db_path)
        self.cursor = self.db.cursor()
        if previous is not None:
            previous.close()

    def starting(self) -> None:
        """
        Start the driver and, if ``connect_at_start`` is set, connect to the
        database.
        """
        super(Sqlite3, self).starting()
        if self.cfg.connect_at_start:
            self.connect()

    def stopping(self) -> None:
        """Stop the driver and close the database connection."""
        super(Sqlite3, self).stopping()
        self._close_connection()

    def aborting(self, *args: Any, **kwargs: Any) -> None:
        """Abort the driver by closing the database connection."""
        self._close_connection()

    @contextmanager
    def commit_at_exit(self) -> Generator[None, None, None]:
        """
        Context manager to perform operations and .commit() at exit.

        The commit only runs when the ``with`` body finishes without raising.

        :raises RuntimeError: if there is no connection when the body exits.
        """
        yield
        self.commit()

    def commit(self) -> None:
        """
        Commit db changes.

        :raises RuntimeError: if the driver is not connected.
        """
        self._require_db().commit()

    @_rollback_on_error
    def execute(self, *args: Any, **kwargs: Any) -> None:
        """Invoke cursor execute, rolling back on error."""
        self._require_cursor().execute(*args, **kwargs)

    @_rollback_on_error
    def executemany(self, *args: Any) -> None:
        """Invoke cursor executemany, rolling back on error."""
        self._require_cursor().executemany(*args)

    def fetchone(self) -> Optional[Any]:
        """Invoke cursor fetchone."""
        return self._require_cursor().fetchone()

    def fetchall(self) -> List[Any]:
        """Invoke cursor fetchall."""
        return self._require_cursor().fetchall()

    def fetch_table(
        self, table: str, columns: Optional[List[str]] = None
    ) -> List[List[Any]]:
        """
        Fetch a table from the db. The first row will be the column names
        and the following rows will be the table rows.

        Returns a table like:

        .. code-block:: bash

            [
                ['symbol', 'amount'],
                ['AAPL', 12],
                ['GOOG', 21],
                ['FB', 32],
                ['AMZN', 5],
                ['MSFT', 42]
            ]

        Table and column names are inserted into the SQL text verbatim
        (identifiers cannot be bound as query parameters), so only pass
        trusted names.

        :param table: Table name in the db.
        :type table: ``str``
        :param columns: Names of columns to be fetched. Defaults to the
            columns reported by ``PRAGMA table_info``.
        :type columns: ``list`` of ``str``

        :return: The table contents.
        :rtype: ``list`` of ``list`` of values.
        """
        if columns is None:
            self.execute("PRAGMA table_info({})".format(table))
            # Field 1 of each table_info row is the column name.
            columns = [str(col[1]) for col in self.fetchall()]

        self.execute("SELECT {} FROM {}".format(", ".join(columns), table))
        result: List[List[Any]] = [columns]
        result.extend(list(row) for row in self.fetchall())
        return result