"""This module contains utilities for testing Testplan itself."""

import collections
import functools
import io
import logging
import os
import pprint
import sys
import warnings
from contextlib import contextmanager

from lxml import objectify

from import Report, BaseReportGroup
from ..utils.comparison import is_regex

# Shared no-op handler. `log_propagation_disabled` attaches it to a logger
# while propagation is switched off and detaches it afterwards.
null_handler = logging.NullHandler()

def context_wrapper(ctx_manager, *ctx_args, **ctx_kwargs):
    """
    Higher order function that returns a decorator which runs the wrapped
    function within the context of the given ``ctx_manager``.

    :param ctx_manager: Callable returning a context manager. It is invoked
        as ``ctx_manager(*ctx_args, **ctx_kwargs)`` on every call of the
        wrapped function, so each call runs in a fresh context.
    :param ctx_args: Positional arguments forwarded to ``ctx_manager``.
    :param ctx_kwargs: Keyword arguments forwarded to ``ctx_manager``.
    :return: Decorator; the function it produces keeps the metadata of the
        wrapped function (``functools.wraps``) and returns its result.
    """

    def _wrapper(func):
        @functools.wraps(func)
        def _inner(*args, **kwargs):
            with ctx_manager(*ctx_args, **ctx_kwargs):
                return func(*args, **kwargs)

        return _inner

    return _wrapper
@contextmanager
def argv_overridden(*override_ctx):
    """
    Override ``sys.argv`` for the given context.

    Inside the block ``sys.argv`` consists of the original program name
    (``sys.argv[0]``, when present) followed by ``override_ctx``. The
    previous value is put back when the block exits, including when it
    exits with an exception.

    This is not a thread safe operation, may cause issues if you run
    tests in parallel threads.

    :param override_ctx: Command line arguments to expose via ``sys.argv``.
    """
    argv_backup = list(sys.argv)
    # Slice instead of indexing so an empty ``sys.argv`` does not raise.
    sys.argv = argv_backup[:1] + list(override_ctx)
    try:
        yield
    finally:
        # Restore even when the body raises, otherwise the override leaks
        # into whatever runs next.
        sys.argv = argv_backup
def override_argv(*override_ctx):
    """
    Override ``sys.argv`` for the wrapped function.

    Decorator form of ``argv_overridden``: every call of the decorated
    function runs inside ``argv_overridden(*override_ctx)``.

    :param override_ctx: Command line arguments to expose via ``sys.argv``.
    :return: Decorator to apply to the function under test.
    """
    return context_wrapper(argv_overridden, *override_ctx)
@contextmanager
def log_propagation_disabled(logger):
    """
    Disables log propagation for the given logger.

    WARNING: Use this logic sparingly as it will hide actual exceptions
    from getting displayed in the console.

    A use case would be testing out exception logging to a custom target
    without showing these messages in the console itself, leaving us
    with a clean console output.

    The previous ``propagate`` value is restored and the module-level
    ``null_handler`` is detached when the block exits, including when it
    exits with an exception.

    :param logger: Logger whose ``propagate`` flag is switched off.
    """
    old_prop = logger.propagate
    logger.propagate = False
    logger.addHandler(null_handler)  # This prevents No handler found warning.
    try:
        yield
    finally:
        # Undo both changes even if the body raised, so a failing test
        # does not leave the logger silenced for the rest of the run.
        logger.propagate = old_prop
        logger.removeHandler(null_handler)
def disable_log_propagation(logger):
    """
    Disables log propagation for the given logger.

    Decorator form of ``log_propagation_disabled``: every call of the
    decorated function runs inside ``log_propagation_disabled(logger)``.

    :param logger: Logger whose propagation is disabled during the call.
    :return: Decorator to apply to the function under test.
    """
    return context_wrapper(log_propagation_disabled, logger)
@contextmanager
def log_level_changed(logger, level):
    """
    Change log level for the given logger.

    The logger's own level (``logger.level``) is saved on entry and set
    again on exit, including when the block exits with an exception.

    :param logger: Logger to adjust.
    :param level: Level to apply inside the block.
    """
    old_level = logger.level
    logger.setLevel(level)
    try:
        yield
    finally:
        logger.setLevel(old_level)
@contextmanager
def captured_logging(logger, level=logging.INFO):
    """
    Utility for capturing a logger object's output at a specific level,
    with a default level of INFO. Useful for command line output testing.

    Yields a wrapper object exposing:

    * ``stream_handler``: the ``logging.StreamHandler`` attached to
      ``logger`` for the duration of the block;
    * ``buffer``: the ``io.StringIO`` the handler writes to;
    * ``output``: the captured text with ``\\r\\n`` normalised to ``\\n``.
      The value is computed on first access and cached, so records
      emitted after the first read are not reflected.

    The handler is removed on exit, including when the block exits with an
    exception.

    :param logger: Logger to capture output from.
    :param level: Minimum level set on the capturing handler.
    """

    class LogWrapper:
        def __init__(self):
            self.buffer = io.StringIO()
            self.stream_handler = logging.StreamHandler(self.buffer)
            self._output = None

        @property
        def output(self):
            if self._output is None:
                self.stream_handler.flush()
                # Standardize line endings
                self._output = self.buffer.getvalue().replace("\r\n", "\n")
            return self._output

    log_wrapper = LogWrapper()
    log_wrapper.stream_handler.setLevel(level)
    logger.addHandler(log_wrapper.stream_handler)
    try:
        yield log_wrapper
    finally:
        # Detach even if the body raised, otherwise the handler keeps
        # collecting records from unrelated tests.
        logger.removeHandler(log_wrapper.stream_handler)
def to_stdout(*items):
    """
    Utility function that can be used for testing logging output along
    with `captured_logging`.

    :param items: Strings, one per expected output line.
    :return: The items joined with ``\\n`` plus a trailing ``\\n``. With no
        items the result is a single ``\\n``.
    """
    return "\n".join(items) + "\n"
@contextmanager
def warnings_suppressed():
    """
    Suppress warnings within a block.

    An ``"ignore"`` filter is installed inside ``warnings.catch_warnings()``,
    so the previous filter configuration is back in effect once the block
    exits.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        yield
def suppress_warnings(func):
    """
    Suppress warnings within a function.

    Intended to be applied directly as ``@suppress_warnings``: every call
    of the returned function runs ``func`` inside ``warnings_suppressed()``.

    :param func: Function to wrap.
    :return: Wrapped function with the same metadata and return value.
    """
    # `context_wrapper` only builds the decorator; it has to be applied to
    # `func`, otherwise `@suppress_warnings` would replace the function with
    # a decorator instead of a wrapped function.
    return context_wrapper(warnings_suppressed)(func)
def check_iterable(
    expected, actual, curr_path="ROOT", _orig_exp=None, _orig_act=None
):
    """
    Utility for checking an iterable, supports custom func assertions
    along with normal value matches.

    ``expected`` acts as a template that is walked recursively:

    * list / tuple: items are compared pairwise with ``zip``, so items
      beyond the length of the shorter sequence are not compared;
    * dict: every key of ``expected`` is looked up in ``actual``
      (a missing key raises ``KeyError``); extra keys in ``actual`` are
      ignored;
    * callable: must return a truthy value for ``actual``;
    * compiled regex (as detected by ``is_regex``): must match ``actual``
      from its start;
    * anything else: must be equal to ``actual``.

    :param expected: Expected value / template.
    :param actual: Value under test.
    :param curr_path: Path of the current element, used in error messages.
    :param _orig_exp: Top-level expected data, shown in error messages.
    :param _orig_act: Top-level actual data, shown in error messages.
    :raises AssertionError: On the first mismatch, with a message holding
        the path, both values and both top-level structures.
    """
    # Default only when nothing was supplied: a falsy original (``[]``,
    # ``{}``, ``0``) is still the original and belongs in the report.
    if _orig_act is None:
        _orig_act = actual
    if _orig_exp is None:
        _orig_exp = expected

    def render_mismatch(act, exp, full_path):
        return (
            "{linesep}"
            'Mismatch: "{full_path}",{linesep}'
            "Expected value: {expected}{linesep}"
            "Actual value: {actual}{linesep}"
            "Expected Data:{linesep}{exp_data}{linesep}"
            "Actual Data:{linesep}{act_data}"
        ).format(
            full_path=full_path,
            linesep=os.linesep,
            expected=pprint.pformat(exp, indent=2),
            actual=pprint.pformat(act, indent=2),
            exp_data=pprint.pformat(_orig_exp, indent=2),
            act_data=pprint.pformat(_orig_act, indent=2),
        )

    # The message is rendered inside the `assert` so the pretty-printing of
    # the complete structures only happens when a check actually fails.
    if isinstance(expected, (list, tuple)):
        for idx, (exp_item, act_item) in enumerate(zip(expected, actual)):
            check_iterable(
                expected=exp_item,
                actual=act_item,
                curr_path="{} | [{}]".format(curr_path, idx),
                _orig_exp=_orig_exp,
                _orig_act=_orig_act,
            )
    elif isinstance(expected, dict):
        for key, value in expected.items():
            check_iterable(
                actual=actual[key],
                expected=value,
                curr_path="{} | {}".format(curr_path, key),
                _orig_exp=_orig_exp,
                _orig_act=_orig_act,
            )
    elif callable(expected):
        assert bool(expected(actual)), render_mismatch(
            act=actual, exp=expected, full_path=curr_path
        )
    elif is_regex(expected):
        # Report the pattern text rather than the compiled object.
        assert expected.match(actual), render_mismatch(
            act=actual, exp=expected.pattern, full_path=curr_path
        )
    else:
        assert expected == actual, render_mismatch(
            act=actual, exp=expected, full_path=curr_path
        )
def check_entry(expected, actual):
    """
    Utility function for comparing serialized entries.

    An entry whose ``"type"`` is ``"Group"`` is compared through its
    ``"entries"`` only: both sides must hold the same number of children,
    which are then compared pairwise and recursively. Any other entry is
    handed to ``check_iterable``.

    :param expected: Expected serialized entry (dict with a ``"type"`` key).
    :param actual: Serialized entry under test.
    :raises AssertionError: On the first mismatch.
    """
    if expected["type"] == "Group":
        exp_children = expected["entries"]
        act_children = actual["entries"]
        assert len(exp_children) == len(
            act_children
        ), "Mismatching number of group entries ({} vs {})".format(
            len(exp_children), len(act_children)
        )
        for expected_child, actual_child in zip(exp_children, act_children):
            check_entry(expected_child, actual_child)
    else:
        check_iterable(expected, actual, _orig_act=actual, _orig_exp=expected)
def check_report(expected, actual, skip=None):
    """
    Utility function for comparing report objects.

    The attributes returned by ``expected._get_comparison_attrs()`` are
    compared, except ``entries``, ``uid``, ``timer``, ``logs`` and anything
    listed in ``skip``. List / dict / tuple attributes go through
    ``check_iterable``, everything else is compared with ``==``.

    Children are then checked recursively: via ``check_report`` when
    ``expected`` is a ``BaseReportGroup``, via ``check_entry`` when it is a
    plain ``Report``.

    :param expected: Expected report object.
    :param actual: Report object under test.
    :param skip: Optional iterable of additional attribute names to ignore.
    :raises AssertionError: On the first mismatch.
    """
    skip = skip or []
    # Build the ignore set once instead of once per attribute.
    ignored = {"entries", "uid", "timer", "logs"}.union(skip)
    attrs = [
        attr for attr in expected._get_comparison_attrs() if attr not in ignored
    ]

    for attr in attrs:
        exp_value = getattr(expected, attr)
        act_value = getattr(actual, attr)

        if isinstance(act_value, (list, dict, tuple)):
            try:
                check_iterable(exp_value, act_value)
            except AssertionError as err:
                msg = (
                    "Report name: {report_name}{linesep}"
                    "Attribute: {attr}{linesep}"
                ).format(
                    # Presumably every report exposes `.name`; fall back to
                    # None rather than masking the real failure if not.
                    report_name=getattr(expected, "name", None),
                    linesep=os.linesep,
                    attr=attr,
                )
                raise AssertionError(
                    "{linesep}{report_msg}{error_msg}".format(
                        linesep=os.linesep, report_msg=msg, error_msg=str(err)
                    )
                ) from err
        else:
            # Message is only rendered when the comparison fails.
            assert exp_value == act_value, (
                "{}Expected: {} {}Actual: {}".format(
                    os.linesep, expected, os.linesep, actual
                )
                + '{}Mismatch: "{}", `{}` != `{}`'.format(
                    os.linesep, attr, exp_value, act_value
                )
            )

    if isinstance(expected, BaseReportGroup):
        msg = "{} {} {}".format(
            pprint.pformat(expected, indent=2),
            os.linesep,
            pprint.pformat(actual, indent=2),
        )
        assert len(expected) == len(actual), msg
        for expected_child, actual_child in zip(
            expected.entries, actual.entries
        ):
            check_report(expected_child, actual_child, skip=skip)
    elif isinstance(expected, Report):
        msg = "{} {} {}".format(
            pprint.pformat(expected.entries, indent=2),
            os.linesep,
            pprint.pformat(actual.entries, indent=2),
        )
        assert len(expected) == len(actual), msg
        for expected_entry, actual_entry in zip(
            expected.entries, actual.entries
        ):
            check_entry(expected_entry, actual_entry)
def check_report_context(report, ctx):
    """
    Utility function for checking filtered/ordered test results, we are
    not interested in report contents, just the existence of reports
    with matching names, with the correct order.

    :param report: Top-level report; iterating it yields multitest reports,
        which yield suite reports, which yield testcase reports. Each level
        is expected to expose ``name`` and support ``len``.
    :param ctx: Expected structure as a list of
        ``(multitest_name, [(suite_name, [testcase, ...]), ...])`` tuples.
        A testcase is either a name or, for a parametrized group, a
        ``(group_name, [testcase_name, ...])`` tuple.
    :raises AssertionError: On the first count, name or order mismatch.
    """
    assert len(report) == len(ctx)

    for mt_report, (multitest_name, suite_ctx) in zip(report, ctx):
        assert mt_report.name == multitest_name
        assert len(mt_report) == len(suite_ctx)

        for suite_report, (suite_name, testcases) in zip(mt_report, suite_ctx):
            assert suite_report.name == suite_name
            assert len(suite_report) == len(testcases)

            for testcase_report, testcase_info in zip(suite_report, testcases):
                if isinstance(testcase_info, tuple):
                    # Parametrized group: check the group name and the
                    # ordered names of its generated testcases.
                    param_group_name, param_testcases = testcase_info
                    assert testcase_report.name == param_group_name
                    assert [
                        entry.name for entry in testcase_report.entries
                    ] == param_testcases
                else:
                    assert testcase_report.name == testcase_info
class XMLComparison:
    r"""
    Testing utility for generated XML file contents. Recursively
    compares children as well, supports simple string or regex matching.

    Only the attributes given to the comparison are checked; additional
    attributes on the XML element are ignored. An attribute that is
    expected but absent from the element raises ``KeyError``.

    Usage:

    my_file.xml

    .. code-block:: xml

        <root foo="bar">
            <parent id="0"/>
            <parent id="1">
                <child hello="world" time="12:00"/>
            </parent>
        </root>

    .. code-block:: python

        comparison = XMLComparison(
            tag='root',
            foo='bar',
            children=[
                XMLComparison(tag='parent', id='0'),
                XMLComparison(
                    tag='parent',
                    id='1',
                    children=[
                        XMLComparison(
                            tag='child',
                            hello='world',
                            time=re.compile('\d{2}:\d{2}')
                        )
                    ]
                ),
            ]
        )

        with open('my_file.xml') as xml_file:
            comparison.compare(xml_file.read())
    """

    def __init__(self, tag, children=None, **kwargs):
        """
        :param tag: Expected tag name (string or compiled regex).
        :param children: Expected child elements, in document order, as
            ``XMLComparison`` objects.
        :param kwargs: Expected attributes (string or compiled regex values).
        """
        self.tag = tag
        self.children = children or []
        self.attrib = kwargs

    def _compare_value(self, first, second, key):
        """Assert that ``second`` equals, or is matched by, ``first``."""
        msg = (
            "Attrib mismatch key: `{key}`, "
            "expected: `{expected}`, actual: `{actual}`"
        )
        if is_regex(first):
            assert first.match(second), msg.format(
                key=key, expected=first.pattern, actual=second
            )
        else:
            assert first == second, msg.format(
                key=key, expected=first, actual=second
            )

    def _compare_obj(self, xml_obj):
        """Compare tag, attributes and children against a parsed element."""
        self._compare_value(self.tag, xml_obj.tag, key="tag")

        for key, value in self.attrib.items():
            self._compare_value(value, xml_obj.attrib[key], key)

        # Children are fetched explicitly: iterating or taking `len()` of an
        # lxml.objectify element walks same-tag siblings, not children.
        xml_children = xml_obj.getchildren()
        num_children = len(self.children)
        num_xml_children = len(xml_children)

        assert (
            num_children == num_xml_children
        ), "Mismatching number of children ({} vs {})".format(
            num_children, num_xml_children
        )

        for curr_child, xml_child in zip(self.children, xml_children):
            curr_child._compare_obj(xml_child)

    def compare(self, xml_str, encoding="utf-8"):
        """
        Compare with xml string input.

        :param xml_str: XML document as ``str`` (encoded with ``encoding``
            before parsing) or as ``bytes`` (parsed as is).
        :param encoding: Encoding applied to ``str`` input.
        :raises AssertionError: On the first tag, attribute or child count
            mismatch.
        """
        # fromstring complains when we have UTF declaration like
        # `<?xml version="1.0" encoding="utf-8" ?>`, so we use bytes instead
        if isinstance(xml_str, bytes):
            xml_bytes = xml_str
        else:
            xml_bytes = xml_str.encode(encoding)
        xml_obj = objectify.fromstring(xml_bytes)
        self._compare_obj(xml_obj)
class FixMessage(collections.OrderedDict):
    """
    Basic FIX message for testing. A FIX message may be either typed or
    untyped. In other respects it acts like a plain dict.
    """

    def __init__(self, *args, **kwargs):
        """
        Accepts the same arguments as ``collections.OrderedDict`` plus the
        keyword ``typed_values`` (default ``False``). That keyword is
        consumed here and never stored as a message field.
        """
        # Default to untyped if no typed_values kwarg was provided.
        self.typed_values = kwargs.pop("typed_values", False)
        super().__init__(*args, **kwargs)

    def copy(self):
        """
        Override copy() to return another FixMessage, preserving the
        typed_values attribute.
        """
        return FixMessage(self.items(), typed_values=self.typed_values)