Source code for testsuite.databases.pgsql.discover

import collections
import dataclasses
import hashlib
import itertools
import logging
import pathlib
from typing import DefaultDict
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional

from . import exceptions
from . import utils

logger = logging.getLogger(__name__)

SINGLE_SHARD = -1
DB_NAME_MAX = 31


@dataclasses.dataclass(frozen=True)
class ShardName:
    db_name: str
    shard: int


@dataclasses.dataclass
class ShardFiles:
    name: ShardName
    files: Optional[List[pathlib.Path]] = None
    pg_migrations: Optional[List[pathlib.Path]] = None


@dataclasses.dataclass
class ShardFileInfo:
    files: List[pathlib.Path]
    pg_migrations: List[pathlib.Path]

    def extend(self, other: ShardFiles) -> None:
        if other.files:
            self.files.extend(other.files)
        if other.pg_migrations:
            self.pg_migrations.extend(other.pg_migrations)


ShardPathesDict = Dict[int, ShardFileInfo]


@dataclasses.dataclass(frozen=True)
class PgShard:
    shard_id: int
    pretty_name: str
    dbname: str
    files: List[pathlib.Path]
    migrations: List[pathlib.Path]

    def get_schema_hash(self) -> str:
        return utils.get_files_hash(
            itertools.chain(self.files, self.migrations),
        )


@dataclasses.dataclass(frozen=True)
class PgShardedDatabase:
    service_name: str
    dbname: str
    shards: List[PgShard]


[docs]def find_schemas( service_name: Optional[str], schema_dirs: List[pathlib.Path], ) -> Dict[str, PgShardedDatabase]: """Read database schemas from directories ``schema_dirs``. :: |- schema_path/ |- database1.sql |- database2.sql :param service_name: service name used as prefix for database name if not empty, e.g. "servicename_dbname". :param schema_dirs: list of pathes to scan for schemas :returns: :py:class:`Dict[str, PgShardedDatabase]` where key is database name as stored in :py:attr:`PgShard.dbname` """ result: Dict[str, PgShardedDatabase] = {} for path in schema_dirs: if not path.is_dir(): continue schemas = _find_databases_schemas(service_name, path) for dbname in schemas.keys() & result.keys(): raise exceptions.PostgresqlError( f'Database {dbname} is declared twice', ) result.update(schemas) return result
def _find_databases_schemas( service_name: Optional[str], schema_path: pathlib.Path, ) -> Dict[str, PgShardedDatabase]: logger.debug('Looking up for PostgreSQL schemas at %s', schema_path) shard_files_map = _build_shard_files_map(schema_path) result = {} for dbname, shards in shard_files_map.items(): _raise_if_invalid_shards(dbname, shards) pg_shards = [] for shard_id, shard_files in sorted(shards.items()): pg_shards.append( _create_pgshard( dbname, service_name=service_name, shard_id=shard_id, files=sorted(shard_files.files), migrations=sorted(shard_files.pg_migrations), ), ) result[dbname] = PgShardedDatabase( service_name=service_name, dbname=dbname, shards=pg_shards, ) return result def _build_shard_files_map( root_path: pathlib.Path, ) -> DefaultDict[str, ShardPathesDict]: result: DefaultDict[str, ShardPathesDict] result = collections.defaultdict( lambda: collections.defaultdict(lambda: ShardFileInfo([], [])), ) for shard in _find_shard_files(root_path): result[shard.name.db_name][shard.name.shard].extend(shard) return result def _find_shard_files(schema_path: pathlib.Path) -> Iterable[ShardFiles]: for entry in schema_path.iterdir(): shard_files = _get_shard_schema_files(entry) if shard_files is not None: yield shard_files def _get_shard_schema_files(path: pathlib.Path) -> Optional[ShardFiles]: shard_name = _parse_shard_name(path.stem) if path.is_file(): if path.suffix == '.sql': return ShardFiles(shard_name, files=[path]) elif path.is_dir(): if path.joinpath('migrations').is_dir(): return ShardFiles(shard_name, pg_migrations=[path]) return ShardFiles(shard_name, files=utils.scan_sql_directory(path)) return None def _raise_if_invalid_shards(dbname: str, shards: ShardPathesDict) -> None: if SINGLE_SHARD in shards: if len(shards) != 1: raise exceptions.PostgresqlError( 'Postgresql database %s has single shard configuration ' 'while defined as multishard' % (dbname,), ) else: if set(shards.keys()) != set(range(len(shards))): raise exceptions.PostgresqlError( 'Postgresql database %s is missing fixtures ' 'for some shards' % (dbname,), ) def _create_pgshard( dbname: str, service_name: Optional[str] = None, shard_id: int = SINGLE_SHARD, files: Optional[List[pathlib.Path]] = None, migrations: Optional[List[pathlib.Path]] = None, ) -> PgShard: if files is None: files = [] if migrations is None: migrations = [] if shard_id == SINGLE_SHARD: actual_shard_id = 0 pretty_name = dbname else: actual_shard_id = shard_id pretty_name = '%s@%d' % (dbname, shard_id) sharded_dbname = _database_name(service_name, dbname, shard_id) return PgShard( shard_id=actual_shard_id, pretty_name=pretty_name, dbname=sharded_dbname, files=files, migrations=migrations, ) _names_used = {} def _database_name(service_name: Optional[str], dbname: str, shard_id: int): dbkey = (service_name, dbname) suffix = '' if shard_id != SINGLE_SHARD: suffix = f'_{shard_id}' prefix = '' if service_name is not None: prefix = f'{service_name}_' name = _normalize_name(prefix + dbname) dbname = _normalize_name(name + suffix) if len(dbname) > DB_NAME_MAX: dbname = _shortened(name, suffix) if dbname not in _names_used: _names_used[dbname] = dbkey elif _names_used[dbname] != dbkey: raise exceptions.NameCannotBeShortend( f'Database name conflict for {dbkey} and {_names_used[dbname]}' ) return dbname def _shortened(name: str, suffix: str): short_name = ''.join([part[:1] for part in name.split('_')]) name_hash = hashlib.sha1(name.encode('utf-8')).hexdigest() hash_len = DB_NAME_MAX - len(short_name) - len(suffix) - 1 name_hash = name_hash[:hash_len] dbname = f'{short_name}_{name_hash}{suffix}' if len(dbname) > DB_NAME_MAX: raise exceptions.NameCannotBeShortend( f'Dbname cannot be shortened {name}{suffix}' ) return dbname def _parse_shard_name(name) -> ShardName: parts = name.rsplit('@', 1) if len(parts) == 2: try: shard_id = int(parts[1]) except (ValueError, TypeError): pass else: return ShardName(parts[0], shard_id) return ShardName(name, SINGLE_SHARD) def _normalize_name(name): return name.replace('.', '_').replace('-', '_')