import contextlib
import itertools
import logging
import pathlib
import re
import ssl
import time
import typing
import urllib.parse
import uuid
import warnings
import yarl
import aiohttp.web
from testsuite.utils import cached_property
from testsuite.utils import callinfo
from testsuite.utils import compat
from testsuite.utils import http
from testsuite.utils import net as net_utils
from testsuite.utils import url_util
from testsuite import utils
from . import classes
from . import exceptions
from . import magicargs
logger = logging.getLogger(__name__)

# Default names of the tracing headers; Server accepts other names via its
# ``trace_id_header`` / ``span_id_header`` arguments.
DEFAULT_TRACE_ID_HEADER = 'X-YaTraceId'
DEFAULT_SPAN_ID_HEADER = 'X-YaSpanId'

# Every trace id produced by generate_trace_id() starts with this prefix.
_TRACE_ID_PREFIX = 'testsuite-'

# Body of the 500 response sent when a request carries another test's trace id.
REQUEST_FROM_ANOTHER_TEST_ERROR = 'Internal error: request is from other test'

# Mocked-errors protocol: the service lists the error codes it understands
# (comma separated) in the first header; the mocked error response carries
# the chosen code in the second one.
_SUPPORTED_ERRORS_HEADER = 'X-Testsuite-Supported-Errors'
_ERROR_HEADER = 'X-Testsuite-Error'

# (request header, log field) pairs copied into mockserver request logs.
_LOGGER_HEADERS = (
    ('X-YaTraceId', 'trace_id'),
    ('X-YaSpanId', 'span_id'),
    ('X-YaRequestId', 'link'),
)

# Named groups captured from a regex handler path, passed to the handler.
RouteParams = typing.Dict[str, str]
class MockserverRequest(aiohttp.web.BaseRequest):
    """aiohttp request that also remembers the original request path.

    ``original_path`` is derived from the raw HTTP message rather than the
    normalized request path, so an absolute request target (scheme and
    hostname included) is kept as sent.
    """

    def __init__(self, message, *args, **kwargs):
        super().__init__(message, *args, **kwargs)
        # Must be computed from the raw message: the parsed request path
        # drops the scheme and hostname.
        self.original_path = _path_from_message(message)
class Handler:
    """Adapter that turns a user mock function into a request handler.

    Calls to the wrapped function are routed through an async call queue
    (``callqueue``); the arguments for each call are built by
    ``magicargs.MagicArgsHandler``.
    """

    def __init__(self, func, *, raw_request=False, json_response=False):
        self.orig_func = func
        self.raw_request = raw_request
        self.json_response = json_response

    @cached_property
    def callqueue(self):
        """Call-recording async wrapper around the original function."""
        return callinfo.acallqueue(self.orig_func)

    @cached_property
    def handler_args(self):
        """Lazily created builder of the original function's arguments."""
        return magicargs.MagicArgsHandler(
            self.orig_func, raw_request=self.raw_request,
        )

    def __repr__(self):
        qualified_name = f'{Handler.__module__}.{Handler.__name__}'
        return (
            f'{qualified_name}({self.orig_func!r}, '
            f'raw_request={self.handler_args.raw_request}, '
            f'json_response={self.json_response})'
        )

    async def __call__(self, request: aiohttp.web.BaseRequest, **kwargs):
        """Invoke the mock for ``request``; ``kwargs`` are route params.

        With ``json_response`` set, any result that is not already an
        ``aiohttp.web.Response`` is wrapped into a JSON response.
        """
        call_args, call_kwargs = await self.handler_args.build_args(
            request, kwargs,
        )
        result = await self.callqueue(*call_args, **call_kwargs)
        needs_wrapping = self.json_response and not isinstance(
            result, aiohttp.web.Response,
        )
        if needs_wrapping:
            return http.make_response(json=result)
        return result
class Session:
    """Per-test registry of mockserver handlers plus collected errors.

    Handlers are looked up by exact path first, then by regex pattern, then
    by path prefix. Errors raised while serving requests are recorded and
    re-raised later by :meth:`raise_errors_if_any`.
    """

    handlers: typing.Dict[str, Handler]
    prefix_handlers: typing.List[typing.Tuple[str, Handler]]
    regex_handlers: typing.List[typing.Tuple[typing.Pattern, Handler]]

    def __init__(
        self,
        *,
        tracing_enabled=True,
        trace_id=None,
        http_proxy_enabled=False,
        mockserver_host=None,
    ):
        if trace_id is None:
            trace_id = generate_trace_id()
        self.trace_id = trace_id
        self.tracing_enabled = tracing_enabled
        self.handlers = {}
        self.prefix_handlers = []
        self.regex_handlers = []
        self.http_proxy_enabled = http_proxy_enabled
        self.mockserver_host = mockserver_host
        self._errors = []

    def get_handler(self, path: str) -> typing.Tuple[Handler, RouteParams]:
        """Return ``(handler, route_params)`` for ``path``.

        Lookup order: exact path match; regex handlers (most recently
        registered first, the whole path must match, named groups become
        route params); prefix handlers (most recently registered first).

        :raises exceptions.HandlerNotFoundError: nothing matches ``path``.
        """
        handler = self.handlers.get(path)
        if handler is not None:
            return handler, {}
        for pattern, handler in reversed(self.regex_handlers):
            match = pattern.fullmatch(path)
            if match:
                return handler, match.groupdict()
        for prefix, handler in reversed(self.prefix_handlers):
            if path.startswith(prefix):
                return handler, {}
        __tracebackhide__ = True
        raise exceptions.HandlerNotFoundError(
            self._get_handler_not_found_message(path),
        )

    def _get_handler_not_found_message(self, path: str) -> str:
        """Build a diagnostic listing every installed handler."""
        tracing_state = 'enabled' if self.tracing_enabled else 'disabled'
        # dict.fromkeys() drops duplicates but keeps registration order, so
        # the listing is stable between runs (set order is hash-randomized).
        patterns = dict.fromkeys(
            regex.pattern for regex, _ in self.regex_handlers
        )
        prefixes = dict.fromkeys(prefix for prefix, _ in self.prefix_handlers)
        handlers_list = '\n'.join(
            itertools.chain(
                (f'- {handler_path}' for handler_path in self.handlers),
                (f'- REGEX {pattern}' for pattern in patterns),
                (f'- PREFIX {prefix}' for prefix in prefixes),
            ),
        )
        return (
            f'Mockserver handler is not installed for {path!r}.\n\n'
            f'Perhaps you forgot to setup mockserver handler. '
            f'Tracing: {tracing_state}. Installed handlers:\n'
            f'{handlers_list}'
        )

    async def handle_request(
        self,
        request: MockserverRequest,
        nofail_404: bool,
    ):
        """Dispatch ``request`` to its handler and return the response.

        - No handler: 500 response; also recorded as an error unless
          ``nofail_404`` is set.
        - Handler raised ``http.MockedError``: mocked error response
          (status 599). If the service does not support the mocked-errors
          protocol, the failure is recorded and a 500 is returned.
        - Any other handler failure, including a non-``Response`` return
          value: recorded and answered with a 500 response.
        """
        __tracebackhide__ = True
        try:
            handler, kwargs = self._get_handler_for_request(request)
        except exceptions.HandlerNotFoundError as exc:
            if not nofail_404:
                self._errors.append(exc)
            return _internal_error(f'Internal server error: {exc!r}')
        try:
            response = await handler(request, **kwargs)
            if isinstance(response, aiohttp.web.Response):
                return response
            raise exceptions.MockServerError(
                'aiohttp.web.Response instance is expected '
                f'{response!r} given',
            )
        except http.MockedError as exc:
            # An exception raised inside this clause is not caught by the
            # sibling ``except Exception`` below. Handle it here so a
            # protocol mismatch is reported to the test like any other
            # handler error instead of escaping the session.
            try:
                return _mocked_error_response(request, exc.error_code)
            except Exception as protocol_exc:  # pylint: disable=broad-except
                return self._record_error(protocol_exc)
        except Exception as exc:  # pylint: disable=broad-except
            return self._record_error(exc)

    def _record_error(self, exc: Exception) -> aiohttp.web.Response:
        """Remember ``exc`` for :meth:`raise_errors_if_any`; return a 500."""
        self._errors.append(exc)
        return _internal_error(f'Internal server error: {exc!r}')

    def clear_errors(self):
        """Forget all recorded errors."""
        self._errors = []

    def raise_errors_if_any(self):
        """Raise ``MockServerError`` chained from the last recorded error.

        Does nothing when no errors were recorded.
        """
        __tracebackhide__ = True
        if not self._errors:
            return
        raise exceptions.MockServerError(
            f'There were {len(self._errors)} errors while processing '
            'mockserver requests, showing the last one',
        ) from self._errors[-1]

    def register_handler(
        self,
        path: str,
        func,
        *,
        prefix: bool = False,
        regex: bool = False,
    ):
        """Register ``func`` for ``path`` and return ``func``.

        :param regex: treat ``path`` as a regex pattern (compiled here).
        :param prefix: match ``path`` as a prefix; incompatible with
            ``regex``.
        :raises RuntimeError: both ``prefix`` and ``regex`` are set.
        """
        if regex:
            if prefix:
                raise RuntimeError(
                    'Parameter value prefix=True is not supported if regex '
                    'parameter is also True.',
                )
            pattern = re.compile(path)
            self.regex_handlers.append((pattern, func))
        elif prefix:
            self.prefix_handlers.append((path, func))
        else:
            self.handlers[path] = func
        return func

    def _get_handler_for_request(
        self,
        request: MockserverRequest,
    ) -> typing.Tuple[Handler, RouteParams]:
        """Resolve the handler for ``request``.

        In HTTP proxy mode, a request addressed to a host other than the
        mockserver itself is looked up by its absolute ``http://host/path``.
        """
        __tracebackhide__ = True
        path = request.original_path
        if self.http_proxy_enabled:
            host = request.headers.get('host')
            if host and host != self.mockserver_host:
                return self.get_handler(f'http://{host}{path}')
        return self.get_handler(path)
# pylint: disable=too-many-instance-attributes
class Server:
    """Mockserver request dispatcher.

    Holds the mockserver configuration (nofail, debug logging, tracing,
    trace/span header names, HTTP proxy mode) and the currently active
    :class:`Session`, and routes every incoming request to that session.
    """

    # Active session; set only while inside ``new_session()``.
    session = None

    def __init__(
        self,
        mockserver_info: classes.MockserverInfo,
        *,
        nofail=False,
        mockserver_debug=False,
        tracing_enabled=True,
        trace_id_header=DEFAULT_TRACE_ID_HEADER,
        span_id_header=DEFAULT_SPAN_ID_HEADER,
        http_proxy_enabled=False,
    ):
        self._info = mockserver_info
        self._nofail = nofail
        self._mockserver_debug = mockserver_debug
        self._tracing_enabled = tracing_enabled
        self._trace_id_header = trace_id_header
        self._span_id_header = span_id_header
        self._http_proxy_enabled = http_proxy_enabled

    @property
    def tracing_enabled(self) -> bool:
        """Tracing flag of the active session, else the configured one."""
        if self.session is None:
            return self._tracing_enabled
        return self.session.tracing_enabled

    @property
    def trace_id_header(self):
        """Name of the request header carrying the trace id."""
        return self._trace_id_header

    @property
    def span_id_header(self):
        """Name of the request header carrying the span id."""
        return self._span_id_header

    @property
    def http_proxy_enabled(self):
        """Whether requests to foreign hosts are routed by absolute url."""
        return self._http_proxy_enabled

    @property
    def server_info(self) -> classes.MockserverInfo:
        """Address information of this mockserver."""
        return self._info

    @contextlib.contextmanager
    def new_session(self, trace_id: typing.Optional[str] = None):
        """Activate a fresh :class:`Session` for the duration of the block.

        On exit the session is deactivated, then any errors it recorded are
        raised via ``Session.raise_errors_if_any()``.

        :param trace_id: trace id for the session; generated when ``None``.
        """
        session = Session(
            tracing_enabled=self._tracing_enabled,
            trace_id=trace_id,
            http_proxy_enabled=self._http_proxy_enabled,
            mockserver_host=self._info.get_host_header(),
        )
        self.session = session
        try:
            yield session
        finally:
            self.session = None
            __tracebackhide__ = True
            session.raise_errors_if_any()

    async def handle_request(self, request):
        """aiohttp entry point: handle ``request`` and log the outcome.

        Failures are always logged; successful requests only when
        ``mockserver_debug`` is enabled.
        """
        started = time.perf_counter()
        try:
            response = await self._handle_request(request)
            self._log_request(started, request, response)
            return response
        except BaseException as exc:
            self._log_request(started, request, exc=exc)
            raise

    def _log_request(self, started, request, response=None, exc=None):
        """Emit a tskv log record for a served (or failed) request."""
        if exc is None and not self._mockserver_debug:
            return
        fields = {
            '_type': 'mockserver_request',
            'timestamp': utils.utcnow(),
            'method': request.method,
            'url': request.rel_url,
        }
        for header, key in _LOGGER_HEADERS:
            if header in request.headers:
                fields[key] = request.headers[header]
        delay_ms = 1000 * (time.perf_counter() - started)
        fields['delay'] = f'{delay_ms:.3f}ms'
        if response is not None:
            log_level = logging.DEBUG
            fields['meta_code'] = response.status
            fields['status'] = 'DONE'
        else:
            log_level = logging.ERROR
            fields['status'] = 'FAIL'
            fields['exc_info'] = str(exc)
        logger.log(log_level, 'Mockserver request', extra={'tskv': fields})

    async def _handle_request(self, request: aiohttp.web.BaseRequest):
        """Apply session/tracing checks, then delegate to the session.

        With tracing enabled, requests whose trace id was not generated by
        testsuite are served in nofail mode; requests carrying another
        test's testsuite trace id are rejected with a 500 response.
        """
        trace_id = request.headers.get(self.trace_id_header)
        nofail = self._nofail
        if self.tracing_enabled and not _is_from_client_fixture(trace_id):
            nofail = True
        if not self.session:
            error_message = 'Internal error: missing mockserver fixture'
            if nofail:
                return _internal_error(error_message)
            raise exceptions.MockServerError(error_message)
        if self.tracing_enabled and _is_other_test(
            trace_id,
            self.session.trace_id,
        ):
            self._report_other_test_request(request, trace_id)
            return _internal_error(REQUEST_FROM_ANOTHER_TEST_ERROR)
        try:
            return await self.session.handle_request(request, nofail_404=nofail)
        except exceptions.HandlerNotFoundError:
            # Defensive: Session.handle_request already turns a missing
            # handler into a 500 response.
            return _internal_error(
                'Internal error: mockserver handler not found',
            )

    def _report_other_test_request(self, request, trace_id):
        """Warn about a request that carries a previous test's trace id."""
        logger.warning(
            'Mockserver called path %s with previous test trace_id %s',
            request.path,
            trace_id,
        )
class MockserverFixture:
    """Mockserver handler installer fixture."""

    def __init__(
        self,
        mockserver: Server,
        session: Session,
        base_prefix: str = '',
    ) -> None:
        self._server = mockserver
        self._session = session
        self._base_prefix = base_prefix
        # Regex-escaped prefix, prepended to the patterns of regex handlers.
        self._base_prefix_re = re.escape(base_prefix)

    def new(self, prefix: str) -> 'MockserverFixture':
        """Create mockserver installer with given base prefix.

        ``prefix`` is resolved relative to this installer's own base prefix.
        """
        return MockserverFixture(
            self._server,
            self._session,
            self._build_fullpath(prefix),
        )

    @property
    def base_url(self) -> str:
        """Mockserver base url."""
        return self._server.server_info.base_url

    @property
    def host(self) -> str:
        """Mockserver hostname."""
        return self._server.server_info.host

    @property
    def port(self) -> int:
        """Mockserver port."""
        return self._server.server_info.port

    @property
    def trace_id_header(self) -> str:
        """Name of the request header carrying the trace id."""
        return self._server.trace_id_header

    @property
    def span_id_header(self) -> str:
        """Name of the request header carrying the span id."""
        return self._server.span_id_header

    @property
    def trace_id(self) -> str:
        """Trace id of the current mockserver session."""
        return self._session.trace_id

    def handler(
        self,
        path: str,
        *,
        prefix: bool = False,
        raw_request: bool = False,
        json_response: bool = False,
        regex: bool = False,
    ) -> classes.GenericRequestDecorator:
        """Register basic http handler for ``path``.

        Returns decorator that registers handler ``path``. Original function is
        wrapped with :ref:`AsyncCallQueue`.

        :param path: url path to match, relative to the installer's base
            prefix (a regex pattern when ``regex=True``)
        :param prefix: set True to match path prefix instead of whole path
        :param raw_request: pass the raw aiohttp request object
            (``aiohttp.web.BaseRequest``) to handler instead of
            ``testsuite.utils.http.Request``. Deprecated, use
            :meth:`aiohttp_handler` instead.
        :param json_response: set True to let handler return json object
            instead of full response object. Deprecated, use
            :meth:`json_handler` instead.
        :param regex: set True to match path as regex pattern (the whole
            path must match)

        .. code-block:: python

            @mockserver.handler('/service/path')
            def handler(request: testsuite.utils.http.Request):
                return mockserver.make_response('Hello, world!')
        """
        # stacklevel=2 attributes the warning to the caller's line.
        if raw_request:
            warnings.warn(
                'raw_request=True is deprecated, '
                'use aiohttp_handler() instead',
                DeprecationWarning,
                stacklevel=2,
            )
        if json_response:
            warnings.warn(
                'json_response=True is deprecated, '
                'use json_handler() instead',
                DeprecationWarning,
                stacklevel=2,
            )
        return self._handler_installer(
            path,
            prefix=prefix,
            raw_request=raw_request,
            json_response=json_response,
            regex=regex,
        )

    def json_handler(
        self,
        path: str,
        *,
        prefix: bool = False,
        raw_request: bool = False,
        regex: bool = False,
    ) -> classes.JsonRequestDecorator:
        """Register json http handler for ``path``.

        Returns decorator that registers handler ``path``. Original function is
        wrapped with :ref:`AsyncCallQueue`.

        :param path: url path to match, relative to the installer's base
            prefix (a regex pattern when ``regex=True``)
        :param prefix: set True to match path prefix instead of whole path
        :param raw_request: pass the raw aiohttp request object
            (``aiohttp.web.BaseRequest``) to handler instead of
            ``testsuite.utils.http.Request``. Deprecated, use
            :meth:`aiohttp_json_handler` instead.
        :param regex: set True to match path as regex pattern (the whole
            path must match)

        .. code-block:: python

            @mockserver.json_handler('/service/path')
            def handler(request: testsuite.utils.http.Request):
                # Return JSON document
                return {...}
                # or call to mockserver.make_response()
                return mockserver.make_response(...)
        """
        if raw_request:
            warnings.warn(
                'raw_request=True is deprecated, '
                'use aiohttp_json_handler() instead',
                DeprecationWarning,
                stacklevel=2,
            )
        return self._handler_installer(
            path,
            prefix=prefix,
            raw_request=raw_request,
            json_response=True,
            regex=regex,
        )

    def aiohttp_handler(
        self,
        path: str,
        *,
        prefix: bool = False,
        regex: bool = False,
    ) -> classes.GenericRequestDecorator:
        """Register handler for ``path`` that gets the raw aiohttp request.

        Same as ``handler(path, raw_request=True)`` but without the
        deprecation warning; the handler must return a full response.
        """
        return self._handler_installer(
            path,
            prefix=prefix,
            raw_request=True,
            json_response=False,
            regex=regex,
        )

    def aiohttp_json_handler(
        self,
        path: str,
        *,
        prefix: bool = False,
        regex: bool = False,
    ) -> classes.JsonRequestDecorator:
        """Register json handler for ``path`` that gets the raw aiohttp request.

        Same as ``json_handler(path, raw_request=True)`` but without the
        deprecation warning.
        """
        return self._handler_installer(
            path,
            prefix=prefix,
            raw_request=True,
            json_response=True,
            regex=regex,
        )

    def url(self, path: str) -> str:
        """Builds mockserver url for ``path``."""
        return url_util.join(self.base_url, path)

    def url_encoded(self, path: str) -> yarl.URL:
        """Builds mockserver url for already percent-encoded ``path``.

        The result is a ``yarl.URL`` created with ``encoded=True``, so the
        path is not quoted again.
        """
        return yarl.URL(url_util.join(self.base_url, path), encoded=True)

    def ignore_trace_id(self) -> typing.ContextManager[None]:
        """Context manager disabling tracing; same as ``tracing(False)``."""
        return self.tracing(False)

    @contextlib.contextmanager
    def tracing(self, value: bool = True):
        """Temporarily set the session's tracing flag to ``value``.

        The previous value is restored on exit.
        """
        original_value = self._session.tracing_enabled
        try:
            self._session.tracing_enabled = value
            yield
        finally:
            self._session.tracing_enabled = original_value

    def get_callqueue_for(self, path) -> callinfo.AsyncCallQueue:
        """Return the call queue of the handler matching ``path``.

        Note: ``path`` is looked up as-is; this installer's base prefix is
        not applied.
        """
        handler, _ = self._session.get_handler(path)
        return handler.callqueue

    make_response = staticmethod(http.make_response)

    TimeoutError = http.TimeoutError
    NetworkError = http.NetworkError

    def _handler_installer(
        self,
        path: str,
        *,
        prefix: bool = False,
        raw_request: bool = False,
        json_response: bool = False,
        regex: bool = False,
    ) -> typing.Callable:
        """Return a decorator registering its function as a handler.

        The decorator returns the handler's call queue rather than the
        original function.
        """
        path = self._build_fullpath(path, regex)

        def decorator(func):
            handler = Handler(
                func,
                raw_request=raw_request,
                json_response=json_response,
            )
            self._session.register_handler(
                path,
                handler,
                prefix=prefix,
                regex=regex,
            )
            return handler.callqueue

        return decorator

    def _build_fullpath(self, path, regex: bool = False) -> str:
        """Prepend this installer's base prefix to ``path``.

        Regex paths get the regex-escaped prefix. Otherwise, when the prefix
        is empty or ends with ``/``, the parts are joined with
        ``url_util.join``; in HTTP proxy mode an absolute ``http://`` path
        is returned unchanged. Other prefixes are concatenated as-is.
        """
        if regex:
            return self._base_prefix_re + path
        if not self._base_prefix or self._base_prefix.endswith('/'):
            if self._server.http_proxy_enabled and path.startswith('http://'):
                return path
            return url_util.join(self._base_prefix, path)
        return self._base_prefix + path
# Alias of MockserverFixture (same class object); presumably kept so
# existing code importing this name keeps working -- TODO confirm.
MockserverSslFixture = MockserverFixture
def _create_ssl_context(ssl_info: classes.SslCertInfo) -> ssl.SSLContext:
    """Create a server-side SSL context with the mockserver certificate.

    ``Purpose.CLIENT_AUTH`` selects a context for a server authenticating
    itself to clients. The certificate chain and private key are loaded
    from the paths in ``ssl_info``.
    """
    context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(
        certfile=ssl_info.cert_path,
        keyfile=ssl_info.private_key_path,
    )
    return context
def _internal_error(message: str = 'Internal error') -> aiohttp.web.Response:
    """Return an HTTP 500 response whose body is ``message``."""
    status_code = 500
    return http.make_response(message, status=status_code)
def _mocked_error_response(request, error_code) -> aiohttp.web.Response:
    """Build the response that signals a mocked error to the service.

    The service advertises the error codes it understands, comma separated,
    in the ``X-Testsuite-Supported-Errors`` request header. The response has
    status 599, an empty body, and the code in ``X-Testsuite-Error``.

    :raises exceptions.MockServerError: the header is missing or does not
        list ``error_code``.
    """
    supported_header = request.headers.get(_SUPPORTED_ERRORS_HEADER)
    if supported_header is None:
        raise exceptions.MockServerError(
            'Service does not support mockserver errors protocol',
        )
    # Strip optional whitespace around list items ("a, b") as allowed for
    # comma-separated HTTP header values.
    supported_errors = [item.strip() for item in supported_header.split(',')]
    if error_code not in supported_errors:
        raise exceptions.MockServerError(
            f'Service does not support mockserver error of type {error_code}',
        )
    return http.make_response(
        response='',
        status=599,
        headers={_ERROR_HEADER: error_code},
    )
def _create_server_obj(mockserver_info, pytestconfig) -> Server:
    """Instantiate a :class:`Server` configured from pytest options/ini."""
    options = pytestconfig.option
    return Server(
        mockserver_info,
        nofail=options.mockserver_nofail,
        mockserver_debug=options.mockserver_debug,
        tracing_enabled=pytestconfig.getini('mockserver-tracing-enabled'),
        trace_id_header=pytestconfig.getini('mockserver-trace-id-header'),
        span_id_header=pytestconfig.getini('mockserver-span-id-header'),
        http_proxy_enabled=pytestconfig.getini(
            'mockserver-http-proxy-enabled',
        ),
    )
def _create_web_server(server: Server, loop) -> aiohttp.web.Server:
    """Create the aiohttp low-level server dispatching to ``server``.

    Requests are constructed as :class:`MockserverRequest` so the original
    path is available. The access log is disabled (``access_log=None``).
    """

    def make_request(*factory_args):
        return MockserverRequest(*factory_args, loop=loop)

    return aiohttp.web.Server(
        server.handle_request,
        request_factory=make_request,
        loop=loop,
        access_log=None,
    )
@compat.asynccontextmanager
async def create_server(
    host: str,
    port: int,
    loop,
    pytestconfig,
    ssl_info: typing.Optional[classes.SslCertInfo],
) -> typing.AsyncGenerator[Server, None]:
    """Start a TCP mockserver (HTTPS when ``ssl_info`` is set); yield it.

    :param host: host to listen on; also used in the server base url.
    :param port: port to listen on; the actual port is read back from the
        bound listening socket.
    :param loop: event loop handed to the aiohttp web server.
    :param pytestconfig: pytest config providing mockserver options.
    :param ssl_info: certificate info; enables SSL when set.
    """
    ssl_context: typing.Optional[ssl.SSLContext]
    if ssl_info:
        ssl_context = _create_ssl_context(ssl_info)
    else:
        ssl_context = None
    # The factory lambda late-binds ``web_server``, which is assigned only
    # below: the Server object needs the real port, known only after the
    # socket is bound. Presumably create_tcp_server calls the factory per
    # incoming connection, i.e. after the assignment -- TODO confirm.
    async with net_utils.create_tcp_server(
        lambda: web_server(),
        host=host,
        port=port,
        ssl=ssl_context,
    ) as aio_server:
        mockserver_info = _create_mockserver_info(
            aio_server.sockets[0],
            host,
            ssl_info,
        )
        server = _create_server_obj(mockserver_info, pytestconfig)
        web_server = _create_web_server(server, loop)
        yield server
@compat.asynccontextmanager
async def create_unix_server(
    socket_path: pathlib.Path,
    loop,
    pytestconfig,
) -> typing.AsyncGenerator[Server, None]:
    """Start a mockserver listening on a unix socket; yield its Server.

    :param socket_path: filesystem path of the unix socket.
    :param loop: event loop handed to the aiohttp web server.
    :param pytestconfig: pytest config providing mockserver options.
    """
    # As in create_server(), the factory lambda late-binds ``web_server``,
    # which is assigned only after the listener is up; presumably the
    # factory is invoked per incoming connection -- TODO confirm.
    async with net_utils.create_unix_server(
        lambda: web_server(),
        path=socket_path,
    ):
        mockserver_info = _create_unix_mockserver_info(socket_path)
        server = _create_server_obj(mockserver_info, pytestconfig)
        web_server = _create_web_server(server, loop)
        yield server
def _create_mockserver_info(
    sock,
    host: str,
    ssl_info: typing.Optional[classes.SslCertInfo],
) -> classes.MockserverInfo:
    """Build :class:`classes.MockserverInfo` for a listening TCP socket.

    :param sock: bound listening socket; its actual port is read via
        ``getsockname()``.
    :param host: host name or IP address the server listens on.
    :param ssl_info: SSL certificate info; selects ``https`` when set.
    """
    port = sock.getsockname()[1]
    schema = 'https' if ssl_info else 'http'
    # IPv6 literals must be enclosed in brackets inside a URL authority
    # (RFC 3986, section 3.2.2); otherwise the port is ambiguous.
    if host and ':' in host and not host.startswith('['):
        url_host = f'[{host}]'
    else:
        url_host = host
    base_url = '%s://%s:%d/' % (schema, url_host, port)
    return classes.MockserverInfo(
        host=host,
        port=port,
        base_url=base_url,
        ssl=ssl_info,
    )
def _create_unix_mockserver_info(
    socket_path: pathlib.Path,
) -> classes.MockserverInfo:
    """Build :class:`classes.MockserverInfo` for a unix-socket mockserver.

    No host, port or SSL info applies to a unix socket.
    """
    # use localhost to avoid aiohttp complains on invalid url
    placeholder_url = 'http://localhost'
    return classes.MockserverInfo(
        base_url=placeholder_url,
        socket_path=socket_path,
        host=None,
        port=None,
        ssl=None,
    )
def generate_trace_id() -> str:
    """Return a new random trace id: the testsuite prefix plus a uuid4 hex."""
    random_part = uuid.uuid4().hex
    return f'{_TRACE_ID_PREFIX}{random_part}'
def _is_from_client_fixture(trace_id: typing.Optional[str]) -> bool:
    """Tell whether ``trace_id`` was generated by testsuite.

    ``trace_id`` may be ``None`` (the request had no trace header), in
    which case the result is ``False``.
    """
    return trace_id is not None and trace_id.startswith(_TRACE_ID_PREFIX)
def _is_other_test(
    trace_id: typing.Optional[str], current_trace_id: str,
) -> bool:
    """Tell whether ``trace_id`` is a testsuite trace id of another test.

    ``trace_id`` may be ``None`` (no trace header), which never counts as
    another test's id.
    """
    return trace_id != current_trace_id and _is_from_client_fixture(trace_id)
def _path_from_message(message):
    """Returns original HTTP path without query.

    The path is taken from ``str(message.url)``, cut at the first ``?`` and
    percent-decoded.
    """
    without_query, _, _ = str(message.url).partition('?')
    return urllib.parse.unquote(without_query)