"""
Custom marshmallow fields.
"""
import abc
import pprint
from datetime import timezone, datetime
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from lxml import etree
from marshmallow import fields
from marshmallow import class_registry
from marshmallow.schema import Schema
from marshmallow.utils import missing as missing_
from testplan.common.utils import comparison
# We explicitly enumerate types that are known to be safe to serialize by
# pickle. All other types will be converted to strings before pickling.
# types.NoneType is gone in python3 so we inspect the type of None directly.
COMPATIBLE_TYPES = (bool, float, type(None), str, bytes, int)
# pylint: disable=unused-argument
[docs]
class Serializable(metaclass=abc.ABCMeta):
[docs]
@abc.abstractmethod
def serialize(self) -> Any:
pass
[docs]
class LogLink(Serializable):
"""
Save an HTML link in WebUI
"""
def __init__(
self,
link: str,
title: Optional[str] = None,
new_window: bool = True,
inner: bool = False,
) -> None:
"""
:param link: The URL of the page the link goes to.
:type link: ``str``
:param title: The name of the link.
:type title: ``str``
:param new_window: Open the link on new window or not. Default is True.
:type new_window: ``bool``
:param inner: Link to internal page (report uid as the prefix). Default is False.
:type new_window: ``bool``
"""
self.link = link
self.title = title or link
self.new_window = new_window
self.inner = inner
[docs]
def serialize(self) -> Dict[str, Any]:
return {
"link": self.link,
"title": self.title,
"new_window": self.new_window,
"inner": self.inner,
"type": "link",
}
def _repr_obj(obj: object) -> str:
# copypasta from unittest code
try:
return repr(obj)
except Exception:
return object.__repr__(obj)
[docs]
class Unicode(fields.Field):
"""
Field that tries to convert value into a unicode
object with the given codecs. Marshmallow internally decodes to
utf-8 encoding, however it fails on Python 2 for str values
like ``@t\xe9\xa7t\xfel\xe5\xf1``.
So we have this field with explicit codecs instead.
"""
codecs = ["utf-8", "latin-1"] # Ideally we will let users override this
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> Optional[str]:
if isinstance(value, str) or value is None:
return value
elif isinstance(value, bytes):
for codec in self.codecs:
try:
return str(value, codec)
except UnicodeDecodeError:
pass
raise ValueError(
"Could not decode {value!r} to unicode"
" with the given codecs: {codecs}".format(
value=value, codecs=self.codecs
)
)
else:
return str(value)
[docs]
class NativeOrPretty(fields.Field):
"""
Uses serialization compatible native values
or pretty formatted str representation.
"""
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> Any:
if isinstance(value, Serializable):
return value.serialize()
else:
return native_or_pformat(value)
[docs]
class NativeOrPrettyDict(fields.Field):
"""
Dictionary serialization with native or pretty formatted values.
Keys should be JSON serializable (str type),
should be used for flat dicts only.
"""
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> Dict[str, Any]:
if not isinstance(value, dict):
raise TypeError(
"`value` ({value}) should be"
" `dict` type, it was: {type}".format(
value=value, type=type(value)
)
)
for k in value:
if not isinstance(k, str):
raise TypeError(
"`key` ({key}) should be of"
" `str` type, it was: {type}".format(key=k, type=type(k))
)
return native_or_pformat_dict(value)
# TODO: Move to entries
[docs]
class RowComparisonField(fields.Field):
"""Serialization logic for RowComparison"""
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> Tuple[Any, List[Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
idx, row, diff, errors, extra = value
return (
idx,
native_or_pformat_list(row),
native_or_pformat_dict(diff),
native_or_pformat_dict(errors),
native_or_pformat_dict(extra),
)
[docs]
class SliceComparisonField(fields.Field):
"""Serialization logic for SliceComparison"""
# TODO: strip actual & expected to save more space, as these value could
# TODO: be retrieved from context
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> Tuple[str, Any, Any, Any, Any]:
def str_or_iterable(val: Any) -> Any:
return val if isinstance(val, str) else native_or_pformat_list(val)
slice_obj, comp_indices, mismatch_indices, actual, expected = value
return (
repr(slice_obj),
comp_indices,
mismatch_indices,
str_or_iterable(actual),
str_or_iterable(expected),
)
[docs]
class ColumnContainComparisonField(fields.Field):
"""Serialization logic for ColumnContainComparison"""
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> Tuple[Any, Any, Any]:
return (value.idx, native_or_pformat(value.value), value.passed)
[docs]
class XMLElementField(fields.Field):
"""Custom field for `lxml.etree.Element serialization`."""
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> str:
return cast(
str, etree.tostring(value, pretty_print=True).decode("utf-8")
)
[docs]
class ClassName(fields.Field):
"""Return the class name of the `obj`."""
_CHECK_ATTRIBUTE = False
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> str:
return cast(str, obj.__class__.__name__)
[docs]
class DictMatch(fields.Field):
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> Dict[str, Any]:
keys = ("value", "ignore", "only")
return {key: getattr(value, key) for key in keys}
[docs]
class GenericNested(fields.Field):
"""
Marshmallow does not support multiple schemas
for a single `Nested` field.
There is a project (marshmallow-oneofschema)
that has similar functionality but it doesn't support
self-referencing schemas, which is
needed for serializing tree structures.
This field should be used along with `ClassNameField`
to return the type (class name) of the objects,
so it can choose the correct schema during deserialization.
"""
def __init__(
self,
schema_context: Dict[Any, Any],
type_field: str = "type",
default: Any = missing_,
**kwargs: Any,
) -> None:
self.schema_context = schema_context
self.type_field = type_field
self.many = kwargs.get("many", False)
super(GenericNested, self).__init__(default=default, metadata=kwargs)
def _get_schema_obj(self, schema_value: Any) -> Schema:
parent_ctx = getattr(self.parent, "context", {})
if callable(schema_value) and not isinstance(schema_value, type):
schema_value = schema_value()
if isinstance(schema_value, Schema):
schema_value.context.update(parent_ctx)
return schema_value
elif isinstance(schema_value, type) and issubclass(
schema_value, Schema
):
return schema_value(many=self.many, context=parent_ctx)
elif isinstance(schema_value, str):
if schema_value == "self":
if self.parent is None:
raise ValueError(
"Cannot use 'self' schema without a parent"
)
return self.parent.__class__( # type: ignore[return-value]
many=self.many, context=parent_ctx
)
else:
schema_class = class_registry.get_class(schema_value)
return schema_class(many=self.many, context=parent_ctx)
raise ValueError(
"Invalid value for schema: {}, {}".format(
schema_value, type(schema_value)
)
)
@property
def schemas(self) -> Dict[str, Schema]:
"""Return schema mapping in `<CLASS_NAME>: <SCHEMA_OBJECT>` format."""
result: Dict[str, Schema] = {}
for object_type, schema_value in self.schema_context.items():
if isinstance(object_type, str):
key = object_type
elif isinstance(object_type, type):
key = object_type.__name__
else:
raise ValueError(
"Invalid value for object type ({}), strings"
" and class objects are allowed.".format(object_type)
)
result[key] = self._get_schema_obj(schema_value)
return result
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> Any:
if value is None:
return None
schemas = self.schemas
if isinstance(value, (list, tuple)):
return [self._serialize(nobj, attr, obj) for nobj in value]
class_name = value.__class__.__name__
if class_name not in schemas:
raise KeyError(
"No schema declaration found in"
" `schema_context` for : {}".format(class_name)
)
schema_obj = schemas[class_name]
return schema_obj.dump(value, many=False)
[docs]
class UTCDateTime(fields.DateTime):
"""
A formatted datetime string that represents UTC time. Naive datetime
will be thought as in UTC timezone.
Example: 2014-12-22T03:12:58.019077+00:00 (always ends with '+00:00')
"""
def _serialize(
self, value: Optional[datetime], attr: Any, obj: Any, **kwargs: Any
) -> Optional[str]:
if value is None:
return None
if value.tzname() != "UTC":
# note: below doesn't work when value is a rpyc netref
# if value.tzinfo != timezone.utc:
raise RuntimeError("Field is expected to have utc timezone info")
return value.isoformat()
def _deserialize( # type: ignore[override]
self, value: Any, attr: Any, data: Any, **kwargs: Any
) -> Optional[datetime]:
if value is None:
return None
dt = datetime.fromisoformat(value)
return (
dt.replace(tzinfo=timezone.utc)
if dt.tzinfo is None
else dt.astimezone(tz=timezone.utc)
)
[docs]
class LocalDateTime(fields.DateTime):
"""
A formatted datetime string that represents machine time. Naive datetime
will be thought as in local timezone.
Example: 2014-12-22T11:12:58.019077+08:00
Note: Since Python 3.6 `datetime.datetime.astimezone` method can be called
on naive instances that are presumed to represent system local time.
"""
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> Optional[str]:
return None if value is None else value.astimezone().isoformat()
def _deserialize( # type: ignore[override]
self, value: Any, attr: Any, data: Any, **kwargs: Any
) -> Optional[datetime]:
return (
None
if value is None
else datetime.fromisoformat(value).astimezone()
)
[docs]
class ExceptionField(fields.Field):
"""
Serialize exceptions type and message.
"""
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> Tuple[str, str]:
return (str(type(value)), str(value))