import collections.abc
import itertools
import operator
import re
import typing
import dateutil.parser
class BaseError(Exception):
pass
class NoValueCapturedError(BaseError):
pass
_Sentinel = object()
class Any:
"""Matches any value."""
def __repr__(self):
return '<Any>'
def __eq__(self, other):
return True
class AnyString:
"""Matches any string."""
__testsuite_types__ = (str,)
def __repr__(self):
return '<AnyString>'
def __eq__(self, other):
if isinstance(other, str):
return True
return any(issubclass(type, str) for type in _resolve_types(other))
[docs]
class RegexString:
"""Match string with regular expression.
.. code-block:: python
assert response.json() == {
'order_id': matching.RegexString('^[0-9a-f]*$'),
...
}
"""
__testsuite_types__ = (str,)
def __init__(self, pattern):
self._pattern = re.compile(pattern)
def __repr__(self):
return f'<{self.__class__.__name__} pattern={self._pattern!r}>'
def __eq__(self, other):
if isinstance(other, str):
return self._pattern.match(other) is not None
if isinstance(other, RegexString):
return other._pattern == self._pattern
return False
class UuidString(RegexString):
"""Matches lower-case hexadecimal uuid string."""
def __init__(self):
super().__init__('^[0-9a-f]{32}$')
class ObjectIdString(RegexString):
"""Matches lower-case hexadecimal objectid string."""
def __init__(self):
super().__init__('^[0-9a-f]{24}$')
class DatetimeString:
"""Matches datetime string in any format."""
__testsuite_types__ = (str,)
def __repr__(self):
return '<DatetimeString>'
def __eq__(self, other):
if isinstance(other, str):
try:
dateutil.parser.parse(other)
return True
except ValueError:
return False
return isinstance(other, DatetimeString)
[docs]
class IsInstance:
"""Match value by its type.
Use this class when you only need to check value type.
.. code-block:: python
assert response.json() == {
# order_id must be a string
'order_id': matching.IsInstance(str),
# int or float is acceptable here
'weight': matching.IsInstance([int, float]),
...
}
"""
def __init__(self, types):
self._types = types
def __repr__(self):
if isinstance(self._types, (list, tuple)):
type_names = [t.__name__ for t in self._types]
else:
type_names = [self._types.__name__]
return f'<IsInstance {", ".join(type_names)}>'
def __eq__(self, other):
if isinstance(other, self._types):
return True
if isinstance(other, IsInstance):
return self._types == other._types
return False
[docs]
class And:
"""Logical AND on conditions.
.. code-block:: python
# match integer is in range [10, 100)
assert num == matching.And([matching.Ge(10), matching.Lt(100)])
"""
def __init__(self, *conditions):
self._conditions = conditions
def __repr__(self):
conditions = [repr(cond) for cond in self._conditions]
return f'<And {", ".join(conditions)}>'
def __eq__(self, other):
if isinstance(other, And):
return self._conditions == other._conditions
for condition in self._conditions:
if condition != other:
return False
return True
def __testsuite_visit__(self, visit):
return And(*[visit(condition) for condition in self._conditions])
[docs]
class Or:
"""Logical OR on conditions.
.. code-block:: python
# match integers abs(num) >= 10
assert num == matching.Or([matching.Ge(10), matching.Le(-10)])
"""
def __init__(self, *conditions):
self._conditions = conditions
def __repr__(self):
conditions = [repr(cond) for cond in self._conditions]
return f'<Or {", ".join(conditions)}>'
def __eq__(self, other):
if isinstance(other, Or):
return self._conditions == other._conditions
for condition in self._conditions:
if condition == other:
return True
return False
def __testsuite_visit__(self, visit):
return Or(*[visit(condition) for condition in self._conditions])
[docs]
class Not:
"""Condition inversion.
Example:
.. code-block:: python
# check value is not 1
assert value == matching.Not(1)
"""
def __init__(self, condition):
self._condition = condition
def __repr__(self):
return f'<Not {self._condition!r}>'
def __eq__(self, other):
if isinstance(other, Not):
return self._condition == other._condition
return self._condition != other
def __testsuite_visit__(self, visit):
return Not(visit(self._condition))
class Comparator:
op: typing.Callable[[typing.Any, typing.Any], bool] = operator.eq
def __init__(self, value):
self._value = value
def __repr__(self):
return f'<{self.op.__name__} {self._value}>'
def __eq__(self, other):
if isinstance(other, Comparator):
return self.op == other.op and self._value == other._value
try:
return self.op(other, self._value)
except TypeError:
return False
def __testsuite_visit__(self, visit):
return self.__class__(visit(self._value))
[docs]
class Gt(Comparator):
"""Value is greater than.
Example:
.. code-block:: python
# Value must be > 10
assert value == matching.Gt(10)
"""
op = operator.gt
[docs]
class Ge(Comparator):
"""Value is greater or equal.
Example:
.. code-block:: python
# Value must be >= 10
assert value == matching.Ge(10)
"""
op = operator.ge
[docs]
class Lt(Comparator):
"""Value is less than.
Example:
.. code-block:: python
# Value must be < 10
assert value == matching.Lt(10)
"""
op = operator.lt
[docs]
class Le(Comparator):
"""Value is less or equal.
Example:
.. code-block:: python
# Value must be <= 10
assert value == matching.Le(10)
"""
op = operator.le
[docs]
class PartialDict(collections.abc.Mapping):
"""Partial dictionary matching.
It might be useful to only check specific keys of a dictionary.
:py:class:`PartialDict` serves to solve this task.
:py:class:`PartialDict` is wrapper around regular `dict()` when instantiated
all arguments are passed as is to internal dict object.
Example:
.. code-block:: python
assert {'foo': 1, 'bar': 2} == matching.PartialDict({
# Only check for foo >= 1 ignoring other keys
'foo': matching.Ge(1),
})
"""
__testsuite_types__ = (dict,)
def __init__(self, *args, **kwargs):
self._dict = dict(*args, **kwargs)
def __contains__(self, item):
return True
def __getitem__(self, item):
return self._dict.get(item, any_value)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return f'<PartialDict {self._dict!r}>'
def __eq__(self, other):
if not isinstance(other, collections.abc.Mapping):
return False
for key in self:
if other.get(key) != self.get(key):
return False
return True
def __testsuite_visit__(self, visit):
return PartialDict(visit(self._dict))
def __testsuite_resolve_value__(self, other, report_error):
if not isinstance(other, collections.abc.Mapping):
return self
return {**other, **self._dict}
class UnorderedList:
def __init__(self, sequence, key):
self._value = sorted(sequence, key=key)
self._key = key
def __repr__(self):
return f'<UnorderedList: {self._value}>'
def __eq__(self, other):
if isinstance(other, list):
return sorted(other, key=self._key) == self._value
if isinstance(other, UnorderedList):
return self._value == other._value and self._key == other._key
return False
def __testsuite_visit__(self, visit):
return UnorderedList(visit(self._value), self._key)
def __testsuite_resolve_value__(self, other, report_error):
if not isinstance(other, list):
return self
sort_key = self._key or (lambda x: x)
other_sorted = sorted(
enumerate(other), key=lambda x: (sort_key(x[1]), x[0])
)
idx_seq = itertools.count(len(other_sorted))
it_self = iter(self._value)
def doit():
item_self = next(it_self, _Sentinel)
for idx_other, item_other in other_sorted:
if item_self is _Sentinel:
return
while sort_key(item_other) > sort_key(item_self):
yield next(idx_seq), item_self
item_self = next(it_self, _Sentinel)
if item_self is _Sentinel:
return
if sort_key(item_other) < sort_key(item_self):
continue
if item_other == item_self:
yield idx_other, item_other
else:
yield next(idx_seq), item_self
item_self = next(it_self, _Sentinel)
if item_self is not _Sentinel:
yield next(idx_seq), item_self
yield from zip(idx_seq, it_self)
return [item for _, item in sorted(doit(), key=operator.itemgetter(0))]
[docs]
class AnyList:
"""Value is a list.
Example:
.. code-block:: python
assert ['foo', 'bar'] == matching.any_dict
"""
def __repr__(self):
return '<AnyList>'
def __eq__(self, other):
return isinstance(other, (list, AnyList))
def __testsuite_resolve_value__(self, other, report_error):
if not isinstance(other, list):
return self
return other
[docs]
class ListOf:
"""Value is a list of values.
Example:
.. code-block:: python
assert ['foo', 'bar'] == matching.ListOf(matching.any_string)
assert [1, 2] != matching.ListOf(matching.any_string)
"""
def __init__(self, value=Any()):
self._value = value
def __repr__(self):
return f'<ListOf value={self._value}>'
def __eq__(self, other):
if isinstance(other, list):
for value in other:
if self._value != value:
return False
return True
if isinstance(other, ListOf):
return self._value == other._value
return False
def __testsuite_visit__(self, visit):
return ListOf(visit(self._value))
def __testsuite_resolve_value__(self, other, report_error):
if not isinstance(other, list):
return self
return [self._value] * len(other)
[docs]
class AnyDict:
"""Value is a dictionary.
Example:
.. code-block:: python
assert {'foo': 'bar'} == matching.any_dict
"""
def __repr__(self):
return '<AnyDict>'
def __eq__(self, other):
return isinstance(other, (dict, AnyDict))
def __testsuite_resolve_value__(self, other, report_error):
if not isinstance(other, list):
return self
return other
[docs]
class DictOf:
"""Value is a dictionary of (key, value) pairs.
Example:
.. code-block:: python
pred = matching.DictOf(key=matching.any_string, value=matching.any_string)
assert pred == {'foo': 'bar'}
assert pred != {'foo': 1}
assert pred != {1: 'bar'}
"""
def __init__(self, key=Any(), value=Any()):
self._key = key
self._value = value
def __repr__(self):
return f'<DictOf key={self._key} value={self._value}>'
def __eq__(self, other):
if isinstance(other, dict):
for key, value in other.items():
if self._key != key:
return False
if self._value != value:
return False
return True
if isinstance(other, DictOf):
return self._key == other._key and self._value == other._value
return False
def __testsuite_visit__(self, visit):
return DictOf(visit(self._key), visit(self._value))
def __testsuite_resolve_value__(self, other, report_error):
if not isinstance(other, collections.abc.Mapping):
return self
result = {}
for key, value in other.items():
if key != self._key:
report_error(
f'dict key must match {self._key} expression',
path=f'[{key!r}]',
)
result[key] = self._value
return result
[docs]
class Capture:
"""Capture matched value(s).
Example:
.. code-block:: python
# You can define matching rule out of pattern
capture_foo = matching.Capture(matching.any_string)
pattern = {'foo': capture_foo}
assert pattern == {'foo': 'bar'}
assert capture_foo.value == 'bar'
assert capture_foo.values_list == ['bar']
# Or do it later
capture_foo = matching.Capture()
pattern = {'foo': capture_foo(matching.any_string)}
assert pattern == {'foo': 'bar'}
assert capture_foo.value == 'bar'
assert capture_foo.values_list == ['bar']
"""
def __init__(self, value=Any(), _link_captured=None):
self._value = value
if _link_captured is None:
self._captured = []
else:
self._captured = _link_captured
@property
def value(self):
if self._captured:
return self._captured[0]
raise NoValueCapturedError(f'No value captured for value {self._value}')
@property
def values_list(self):
return self._captured
def __eq__(self, other):
if self._value != other:
return False
self._captured.append(other)
return True
def __call__(self, value):
return Capture(value, _link_captured=self._captured)
def __testsuite_visit__(self, visit):
return Capture(visit(self._value), self._captured)
def __testsuite_resolve_value__(self, other, report_error):
return _resolve_value(self._value, other, report_error)
[docs]
def unordered_list(sequence, *, key=None):
"""Unordered list comparison.
You may want to compare lists without respect to order. For instance,
when your service is serializing std::unordered_map to array.
`unordered_list` can help you with that. It sorts both array before
comparison.
:param sequence: Initial sequence
:param key: Sorting key function
Example:
.. code-block:: python
assert [3, 2, 1] == matching.unordered_list([1, 2, 3])
"""
return UnorderedList(sequence, key)
class _ObjectTransform:
def visit(self, value):
if isinstance(value, dict):
return self.visit_dict(value)
if isinstance(value, list):
return self.visit_list(value)
visit = getattr(value, '__testsuite_visit__', None)
if visit:
return visit(self.visit)
return value
def visit_dict(self, value):
return {key: self.visit(value) for key, value in value.items()}
def visit_list(self, value):
return [self.visit(item) for item in value]
[docs]
def recursive_partial_dict(*args, **kwargs):
"""Creates recursive partial dict.
Traverse input dict and create `PartialDict` for nested dicts.
Supports visiting `testsuite.matching` predicates. Skips inner
:py:class:`PartialDict` nodes in order to allow user to customize
behavior.
l Example:
.. code-block:: python
assert {
'foo': {'bar': 123, 'extra'}, 'extra'
} == matching.recursive_partial_dict({
'foo: {'bar': 123}
})
"""
class Transform(_ObjectTransform):
def visit(self, value):
if isinstance(value, PartialDict):
return value
return super().visit(value)
def visit_dict(self, value):
value = super().visit_dict(value)
return PartialDict(value)
root = dict(*args, **kwargs)
return Transform().visit(root)
def _resolve_types(value):
return getattr(value, '__testsuite_types__', ())
def _resolve_value(obj, other, report_error):
if hasattr(obj, '__testsuite_resolve_value__'):
return obj.__testsuite_resolve_value__(other, report_error)
return obj
any_value = Any()
any_float = IsInstance(float)
any_integer = IsInstance(int)
any_numeric = IsInstance((int, float))
positive_float = And(any_float, Gt(0))
positive_integer = And(any_integer, Gt(0))
positive_numeric = And(any_numeric, Gt(0))
negative_float = And(any_float, Lt(0))
negative_integer = And(any_integer, Lt(0))
negative_numeric = And(any_numeric, Lt(0))
non_negative_float = And(any_float, Ge(0))
non_negative_integer = And(any_integer, Ge(0))
non_negative_numeric = And(any_numeric, Ge(0))
any_string = AnyString()
datetime_string = DatetimeString()
objectid_string = ObjectIdString()
uuid_string = UuidString()
any_dict = AnyDict()
any_list = AnyList()