# Source code for testsuite.databases.pgsql.discover

import collections
import dataclasses
import hashlib
import itertools
import logging
import pathlib
import typing
from typing import DefaultDict, Dict, Iterable, List, Optional

from . import exceptions, utils

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Sentinel shard number for databases that are not split into shards, i.e.
# whose schema name carries no ``@<shard>`` suffix.
SINGLE_SHARD = -1
# Longest database name this module generates as-is; longer names are replaced
# by an abbreviated, hash-based form (see ``_shortened``).
DB_NAME_MAX = 31


@dataclasses.dataclass(frozen=True)
class ShardName:
    """Database name and shard number parsed from a schema entry name."""

    # Database name with the ``@<shard>`` suffix (if any) removed.
    db_name: str
    # Shard number, or ``SINGLE_SHARD`` when the name has no shard suffix.
    shard: int


@dataclasses.dataclass
class ShardFiles:
    """Schema sources contributed by one entry of a schema directory."""

    # Database and shard the entry belongs to.
    name: ShardName
    # Schema files: either a single ``.sql`` file or whatever
    # ``utils.scan_sql_directory`` returned for a directory entry.
    files: Optional[List[pathlib.Path]] = None
    # Directories that contain a ``migrations`` subdirectory.
    pg_migrations: Optional[List[pathlib.Path]] = None


@dataclasses.dataclass
class ShardFileInfo:
    """Accumulated schema files and migration directories of one shard."""

    files: List[pathlib.Path]
    pg_migrations: List[pathlib.Path]

    def extend(self, other: ShardFiles) -> None:
        """Append the paths carried by ``other`` to this accumulator.

        Attributes of ``other`` that are ``None`` or empty are skipped.
        """
        pairs = (
            (self.files, other.files),
            (self.pg_migrations, other.pg_migrations),
        )
        for target, addition in pairs:
            if addition:
                target.extend(addition)


# Maps a shard number to the schema sources collected for that shard.
ShardPathesDict = Dict[int, ShardFileInfo]


@dataclasses.dataclass(frozen=True)
class PgShard:
    """Immutable description of one database shard and its schema sources."""

    shard_id: int
    pretty_name: str
    dbname: str
    files: List[pathlib.Path]
    migrations: List[pathlib.Path]

    def get_schema_hash(self) -> str:
        """Return ``utils.get_files_hash`` of the files, then the migrations."""
        schema_sources = itertools.chain(self.files, self.migrations)
        return utils.get_files_hash(schema_sources)


@dataclasses.dataclass(frozen=True)
class PgShardedDatabase:
    """A discovered database together with the list of its shards."""

    # Service name that was passed to the discovery functions; may be ``None``.
    service_name: typing.Optional[str]
    # Database name taken from the schema file or directory name.
    dbname: str
    # Shards of the database, ordered by shard number.
    shards: List[PgShard]


def find_schemas(
    service_name: Optional[str],
    schema_dirs: List[pathlib.Path],
) -> Dict[str, PgShardedDatabase]:
    """Read database schemas from directories ``schema_dirs``.

    ::

     |- schema_path/
       |- database1.sql
       |- database2.sql

    Entries of ``schema_dirs`` that are not existing directories are
    skipped silently.

    :param service_name: service name used as prefix for the generated
        database names when it is not ``None``, e.g. "servicename_dbname".
    :param schema_dirs: list of paths to scan for schemas
    :returns: :py:class:`Dict[str, PgShardedDatabase]` where key is the
        database name as stored in :py:attr:`PgShardedDatabase.dbname`
    :raises exceptions.PostgresqlError: when the same database is declared
        in more than one directory.
    """
    found: Dict[str, PgShardedDatabase] = {}
    for directory in schema_dirs:
        if not directory.is_dir():
            continue
        schemas = _find_databases_schemas(service_name, directory)
        duplicates = schemas.keys() & found.keys()
        if duplicates:
            # Report one offending name, exactly as the first set element.
            dbname = next(iter(duplicates))
            raise exceptions.PostgresqlError(
                f'Database {dbname} is declared twice',
            )
        found.update(schemas)
    return found
def _find_databases_schemas(
    service_name: Optional[str],
    schema_path: pathlib.Path,
) -> Dict[str, PgShardedDatabase]:
    """Build the database descriptions declared directly in ``schema_path``.

    :param service_name: optional service name forwarded to the database
        name generator.
    :param schema_path: directory whose entries are inspected.
    :returns: mapping from database name to its description; shards are
        ordered by shard number and their file lists are sorted.
    :raises exceptions.PostgresqlError: on an inconsistent shard layout.
    """
    logger.debug('Looking up for PostgreSQL schemas at %s', schema_path)
    shard_files_map = _build_shard_files_map(schema_path)
    result: Dict[str, PgShardedDatabase] = {}
    for dbname, shards in shard_files_map.items():
        _raise_if_invalid_shards(dbname, shards)
        pg_shards = [
            _create_pgshard(
                dbname,
                service_name=service_name,
                shard_id=shard_id,
                files=sorted(shard_files.files),
                migrations=sorted(shard_files.pg_migrations),
            )
            # Shard numbers are unique keys, so sorting never has to compare
            # the ``ShardFileInfo`` values.
            for shard_id, shard_files in sorted(shards.items())
        ]
        result[dbname] = PgShardedDatabase(
            service_name=service_name,
            dbname=dbname,
            shards=pg_shards,
        )
    return result


def _build_shard_files_map(
    root_path: pathlib.Path,
) -> DefaultDict[str, ShardPathesDict]:
    """Group the schema sources found in ``root_path`` by database and shard.

    :returns: nested mapping ``db_name -> shard number -> ShardFileInfo``.
    """
    result: DefaultDict[str, ShardPathesDict]
    result = collections.defaultdict(
        lambda: collections.defaultdict(lambda: ShardFileInfo([], [])),
    )
    for shard in _find_shard_files(root_path):
        result[shard.name.db_name][shard.name.shard].extend(shard)
    return result


def _find_shard_files(schema_path: pathlib.Path) -> Iterable[ShardFiles]:
    """Yield schema sources for every recognised entry of ``schema_path``."""
    for entry in schema_path.iterdir():
        shard_files = _get_shard_schema_files(entry)
        if shard_files is not None:
            yield shard_files


def _get_shard_schema_files(path: pathlib.Path) -> Optional[ShardFiles]:
    """Describe the schema sources provided by a single directory entry.

    * a ``*.sql`` file contributes itself as a schema file;
    * a directory with a ``migrations`` subdirectory is recorded as a
      migrations directory;
    * any other directory contributes the result of
      ``utils.scan_sql_directory``.

    :returns: the sources, or ``None`` when the entry is not a schema source.
    :raises exceptions.PostgresqlError: when the entry name carries a
        negative shard number.
    """
    # The name is parsed only for entries that are schema sources, so that
    # unrelated files can never trigger a shard name error.
    if path.is_file():
        if path.suffix != '.sql':
            return None
        return ShardFiles(_parse_shard_name(path.stem), files=[path])
    if path.is_dir():
        shard_name = _parse_shard_name(path.stem)
        if path.joinpath('migrations').is_dir():
            return ShardFiles(shard_name, pg_migrations=[path])
        return ShardFiles(shard_name, files=utils.scan_sql_directory(path))
    return None


def _raise_if_invalid_shards(dbname: str, shards: ShardPathesDict) -> None:
    """Validate the set of shard numbers declared for ``dbname``.

    A database is either unsharded (only ``SINGLE_SHARD`` present) or has
    shards numbered ``0 .. len(shards) - 1`` without gaps.

    :raises exceptions.PostgresqlError: when neither layout matches.
    """
    if SINGLE_SHARD in shards:
        if len(shards) != 1:
            raise exceptions.PostgresqlError(
                f'Postgresql database {dbname} has single shard '
                'configuration while defined as multishard',
            )
    elif set(shards.keys()) != set(range(len(shards))):
        raise exceptions.PostgresqlError(
            f'Postgresql database {dbname} is missing fixtures '
            'for some shards',
        )


def _create_pgshard(
    dbname: str,
    service_name: Optional[str] = None,
    shard_id: int = SINGLE_SHARD,
    files: Optional[List[pathlib.Path]] = None,
    migrations: Optional[List[pathlib.Path]] = None,
) -> PgShard:
    """Create the :class:`PgShard` for one shard of ``dbname``.

    An unsharded database (``shard_id == SINGLE_SHARD``) is stored with
    ``shard_id`` 0 and its plain name as ``pretty_name``; real shards keep
    their number and get ``"<dbname>@<shard_id>"`` as ``pretty_name``.

    :raises exceptions.NameCannotBeShortend: propagated from the database
        name generator.
    """
    if files is None:
        files = []
    if migrations is None:
        migrations = []

    if shard_id == SINGLE_SHARD:
        actual_shard_id = 0
        pretty_name = dbname
    else:
        actual_shard_id = shard_id
        pretty_name = '%s@%d' % (dbname, shard_id)

    sharded_dbname = _database_name(service_name, dbname, shard_id)
    return PgShard(
        shard_id=actual_shard_id,
        pretty_name=pretty_name,
        dbname=sharded_dbname,
        files=files,
        migrations=migrations,
    )


# Generated database name -> ``(service_name, dbname)`` it was generated for.
# Shared module state used to detect two databases mapping to the same name.
_names_used: Dict[str, tuple] = {}


def _database_name(
    service_name: Optional[str], dbname: str, shard_id: int,
) -> str:
    """Generate the database name for a shard.

    The name is ``[<service_name>_]<dbname>[_<shard_id>]`` with ``.`` and
    ``-`` replaced by ``_``; names longer than ``DB_NAME_MAX`` are shortened.
    Every generated name is registered in ``_names_used``.

    :raises exceptions.NameCannotBeShortend: when the name cannot be
        shortened, or when it is already registered for a different
        ``(service_name, dbname)`` pair.
    """
    dbkey = (service_name, dbname)
    suffix = '' if shard_id == SINGLE_SHARD else f'_{shard_id}'
    prefix = '' if service_name is None else f'{service_name}_'

    name = _normalize_name(prefix + dbname)
    full_name = _normalize_name(name + suffix)
    if len(full_name) > DB_NAME_MAX:
        full_name = _shortened(name, suffix)

    # Registers the name on first use and returns the previous owner otherwise.
    owner = _names_used.setdefault(full_name, dbkey)
    if owner != dbkey:
        raise exceptions.NameCannotBeShortend(
            f'Database name conflict for {dbkey} and {owner}'
        )
    return full_name


def _shortened(name: str, suffix: str) -> str:
    """Abbreviate ``name`` so that the result fits into ``DB_NAME_MAX``.

    The result is ``<initials>_<sha1 prefix><suffix>``, where the initials
    are the first characters of the ``_``-separated parts of ``name`` and the
    SHA-1 prefix of ``name`` fills the remaining space.

    :raises exceptions.NameCannotBeShortend: when the initials, separator and
        suffix alone already exceed ``DB_NAME_MAX``.
    """
    short_name = ''.join(part[:1] for part in name.split('_'))
    hash_len = DB_NAME_MAX - len(short_name) - len(suffix) - 1
    if hash_len < 0:
        # A negative slice bound would keep most of the digest instead of
        # trimming it, so reject the name explicitly.
        raise exceptions.NameCannotBeShortend(
            f'Dbname cannot be shortened {name}{suffix}'
        )
    name_hash = hashlib.sha1(name.encode('utf-8')).hexdigest()[:hash_len]
    return f'{short_name}_{name_hash}{suffix}'


def _parse_shard_name(name: str) -> ShardName:
    """Split a schema entry name into database name and shard number.

    ``"db@3"`` becomes ``ShardName("db", 3)``. A name without ``@``, or whose
    part after the last ``@`` is not an integer, is treated as an unsharded
    database with that whole name.

    :raises exceptions.PostgresqlError: when the shard number is negative;
        ``-1`` in particular would be indistinguishable from ``SINGLE_SHARD``.
    """
    db_name, separator, shard_part = name.rpartition('@')
    if separator:
        try:
            shard_id = int(shard_part)
        except (ValueError, TypeError):
            pass
        else:
            if shard_id < 0:
                raise exceptions.PostgresqlError(
                    f'Invalid shard number in PostgreSQL schema name '
                    f'{name!r}: shard numbers must be non-negative',
                )
            return ShardName(db_name, shard_id)
    return ShardName(name, SINGLE_SHARD)


def _normalize_name(name: str) -> str:
    """Replace ``.`` and ``-`` in ``name`` with ``_``."""
    return name.replace('.', '_').replace('-', '_')