import contextlib
import dataclasses
import multiprocessing.pool
import pathlib
import pprint
import random
import re
import typing
from bson import json_util
import pymongo
import pymongo.collection
import pymongo.errors
import pytest
from testsuite import annotations
from testsuite import utils
from . import connection
from . import ensure_db_indexes
from . import mongo_schema
from . import service
# pylint: disable=too-many-statements
# Static data file names look like ``db_<alias>.json``; the alias part is
# captured in the ``mongo_db_alias`` group.
DB_FILE_RE_PATTERN = re.compile(r'^db_(?P<mongo_db_alias>\w+)\.json$')
# Options passed to ``json_util.object_hook``; ``tz_aware=False`` asks bson
# for naive datetimes.
JSON_OPTIONS = json_util.JSONOptions(tz_aware=False)
# Extended-JSON markers registered (by pytest_register_object_hooks) to be
# decoded through ``_mongo_object_hook``.
MONGO_OBJECT_HOOKS = (
    '$binary', '$code', '$date', '$dbPointer',
    '$maxKey', '$minKey',
    '$numberDecimal', '$numberDouble', '$numberInt', '$numberLong',
    '$oid', '$ref',
    '$regex', '$regularExpression',
    '$symbol', '$timestamp', '$undefined', '$uuid',
)
class BaseError(Exception):
    """Root exception for errors raised by this testsuite plugin."""
class UnknownCollectionError(BaseError):
    """Raised when a collection alias is not known to the plugin.

    Used both for names missing from the ``mongodb_settings`` fixture and for
    ``filldb`` marker keys that are not among the test's collections.
    """
class CollectionWrapper:
    """Alias-keyed access to the mongo collections available to a test.

    Collections are reachable via item access (``wrapper[alias]``), via
    membership checks (``alias in wrapper``) and, for backward compatibility,
    as instance attributes named after their aliases.
    """

    def __init__(self, collections):
        """
        :param collections: mapping of collection alias to collection object.
            The mapping is copied, and its key order is kept for
            ``get_aliases``.
        """
        # TODO: deprecate collection as attribute
        for alias, collection in collections.items():
            setattr(self, alias, collection)
        self._collections = collections.copy()
        self._aliases = tuple(collections.keys())

    def __getitem__(self, alias: str) -> pymongo.collection.Collection:
        """Return the collection registered under ``alias``.

        :raises KeyError: if ``alias`` is unknown.
        """
        return self._collections[alias]

    def __contains__(self, alias: str) -> bool:
        return alias in self._collections

    def get_aliases(self) -> typing.Tuple[str, ...]:
        """Return all aliases, in the order they were passed in.

        The result has one entry per collection, so it is a variable-length
        tuple (``Tuple[str, ...]``), not a 1-tuple (``Tuple[str]``).
        """
        return self._aliases
class CollectionWrapperFactory:
    """Creates ``CollectionWrapper`` objects backed by one mongo client."""

    def __init__(self, connection_info: connection.ConnectionInfo):
        self._connection_info = connection_info

    @property
    def connection_string(self) -> str:
        """Mongo URI built from the stored connection info."""
        return self._connection_info.get_uri()

    @utils.cached_property
    def client(self) -> pymongo.MongoClient:
        """Client for ``connection_string``, created lazily on first access."""
        return pymongo.MongoClient(self.connection_string)

    def create_collection_wrapper(
        self,
        collection_names,
        mongodb_settings,
    ) -> CollectionWrapper:
        """Wrap the collections listed in ``collection_names``.

        Each name is resolved through ``mongodb_settings[name]['settings']``,
        whose ``database`` and ``collection`` keys select the target.

        :raises UnknownCollectionError: if a name is absent from
            ``mongodb_settings``.
        """
        resolved = {}
        for collection_name in collection_names:
            if collection_name not in mongodb_settings:
                raise UnknownCollectionError(
                    f'Missing collection {collection_name} in mongodb_settings fixture',
                )
            # pylint: disable=unsubscriptable-object
            target = mongodb_settings[collection_name]['settings']
            db_handle = self.client[target['database']]
            resolved[collection_name] = db_handle[target['collection']]
        return CollectionWrapper(resolved)
def pytest_configure(config):
    """Register the pytest markers understood by this plugin."""
    marker_descriptions = (
        'noshuffledb: disable data set shuffle for marked test',
        'filldb: specify mongo static file suffix',
        'mongodb_collections: override mongo collections list',
    )
    for description in marker_descriptions:
        config.addinivalue_line('markers', description)
def pytest_addoption(parser):
    """Register mongo command line options and the retry-writes ini setting.

    :param parser: pytest's argument parser
    """
    mongo_group = parser.getgroup('mongo')
    mongo_group.addoption('--mongo', help='Mongo connection string.')
    # Plain on/off switches, all registered with ``action='store_true'``.
    boolean_switches = (
        ('--no-indexes', 'Disable index creation.'),
        ('--no-shuffle-db', 'Disable fixture data shuffle.'),
        ('--no-sharding', 'Disable collections sharding.'),
        ('--no-mongo', 'Disable mongo startup'),
    )
    for flag, description in boolean_switches:
        mongo_group.addoption(flag, action='store_true', help=description)
    parser.addini(
        'mongo-retry-writes',
        type='bool',
        default=False,
        help=(
            'Controls value of \'retryWrites\' parameter of mongo connection '
            'string.'
        ),
    )
def pytest_service_register(register_service):
    """Register the ``mongo`` service factory with testsuite."""
    mongo_factory = service.create_mongo_service
    register_service('mongo', mongo_factory)
def pytest_register_object_hooks():
    """Map every mongo extended-JSON marker to ``_mongo_object_hook``."""
    return dict.fromkeys(MONGO_OBJECT_HOOKS, _mongo_object_hook)
@pytest.fixture
def mongodb(
    mongodb_init,
    _mongodb_local: CollectionWrapper,
) -> CollectionWrapper:
    """Return the collections wrapper for the current test.

    Depends on ``mongodb_init``, so the collections are already filled with
    static fixture data when the test body runs.
    """
    return _mongodb_local
@pytest.fixture
def mongo_connections(
    mongodb_settings,
    mongo_connection_info,
    mongo_extra_connections,
    _mongo_local_collections,
) -> typing.Dict[str, str]:
    """Map mongo connection names to the URI used by the test run.

    Connection names come from the ``settings.connection`` entry of every
    local collection's schema, followed by ``mongo_extra_connections``.
    All of them point to the same URI.
    """
    shared_uri = mongo_connection_info.get_uri()
    connections: typing.Dict[str, str] = {}
    for collection_name in _mongo_local_collections:
        schema_settings = mongodb_settings[collection_name]['settings']
        connections[schema_settings['connection']] = shared_uri
    for extra_name in mongo_extra_connections:
        connections[extra_name] = shared_uri
    return connections
@pytest.fixture
def mongo_extra_connections() -> typing.Tuple[str, ...]:
    """
    Override this to expose additional mongo connection names in
    ``mongo_connections`` besides those derived from collection schemas.
    Empty by default.
    """
    return tuple()
@pytest.fixture(scope='session')
def mongo_connection_info(
    pytestconfig,
    _mongo_service_settings,
) -> connection.ConnectionInfo:
    """Return connection parameters of the mongo instance used by tests.

    If ``--mongo`` is given, its URI is parsed and returned unchanged (the
    ``mongo-retry-writes`` ini option is not applied). Otherwise the
    testsuite-managed service's connection info is returned with
    ``retry_writes`` taken from the ``mongo-retry-writes`` ini option.
    """
    external_uri = pytestconfig.option.mongo
    if external_uri:
        # External mongo instance
        return connection.parse_connection_uri(external_uri)
    connection_info = _mongo_service_settings.get_connection_info()
    retry_writes = pytestconfig.getini('mongo-retry-writes')
    return dataclasses.replace(connection_info, retry_writes=retry_writes)
@pytest.fixture
def mongodb_settings(
    mongo_schema_directory,
    mongo_schema_extra_directories,
    _mongo_schema_cache,
) -> mongo_schema.MongoSchemas:
    """Collection schemas from the main and extra schema directories."""
    schema_dirs = (mongo_schema_directory, *mongo_schema_extra_directories)
    return mongo_schema.MongoSchemas(_mongo_schema_cache, schema_dirs)
@pytest.fixture
def mongodb_collections(mongodb_settings) -> typing.Tuple[str, ...]:
    """
    Override this to restrict which named collections a test module can use.
    By default every collection known to ``mongodb_settings`` is available.
    """
    known_names = mongodb_settings.keys()
    return tuple(known_names)
@pytest.fixture(scope='session')
def mongo_schema_extra_directories() -> typing.Tuple[str, ...]:
    """
    Override to load collection schemas from directories in addition to
    ``mongo_schema_directory``. Empty by default.
    """
    return tuple()
@pytest.fixture(scope='session')
def _mongo_indexes_ensured() -> typing.Set[str]:
    """Session-wide set of collection aliases whose indexes were ensured."""
    ensured_aliases: typing.Set[str] = set()
    return ensured_aliases
@pytest.fixture
def _mongo_service(
    pytestconfig,
    ensure_service_started,
    _mongodb_local,
    _mongo_service_settings,
) -> None:
    """Start the testsuite-managed mongo service when the test needs it.

    Nothing is started if the test has no collections, if an external mongo
    is given via ``--mongo``, or if ``--no-mongo`` is passed.
    """
    if not _mongodb_local.get_aliases():
        return
    options = pytestconfig.option
    if options.mongo or options.no_mongo:
        return
    ensure_service_started('mongo', settings=_mongo_service_settings)
@pytest.fixture
def _mongo_create_indexes(
    _mongodb_local,
    mongodb_settings,
    pytestconfig,
    _mongo_indexes_ensured,
    _mongo_service,
) -> None:
    """Ensure indexes for this test's collections, once per session.

    Skipped entirely with ``--no-indexes``. Aliases already recorded in
    ``_mongo_indexes_ensured`` are not processed again. Sharding is requested
    unless ``--no-sharding`` is passed.
    """
    if pytestconfig.option.no_indexes:
        return
    pending = {
        alias: mongodb_settings[alias]
        for alias in _mongodb_local.get_aliases()
        if alias not in _mongo_indexes_ensured and alias in mongodb_settings
    }
    if not pending:
        return
    ensure_db_indexes.ensure_db_indexes(
        _mongodb_local,
        pending,
        sharding_enabled=not pytestconfig.option.no_sharding,
    )
    # Remember processed aliases so later tests skip them.
    _mongo_indexes_ensured.update(pending)
@pytest.fixture(scope='session')
def _mongo_thread_pool() -> (
    annotations.YieldFixture[multiprocessing.pool.ThreadPool,]
):
    """Session-wide single-worker thread pool; closed when the session ends.

    Used by ``mongodb_init`` to load collections.
    """
    with contextlib.closing(
        multiprocessing.pool.ThreadPool(processes=1),
    ) as worker_pool:
        yield worker_pool
@pytest.fixture
def _mongo_query_loader(load_json):
    """Return a JSON loader that turns a ``None`` result into ``[]``.

    ``load_json`` may return ``None`` (presumably for a missing file when
    ``missing_ok`` is set); callers always receive a list-like value instead.
    """

    def load_or_empty(filename, missing_ok=False):
        loaded = load_json(filename, missing_ok=missing_ok)
        return [] if loaded is None else loaded

    return load_or_empty
@pytest.fixture
def mongodb_init(
    request,
    verify_file_paths,
    static_dir: pathlib.Path,
    _mongodb_local,
    _mongo_thread_pool,
    _mongo_create_indexes,
    _mongo_query_loader,
) -> None:
    """Populate mongodb with fixture data.

    For each collection alias of the test, ``db_<file alias>.json`` is loaded
    and the collection contents are replaced by its documents. By default the
    file alias equals the collection alias. ``filldb(<collection>=<suffix>)``
    markers switch it to ``<collection>_<suffix>`` and make the file
    mandatory. Tests marked ``nofilldb`` are skipped.

    :raises UnknownCollectionError: if a ``filldb`` marker names a collection
        the test does not use.
    """
    node = request.node
    if node.get_closest_marker('nofilldb'):
        return

    # Disable shuffle to make some buggy test work
    shuffle_enabled = not (
        request.config.option.no_shuffle_db
        or node.get_closest_marker('noshuffledb')
    )

    # Collection alias -> alias used in the static file name.
    file_aliases = {name: name for name in _mongodb_local.get_aliases()}
    # Collections whose static file was explicitly requested (must exist).
    explicitly_requested = set()
    for marker in node.iter_markers('filldb'):
        for dbname, alias in marker.kwargs.items():
            if dbname not in file_aliases:
                raise UnknownCollectionError(
                    f'Unknown collection {dbname} requested'
                )
            if alias != 'default':
                file_aliases[dbname] = '%s_%s' % (dbname, alias)
            explicitly_requested.add(dbname)

    def _verify_db_alias(file_path: pathlib.Path) -> bool:
        # Files outside this test's static dirs are not our concern.
        if not _is_relevant_file(request, static_dir, file_path):
            return True
        match = DB_FILE_RE_PATTERN.search(file_path.name)
        if match is None:
            return True
        db_alias = match.group('mongo_db_alias')
        if db_alias in file_aliases:
            return True
        return any(db_alias.startswith(known + '_') for known in file_aliases)

    verify_file_paths(
        _verify_db_alias,
        check_name='mongo_db_aliases',
        text_at_fail='file has not valid mongo collection name alias '
        '(probably should add to service.yaml)',
    )

    def _fill_collection(item):
        dbname, file_alias = item
        try:
            collection = getattr(_mongodb_local, dbname)
        except AttributeError:
            return
        documents = _mongo_query_loader(
            f'db_{file_alias}.json',
            missing_ok=dbname not in explicitly_requested,
        )
        if not documents and collection.find_one({}, []) is None:
            # Nothing to insert and nothing to clean up.
            return
        if shuffle_enabled:
            # Make sure no tests depend on the order of documents in the
            # fixture file.
            random.shuffle(documents)
        operations = [pymongo.DeleteMany({})]
        operations.extend(pymongo.InsertOne(doc) for doc in documents)
        try:
            collection.bulk_write(operations, ordered=True)
        except pymongo.errors.BulkWriteError as bwe:
            pprint.pprint(bwe.details)
            raise

    _mongo_thread_pool.map(_fill_collection, list(file_aliases.items()))
@pytest.fixture
def _mongodb_local(
    mongodb_settings,
    _mongo_local_collections,
    _mongo_collection_wrapper_factory: CollectionWrapperFactory,
) -> CollectionWrapper:
    """Wrap the collections selected for the current test.

    Raises ``UnknownCollectionError`` (from ``create_collection_wrapper``) for
    names missing from ``mongodb_settings``.
    """
    factory = _mongo_collection_wrapper_factory
    wrapper = factory.create_collection_wrapper(
        _mongo_local_collections,
        mongodb_settings,
    )
    return wrapper
@pytest.fixture(scope='session')
def _mongo_collection_wrapper_factory(
    mongo_connection_info: connection.ConnectionInfo,
) -> CollectionWrapperFactory:
    """Session-wide wrapper factory bound to ``mongo_connection_info``."""
    factory = CollectionWrapperFactory(mongo_connection_info)
    return factory
@pytest.fixture
def _mongo_local_collections(request, mongodb_collections) -> typing.Set[str]:
    """Collection aliases available to the current test.

    The union of ``mongodb_collections`` and the positional arguments of every
    ``mongodb_collections`` marker applied to the test.
    """
    marker_args = (
        marker.args
        for marker in request.node.iter_markers('mongodb_collections')
    )
    return set(mongodb_collections).union(*marker_args)
@pytest.fixture(scope='session')
def _mongo_schema_cache() -> mongo_schema.MongoSchemaCache:
    """Session-wide schema cache shared by every ``mongodb_settings``."""
    cache = mongo_schema.MongoSchemaCache()
    return cache
@pytest.fixture(scope='session')
def _mongo_service_settings(
    pytestconfig,
) -> typing.Optional[service.ServiceSettings]:
    """Settings of the testsuite-managed mongo service.

    ``None`` when an external mongo instance is given via ``--mongo``.
    """
    if not pytestconfig.option.mongo:
        return service.get_service_settings()
    return None
def _is_relevant_file(
    request,
    static_dir: pathlib.Path,
    file_path: pathlib.Path,
) -> bool:
    """Tell whether ``file_path`` belongs to the current test's static data.

    A file is relevant when it lies under ``<static_dir>/default`` or under
    ``<static_dir>/<test module file stem>``.
    """
    default_static_dir = static_dir / 'default'
    module_static_dir = static_dir / pathlib.Path(request.fspath).stem
    # _is_nested_path(parent, nested): the directory is the parent and the
    # checked file is the candidate nested path. Passing them the other way
    # round asks whether a directory lies under a file, which is never true,
    # and silently disables the alias verification that relies on this check.
    return _is_nested_path(default_static_dir, file_path) or _is_nested_path(
        module_static_dir,
        file_path,
    )
def _is_nested_path(parent: pathlib.Path, nested: pathlib.Path) -> bool:
try:
pathlib.PurePath(nested).relative_to(parent)
return True
except ValueError:
return False
def _mongo_object_hook(doc):
    """Decode an extended-JSON object with bson using ``JSON_OPTIONS``."""
    decode_options = JSON_OPTIONS
    return json_util.object_hook(doc, decode_options)