Source code for testsuite.databases.mongo.pytest_plugin

import contextlib
import dataclasses
import multiprocessing.pool
import pathlib
import pprint
import random
import re
import typing

import pymongo
import pymongo.collection
import pymongo.errors
import pytest

from testsuite import annotations
from testsuite import utils

from . import connection
from . import ensure_db_indexes
from . import mongo_schema
from . import service

# pylint: disable=too-many-statements

# Name of a mongo fixture data file: ``db_<alias>.json``. The named group
# ``mongo_db_alias`` captures ``<alias>``; ``\w+`` also matches underscores,
# so suffixed names such as ``db_<collection>_<suffix>.json`` match as well.
DB_FILE_RE_PATTERN = re.compile(r'^db_(?P<mongo_db_alias>\w+)\.json$')


class BaseError(Exception):
    """Common base class for errors raised by the mongo testsuite plugin."""


class UnknownCollectionError(BaseError):
    """A collection was requested that the plugin has no settings for."""


class CollectionWrapper:
    """Lookup of pymongo collections by alias.

    Collections are reachable as ``wrapper['alias']``; for backward
    compatibility they are also exposed as instance attributes
    (``wrapper.alias``). ``alias in wrapper`` tests membership and
    ``get_aliases()`` returns all aliases in insertion order.
    """

    def __init__(
        self,
        collections: typing.Dict[str, 'pymongo.collection.Collection'],
    ):
        """
        :param collections: mapping of alias to collection object; it is
            copied, so later changes to it are not reflected here.
        """
        # TODO: deprecate collection as attribute
        for alias, collection in collections.items():
            setattr(self, alias, collection)
        # Assigned after the attribute loop so internal state always wins over
        # an alias that happens to collide with these attribute names.
        self._collections = collections.copy()
        self._aliases = tuple(collections.keys())

    def __getitem__(self, alias: str) -> 'pymongo.collection.Collection':
        """Return the collection for ``alias``; raises ``KeyError`` if unknown."""
        return self._collections[alias]

    def __contains__(self, alias: str) -> bool:
        return alias in self._collections

    def get_aliases(self) -> typing.Tuple[str, ...]:
        """Return every alias held by the wrapper, in insertion order."""
        return self._aliases


class CollectionWrapperFactory:
    """Builds ``CollectionWrapper`` objects bound to one mongo instance.

    A single ``pymongo.MongoClient`` is created lazily from the connection
    info on first use of ``client``.
    """

    def __init__(self, connection_info: connection.ConnectionInfo):
        self._connection_info = connection_info

    @property
    def connection_string(self) -> str:
        """Mongo URI derived from the connection info."""
        return self._connection_info.get_uri()

    @utils.cached_property
    def client(self) -> pymongo.MongoClient:
        return pymongo.MongoClient(self.connection_string)

    def create_collection_wrapper(
        self,
        collection_names,
        mongodb_settings,
    ) -> CollectionWrapper:
        """Resolve every alias in ``collection_names`` to a collection.

        :param collection_names: iterable of collection aliases.
        :param mongodb_settings: schema lookup; each entry must provide
            ``['settings']['database']`` and ``['settings']['collection']``.
        :raises UnknownCollectionError: an alias has no entry in
            ``mongodb_settings``.
        """
        resolved = {}
        for alias in collection_names:
            if alias not in mongodb_settings:
                raise UnknownCollectionError(
                    f'Missing collection {alias} in mongodb_settings fixture',
                )
            # pylint: disable=unsubscriptable-object
            collection_settings = mongodb_settings[alias]['settings']
            resolved[alias] = self.client[collection_settings['database']][
                collection_settings['collection']
            ]
        return CollectionWrapper(resolved)


def pytest_configure(config):
    """Register the pytest markers understood by the mongo plugin.

    :param config: pytest config object
    """
    config.addinivalue_line(
        'markers',
        'noshuffledb: disable data set shuffle for marked test',
    )
    config.addinivalue_line(
        'markers',
        'filldb: specify mongo static file suffix',
    )
    config.addinivalue_line(
        'markers',
        'mongodb_collections: override mongo collections list',
    )
    # ``mongodb_init`` also checks for ``nofilldb``. Register it so runs with
    # ``--strict-markers`` do not reject it as an unknown marker.
    config.addinivalue_line(
        'markers',
        'nofilldb: disable mongo fixture data loading for marked test',
    )


def pytest_addoption(parser):
    """Register the mongo command-line options and ini setting.

    :param parser: pytest's argument parser
    """
    mongo_group = parser.getgroup('mongo')
    mongo_group.addoption('--mongo', help='Mongo connection string.')

    # Boolean switches, each turning off part of the default behaviour.
    switches = (
        ('--no-indexes', 'Disable index creation.'),
        ('--no-shuffle-db', 'Disable fixture data shuffle.'),
        ('--no-sharding', 'Disable collections sharding.'),
    )
    for flag, description in switches:
        mongo_group.addoption(flag, action='store_true', help=description)

    mongo_group.addoption(
        '--no-mongo', action='store_true', help='Disable mongo startup'
    )
    parser.addini(
        'mongo-retry-writes',
        type='bool',
        default=False,
        help=(
            "Controls value of 'retryWrites' parameter of mongo "
            'connection string.'
        ),
    )


def pytest_service_register(register_service):
    """Register the ``mongo`` service, built by ``service.create_mongo_service``."""
    factory = service.create_mongo_service
    register_service('mongo', factory)


@pytest.fixture
def mongodb(
    mongodb_init,
    _mongodb_local: CollectionWrapper,
) -> CollectionWrapper:
    """Collections enabled for the current test, already filled with data.

    ``mongodb_init`` is requested only so that it runs before the test;
    its value is unused. Collections are looked up as ``mongodb['alias']``.
    """
    wrapper = _mongodb_local
    return wrapper
@pytest.fixture
def mongo_connections(
    mongodb_settings,
    mongo_connection_info,
    mongo_extra_connections,
    _mongo_local_collections,
) -> typing.Dict[str, str]:
    """Map each mongo connection name used by the test to the mongo URI.

    Names come from the ``connection`` setting of every locally enabled
    collection, followed by those listed in ``mongo_extra_connections``.
    All of them point at the same instance.
    """
    uri = mongo_connection_info.get_uri()
    connection_names = [
        mongodb_settings[collection]['settings']['connection']
        for collection in _mongo_local_collections
    ]
    connection_names.extend(mongo_extra_connections)
    return dict.fromkeys(connection_names, uri)


@pytest.fixture
def mongo_extra_connections() -> typing.Tuple[str, ...]:
    """
    Override to add connection names (beyond those derived from the enabled
    collections) that ``mongo_connections`` should map to the mongo URI.
    """
    return ()
@pytest.fixture(scope='session')
def mongo_connection_info(
    pytestconfig,
    _mongo_service_settings,
) -> connection.ConnectionInfo:
    """Connection parameters of the mongo instance used by the session.

    If ``--mongo`` is given, its URI (an external instance) is parsed as is.
    Otherwise the locally started service's connection info is used, with
    ``retry_writes`` taken from the ``mongo-retry-writes`` ini option.
    """
    external_uri = pytestconfig.option.mongo
    if external_uri:
        return connection.parse_connection_uri(external_uri)
    return dataclasses.replace(
        _mongo_service_settings.get_connection_info(),
        retry_writes=pytestconfig.getini('mongo-retry-writes'),
    )
@pytest.fixture
def mongodb_settings(
    mongo_schema_directory,
    mongo_schema_extra_directories,
    _mongo_schema_cache,
) -> mongo_schema.MongoSchemas:
    """Collection schemas from ``mongo_schema_directory`` and any
    ``mongo_schema_extra_directories``."""
    return mongo_schema.MongoSchemas(
        _mongo_schema_cache,
        (mongo_schema_directory, *mongo_schema_extra_directories),
    )


@pytest.fixture
def mongodb_collections(mongodb_settings) -> typing.Tuple[str, ...]:
    """
    Override this to enable access to named collections within test module

    Returns all available collections by default.
    """
    return tuple(mongodb_settings.keys())


@pytest.fixture(scope='session')
def mongo_schema_extra_directories() -> typing.Tuple[str, ...]:
    """
    Override to use collection schemas besides those defined by
    ``mongo_schema_directory`` fixture
    """
    return ()


@pytest.fixture(scope='session')
def _mongo_indexes_ensured() -> typing.Set[str]:
    # Aliases whose indexes were already ensured during this session.
    return set()


@pytest.fixture
def _mongo_service(
    pytestconfig,
    ensure_service_started,
    _mongodb_local,
    _mongo_service_settings,
) -> None:
    """Start the local mongo service if the test uses any collection.

    Skipped when an external instance is given (``--mongo``) or startup is
    disabled (``--no-mongo``).
    """
    aliases = _mongodb_local.get_aliases()
    if (
        aliases
        and not pytestconfig.option.mongo
        and not pytestconfig.option.no_mongo
    ):
        ensure_service_started('mongo', settings=_mongo_service_settings)


@pytest.fixture
def _mongo_create_indexes(
    _mongodb_local,
    mongodb_settings,
    pytestconfig,
    _mongo_indexes_ensured,
    _mongo_service,
) -> None:
    """Ensure indexes for enabled collections not yet handled this session.

    Does nothing with ``--no-indexes``. ``sharding_enabled`` passed to
    ``ensure_db_indexes`` is false when ``--no-sharding`` is given.
    """
    if pytestconfig.option.no_indexes:
        return
    pending = {
        alias: mongodb_settings[alias]
        for alias in _mongodb_local.get_aliases()
        if alias not in _mongo_indexes_ensured and alias in mongodb_settings
    }
    if not pending:
        return
    ensure_db_indexes.ensure_db_indexes(
        _mongodb_local,
        pending,
        sharding_enabled=not pytestconfig.option.no_sharding,
    )
    _mongo_indexes_ensured.update(pending)


@pytest.fixture(scope='session')
def _mongo_thread_pool() -> (
    annotations.YieldFixture[multiprocessing.pool.ThreadPool]
):
    """Session-wide thread pool (20 threads) used to load collections in
    parallel; ``close()`` is called on it when the session ends."""
    pool = multiprocessing.pool.ThreadPool(processes=20)
    with contextlib.closing(pool):
        yield pool


@pytest.fixture
def _mongo_query_loader(load_json):
    """Return ``loader(filename, missing_ok=False)`` reading a JSON file.

    A ``None`` result from ``load_json`` (presumably a missing file with
    ``missing_ok=True``) is turned into an empty list.
    """

    def loader(filename, missing_ok=False):
        data = load_json(filename, missing_ok=missing_ok)
        if data is None:
            return []
        return data

    return loader


@pytest.fixture
def mongodb_init(
    request,
    verify_file_paths,
    static_dir: pathlib.Path,
    _mongodb_local,
    _mongo_thread_pool,
    _mongo_create_indexes,
    _mongo_query_loader,
) -> None:
    """Populate mongodb with fixture data.

    Each enabled collection has its content replaced with the documents
    from ``db_<alias>.json``. ``@pytest.mark.filldb(<collection>=<suffix>)``
    switches that collection to ``db_<collection>_<suffix>.json``
    (``'default'`` keeps the plain name) and makes the file mandatory.
    ``nofilldb`` skips population; ``noshuffledb`` or ``--no-shuffle-db``
    keep the document order of the file.

    :raises UnknownCollectionError: ``filldb`` names a collection that is
        not enabled for the test.
    """
    if request.node.get_closest_marker('nofilldb'):
        return

    # Disable shuffle to make some buggy test work
    shuffle_enabled = (
        not request.config.option.no_shuffle_db
        and not request.node.get_closest_marker('noshuffledb')
    )

    # collection alias -> data file alias (file name is ``db_<value>.json``)
    aliases = {key: key for key in _mongodb_local.get_aliases()}
    # Collections named in ``filldb``: their data files must exist.
    requested = set()
    for marker in request.node.iter_markers('filldb'):
        for dbname, alias in marker.kwargs.items():
            if dbname not in aliases:
                raise UnknownCollectionError(
                    f'Unknown collection {dbname} requested'
                )
            if alias != 'default':
                aliases[dbname] = f'{dbname}_{alias}'
            requested.add(dbname)

    def _verify_db_alias(file_path: pathlib.Path) -> bool:
        # Files outside the relevant static dirs are never rejected.
        if not _is_relevant_file(request, static_dir, file_path):
            return True
        match = DB_FILE_RE_PATTERN.search(file_path.name)
        if not match:
            return True
        db_alias = match.group('mongo_db_alias')
        return db_alias in aliases or any(
            db_alias.startswith(alias + '_') for alias in aliases
        )

    verify_file_paths(
        _verify_db_alias,
        check_name='mongo_db_aliases',
        text_at_fail='file has not valid mongo collection name alias '
        '(probably should add to service.yaml)',
    )

    def load_collection(params):
        dbname, alias = params
        # Item access, not the deprecated attribute access: an alias that
        # collides with a wrapper attribute name (e.g. ``_collections``)
        # would otherwise resolve to the wrong object.
        try:
            col = _mongodb_local[dbname]
        except KeyError:
            return
        docs = _mongo_query_loader(
            f'db_{alias}.json', missing_ok=dbname not in requested
        )
        # Nothing to load and the collection is already empty.
        if not docs and col.find_one({}, []) is None:
            return
        if shuffle_enabled:
            # Ensure no test depends on the order of documents in the
            # fixture file.
            random.shuffle(docs)
        try:
            col.bulk_write(
                [
                    pymongo.DeleteMany({}),
                    *(pymongo.InsertOne(doc) for doc in docs),
                ],
                ordered=True,
            )
        except pymongo.errors.BulkWriteError as bwe:
            pprint.pprint(bwe.details)
            raise

    _mongo_thread_pool.map(load_collection, list(aliases.items()))


@pytest.fixture
def _mongodb_local(
    mongodb_settings,
    _mongo_local_collections,
    _mongo_collection_wrapper_factory: CollectionWrapperFactory,
) -> CollectionWrapper:
    # Wrapper over the collections enabled for the current test.
    return _mongo_collection_wrapper_factory.create_collection_wrapper(
        _mongo_local_collections,
        mongodb_settings,
    )


@pytest.fixture(scope='session')
def _mongo_collection_wrapper_factory(
    mongo_connection_info: connection.ConnectionInfo,
) -> CollectionWrapperFactory:
    # Session-scoped so the factory's lazily created client can be reused
    # across tests (assuming utils.cached_property caches per instance).
    return CollectionWrapperFactory(mongo_connection_info)


@pytest.fixture
def _mongo_local_collections(request, mongodb_collections) -> typing.Set[str]:
    # ``mongodb_collections`` plus every collection named by
    # ``@pytest.mark.mongodb_collections(...)`` markers.
    result = set(mongodb_collections)
    for marker in request.node.iter_markers('mongodb_collections'):
        result.update(marker.args)
    return result


@pytest.fixture(scope='session')
def _mongo_schema_cache() -> mongo_schema.MongoSchemaCache:
    return mongo_schema.MongoSchemaCache()


@pytest.fixture(scope='session')
def _mongo_service_settings(
    pytestconfig,
) -> typing.Optional[service.ServiceSettings]:
    # No local service settings when an external instance is given (--mongo).
    if pytestconfig.option.mongo:
        return None
    return service.get_service_settings()


def _is_relevant_file(
    request,
    static_dir: pathlib.Path,
    file_path: pathlib.Path,
) -> bool:
    """Intended: whether ``file_path`` lies in ``static_dir/default`` or in
    the static dir named after the current test module."""
    default_static_dir = static_dir / 'default'
    module_static_dir = static_dir / pathlib.Path(request.fspath).stem
    # NOTE(review): ``_is_nested_path(parent, nested)`` is called with its
    # arguments swapped, so this tests whether the *directories* lie under
    # ``file_path``. That cannot hold for a file path, so ``_verify_db_alias``
    # in ``mongodb_init`` currently accepts every file. Swapping the
    # arguments would enable the check. However, the check compares against
    # the locally enabled collections only, so modules that narrow
    # ``mongodb_collections`` could start failing on shared ``default``
    # files. Settle the intended semantics before fixing; kept as-is for now.
    return _is_nested_path(file_path, default_static_dir) or _is_nested_path(
        file_path,
        module_static_dir,
    )


def _is_nested_path(parent: pathlib.Path, nested: pathlib.Path) -> bool:
    """Return True if ``nested`` equals ``parent`` or lies below it.

    The comparison is purely lexical (``PurePath.relative_to``); the
    filesystem is not consulted.
    """
    try:
        pathlib.PurePath(nested).relative_to(parent)
        return True
    except ValueError:
        return False