# Source code for testsuite.logcapture

"""
Logcapture allows intercepting service logs on demand with a context manager.

It starts a TCP server and reads the log lines that the service sends to it.
"""

from __future__ import annotations

import asyncio
import collections
import contextlib
import enum
import logging
import typing

from testsuite.utils import callinfo, traceback

# Module logger; used for debug messages about logcapture client connections.
logger: logging.Logger = logging.getLogger(name=__name__)


class BaseError(Exception):
    """Root of this module's exception hierarchy."""


class IncorrectUsageError(BaseError):
    """Raised when the logcapture API is used in an unsupported way.

    Examples: starting a ``CaptureServer`` twice, calling
    ``Capture.select()`` while the capture is still open, or calling
    ``Capture.subscribe()`` after the capture has been closed.
    """


class ClientConnectTimeoutError(BaseError):
    """No logcapture client connected before the wait_for_client() timeout."""


class TimeoutError(BaseError):
    """Capture reader tasks did not finish before the capture() timeout.

    NOTE: this class shadows the builtin ``TimeoutError`` throughout this
    module, so a bare ``except TimeoutError`` here does not catch timeouts
    raised by asyncio; refer to ``asyncio.TimeoutError`` for those.
    """


class LogLevel(enum.IntEnum):
    """Ordered log severity.

    Because this is an IntEnum, members compare numerically, e.g.
    ``LogLevel.DEBUG < LogLevel.INFO``. From lowest to highest:
    TRACE, DEBUG, INFO, WARNING, ERROR, CRITICAL, NONE.
    """

    TRACE = 0
    DEBUG = 1
    INFO = 2
    WARNING = 3
    ERROR = 4
    CRITICAL = 5
    NONE = 6

    @classmethod
    def from_string(cls, level: str) -> 'LogLevel':
        """Return the member named ``level``, ignoring case.

        Unknown names raise ``KeyError`` (standard Enum lookup).
        """
        member_name = level.upper()
        return cls[member_name]
class CapturedLogs:
    """Storage for the log rows received during one capture.

    Rows are kept in arrival order; subscribers whose query matches a row
    are awaited with that row's fields as keyword arguments.
    """

    def __init__(self, *, log_level: LogLevel) -> None:
        self._log_level = log_level
        self._logs: list[dict] = []
        # (query, callback) pairs registered via subscribe().
        self._subscribers = []
        self._closed = False

    @property
    def log_level(self):
        """Log level this capture was created with."""
        return self._log_level

    def is_closed(self):
        """Return whether close() has been called."""
        return self._closed

    def close(self):
        """Mark the capture as closed.

        publish() still accepts rows afterwards; only the flag changes.
        """
        self._closed = True

    async def publish(self, row: dict) -> None:
        """Store ``row`` and await every subscriber whose query matches it."""
        self._logs.append(row)
        # Lazily filtered, so each match check runs right before that
        # subscriber's callback is awaited.
        matching_callbacks = (
            callback
            for query, callback in self._subscribers
            if _match_entry(row, query)
        )
        for callback in matching_callbacks:
            await callback(**row)

    def subscribe(self, query: dict, decorated):
        """Register ``decorated`` to be awaited for rows matching ``query``."""
        self._subscribers.append((query, decorated))

    def __iter__(self) -> typing.Iterator[dict]:
        return iter(self._logs)
class Capture:
    """User-facing handle for one logs capture.

    Wraps a ``CapturedLogs`` instance and exposes ``select()`` for querying
    after the capture ends and ``subscribe()`` for reacting while it runs.
    """

    def __init__(self, logs: CapturedLogs):
        self._logs = logs

    def select(self, **query) -> list[dict]:
        """Return the captured records whose fields equal ``query``.

        Only valid once the capture context manager block has finished.

        .. code-block:: python

            async with logcapture_server.capture() as capture:
                ...
            records = capture.select(text='Message to capture')

        :raises IncorrectUsageError: if the capture is still open, or if
            the requested ``level`` is below the capture's log level.
        """
        logs = self._logs
        if not logs.is_closed():
            raise IncorrectUsageError(
                'select() is only supported for closed captures\n'
                'Please move select() after context manager body',
            )

        requested_name = query.get('level')
        if requested_name:
            # Rows below the capture's level were never emitted, so asking
            # for them is a usage error rather than an empty result.
            requested = LogLevel[requested_name]
            if requested.value < logs.log_level.value:
                raise IncorrectUsageError(
                    f'Requested log level={requested.name} is lower than service log level {logs.log_level.name}',
                )

        return [row for row in logs if _match_entry(row, query)]

    def subscribe(self, **query):
        """Subscribe to records matching ``query``; returns a decorator.

        Only usable inside the ``capture()`` block. The decorator wraps the
        function with ``callinfo.acallqueue``, registers it, and returns that
        call queue. Matching rows are delivered as keyword arguments.

        .. code-block:: python

            async with logcapture_server.capture() as capture:
                @capture.subscribe(text='Message to capture')
                def log_event(link, **other):
                    ...
                ...
                assert log_event.wait_call()

        :raises IncorrectUsageError: if the capture is already closed.
        """
        if self._logs.is_closed():
            raise IncorrectUsageError(
                'subscribe() is not supported for closed captures\n'
                'Please move subscribe() into context manager body',
            )

        def decorator(func):
            call_queue = callinfo.acallqueue(func)
            self._logs.subscribe(query, call_queue)
            return call_queue

        return decorator
class CaptureServer:
    """Asyncio TCP server that collects service logs into captures.

    Each client that connects while a capture is active gets a reader task.
    That task parses every received line with ``parse_line`` and publishes
    the resulting row to the capture. Clients that connect while no capture
    is active are disconnected immediately.
    """

    _capture: CapturedLogs | None

    def __init__(
        self,
        *,
        log_level: LogLevel,
        parse_line: collections.abc.Callable[[bytes], dict],
    ):
        """Capture server.

        :param log_level: default log level for captures that do not pass
            their own ``log_level``.
        :param parse_line: turns one raw line received from a client into a
            log row dict.
        """
        self._log_level = log_level
        self._client_cond = asyncio.Condition()
        self._capture = None
        # Reader tasks of clients that connected during the current capture.
        self._tasks: list[asyncio.Task] = []
        self._parse_line = parse_line
        self._started = False
        self._socknames: list[tuple] = []

    @property
    def default_log_level(self) -> LogLevel:
        """Returns default log level specified on object creation."""
        return self._log_level

    def getsocknames(self) -> list[tuple]:
        """Return list of server socket names."""
        return self._socknames

    @contextlib.asynccontextmanager
    async def start(
        self, *args, **kwargs
    ) -> typing.AsyncIterator['CaptureServer']:
        """Starts capture logs asyncio server.

        Arguments are passed directly to ``asyncio.start_server``. The server
        can be started only once. The capture server itself is yielded, and
        the server is closed when the context manager exits.

        :raises IncorrectUsageError: if the server was already started.
        """
        if self._started:
            raise IncorrectUsageError('Service was already started')
        server = await asyncio.start_server(
            self._handle_client, *args, **kwargs
        )
        self._started = True
        self._socknames = [sock.getsockname() for sock in server.sockets]
        try:
            yield self
        finally:
            server.close()
            await server.wait_closed()

    async def wait_for_client(self, timeout: float = 10.0):
        """Waits for logserver client to connect.

        :raises ClientConnectTimeoutError: if no client connects within
            ``timeout`` seconds.
        """

        async def waiter():
            async with self._client_cond:
                await self._client_cond.wait_for(lambda: self._tasks)

        logger.debug('Waiting for logcapture client to connect...')
        try:
            await asyncio.wait_for(waiter(), timeout=timeout)
        except asyncio.TimeoutError as exc:
            # The bare name ``TimeoutError`` in this module is the module's
            # own BaseError subclass, so asyncio's timeout must be named
            # explicitly for this handler to fire.
            raise ClientConnectTimeoutError(
                'Timedout while waiting for logcapture client to connect',
            ) from exc

    async def _handle_client(self, reader, writer):
        logger.debug('logcapture client connected')

        async def log_reader(capture: CapturedLogs):
            with contextlib.closing(writer):
                try:
                    async for line in reader:
                        row = self._parse_line(line)
                        await capture.publish(row)
                except Exception:
                    # Wait for the data transfer to finish so the client is
                    # not cut off mid-stream, then propagate the error.
                    async for _ in reader:
                        pass
                    raise
            await writer.wait_closed()

        if not self._capture:
            # No capture is active: nobody is interested in these logs.
            writer.close()
            await writer.wait_closed()
        else:
            self._tasks.append(asyncio.create_task(log_reader(self._capture)))
            async with self._client_cond:
                self._client_cond.notify_all()

    @contextlib.asynccontextmanager
    async def capture(
        self,
        *,
        log_level: LogLevel | None = None,
        timeout: float = 10.0,
    ) -> typing.AsyncIterator[Capture]:
        """Starts logs capture. Returns `Capture` object.

        A nested call reuses the capture that is already active. Its
        ``log_level`` is ignored, and only the outermost call closes the
        capture.

        On exit, the capture is closed and the reader tasks of clients that
        connected during it are awaited for up to ``timeout`` seconds.

        :raises TimeoutError: (this module's class) if reader tasks are still
            running after ``timeout`` seconds; they are cancelled first.
        """
        if self._capture:
            # Wrap the shared logs so nested callers get the same Capture
            # API (select/subscribe) as the outermost caller.
            yield Capture(self._capture)
            return
        # Explicit None check: LogLevel.TRACE == 0 is falsy, so ``or`` would
        # silently replace a TRACE request with the default level.
        effective_level = (
            log_level if log_level is not None else self._log_level
        )
        self._capture = CapturedLogs(log_level=effective_level)
        try:
            yield Capture(self._capture)
        finally:
            self._capture.close()
            self._capture = None
            if self._tasks:
                done, pending = await asyncio.wait(
                    self._tasks, timeout=timeout,
                )
                self._tasks = []
                for task in done:
                    # Retrieve reader failures explicitly so they are
                    # reported here instead of being dropped until the
                    # task is garbage-collected.
                    if not task.cancelled() and task.exception() is not None:
                        logger.error(
                            'logcapture reader task failed',
                            exc_info=task.exception(),
                        )
                if pending:
                    # Stop readers that are still running so they do not
                    # keep publishing into the finished capture.
                    for task in pending:
                        task.cancel()
                    raise TimeoutError(
                        'Timeout while waiting for capture task to finish',
                    )
def _match_entry(row: dict, query: dict) -> bool:
    """Return True when every ``query`` key maps to an equal value in ``row``.

    Keys missing from ``row`` read as None, so a query value of None also
    matches rows that lack that key. An empty query matches every row.
    """
    return not any(row.get(key) != value for key, value in query.items())


# NOTE(review): presumably registers BaseError / FileNotFoundError with the
# testsuite traceback helper so frames from this module are hidden in test
# failure output — confirm against testsuite.utils.traceback.hide.
__tracebackhide__ = traceback.hide(BaseError, FileNotFoundError)