"""
Class based assertions, these will be serialized into native dicts via
marshmallow schemas.
An assertion object will call ``evaluate`` on instantiation and will
use the result of that call to set its ``passed`` attribute.
"""
import cmath
import collections
import decimal
import lxml
import numbers
import operator
import os
import pprint
import re
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from typing import Dict, Hashable, List, Optional
from testplan.common.utils.convert import make_tuple, flatten_dict_comparison
from testplan.common.utils import comparison, difflib
from testplan.common.utils.process import subprocess_popen
from testplan.common.utils.strings import map_to_str
from testplan.common.utils.table import TableEntry
from .base import BaseEntry
# Public assertion entry classes exported by this module, grouped by kind.
__all__ = [
    # generic / boolean outcomes
    "Assertion", "RawAssertion", "IsTrue", "IsFalse", "Fail",
    # binary operator comparisons
    "FuncAssertion", "Equal", "NotEqual", "Less", "LessEqual",
    "Greater", "GreaterEqual", "IsClose",
    # membership
    "Contain", "NotContain",
    # regular expressions
    "RegexAssertion", "RegexMatch", "RegexMatchNotExists", "RegexSearch",
    "RegexSearchNotExists", "RegexFindIter", "RegexMatchLine",
    # exceptions
    "ExceptionRaised",
    # sequences / text
    "EqualSlices", "EqualExcludeSlices", "LineDiff",
    # tables
    "ColumnContain", "TableMatch", "TableDiff",
    # XML / dict / FIX
    "XMLCheck", "DictCheck", "DictMatch", "DictMatchAll",
    "FixCheck", "FixMatch", "FixMatchAll",
    # other
    "LogfileMatch",
]
class Assertion(BaseEntry):
    """
    Base class for all assertion entries.

    Construction runs :meth:`evaluate` exactly once and stores the
    truthiness of its result in ``passed``. Subclasses therefore set any
    state ``evaluate`` needs *before* calling this ``__init__``. The entry
    itself is truthy iff it passed.
    """

    meta_type = "assertion"

    def __init__(self, description=None, category=None, flag=None):
        super().__init__(description=description, category=category, flag=flag)
        outcome = self.evaluate()
        self.passed = bool(outcome)

    def evaluate(self):
        """Compute the assertion outcome; must be overridden by subclasses."""
        raise NotImplementedError

    def __bool__(self):
        return self.passed
class RawAssertion(Assertion):
    """
    Explicit pass/fail entry carrying arbitrary content.

    The outcome is supplied by the caller instead of being computed. The
    content is meant to be displayed preformatted, which makes this entry
    useful for bridging results from third-party testing libraries
    (unittest, qunit, ...).
    """

    def __init__(self, passed, content, description=None, category=None):
        self.content = content
        # Returned verbatim by evaluate().
        self._passed_override = passed
        super().__init__(description=description, category=category)

    def evaluate(self):
        """Return the caller-supplied outcome unchanged."""
        return self._passed_override
class IsTrue(Assertion):
    """Passes when ``expr`` is truthy."""

    def __init__(self, expr, description=None, category=None):
        self.expr = expr
        super().__init__(description=description, category=category)

    def evaluate(self):
        return True if self.expr else False
class IsFalse(IsTrue):
    """Passes when ``expr`` is falsy."""

    def evaluate(self):
        return False if self.expr else True
class Fail(Assertion):
    """
    Unconditionally failing entry with an optional message.

    ``message`` is stored as text: ``str`` is kept as-is, ``bytes`` are
    decoded, and anything else is pretty-printed. When no description is
    given, the first non-empty line of the message is used, truncated to
    80 characters (plus ``...``).
    """

    def __init__(
        self, description=None, category=None, flag=None, message=None
    ):
        if isinstance(message, bytes):
            text = message.decode()
        elif isinstance(message, str):
            text = message
        else:
            text = pprint.pformat(message)
        self.message = text

        if not description:
            non_empty = [line for line in text.split("\n") if line]
            description = non_empty[0] if non_empty else ""
            if len(description) > 80:
                description = description[:80] + "..."

        super().__init__(
            description=description, category=category, flag=flag
        )

    def evaluate(self):
        """Always fail."""
        return False
class FuncAssertion(Assertion):
    """
    Base class for assertions that apply a binary callable to two operands.

    Subclasses provide ``func`` (invoked as ``func(first, second)``) and
    ``label`` (the operator symbol used in the auto-generated description).
    """

    func = None

    def __init__(self, first, second, description=None, category=None):
        self.first = first
        self.second = second
        if not description:
            description = "{} {} {}".format(
                self._abbreviate(first),
                self.label,
                self._abbreviate(second),
            )
        super().__init__(description=description, category=category)

    @staticmethod
    def _abbreviate(value):
        """Return ``value`` itself, or the first 30 characters of its string
        form followed by ``...`` when that form is longer than 30 chars."""
        text = str(value)
        if len(text) > 30:
            return text[:30] + "..."
        return value

    def evaluate(self):
        # pylint: disable=not-callable
        return self.func(self.first, self.second)
        # pylint: enable=not-callable
class Equal(FuncAssertion):
    """
    Passes when ``first == second``.

    Also records the type names of both operands in ``type_actual`` and
    ``type_expected`` (presumably so a report can show them next to the
    values).
    """

    func = operator.eq
    label = "=="

    def __init__(self, first, second, description=None, category=None):
        self.type_actual, self.type_expected = (
            type(first).__name__,
            type(second).__name__,
        )
        super().__init__(
            first=first,
            second=second,
            description=description,
            category=category,
        )
class NotEqual(FuncAssertion):
    """Passes when ``first != second``."""

    func = operator.ne
    label = "!="
class Less(FuncAssertion):
    """Passes when ``first < second``."""

    func = operator.lt
    label = "<"
class LessEqual(FuncAssertion):
    """Passes when ``first <= second``."""

    func = operator.le
    label = "<="
class Greater(FuncAssertion):
    """Passes when ``first > second``."""

    func = operator.gt
    label = ">"
class GreaterEqual(FuncAssertion):
    """Passes when ``first >= second``."""

    func = operator.ge
    label = ">="
class IsClose(Assertion):
    """
    Passes when two numbers are close, using the same rule as
    ``math.isclose``: equal values pass, infinities only match themselves,
    otherwise ``|second - first|`` must be within ``rel_tol`` times either
    operand's magnitude, or within ``abs_tol``.

    :raises ValueError: if the operands are not numbers, or the tolerances
        are not real/Decimal values or are negative.
    """

    label = "~="

    def __init__(
        self,
        first,
        second,
        rel_tol=1e-09,
        abs_tol=0.0,
        description=None,
        category=None,
    ):
        operands_ok = isinstance(first, numbers.Number) and isinstance(
            second, numbers.Number
        )
        if not operands_ok:
            raise ValueError("`first` and `second` must be numbers.")

        tolerance_types = (numbers.Real, decimal.Decimal)
        # ``not (x < 0)`` rather than ``x >= 0`` keeps NaN tolerances accepted.
        tolerances_ok = (
            isinstance(rel_tol, tolerance_types)
            and isinstance(abs_tol, tolerance_types)
            and not (rel_tol < 0 or abs_tol < 0)
        )
        if not tolerances_ok:
            raise ValueError("`rel_tol` and `abs_tol` must be non-negative.")

        self.first, self.second = first, second
        self.rel_tol, self.abs_tol = rel_tol, abs_tol

        if not description:
            description = "{} {} {}".format(
                self.first, self.label, self.second
            )
        super().__init__(description=description, category=category)

    def evaluate(self):
        first, second = self.first, self.second
        if first == second:
            return True
        if cmath.isinf(first) or cmath.isinf(second):
            return False
        diff = abs(second - first)
        return (
            diff <= abs(self.rel_tol * first)
            or diff <= abs(self.rel_tol * second)
            or diff <= self.abs_tol
        )
class Contain(Assertion):
    """Passes when ``member in container``."""

    def __init__(self, member, container, description=None, category=None):
        self.member, self.container = member, container
        super().__init__(description=description, category=category)

    def evaluate(self):
        return self.member in self.container
class NotContain(Contain):
    """Passes when ``member not in container``."""

    def evaluate(self):
        return not (self.member in self.container)
class RegexAssertion(Assertion):
    """
    Base class for regex assertions.

    ``regexp`` may be a pattern string (compiled with ``flags``) or an
    already compiled ``re.Pattern``; in the latter case ``flags`` must be 0.
    Subclasses implement :meth:`get_regex_result`. After evaluation,
    ``string`` and ``pattern`` are converted with ``map_to_str`` (a no-op
    for ``str``, a conversion for ``bytes``).

    :raises ValueError: if ``flags`` is given alongside a compiled pattern.
    """

    def __init__(
        self, regexp, string, flags=0, description=None, category=None
    ):
        if not isinstance(regexp, re.Pattern):
            self.pattern = regexp
            self.regexp = re.compile(regexp, flags=flags)
        elif flags != 0:
            raise ValueError(
                "`flags` argument is redundant if"
                " `regexp` is of type `re.Pattern`"
            )
        else:
            self.pattern = regexp.pattern
            self.regexp = regexp

        self.string = string
        # (start, end) spans of matches, filled in by evaluate().
        self.match_indexes = []
        super().__init__(description=description, category=category)

        # Evaluation is done; normalize bytes inputs to str for reporting.
        self.string = map_to_str(self.string)
        self.pattern = map_to_str(self.pattern)

    def get_regex_result(self):
        """Return a match object (or ``None``); implemented by subclasses."""
        raise NotImplementedError

    def evaluate(self):
        found = self.get_regex_result()
        if not found:
            return False
        self.match_indexes.append((found.start(), found.end()))
        return True
class RegexMatch(RegexAssertion):
    """Passes when the pattern matches at the start of ``string``."""

    def get_regex_result(self):
        matcher = self.regexp.match
        return matcher(self.string)
class RegexMatchNotExists(RegexMatch):
    """Passes when the pattern does NOT match at the start of ``string``."""

    def evaluate(self):
        matched = super().evaluate()
        return not matched
class RegexSearch(RegexAssertion):
    """Passes when the pattern is found anywhere in ``string``."""

    def get_regex_result(self):
        searcher = self.regexp.search
        return searcher(self.string)
class RegexSearchNotExists(RegexSearch):
    """Passes when the pattern is NOT found anywhere in ``string``."""

    def evaluate(self):
        found = super().evaluate()
        return not found
class RegexFindIter(RegexAssertion):
    """
    Collect all non-overlapping matches of the pattern in ``string``.

    Without ``condition`` the assertion passes when at least one match is
    found. With ``condition`` (a callable receiving the match count), its
    return value is stored in ``condition_match`` and its truthiness
    decides the outcome.
    """

    def __init__(
        self,
        regexp,
        string,
        flags=0,
        condition=None,
        description=None,
        category=None,
    ):
        self.condition = condition
        self.condition_match = None  # may be set by evaluate()
        super().__init__(
            regexp,
            string,
            flags=flags,
            description=description,
            category=category,
        )

    def evaluate(self):
        matches = list(self.regexp.finditer(self.string))
        self.match_indexes.extend((m.start(), m.end()) for m in matches)
        if not self.condition:
            return bool(matches)
        self.condition_match = self.condition(len(matches))
        return bool(self.condition_match)
class RegexMatchLine(RegexAssertion):
    """
    Match the pattern against the start of every line of ``string``.

    Lines are obtained by splitting on ``os.linesep`` (encoded for
    ``bytes`` input). Match indexes differ from the other regex
    assertions: each entry is ``(line_no, begin, end)``. The assertion
    passes when at least one line matches.
    """

    def __init__(
        self, regexp, string, flags=0, description=None, category=None
    ):
        self.lines = None  # populated by evaluate()
        super().__init__(
            regexp,
            string,
            flags=flags,
            description=description,
            category=category,
        )
        # Materialize as a list: a lazy ``map`` iterator can be consumed only
        # once, so any second reader (e.g. serializing the entry twice) would
        # silently see no lines at all.
        self.lines = [map_to_str(line) for line in self.lines]

    def evaluate(self):
        separator = (
            os.linesep.encode() if isinstance(self.string, bytes) else os.linesep
        )
        self.lines = self.string.split(separator)
        for line_num, line in enumerate(self.lines):
            match = self.regexp.match(line)
            if match:
                self.match_indexes.append(
                    (line_num, match.start(), match.end())
                )
        # Non-empty list -> truthy -> passed.
        return self.match_indexes
class ExceptionRaised(Assertion):
    """
    Check a raised exception against expectations.

    Passes when ``raised_exception`` is an instance of one of
    ``expected_exceptions`` AND (if given) ``pattern`` is found via
    ``re.search`` in ``str(raised_exception)`` AND (if given)
    ``func(raised_exception)`` is truthy. The individual outcomes are kept
    in ``exception_match``, ``pattern_match`` and ``func_match``.

    :raises AssertionError: on invalid arguments (empty or non-``Exception``
        ``expected_exceptions``, non-callable ``func``, non-str ``pattern``).
    """

    def __init__(
        self,
        raised_exception,
        expected_exceptions,
        pattern=None,
        func=None,
        description=None,
        category=None,
    ):
        expected_exceptions = make_tuple(expected_exceptions)
        assert expected_exceptions, "`expected_exceptions` cannot be empty."
        # ``all`` is required: asserting the list itself only checked that it
        # was non-empty, so non-Exception types were never rejected.
        assert all(
            issubclass(exc, Exception) for exc in expected_exceptions
        ), "items in `expected_exceptions` must be subclass of `Exception` ."
        if func:
            assert callable(func), "`func` must be a callable."
        if pattern:
            assert isinstance(
                pattern, str
            ), "`pattern` must be of string type, it was: {}".format(
                type(pattern)
            )
        self.raised_exception = raised_exception
        self.expected_exceptions = expected_exceptions
        self.pattern = pattern
        self.func = func
        # These will be set by `evaluate`
        self.exception_match = None
        self.pattern_match = None
        self.func_match = None
        super(ExceptionRaised, self).__init__(
            description=description, category=category
        )

    def get_match_context(self):
        """
        Return ``(exception_match, pattern_match, func_match)``.

        Checks that were not requested (``pattern``/``func`` is ``None``)
        yield ``True``; ``pattern_match`` is otherwise the ``re.search``
        result (a match object or ``None``).
        """
        exception_match = isinstance(
            self.raised_exception, self.expected_exceptions
        )
        pattern_match = self.pattern is None or re.search(
            self.pattern, str(self.raised_exception)
        )
        func_match = self.func is None or self.func(self.raised_exception)
        return exception_match, pattern_match, func_match

    def evaluate(self):
        match_ctx = self.get_match_context()
        self.exception_match, self.pattern_match, self.func_match = match_ctx
        return all(match_ctx)
class ExceptionNotRaised(ExceptionRaised):
    """Inverse of ``ExceptionRaised``: passes when its combined check fails."""

    def evaluate(self):
        matched = super().evaluate()
        return not matched
_SliceComparison = collections.namedtuple(
"_SliceComparison",
"slice comparison_indices mismatch_indices actual expected",
)
class SliceComparison(_SliceComparison):
    """
    Result record produced by the EqualSlices / EqualExcludeSlices
    assertions for a single slice.

    Attributes:
        slice: The slice object this record describes.
        comparison_indices: Sorted indices that were compared; they lie
            inside or outside the slice depending on the assertion type.
        mismatch_indices: Sorted indices whose items differed.
        actual: Items of the actual iterable at ``comparison_indices``.
        expected: Items of the expected iterable at ``comparison_indices``.
    """

    @property
    def passed(self):
        """True when no index mismatched."""
        return not self.mismatch_indices
class EqualSlices(Assertion):
    """
    Check that the given slices of two iterables are equal.

    Fails immediately when the iterables differ in length. Otherwise one
    ``SliceComparison`` per slice is stored in ``data`` and the assertion
    passes when all of them pass.
    """

    def __init__(
        self, actual, expected, slices, description=None, category=None
    ):
        assert slices and isinstance(slices, (list, tuple))
        self.actual = actual
        self.expected = expected
        self.slices = slices
        self.data = []  # populated by evaluate()
        self.included_indices = set()
        super().__init__(description=description, category=category)

    def get_comparison_indices(self, slice_obj, iterable):
        """Indices selected by ``slice_obj`` for an iterable of this length."""
        start, stop, step = slice_obj.indices(len(iterable))
        return range(start, stop, step)

    def get_iterable(self, iterable, comparison_indices):
        """
        Return the items of ``iterable`` at ``comparison_indices`` (in
        original order), rebuilt as the same type (joined for ``str``).
        """
        picked = [
            item
            for position, item in enumerate(iterable)
            if position in comparison_indices
        ]
        if isinstance(iterable, str):
            return "".join(picked)
        return type(iterable)(picked)

    def generate_data(self, slices, actual, expected):
        """Return one ``SliceComparison`` per slice in ``slices``."""
        comparisons = []
        for current in slices:
            indices = self.get_comparison_indices(current, expected)
            mismatches = [i for i in indices if actual[i] != expected[i]]
            comparisons.append(
                SliceComparison(
                    slice=current,
                    comparison_indices=sorted(indices),
                    mismatch_indices=sorted(mismatches),
                    actual=self.get_iterable(actual, indices),
                    expected=self.get_iterable(expected, indices),
                )
            )
        return comparisons

    def evaluate(self):
        """Pass only if the lengths match and every slice comparison passes."""
        if len(self.actual) != len(self.expected):
            return False
        self.data = self.generate_data(self.slices, self.actual, self.expected)
        return all(entry.passed for entry in self.data)
class EqualExcludeSlices(EqualSlices):
    """
    Check that the items lying OUTSIDE all given slices of two iterables
    are equal.

    One ``SliceComparison`` per slice is still stored in ``data`` (each
    covering the items outside that single slice), but the verdict is based
    on the indices outside the union of all slices.
    """

    def get_comparison_indices(self, slice_obj, iterable):
        """Indices NOT selected by ``slice_obj`` (returned as a set)."""
        inside = super().get_comparison_indices(slice_obj, iterable)
        return set(range(len(iterable))).difference(inside)

    def evaluate(self):
        """
        Build per-slice SliceComparison data, then compare the items at the
        indices left over after excluding every slice.
        """
        actual, expected = self.actual, self.expected
        if len(actual) != len(expected):
            return False

        self.data = self.generate_data(self.slices, actual, expected)

        # Individual SliceComparisons may fail while the assertion passes,
        # because only indices outside *all* slices matter. Example:
        #   slices   = [slice(0, 2), slice(5, 7)]
        #   actual   = [0, 1, 2, 3, 4, 5, 6]
        #   expected = ['a', 'b', 2, 3, 4, 'c', 'd']
        # slice(0, 2) compares [2, 3, 4, 5, 6] (fails on 5, 6) and
        # slice(5, 7) compares [0, 1, 2, 3, 4] (fails on 0, 1), yet the
        # merged indices [2, 3, 4] hold equal items, so the assertion passes.
        size = len(expected)
        excluded = set()
        for slice_ in self.slices:
            excluded.update(range(*slice_.indices(size)))
        self.included_indices = set(range(size)) - excluded

        outcomes = [actual[i] == expected[i] for i in self.included_indices]
        return all(outcomes)
class LineDiff(Assertion):
    """
    Assertion that checks whether two blocks of text differ line by line.

    ``first`` / ``second`` may be strings (split with ``splitlines(True)``)
    or lists of lines. The assertion passes when no difference is found;
    otherwise ``delta`` holds the diff output as a list of lines. On
    non-Windows platforms the external ``diff`` command is used, on Windows
    the project ``difflib.diff`` helper.

    :raises ValueError: if the inputs are not str/list, or ``unified`` /
        ``context`` is a negative integer.
    """

    def __init__(
        self,
        first,
        second,
        ignore_space_change=False,
        ignore_whitespaces=False,
        ignore_blank_lines=False,
        unified=False,
        context=False,
        description=None,
        category=None,
    ):
        if (not isinstance(first, (str, list))) or (
            not isinstance(second, (str, list))
        ):
            raise ValueError("`first` and `second` must be string or list.")
        if isinstance(unified, int) and unified < 0:
            raise ValueError("`unified` cannot be negative integer.")
        if isinstance(context, int) and context < 0:
            raise ValueError("`context` cannot be negative integer.")
        self.first = (
            first.splitlines(True) if isinstance(first, str) else first
        )
        self.second = (
            second.splitlines(True) if isinstance(second, str) else second
        )
        self.ignore_space_change = ignore_space_change
        self.ignore_whitespaces = ignore_whitespaces
        self.ignore_blank_lines = ignore_blank_lines
        self.unified = unified
        self.context = context
        self.delta = []  # will be populated via self.evaluate
        super(LineDiff, self).__init__(
            description=description, category=category
        )

    def evaluate(self):
        if sys.platform != "win32":
            self.delta = self._diff_process().splitlines(True)
        else:
            self.delta = list(self._diff_difflib())
        return self.delta == []

    def _diff_difflib(self):
        """Compute the diff in-process via the project ``difflib.diff``."""
        return difflib.diff(
            self.first,
            self.second,
            ignore_space_change=self.ignore_space_change,
            ignore_whitespaces=self.ignore_whitespaces,
            ignore_blank_lines=self.ignore_blank_lines,
            unified=self.unified,
            context=self.context,
        )

    def _diff_command(self):
        """Build the ``diff`` argv (without file operands) from the options."""
        cmd = ["diff"]
        for enabled, flag in (
            (self.ignore_space_change, "-b"),
            (self.ignore_whitespaces, "-w"),
            (self.ignore_blank_lines, "-B"),
            (self.unified, "-u"),
            (self.context, "-c"),
        ):
            if enabled:
                cmd.append(flag)
        return cmd

    def _diff_process(self):
        """
        Run the external ``diff`` command on temporary copies of both inputs.

        :return: ``diff`` stdout as text (empty when the inputs match).
        :raises RuntimeError: if ``diff`` exits with status 2, carrying its
            stderr.
        """
        cmd = self._diff_command()
        paths = []
        try:
            for lines in (self.first, self.second):
                with tempfile.NamedTemporaryFile(
                    delete=False, mode="w"
                ) as tmp:
                    # Record the path before writing so the file is removed
                    # even if the write itself fails.
                    paths.append(tmp.name)
                    tmp.write("".join(lines))
            cmd.extend(paths)
            handler = subprocess_popen(
                cmd,
                stderr=subprocess.PIPE,
                stdout=subprocess.PIPE,
                universal_newlines=True,  # otherwise we get a byte stream
            )
            out, err = handler.communicate()
        finally:
            # Always clean up; previously the files leaked whenever diff
            # reported an error (the raise happened before the unlinks).
            for path in paths:
                os.unlink(path)
        if handler.returncode == 2:
            raise RuntimeError(err)
        return out
# One row's outcome for ColumnContain: row index, the column value, and
# whether that value was found among the allowed values.
ColumnContainComparison = collections.namedtuple(
    "ColumnContainComparison", ["idx", "value", "passed"]
)
class ColumnContain(Assertion):
    """
    Check that the value in ``column`` of every row of ``table`` is one of
    ``values``.

    Each visited row yields a ``ColumnContainComparison`` that is appended
    to ``data`` (only failing rows when ``report_fails_only`` is set). When
    ``limit`` is non-zero, processing stops as soon as ``data`` holds
    ``limit`` entries, so later rows are neither reported nor checked.
    """

    def __init__(
        self,
        table,
        values,
        column,
        limit=0,
        report_fails_only=False,
        description=None,
        category=None,
    ):
        self.table = TableEntry(table).as_list_of_dict()
        self.values = values
        self.column = column
        self.limit = limit
        self.report_fails_only = report_fails_only
        self.data = []  # populated by evaluate()
        super().__init__(description=description, category=category)

    def evaluate(self):
        all_passed = True
        for idx, row in enumerate(self.table):
            value = row[self.column]
            entry = ColumnContainComparison(
                idx=idx, value=value, passed=value in self.values
            )
            if not entry.passed:
                all_passed = False
            if not (self.report_fails_only and entry.passed):
                self.data.append(entry)
            if self.limit and len(self.data) >= self.limit:
                break
        return all_passed
_RowComparison = collections.namedtuple(
"_RowComparison", "idx data diff errors extra"
)
class RowComparison(_RowComparison):
    """
    Comparison result for one pair of rows from two tables.

    Together with the column list this is enough to render both tables.

    idx: Index of the row in the table.
    data (list): Column values of the first (original) table's row.
    diff (dict): Mismatches on the second table's row
        (key: column name, value: second table value or comparator
        representation).
    errors (dict): Errors raised while comparing
        (key: column name, value: error text).
    extra (dict): Second-table values / comparator representations that
        did not fail, e.g. a passing custom comparator or a display-only
        column with a different value.
    """

    @property
    def passed(self):
        """True when the row produced neither diffs nor errors."""
        return not self.diff and not self.errors

    def get_comparison_value(self, column, column_idx):
        """
        Return ``(value, status)`` for the second-table side of ``column``.

        ``status`` is ``False`` for a diff or error, ``None`` for an extra
        value, and ``True`` when nothing was recorded, in which case the
        row's own value at ``column_idx`` is returned.
        """
        for source, status in (
            (self.diff, False),
            (self.errors, False),
            (self.extra, None),
        ):
            if column in source:
                return source[column], status
        return self.data[column_idx], True
def get_comparison_columns(
    columns_1, columns_2, include_columns, exclude_columns
):
    """
    Given two tables' columns and inclusion / exclusion rules, return the
    list of columns that will be used for comparison.

    Inclusion/exclusion rules apply to both tables; the resulting
    sub-tables must have matching column sets.

    :param columns_1: Columns of the first table
    :type columns_1: ``list``
    :param columns_2: Columns of the second table
    :type columns_2: ``list``
    :param include_columns: Inclusion rules for columns.
    :type include_columns: ``list`` of ``str``
    :param exclude_columns: Exclusion rules for columns.
    :type exclude_columns: ``list`` of ``str``
    :return: Comparison columns, in the order they appear in ``columns_1``.
    :rtype: ``list``
    :raises ValueError: if both rule kinds are given, if an included column
        is missing from either table, or if the (filtered) column sets of
        the two tables differ.
    """

    def check_missing_columns(columns, lookup):
        """Raise ``ValueError`` if any item of ``lookup`` is not in ``columns``."""
        missing = set(lookup) - set(columns)
        if missing:
            # Sorted so the message is deterministic (set order is not) and
            # consistent with the other messages of this function.
            raise ValueError(
                "Missing columns: {}".format(", ".join(sorted(missing)))
            )

    if include_columns and exclude_columns:
        raise ValueError(
            "Either use `include_columns` or `exclude_columns`, not"
            " both. (include_columns: {}, exclude_columns: {})".format(
                include_columns, exclude_columns
            )
        )

    comparison_columns = columns_1

    if include_columns:
        check_missing_columns(columns_1, lookup=include_columns)
        check_missing_columns(columns_2, lookup=include_columns)
        comparison_columns = [c for c in columns_1 if c in include_columns]
    elif exclude_columns:
        columns_1 = [c for c in columns_1 if c not in exclude_columns]
        columns_2 = [c for c in columns_2 if c not in exclude_columns]
        if set(columns_1) != set(columns_2):
            raise ValueError(
                'Table columns ("{}", "{}") do not match after'
                ' applying exclusion rules: "{}"'.format(
                    ", ".join(sorted(columns_1)),
                    ", ".join(sorted(columns_2)),
                    ", ".join(exclude_columns),
                )
            )
        comparison_columns = columns_1
    elif set(columns_1) != set(columns_2):
        raise ValueError(
            'Table columns ("{}", "{}") do not match,'
            " consider using `include_columns` or"
            " `exclude_columns` arguments.".format(
                ", ".join(sorted(columns_1)), ", ".join(sorted(columns_2))
            )
        )

    return comparison_columns
def compare_rows(
    table,
    expected_table,
    comparison_columns,
    display_columns,
    strict=True,
    fail_limit=0,
    report_fails_only=False,
):
    """
    Compare two tables row by row, producing a ``RowComparison`` per row pair.

    :param table: Original table.
    :type table: ``list`` of ``dict``
    :param expected_table: Comparison table; it may contain custom
        comparators as column values.
    :type expected_table: ``list`` of ``dict``
    :param comparison_columns: Columns used for comparison.
    :type comparison_columns: ``list`` of ``str``
    :param display_columns: Columns used to populate ``RowComparison``
        data; must be a superset of ``comparison_columns``.
    :type display_columns: ``list`` of ``str``
    :param strict: Forwarded to ``comparison.basic_compare``; per the
        original docs, when ``False`` non-str values are converted to
        ``str`` for pattern matching.
    :type strict: ``bool``
    :param fail_limit: If positive, stop after this many failing rows.
    :type fail_limit: ``int``
    :param report_fails_only: If ``True``, only failing rows are kept in
        the returned data (used for diffs).
    :type report_fails_only: ``bool``
    :returns: ``(passed, data)`` where ``passed`` is ``True`` when no
        visited row failed and ``data`` is the list of ``RowComparison``.
    :raises ValueError: if ``comparison_columns`` is not a subset of
        ``display_columns``.
    """
    # Display columns must cover every comparison column, otherwise a failing
    # row could be reported without the values that made it fail.
    if not set(comparison_columns).issubset(display_columns):
        raise ValueError(
            "comparison_columns ({}) must be "
            "subset of display_columns ({})".format(
                ", ".join(sorted(comparison_columns)),
                ", ".join(sorted(display_columns)),
            )
        )

    display_only = [c for c in display_columns if c not in comparison_columns]
    rows = []
    failures = 0

    for row_idx, (actual_row, expected_row) in enumerate(
        zip(table, expected_table)
    ):
        diff, errors, extra = {}, {}, {}

        for column in comparison_columns:
            in_actual = column in actual_row
            in_expected = column in expected_row
            if not in_actual and not in_expected:
                continue
            if in_actual != in_expected:
                # Present on one side only: always a mismatch.
                diff[column] = expected_row.get(column, None)
                continue

            actual_value = actual_row[column]
            expected_value = expected_row[column]
            passed, error = comparison.basic_compare(
                first=actual_value, second=expected_value, strict=strict
            )
            if error:
                errors[column] = error
            elif not passed:
                diff[column] = expected_value
            # Keep the expected-side value for display when it is a
            # different object (identity, not equality, to avoid raising on
            # incompatible types) and the column did not end up in ``diff``.
            if actual_value is not expected_value and (error or passed):
                extra[column] = expected_value

        row_data = [actual_row.get(c, None) for c in display_columns]
        # Display-only columns come from the second table for rendering.
        extra.update({c: expected_row.get(c, None) for c in display_only})
        row_comparison = RowComparison(row_idx, row_data, diff, errors, extra)

        if not (report_fails_only and row_comparison.passed):
            rows.append(row_comparison)

        if row_comparison.passed:
            continue
        failures += 1
        if fail_limit > 0 and failures >= fail_limit:
            break

    return failures == 0, rows
class TableMatch(Assertion):
    """
    Match two tables row by row using ``compare_rows``.

    ``message`` is set when the tables cannot be compared (different row
    counts or incompatible columns) or are both empty. ``display_columns``
    is every column of the first table when ``report_all`` is true,
    otherwise only the comparison columns. ``report_fail_only`` is stored
    as ``report_fails_only`` and forwarded to ``compare_rows``.
    """

    def __init__(
        self,
        table,
        expected_table,
        include_columns=None,
        exclude_columns=None,
        report_all=True,
        fail_limit=0,
        report_fail_only=False,
        strict=False,
        description=None,
        category=None,
    ):
        actual_entry = TableEntry(table)
        expected_entry = TableEntry(expected_table)
        self.table = actual_entry.as_list_of_dict()
        self.table_columns = actual_entry.columns
        self.expected_table = expected_entry.as_list_of_dict()
        self.expected_table_columns = expected_entry.columns

        self.include_columns = include_columns
        self.exclude_columns = exclude_columns
        self.strict = strict
        self.report_all = report_all
        self.fail_limit = fail_limit
        self.report_fails_only = report_fail_only

        # Populated by evaluate().
        self.display_columns = []
        self.message = None
        self.data = []
        super().__init__(description=description, category=category)

    def evaluate(self):
        n_actual, n_expected = len(self.table), len(self.expected_table)
        if n_actual != n_expected:
            self.message = (
                "Cannot run comparison on tables with different number "
                "of rows ({} vs {}), make sure tables have the same size."
            ).format(n_actual, n_expected)
            return False

        if not self.table and not self.expected_table:
            self.message = "Both tables are empty."
            return True

        try:
            comparison_columns = get_comparison_columns(
                columns_1=self.table_columns,
                columns_2=self.expected_table_columns,
                include_columns=self.include_columns,
                exclude_columns=self.exclude_columns,
            )
        except ValueError as exc:
            # Incomparable tables fail the assertion with an explanation.
            self.message = str(exc)
            return False

        if self.report_all:
            self.display_columns = self.table_columns
        else:
            self.display_columns = comparison_columns

        passed, self.data = compare_rows(
            table=self.table,
            expected_table=self.expected_table,
            comparison_columns=comparison_columns,
            display_columns=self.display_columns,
            strict=self.strict,
            fail_limit=self.fail_limit,
            report_fails_only=self.report_fails_only,
        )
        return passed
class TableDiff(TableMatch):
    """
    Table comparison meant to keep only the failing row comparisons; may
    generate a custom message if the tables cannot be compared.

    NOTE(review): this class adds no behavior of its own. Dropping passing
    rows is controlled by the inherited ``report_fail_only`` argument
    (default ``False``), so presumably callers pass
    ``report_fail_only=True`` -- confirm at the call sites.
    """
_XMLTagComparison = collections.namedtuple(
"_XMLTagComparison", "tag diff error extra"
)
class XMLTagComparison(_XMLTagComparison):
    """Comparison result for a single XML tag value checked by XMLCheck."""

    @property
    def passed(self):
        """True when there is neither a diff nor an error."""
        return not self.diff and not self.error

    @property
    def comparison_value(self):
        """
        The value to display: the first truthy of error, diff, extra and
        tag; regex objects are rendered as ``REGEX('<pattern>')``.
        """
        value = self.error or self.diff or self.extra or self.tag
        if not comparison.is_regex(value):
            return value
        return "REGEX('{}')".format(value.pattern)
class XMLCheck(Assertion):
    """
    Validate XML tag texts or existence in a given xpath,
    supports regex patterns as tag values as well.

    Without ``tags`` the assertion passes when ``xpath`` yields any result.
    With ``tags``, the ``idx``-th tag is checked against the text of the
    ``idx``-th xpath result: a ``str`` tag is first tried as a regex
    (``re.match``), otherwise ``comparison.basic_compare`` is used. One
    ``XMLTagComparison`` per tag is stored in ``data``; ``message`` is set
    when no tag lookup takes place.

    :raises ValueError: if ``element`` is neither a string nor an lxml
        element.
    """

    def __init__(
        self,
        element,
        xpath,
        tags=None,
        namespaces=None,
        description=None,
        category=None,
    ):
        # ``import lxml`` alone does not load the ``lxml.etree`` submodule, so
        # ``lxml.etree`` would only resolve if some other module had already
        # imported it. Import it explicitly here.
        import lxml.etree  # pylint: disable=import-outside-toplevel

        self.xpath = xpath
        self.tags = tags
        if isinstance(element, str):
            element = lxml.etree.fromstring(element)
        # pylint: disable=protected-access
        elif not isinstance(element, lxml.etree._Element):
            raise ValueError(
                "`element` must be either an XML"
                " string or `lxml.etree.Element`."
                " It was of type: {}".format(type(element))
            )
        self.element = element
        self.namespaces = namespaces
        self.data = []  # will be populated by evaluate
        self.message = None  # will be populated by evaluate
        super(XMLCheck, self).__init__(
            description=description, category=category
        )

    def evaluate(self):
        # This may raise XPathEvalError for incorrect namespacing
        results = self.element.xpath(self.xpath, namespaces=self.namespaces)

        if not results:
            self.message = "xpath: `{}` does not" " exist in the XML.".format(
                self.xpath
            )
            return False

        # xpath exists and no tag lookup was requested -> pass.
        if not self.tags:
            self.message = "xpath: `{}` exists in the XML.".format(self.xpath)
            return True

        self.data = [
            self._compare_tag(results, idx, tag)
            for idx, tag in enumerate(self.tags)
        ]
        return all(comp.passed for comp in self.data)

    @staticmethod
    def _compare_tag(results, idx, tag):
        """Build the ``XMLTagComparison`` of ``tag`` vs. ``results[idx]``."""
        # Only the positional lookup can legitimately raise IndexError; the
        # try is kept narrow so comparator errors are not mislabelled.
        try:
            text = results[idx].text
        except IndexError:
            return XMLTagComparison(
                tag=None,
                diff=tag,
                error="No tags found for the index: {}".format(idx),
                extra=None,
            )

        if not text:
            return XMLTagComparison(
                tag=tag,
                diff=None,
                error="No value is found," " although the path exists.",
                extra=None,
            )

        # String tags are tried as regex patterns first.
        if isinstance(tag, str) and re.match(tag, text):
            return XMLTagComparison(
                tag=text,
                diff=None,
                error=None,
                extra=tag if tag != text else None,
            )

        passed, error = comparison.basic_compare(first=text, second=tag)
        if error:
            return XMLTagComparison(tag=text, diff=None, error=error, extra=tag)
        if not passed:
            return XMLTagComparison(tag=text, diff=tag, error=None, extra=None)
        return XMLTagComparison(tag=text, diff=None, error=None, extra=tag)
class DictCheck(Assertion):
    """
    Check that a ``dict`` contains ``has_keys`` and lacks ``absent_keys``.

    The offending keys reported by ``comparison.check_dict_keys`` are kept
    in ``has_keys_diff`` and ``absent_keys_diff``; the assertion passes
    when both are empty.
    """

    def __init__(
        self,
        dictionary,
        has_keys=None,
        absent_keys=None,
        description=None,
        category=None,
    ):
        self.dictionary = dictionary
        self.has_keys = has_keys
        self.absent_keys = absent_keys
        # Both set by evaluate().
        self.has_keys_diff = None
        self.absent_keys_diff = None
        super().__init__(description=description, category=category)

    def evaluate(self):
        self.has_keys_diff, self.absent_keys_diff = comparison.check_dict_keys(
            data=self.dictionary,
            has_keys=self.has_keys,
            absent_keys=self.absent_keys,
        )
        return not self.has_keys_diff and not self.absent_keys_diff
class FixCheck(DictCheck):
    """
    ``DictCheck`` for FIX messages: ``has_tags`` / ``absent_tags`` map onto
    ``has_keys`` / ``absent_keys``. The keys are FIX tags, which the web UI
    shows with tag info popups.
    """

    def __init__(
        self,
        msg,
        has_tags=None,
        absent_tags=None,
        description=None,
        category=None,
    ):
        forwarded = dict(
            dictionary=msg,
            has_keys=has_tags,
            absent_keys=absent_tags,
            description=description,
            category=category,
        )
        super().__init__(**forwarded)
class DictMatch(Assertion):
    """
    Match two dictionaries by comparing values under each key recursively.

    :param value: Actual dictionary.
    :param expected: Expected dictionary (values may be comparators).
    :param include_only_expected: Forwarded to ``comparison.compare`` as
        ``include_only_rhs``.
    :param include_keys: Forwarded as ``include``.
    :param exclude_keys: Forwarded as ``ignore``.
    :param report_mode: Forwarded as ``report_mode``.
    :param actual_description: Stored for display of the actual side.
    :param expected_description: Stored for display of the expected side.
    :param value_cmp_func: Forwarded as ``value_cmp_func``.

    After evaluation ``comparison`` holds the flattened comparison result.
    """

    def __init__(
        self,
        value: Dict,
        expected: Dict,
        include_only_expected: bool = False,
        include_keys: Optional[List[Hashable]] = None,
        exclude_keys: Optional[List[Hashable]] = None,
        report_mode=comparison.ReportOptions.ALL,
        description: Optional[str] = None,
        category: Optional[str] = None,
        actual_description: Optional[str] = None,
        expected_description: Optional[str] = None,
        value_cmp_func=comparison.COMPARE_FUNCTIONS["native_equality"],
    ):
        self.value = value
        self.expected = expected
        self.include_only_expected = include_only_expected
        self.include_keys = include_keys
        self.exclude_keys = exclude_keys
        self.actual_description = actual_description
        self.expected_description = expected_description
        self._report_mode = report_mode
        self._value_cmp_func = value_cmp_func
        self.comparison = None  # will be set by evaluate
        super(DictMatch, self).__init__(
            description=description, category=category
        )

    def evaluate(self):
        """Evaluate the dict match."""
        passed, cmp_result = comparison.compare(
            lhs=self.value,
            rhs=self.expected,
            ignore=self.exclude_keys,
            include=self.include_keys,
            report_mode=self._report_mode,
            value_cmp_func=self._value_cmp_func,
            include_only_rhs=self.include_only_expected,
        )
        self.comparison = flatten_dict_comparison(cmp_result)
        return passed
class FixMatch(DictMatch):
    """
    Similar to DictMatch, however dict keys
    will have fix tag info popups on web UI.

    ``include_tags`` / ``exclude_tags`` map onto DictMatch's
    ``include_keys`` / ``exclude_keys``.
    """

    def __init__(
        self,
        value: Dict,
        expected: Dict,
        include_only_expected: bool = False,
        include_tags: Optional[List[Hashable]] = None,
        exclude_tags: Optional[List[Hashable]] = None,
        report_mode=comparison.ReportOptions.ALL,
        description: Optional[str] = None,
        category: Optional[str] = None,
        actual_description: Optional[str] = None,
        expected_description: Optional[str] = None,
    ):
        """
        If both FIX messages are typed, we enable strict type checking.
        Otherwise, if either side is untyped we will compare the values as
        strings.

        "Typed" is read from a truthy ``typed_values`` attribute on each
        message; objects without it count as untyped.
        """
        typed_value = getattr(value, "typed_values", False)
        typed_expected = getattr(expected, "typed_values", False)
        if typed_value and typed_expected:
            value_cmp_func = comparison.COMPARE_FUNCTIONS["check_types"]
        else:
            value_cmp_func = comparison.COMPARE_FUNCTIONS["untyped_fixtag"]
        super(FixMatch, self).__init__(
            value=value,
            expected=expected,
            include_only_expected=include_only_expected,
            include_keys=include_tags,
            exclude_keys=exclude_tags,
            report_mode=report_mode,
            description=description,
            category=category,
            actual_description=actual_description,
            expected_description=expected_description,
            value_cmp_func=value_cmp_func,
        )
class DictMatchAll(Assertion):
    """
    Match a collection of dictionaries against a collection of expected
    comparisons.

    The matching is delegated to ``comparison.dictmatch_all_compat``. Each
    resulting match's ``"comparison"`` entry is flattened in place for
    reporting.
    """

    def __init__(
        self,
        values,
        comparisons,
        key_weightings=None,
        description=None,
        category=None,
        value_cmp_func=comparison.COMPARE_FUNCTIONS["native_equality"],
    ):
        """
        :param values: Actual dictionaries to be matched.
        :param comparisons: Expected comparisons to match against.
        :param key_weightings: Forwarded to ``dictmatch_all_compat``.
        :param description: Assertion description.
        :param category: Assertion category.
        :param value_cmp_func: Forwarded to ``dictmatch_all_compat``;
            defaults to ``COMPARE_FUNCTIONS["native_equality"]``.
        """
        self.comparisons = comparisons
        self.values = values
        self.key_weightings = key_weightings
        self.value_cmp_func = value_cmp_func
        # Both populated by ``evaluate``, which the base initializer calls.
        self.matches = None
        self.result = None
        super().__init__(description=description, category=category)

    def evaluate(self):
        """
        Run the match-all comparison.

        Stores the match list on ``self.matches`` and the overall result on
        ``self.result``, then returns ``self.result.passed``.
        """
        outcome = comparison.dictmatch_all_compat(
            match_name=type(self).__name__,
            comparisons=self.comparisons,
            values=self.values,
            key_weightings=self.key_weightings,
            description=self.description,
            value_cmp_func=self.value_cmp_func,
        )
        self.matches, self.result = outcome
        for entry in self.matches:
            # Replace the nested comparison with its flattened form in place.
            entry["comparison"] = flatten_dict_comparison(entry["comparison"])
        return self.result.passed
class FixMatchAll(DictMatchAll):
    """
    Similar to DictMatchAll, however dict keys
    will have fix tag info popups on web UI
    """

    def __init__(
        self,
        values,
        comparisons,
        tag_weightings=None,
        description=None,
        category=None,
    ):
        """
        Choose the value comparison strategy from the typing of all inputs.

        If every input FIX message is typed, strict type checking
        (``check_types``) is enabled, matching the behavior of ``FixMatch``.
        Otherwise, if any entry on either side is untyped, values are
        compared as strings (``untyped_fixtag``).

        :param values: Actual FIX messages.
        :param comparisons: Expected comparisons; each entry's ``.value``
            is inspected for ``typed_values``.
        :param tag_weightings: Forwarded to ``DictMatchAll`` as
            ``key_weightings``.
        :param description: Assertion description.
        :param category: Assertion category.
        """
        # A missing ``typed_values`` attribute counts as untyped.
        typed_value = all(
            getattr(value, "typed_values", False) for value in values
        )
        typed_expected = all(
            getattr(expected.value, "typed_values", False)
            for expected in comparisons
        )
        if typed_value and typed_expected:
            # Previously "native_equality", which silently dropped the strict
            # type checking promised above and used by FixMatch.
            value_cmp_func = comparison.COMPARE_FUNCTIONS["check_types"]
        else:
            value_cmp_func = comparison.COMPARE_FUNCTIONS["untyped_fixtag"]
        super().__init__(
            values=values,
            comparisons=comparisons,
            key_weightings=tag_weightings,
            description=description,
            category=category,
            value_cmp_func=value_cmp_func,
        )
class LogfileMatch(Assertion):
    """
    NOTE: this structure was designed for multiple regexes matching,
    NOTE: which will be implemented in the future
    """

    @dataclass
    class Result:
        # Matched text (truncated if long), or None when there was no match.
        matched: Optional[str]
        # Regex pattern text (truncated if long).
        pattern: str
        # Start position as a string; "<BOF>" when no start was supplied.
        start_pos: str
        # End position as a string.
        end_pos: str

    def __init__(
        self,
        timeout: float,
        results: List[tuple],
        failure: Optional[tuple],
        description: Optional[str] = None,
        category: Optional[str] = None,
    ):
        """
        :param timeout: Timeout value, stored as-is.
        :param results: ``(match, regex, start, end)`` tuples for the
            successful matches.
        :param failure: A single ``(match, regex, start, end)`` tuple
            describing a failure, or ``None`` if there was none.
        :param description: Assertion description.
        :param category: Assertion category.
        """
        self.timeout = timeout
        # Deliberately bound on LogfileMatch (not self) so that
        # LogfileMatch.Result is always the produced type.
        to_result = LogfileMatch._handle_quadruple
        self.results = list(map(to_result, results))
        self.failure = [] if failure is None else [to_result(failure)]
        super().__init__(description=description, category=category)

    @staticmethod
    def _truncate_str(s):
        """Cut ``s`` to 50 characters, noting how many were omitted."""
        limit = 50
        overflow = len(s) - limit
        if overflow > 0:
            return f"{s[:limit]} ... ({overflow} chars omitted)"
        return s

    @classmethod
    def _handle_quadruple(cls, tup):
        """
        Convert a ``(match, regex, start, end)`` tuple into ``cls.Result``.

        ``regex`` may be a pattern string or anything ``re.compile``
        accepts; only its ``.pattern`` text is kept.
        """
        found, regex, start, end = tup
        pattern_text = re.compile(regex).pattern
        matched_text = (
            None if found is None else cls._truncate_str(found.group())
        )
        return cls.Result(
            matched_text,
            cls._truncate_str(pattern_text),
            "<BOF>" if start is None else str(start),
            str(end),
        )

    def evaluate(self):
        """Pass exactly when no failure was recorded."""
        return len(self.failure) == 0