Skip to content

Instantly share code, notes, and snippets.

@skwerlman
Created January 24, 2018 04:32
Show Gist options
  • Save skwerlman/5610695e49ca605c1bdd6957fabdfa25 to your computer and use it in GitHub Desktop.
Save skwerlman/5610695e49ca605c1bdd6957fabdfa25 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
Run migrations on the servatrice database.
Reads `*.sql` migration files from a directory (default `./migrations`,
set with `-D`) and runs them on the database in file-name order.
Only runs migrations if they are needed. This is determined by comparing
the version number in each migration's file name against the version
stored in the cockatrice_schema_version table.
Stops running migrations if any fail for any reason.
"""
import os
from argparse import ArgumentParser
from typing import Sequence

import pymysql
import pymysql.cursors
# Module-level database cursor. It stays None until main() assigns
# connection.cursor() to it; run_sql_command() executes queries through it.
SQL_CONTROLLER = None
def run_sql_command(sql: str) -> Sequence[dict]:
    """Execute ``sql`` on the module cursor and return the rows it produced.

    The rows come from ``fetchall()`` on the cursor after ``execute()``, i.e.
    the first result set of the query. With the DictCursor that main()
    configures, each row is a dict keyed by column name.

    Raises:
        RuntimeError: if called before main() has created the cursor.
        pymysql.err.MySQLError: propagated unchanged from the driver.
    """
    if SQL_CONTROLLER is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError('No database cursor; call main() to connect first.')
    SQL_CONTROLLER.execute(sql)
    return SQL_CONTROLLER.fetchall()
def get_all_migrations(args) -> list:
    """Collect every ``.sql`` file in ``args.migration_directory``.

    Paths are built as ``'<directory>/<file name>'`` and returned sorted,
    which for zero-padded migration names is also execution order.
    """
    directory = args.migration_directory
    found = []
    for name in os.listdir(directory):
        if name.endswith('.sql'):
            found.append(f'{directory}/{name}')
    return sorted(found)
def get_schema_version() -> int:
    """Read the current schema version from cockatrice_schema_version.

    Only the first returned row is consulted; its ``version`` column is
    converted to an int.
    """
    rows = run_sql_command('SELECT version FROM cockatrice_schema_version;')
    first_row = rows[0]
    return int(first_row['version'])
def valid_migrations(migrations: list, schema_version: int) -> list:
    """Return the migrations that still need to run, sorted.

    The integer after the first underscore of each migration's *file name*
    (e.g. ``0021`` in ``servatrice_0021_to_0022.sql``) is compared with
    ``schema_version``; migrations whose number is >= ``schema_version``
    are kept. Presumably that number is the version the migration upgrades
    from -- confirm against the migration naming convention.

    Raises:
        ValueError: if a file name has no integer field after its first
            underscore.
    """
    valid = []
    for migration in migrations:
        # Split only the file name: underscores in the directory part of the
        # path (e.g. "./my_migrations/") must not shift the version field.
        name = os.path.basename(migration)
        parts = name.split('_')
        if len(parts) < 2:
            raise ValueError(f'cannot read a version from migration file name {name!r}')
        if schema_version <= int(parts[1]):
            valid.append(migration)
    valid.sort()
    return valid
def run_migration(migration: str) -> dict:
    """Load a migration file from disk and execute its SQL.

    Returns:
        ``{'success': True, 'result': rows}`` where ``rows`` is what
        run_sql_command returned for the first statement, or
        ``{'success': False, 'error': exception}`` if MySQL reported an error
        for any statement in the file.
    """
    with open(migration, 'r', encoding='utf-8') as sql_file:
        sql = sql_file.read()
    try:
        result = run_sql_command(sql)
        # When the connection accepts several statements per query, each one
        # produces its own result set, and an error in a later statement is
        # only raised while advancing to it. Drain them here so such an error
        # is reported for *this* migration rather than the next query.
        # Without multi-statement support nextset() is falsy at once.
        while SQL_CONTROLLER.nextset():
            pass
    except pymysql.err.MySQLError as exception:
        return {
            'success': False,
            'error': exception
        }
    return {
        'success': True,
        'result': result
    }
def main() -> None:
"""Run the migrations."""
global SQL_CONTROLLER
parser = ArgumentParser(
description='Run migrations on a servatrice database.',
epilog='Be sure to manually verify migrations _before_ running them!'
)
mysql_group = parser.add_argument_group('MySql Server Args')
mysql_group.add_argument('-u', '--user', required=True)
mysql_group.add_argument('-p', '--password', '--pass', required=True)
mysql_group.add_argument('-H', '--host', default='127.0.0.1')
mysql_group.add_argument('-d', '--database', default='servatrice')
mysql_group.add_argument('-P', '--port', type=int, default=3306)
script_group = parser.add_argument_group('Script Args')
script_group.add_argument('-D', '--migration-directory', default='./migrations')
script_group.add_argument('--safe-mode', type=bool, default=True)
args = parser.parse_args()
connection = pymysql.connect(
host=args.host, user=args.user, password=args.password,
db=args.database, charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor)
SQL_CONTROLLER = connection.cursor()
migrations = get_all_migrations(args)
schema_version = get_schema_version()
migrations = valid_migrations(migrations, schema_version)
for migration in migrations:
status = run_migration(migration)
if not status['success']:
exc = status['error']
print(exc)
break
# TODO handle ctlaltca's concerns from #2969
# Run the migrations only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment