import contextlib
import dataclasses
import logging
import pathlib
import typing
import pymysql
import pymysql.constants
from testsuite.environment import shell
from testsuite.utils import cached_property
from . import classes
from . import exceptions
# Module-level logger.
logger = logging.getLogger(__name__)
# Helper script shipped next to this module under ``scripts/``; it is invoked
# by ``_run_script`` (presumably a wrapper around the ``mysql`` client).
MYSQL_HELPER = pathlib.Path(__file__).parent / 'scripts' / 'mysql-helper'
@dataclasses.dataclass(frozen=True)
class MysqlQuery:
    """Immutable SQL snippet applied by ``ConnectionWrapper.apply_queries``."""

    # SQL text passed to ``cursor.execute``.
    body: str
    # Description of where the query came from; reported in error messages.
    source: str
    # Optional file path of the query; reported in error messages when set.
    path: typing.Union[str, None]
class ConnectionWrapper:
    """MySQL database connection wrapper.

    Bundles a pymysql connection with its connection info and the list of
    known table names, and offers helpers to reset table contents and to
    apply SQL snippets.
    """

    def __init__(self, connection, conninfo, tables):
        self._connection = connection
        self._conninfo = conninfo
        # Known table names of the database. May be ``None`` (no table list
        # recorded), in which case ``apply_queries`` truncates nothing.
        self._tables: typing.Optional[typing.List[str]] = tables

    @property
    def conninfo(self) -> classes.ConnectionInfo:
        """:py:class:`classes.ConnectionInfo` instance."""
        return self._conninfo

    def cursor(self, **kwargs) -> pymysql.cursors.Cursor:
        """Return a new cursor; keyword arguments go to ``connection.cursor``."""
        return self._connection.cursor(**kwargs)

    def dict_cursor(self, **kwargs) -> pymysql.cursors.Cursor:
        """Return dictionary cursor, pymysql.cursors.DictCursor."""
        kwargs['cursor'] = pymysql.cursors.DictCursor
        return self.cursor(**kwargs)

    def commit(self) -> None:
        """Commit the current transaction of the wrapped connection."""
        self._connection.commit()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a backtick-quoted MySQL identifier."""
        # Embedded backticks are escaped by doubling them.
        return '`' + name.replace('`', '``') + '`'

    @staticmethod
    def _drain_results(cursor) -> None:
        """Consume all remaining result sets of a multi-statement query."""
        # pymysql reports errors of later statements only while their results
        # are read, so draining here surfaces them at the right place.
        while cursor.nextset():
            pass

    def _truncate_non_empty_tables(self) -> typing.Optional[typing.List[str]]:
        """Return the names of known tables that currently contain rows.

        Despite the name, nothing is truncated here; the caller does that.
        Returns ``None`` when no table list is known.
        """
        if not self._tables:
            return None
        # Count the rows of every table in one round-trip. Each sub-select
        # is tagged with the table's index, so the table name never has to
        # be embedded as a string literal.
        subquery = ' union all '.join(
            f'select {index} as idx, count(*) as c '
            f'from {self._quote_identifier(table)}'
            for index, table in enumerate(self._tables)
        )
        with contextlib.closing(self.cursor()) as cursor:
            cursor.execute(f'select idx from ({subquery}) counts where c > 0')
            rows = cursor.fetchall()
        return [self._tables[index] for (index,) in rows]

    def apply_queries(
        self,
        queries: typing.List[MysqlQuery],
        keep_tables: typing.Optional[typing.List[str]] = None,
        truncate_non_empty: bool = False,
    ) -> None:
        """Truncate known tables, apply *queries* in order, then commit.

        :param queries: SQL snippets executed one after another.
        :param keep_tables: names of tables that must not be truncated.
        :param truncate_non_empty: if true, only tables that currently hold
            rows are truncated; otherwise every known table is.
        :raises exceptions.MysqlError: when a query fails. The message names
            the query source and, if set, its file path.
        """
        keep = set(keep_tables or ())
        with self.cursor() as cursor:
            if truncate_non_empty:
                tables = self._truncate_non_empty_tables()
            else:
                tables = self._tables
            if tables:
                truncate_sql = ' '.join(
                    f'truncate table {self._quote_identifier(table)};'
                    for table in tables
                    if table not in keep
                )
                # TRUNCATE fails on tables referenced by foreign keys unless
                # the checks are disabled for the duration of the statement.
                cursor.execute(
                    'set foreign_key_checks=0;'
                    f'{truncate_sql}'
                    'set foreign_key_checks=1;',
                )
                self._drain_results(cursor)
            for query in queries:
                try:
                    # NOTE: ``args=[]`` (not ``None``) makes pymysql %-format
                    # the body, so a literal ``%`` must be written as ``%%``;
                    # kept for compatibility with existing query files.
                    cursor.execute(query.body, args=[])
                    # Read every result so that an error in a later statement
                    # of this body is attributed to this query.
                    self._drain_results(cursor)
                except pymysql.Error as exc:
                    error_message = (
                        f'MySQL apply query error\n'
                        f'Query from: {query.source}\n'
                    )
                    if query.path:
                        error_message += f'File path: {query.path}\n'
                    error_message += '\n' + str(exc)
                    raise exceptions.MysqlError(error_message) from exc
        self.commit()
class ConnectionCache:
    """Lazily opens and memoizes pymysql connections.

    One connection is kept for the base connection info ("master") and one
    per database name requested through :py:meth:`get_connection`.
    """

    def __init__(self, conninfo, verbose: bool = False):
        # ``verbose`` is accepted for interface compatibility but unused here.
        self._conninfo = conninfo
        self._cache: dict = {}
        self._master_connection = None

    def get_master_connection(self):
        """Return the connection for the base conninfo, opening it once."""
        connection = self._master_connection
        if connection is None:
            connection = self._connect(self._conninfo)
            self._master_connection = connection
        return connection

    def get_conninfo(self, dbname: str) -> classes.ConnectionInfo:
        """Return the base conninfo with ``dbname`` replaced by *dbname*."""
        return self._conninfo.replace(dbname=dbname)

    def get_connection(self, dbname):
        """Return the memoized connection for *dbname*, opening it once."""
        if dbname in self._cache:
            return self._cache[dbname]
        connection = self._create_connection(dbname)
        self._cache[dbname] = connection
        return connection

    def _create_connection(self, dbname):
        conninfo = self.get_conninfo(dbname)
        return self._connect(conninfo)

    def _connect(self, conninfo: classes.ConnectionInfo):
        """Open a pymysql connection with multi-statement queries enabled."""
        options = dict(
            host=conninfo.hostname,
            port=conninfo.port,
            user=conninfo.user,
            password=conninfo.password,
            database=conninfo.dbname,
            client_flag=pymysql.constants.CLIENT.MULTI_STATEMENTS,
        )
        return pymysql.connect(**options)
class DatabasesState:
    """Records database initialization, executed migrations and table lists.

    Database (re)creation and each (database, script) migration run at most
    once per instance; table lists are re-read after migrations touch them.
    """

    _migrations_run: typing.Set[typing.Tuple[str, str]]
    _initialized: typing.Set[str]
    _tables_outdated: typing.Set[str]

    def __init__(self, connections: ConnectionCache, verbose: bool = False):
        self._connections = connections
        self._verbose = verbose
        # (dbname, script path) pairs that were already executed.
        self._migrations_run = set()
        # Databases already handled by ``get_connection``.
        self._initialized = set()
        # dbname -> table names, filled by ``save_tables``.
        self._tables: typing.Dict[str, typing.List[str]] = {}
        # Databases whose recorded table list is stale because a migration
        # ran against them after it was last saved. Tracked per database:
        # a single global flag lost updates when several databases were
        # saved in turn.
        self._tables_outdated = set()

    def get_connection(self, dbname: str, create_db: bool = True):
        """Return the connection for *dbname*.

        On the first request for *dbname*, the database is dropped (if it
        exists) and created again when *create_db* is true.
        """
        if dbname not in self._initialized:
            if create_db:
                self._initdb(dbname)
            self._initialized.add(dbname)
        return self._connections.get_connection(dbname)

    def wrapper_for(self, dbname: str):
        """Return a :py:class:`ConnectionWrapper` for *dbname*.

        The wrapper's table list is ``None`` if ``save_tables`` never ran for
        *dbname*.
        """
        return ConnectionWrapper(
            self._connections.get_connection(dbname),
            self._connections.get_conninfo(dbname),
            self._tables.get(dbname),
        )

    def run_migration(self, dbname: str, path: str):
        """Run the SQL script at *path* against *dbname* once."""
        key = dbname, path
        if key in self._migrations_run:
            return
        logger.debug(
            'Running mysql script %s against database %s',
            path,
            dbname,
        )
        conninfo = self._connections.get_conninfo(dbname)
        _run_script(conninfo, ['-e', f'source {path}'], verbose=self._verbose)
        self._migrations_run.add(key)
        self._tables_outdated.add(dbname)

    @cached_property
    def known_databases(self):
        """Database names from ``show databases``, computed on first access.

        Assuming ``cached_property`` caches per instance, databases created
        afterwards are not reflected here.
        """
        connection = self._connections.get_master_connection()
        with connection.cursor() as cursor:
            cursor.execute('show databases')
            return {row[0] for row in cursor.fetchall()}

    def _initdb(self, dbname: str):
        """Drop *dbname* if it was already known, then create it empty."""
        connection = self._connections.get_master_connection()
        # Backticks inside the name are escaped by doubling them.
        quoted = '`' + dbname.replace('`', '``') + '`'
        with connection.cursor() as cursor:
            if dbname in self.known_databases:
                cursor.execute(f'DROP DATABASE IF EXISTS {quoted}')
            cursor.execute(f'CREATE DATABASE {quoted}')
        connection.commit()
        self._initialized.add(dbname)

    def save_tables(self, dbname: str) -> None:
        """Record the table names of *dbname*.

        The list is read when it was never recorded for *dbname*, or when a
        migration ran against *dbname* since it was last recorded.
        """
        if dbname in self._tables and dbname not in self._tables_outdated:
            return
        connection = self._connections.get_connection(dbname)
        with contextlib.closing(connection.cursor()) as cursor:
            cursor.execute('show tables')
            self._tables[dbname] = [table for (table,) in cursor.fetchall()]
        self._tables_outdated.discard(dbname)
class Control:
    """Runs migrations for a set of configured databases."""

    def __init__(
        self,
        databases: classes.DatabasesDict,
        state: DatabasesState,
    ):
        self._databases = databases
        self._state = state

    def get_wrappers(self):
        """Return a mapping of database alias to connection wrapper."""
        wrappers = {}
        for alias, dbconfig in self._databases.items():
            wrappers[alias] = self._state.wrapper_for(dbconfig.dbname)
        return wrappers

    def run_migrations(self):
        """Prepare every configured database and apply its migrations."""
        for config in self._databases.values():
            self._run_database_migrations(config)

    def _run_database_migrations(self, dbconfig):
        dbname = dbconfig.dbname
        # On first use, ``get_connection`` initializes the database when
        # ``create_db`` is true.
        self._state.get_connection(dbname, create_db=dbconfig.create)
        for script_path in dbconfig.migrations:
            self._state.run_migration(dbname, script_path)
        self._state.save_tables(dbname)
def _build_mysql_args(conninfo: classes.ConnectionInfo) -> typing.List[str]:
    """Translate *conninfo* into command-line options for the mysql helper.

    ``--protocol=tcp`` always comes first. Each other option is emitted only
    when its value is truthy.
    """
    # NOTE(review): the password ends up on the command line, where other
    # local users may see it in the process list; an option file would avoid
    # that.
    options = [
        ('host', conninfo.hostname),
        ('port', conninfo.port),
        ('user', conninfo.user),
        ('password', conninfo.password),
        ('database', conninfo.dbname),
    ]
    return ['--protocol=tcp'] + [
        f'--{option}={value}' for option, value in options if value
    ]
def _run_script(
    conninfo: classes.ConnectionInfo,
    args: typing.List[str],
    verbose: bool,
):
    """Run ``MYSQL_HELPER`` with options derived from *conninfo* plus *args*."""
    connection_args = _build_mysql_args(conninfo)
    command = [str(MYSQL_HELPER)] + connection_args + list(args)
    shell.execute(command, verbose=verbose, command_alias='mysql/script')