import concurrent.futures
import collections
import collections.abc
import contextlib
import re
import typing
import pytest
from . import connection
from . import control
from . import discover
from . import exceptions
from . import service
from . import utils
# Matches a path that ends in ``/pg_<alias>.sql`` or ``/pg_<alias>/<name>.sql``
# (word characters only); the alias is captured in the ``pg_db_alias`` group.
DB_FILE_RE_PATTERN = re.compile(r'/pg_(?P<pg_db_alias>\w+)(/?\w*)\.sql$')
class ServiceLocalConfig(collections.abc.Mapping):
    """Mapping from shard pretty name to its connection info.

    A connection wrapper is requested from ``pgsql_control`` for every shard
    of every database at construction time. The databases themselves are
    only set up when :py:meth:`initialize` is called.
    """

    def __init__(
        self,
        databases: typing.List[discover.PgShardedDatabase],
        pgsql_control: control.PgControl,
        cleanup_exclude_tables: typing.FrozenSet[str],
    ):
        self._initialized = False
        self._pgsql_control = pgsql_control
        self._databases = databases
        self._cleanup_exclude_tables = cleanup_exclude_tables
        # Keyed by shard pretty name; a later shard with the same pretty
        # name replaces an earlier one.
        self._shard_connections = {}
        for database in self._databases:
            for shard in database.shards:
                wrapper = pgsql_control.get_connection_cached(shard.dbname)
                self._shard_connections[shard.pretty_name] = wrapper

    def __len__(self) -> int:
        return len(self._shard_connections)

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._shard_connections)

    def __getitem__(self, dbname: str) -> connection.PgConnectionInfo:
        """Get
        :py:class:`testsuite.databases.pgsql.connection.PgConnectionInfo`
        instance by database name
        """
        wrapper = self._shard_connections[dbname]
        return wrapper.conninfo

    def _init_database(self, database) -> None:
        """Initialize one sharded database and each of its shard connections."""
        self._pgsql_control.initialize_sharded_db(database)
        for shard in database.shards:
            wrapper = self._shard_connections[shard.pretty_name]
            wrapper.initialize(self._cleanup_exclude_tables)

    def initialize(
        self, parallel_init: bool
    ) -> typing.Dict[str, control.ConnectionWrapper]:
        """Initialize all configured databases once and return connections.

        The control object is initialized only when at least one database is
        configured. Calls after the first successful one do no work.

        :param parallel_init: when true, databases are initialized
            concurrently in a thread pool; otherwise one after another.
        :returns: dictionary of shard pretty name to connection wrapper.
        """
        if not self._initialized:
            if self._databases:
                self._pgsql_control.initialize()

            if parallel_init:
                with concurrent.futures.ThreadPoolExecutor() as pool:
                    pending = [
                        pool.submit(self._init_database, database)
                        for database in self._databases
                    ]
                    # result() re-raises any exception from the worker.
                    for task in pending:
                        task.result()
            else:
                for database in self._databases:
                    self._init_database(database)

            self._initialized = True
        return self._shard_connections
def pytest_addoption(parser):
    """Register the PostgreSQL command-line options.

    :param parser: pytest's argument parser
    """
    keep_existing_help = (
        'Keep existing databases with up-to-date schema. By default '
        'testsuite will drop and create anew any existing database when '
        'initializing databases.'
    )
    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = [
        ('--postgresql', {'help': 'PostgreSQL connection string'}),
        (
            '--no-postgresql',
            {'help': 'Disable use of PostgreSQL', 'action': 'store_true'},
        ),
        (
            '--postgresql-keep-existing-db',
            {'action': 'store_true', 'help': keep_existing_help},
        ),
    ]
    pg_group = parser.getgroup('postgresql')
    for flag, params in option_specs:
        pg_group.addoption(flag, **params)
def pytest_configure(config):
    """Declare the ``pgsql`` marker in pytest's ``markers`` ini setting."""
    marker_line = 'pgsql: per-test PostgreSQL initialization'
    config.addinivalue_line('markers', marker_line)
def pytest_service_register(register_service):
    """Register the PostgreSQL service factory under the name ``postgresql``.

    :param register_service: callable taking a service name and a factory.
    """
    service_factory = service.create_pgsql_service
    register_service('postgresql', service_factory)
@pytest.fixture(scope='session')
def pgsql_cleanup_exclude_tables():
    """Return the table names handed to each shard connection on init.

    The default is an empty set. Presumably these are tables that cleanup
    should leave untouched -- confirm against the connection wrapper.
    """
    no_tables = frozenset()
    return no_tables
@pytest.fixture
def pgsql(_pgsql, pgsql_apply) -> typing.Dict[str, control.PgDatabaseWrapper]:
    """
    Returns str to
    :py:class:`testsuite.databases.pgsql.control.PgDatabaseWrapper` dictionary

    Example usage:

    .. code-block:: python

      def test_pg(pgsql):
          cursor = pgsql['example_db'].cursor()
          cursor.execute('SELECT ... FROM ...WHERE ...')
          assert list(cursor) == [...]
    """
    # ``pgsql_apply`` is requested only so that data is loaded before the
    # test body runs; its value is not used here.
    wrappers = {}
    for name, shard_connection in _pgsql.items():
        wrappers[name] = control.PgDatabaseWrapper(shard_connection)
    return wrappers
@pytest.fixture(scope='session')
def pgsql_local_create(
    _pgsql_control,
    pgsql_cleanup_exclude_tables,
) -> typing.Callable[
    [typing.List[discover.PgShardedDatabase]],
    ServiceLocalConfig,
]:
    """Creates pgsql configuration.

    The fixture value is a factory function:

    :param databases: List of databases.
    :returns: :py:class:`ServiceLocalConfig` instance.
    """

    def create(databases):
        # Bind the session-wide control object and exclusion set so the
        # caller only has to supply the database list.
        return ServiceLocalConfig(
            databases=databases,
            pgsql_control=_pgsql_control,
            cleanup_exclude_tables=pgsql_cleanup_exclude_tables,
        )

    return create
@pytest.fixture(scope='session')
def pgsql_disabled(pytestconfig) -> bool:
    """Return the value of the ``--no-postgresql`` command-line flag."""
    options = pytestconfig.option
    return options.no_postgresql
@pytest.fixture
def pgsql_local(pgsql_local_create) -> ServiceLocalConfig:
    """Configures local pgsql instance.

    :returns: :py:class:`ServiceLocalConfig` instance.

    In order to use pgsql fixture you have to override pgsql_local()
    in your local conftest.py file, example:

    .. code-block:: python

        @pytest.fixture(scope='session')
        def pgsql_local(pgsql_local_create):
            databases = discover.find_schemas(
                'service_name', [PG_SCHEMAS_PATH])
            return pgsql_local_create(list(databases.values()))

    Sometimes it is desirable to have tests-only database, maybe used in one
    particular test or tests group. This can be achieved by overriding
    ``pgsql_local`` fixture in your test file:

    .. code-block:: python

        @pytest.fixture
        def pgsql_local(pgsql_local_create):
            databases = discover.find_schemas(
                'testsuite', [pathlib.Path('custom/pgsql/schema/path')])
            return pgsql_local_create(list(databases.values()))

    ``pgsql_local`` provides access to PostgreSQL connection parameters:

    .. code-block:: python

        def get_custom_connection_string(pgsql_local):
            conninfo = pgsql_local['database_name']
            custom_dsn: str = conninfo.replace(options='-c opt=val').get_dsn()
            return custom_dsn
    """
    # The default configuration has no databases; projects override this.
    no_databases: typing.List[discover.PgShardedDatabase] = []
    return pgsql_local_create(no_databases)
@pytest.fixture(scope='session')
def pgsql_parallelization_enabled():
    """Return whether databases are initialized and filled in thread pools.

    Consumed by ``_pgsql`` (as ``parallel_init``) and by ``pgsql_apply``.
    The default is ``True``.
    """
    enabled = True
    return enabled
@pytest.fixture
def _pgsql(
    _pgsql_service,
    _pgsql_control,
    pgsql_local,
    pgsql_disabled: bool,
    pgsql_parallelization_enabled: bool,
    pgsql_cleanup_exclude_tables: typing.FrozenSet[str],
) -> typing.Dict[str, control.ConnectionWrapper]:
    """Initialize the configured databases and return shard connections.

    When PostgreSQL is disabled, the configuration supplied by
    ``pgsql_local`` is replaced with an empty one, so nothing is initialized
    and an empty dictionary is returned.

    ``_pgsql_service`` is requested only so that it runs before this
    fixture; its value is not used.

    :returns: dictionary of shard pretty name to connection wrapper.
    """
    if pgsql_disabled:
        # ``pgsql_cleanup_exclude_tables`` must come in as a fixture value.
        # Without the parameter the name resolves to the module-level
        # fixture function rather than the frozenset it produces.
        pgsql_local = ServiceLocalConfig(
            [],
            _pgsql_control,
            pgsql_cleanup_exclude_tables,
        )
    return pgsql_local.initialize(parallel_init=pgsql_parallelization_enabled)
@pytest.fixture(scope='session')
def pgsql_background_truncate_enabled():
    """Return whether ``pgsql_apply`` schedules truncation after each test.

    The default is ``True``.
    """
    enabled = True
    return enabled
@pytest.fixture
def _pgsql_apply_queries(
    request,
    _pgsql: typing.Dict[str, control.ConnectionWrapper],
    _pgsql_query_loader,
) -> typing.Dict[str, typing.List[control.PgQuery]]:
    """Collect the queries to apply to every database for the current test.

    For each database known to ``_pgsql`` the result holds either the
    queries gathered from ``pytest.mark.pgsql`` markers naming that database,
    or, when no marker names it, the default queries loaded from
    ``pg_<dbname>.sql`` and the ``pg_<dbname>`` directory. A marker replaces
    the defaults even if it supplies no files, directories or queries.

    :raises exceptions.PostgresqlError: if a marker names an unknown
        database or passes a query that is not a string, list or tuple.
    """

    def pgsql_default_queries(dbname):
        # Both locations are optional; a missing one contributes nothing.
        return [
            *_pgsql_query_loader.load(
                f'pg_{dbname}.sql',
                'pgsql.default_queries',
                missing_ok=True,
            ),
            *_pgsql_query_loader.loaddir(
                f'pg_{dbname}',
                'pgsql.default_queries',
                missing_ok=True,
            ),
        ]

    def pgsql_mark(dbname, files=(), directories=(), queries=()):
        # Mirrors the signature of ``pytest.mark.pgsql(...)``.
        result_queries = []
        for path in files:
            result_queries += _pgsql_query_loader.load(path, 'mark.pgsql.files')
        for path in directories:
            result_queries += _pgsql_query_loader.loaddir(
                path,
                'mark.pgsql.directories',
            )
        for query in queries:
            # Each entry is a single statement or a list/tuple of them.
            if isinstance(query, str):
                queries_str = [query]
            elif isinstance(query, (list, tuple)):
                queries_str = query
            else:
                raise exceptions.PostgresqlError(
                    f'sql queries of type {type(query)} are not supported',
                )
            result_queries.extend(
                control.PgQuery(
                    body=query_str,
                    source='mark.pgsql.queries',
                    path=None,
                )
                for query_str in queries_str
            )
        return dbname, result_queries

    overrides: typing.DefaultDict[
        str,
        typing.List[control.PgQuery],
    ] = collections.defaultdict(list)
    for mark in request.node.iter_markers('pgsql'):
        dbname, mark_queries = pgsql_mark(*mark.args, **mark.kwargs)
        if dbname not in _pgsql:
            raise exceptions.PostgresqlError(f'Unknown database {dbname}')
        overrides[dbname].extend(mark_queries)

    # Load the defaults lazily: reading the default SQL files is wasted I/O
    # for any database whose queries are overridden by a marker.
    return {
        dbname: (
            overrides[dbname]
            if dbname in overrides
            else pgsql_default_queries(dbname)
        )
        for dbname in _pgsql
    }
@pytest.fixture
def pgsql_apply(
    _pgsql: typing.Dict[str, control.ConnectionWrapper],
    load,
    pgsql_background_truncate_enabled: bool,
    pgsql_parallelization_enabled: bool,
    _pgsql_apply_queries,
) -> typing.Iterator[None]:
    """Initialize PostgreSQL database with data.

    By default pg_${DBNAME}.sql and pg_${DBNAME}/*.sql files are used
    to fill PostgreSQL databases.

    Use pytest.mark.pgsql to change this behaviour:

    .. code-block:: python

        @pytest.mark.pgsql(
            'foo@0',
            files=[
                'pg_foo@0_alternative.sql'
            ],
            directories=[
                'pg_foo@0_alternative_dir'
            ],
            queries=[
                'INSERT INTO foo VALUES (1, 2, 3, 4)',
            ]
        )

    After the test, truncation is scheduled on every connection when
    ``pgsql_background_truncate_enabled`` is true.
    """
    if pgsql_parallelization_enabled:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            pending = [
                executor.submit(
                    pg_db.apply_queries, _pgsql_apply_queries[dbname]
                )
                for dbname, pg_db in _pgsql.items()
            ]
            # result() re-raises any exception from the worker thread.
            for future in pending:
                future.result()
    else:
        for dbname, pg_db in _pgsql.items():
            pg_db.apply_queries(_pgsql_apply_queries[dbname])

    yield

    if pgsql_background_truncate_enabled:
        for pg_db in _pgsql.values():
            pg_db.schedule_truncation()
@pytest.fixture
def _pgsql_query_loader(get_file_path, get_directory_path, mockserver_info):
    """Return a loader that turns SQL files into ``control.PgQuery`` objects.

    Every occurrence of ``$mockserver`` in a loaded file is replaced with
    the ``http://<host>:<port>`` address taken from ``mockserver_info``.
    """

    def _read_query(sql_path, source):
        raw_text = sql_path.read_text()
        mockserver_url = f'http://{mockserver_info.host}:{mockserver_info.port}'
        body = raw_text.replace('$mockserver', mockserver_url)
        return control.PgQuery(body=body, source=source, path=str(sql_path))

    class Loader:
        @staticmethod
        def load(path, source, missing_ok=False):
            """Load one file; return an empty list if it resolves to nothing."""
            resolved = get_file_path(path, missing_ok=missing_ok)
            return [_read_query(resolved, source)] if resolved else []

        @staticmethod
        def loaddir(directory, source, missing_ok=False):
            """Load every SQL file that ``utils.scan_sql_directory`` yields."""
            resolved = get_directory_path(directory, missing_ok=missing_ok)
            if not resolved:
                return []
            return [
                _read_query(sql_path, source)
                for sql_path in utils.scan_sql_directory(resolved)
            ]

    return Loader()
@pytest.fixture
def _pgsql_service(
    pytestconfig,
    pgsql_disabled: bool,
    ensure_service_started,
    pgsql_local: ServiceLocalConfig,
    _pgsql_service_settings,
) -> None:
    """Start the ``postgresql`` service when the test needs it.

    Nothing is started if PostgreSQL is disabled, if ``pgsql_local`` is
    empty, or if a connection string was given with ``--postgresql``.
    """
    if pgsql_disabled:
        return
    if not pgsql_local:
        return
    if pytestconfig.option.postgresql:
        return
    ensure_service_started('postgresql', settings=_pgsql_service_settings)
@pytest.fixture(scope='session')
def _pgsql_control(pytestconfig, _pgsql_conninfo, pgsql_disabled: bool):
    """Provide the session-wide ``control.PgControl`` instance.

    When PostgreSQL is disabled an empty dict is provided instead. Otherwise
    the control object is created from ``_pgsql_conninfo`` and closed at the
    end of the session. Already-applied schemas are skipped when either
    ``--postgresql-keep-existing-db`` or the ``service_wait`` option is set.
    """
    if pgsql_disabled:
        # This function is a generator, so it has to yield its value. A
        # plain ``return {}`` yields nothing and pytest fails the fixture
        # with "did not yield a value".
        yield {}
        return

    instance = control.PgControl(
        _pgsql_conninfo,
        verbose=pytestconfig.option.verbose,
        skip_applied_schemas=(
            pytestconfig.option.postgresql_keep_existing_db
            or pytestconfig.option.service_wait
        ),
    )
    with contextlib.closing(instance):
        yield instance
@pytest.fixture(scope='session')
def _pgsql_service_settings() -> service.ServiceSettings:
    """Return the settings produced by ``service.get_service_settings()``."""
    settings = service.get_service_settings()
    return settings
@pytest.fixture(scope='session')
def _pgsql_conninfo(
    request,
    _pgsql_service_settings,
) -> connection.PgConnectionInfo:
    """Return the connection info used for the whole session.

    A non-empty ``--postgresql`` connection string takes precedence;
    otherwise the connection info comes from the service settings.
    """
    explicit_dsn = request.config.option.postgresql
    if not explicit_dsn:
        return _pgsql_service_settings.get_conninfo()
    return connection.parse_connection_string(explicit_dsn)