Source code for migration_runner.controller

# -*- coding: utf-8 -*-
import logging

import mysql.connector

from migration_runner.database_tools import DatabaseTools
from migration_runner.helpers import Helpers


class Controller:
    """Coordinate applying SQL migration files and tracking the DB version."""

    def __init__(self, logger=None):
        """Create a controller.

        :param logger: optional logger for this controller's messages; when
            None, ``logging.getLogger(__name__)`` is used.
        """
        if logger is None:
            self.logger = logging.getLogger(__name__)
        else:
            self.logger = logger
        # NOTE(review): the caller-supplied ``logger`` (possibly None) is
        # forwarded unchanged, not ``self.logger``; presumably Helpers and
        # DatabaseTools choose their own default logger — confirm.
        self.helpers = Helpers(logger)
        self.database = DatabaseTools(logger)

    def process_single_file(self, db_params, single_file):
        """Apply one SQL file without touching the stored DB version.

        :param db_params: connection parameters passed to DatabaseTools.
        :param single_file: path of the SQL file to execute.
        """
        self.logger.warning(
            "Use of this option means DB version will be out of sync!")
        self.database.apply_migration(db_params, single_file)
        self.logger.info(
            "Successfully executed SQL in file: '{}'".format(single_file)
        )

    def update_current_version(self, db_params, new_version):
        """Store ``new_version`` in ``versionTable`` and read it back.

        :param db_params: connection parameters passed to
            ``DatabaseTools.connect_database``.
        :param new_version: version value to store (written as a string).
        :returns: the version read back from ``versionTable``, or 0 when no
            row exists or a ``mysql.connector.Error`` occurs (the error is
            logged, not raised).
        """
        current_db_version = 0
        try:
            db_connection = self.database.connect_database(db_params)
            try:
                cursor = db_connection.cursor()
                # Bind the value instead of formatting it into the SQL text:
                # it originates from migration file names. str() keeps the
                # original behaviour of storing the version as a string.
                cursor.execute("UPDATE versionTable SET version = %s",
                               (str(new_version),))
                # Connector/Python does not autocommit by default; whether
                # connect_database enables it isn't visible here, so commit
                # explicitly (a no-op under autocommit) so the UPDATE is not
                # rolled back when the connection closes.
                db_connection.commit()
                cursor.execute("SELECT version FROM versionTable LIMIT 1")
                db_version_row = cursor.fetchone()
                if db_version_row is not None:
                    current_db_version = db_version_row[0]
            finally:
                # Always release the connection, even if a statement failed;
                # errors from close() still reach the handler below.
                db_connection.close()
        except mysql.connector.Error as error:
            self.logger.error(
                "{} while attempting to update current database version, "
                "assuming version 0: {}".format(type(error).__name__, error)
            )
        return current_db_version

    def process_migrations(self, db_params, db_version,
                           unprocessed_migrations):
        """Apply migrations in order, stopping at the first DB error.

        :param db_params: connection parameters passed to DatabaseTools.
        :param db_version: version in effect before processing.
        :param unprocessed_migrations: iterable of
            ``(version_code, sql_filename)`` pairs, applied in order.
        :returns: ``(db_version, total_processed)`` — the version returned by
            the last ``update_current_version`` call (or the input version
            if nothing was applied) and the count of migrations applied.
        """
        total_processed = 0
        for version_code, sql_filename in unprocessed_migrations:
            self.logger.debug(
                "Applying migration: {version} with filename: '{file}'".format
                (version=version_code, file=sql_filename)
            )
            try:
                self.database.apply_migration(db_params, sql_filename)
                self.logger.info(
                    "Upgraded DB version from {old} to {new} by executing file"
                    ": '{file}'".format(
                        old=db_version, new=version_code, file=sql_filename)
                )
                db_version = self.update_current_version(db_params,
                                                         version_code)
                total_processed += 1
            except mysql.connector.Error as error:
                self.logger.error(
                    "{type} while processing migration in file: '{file}': "
                    "{error}".format(type=type(error).__name__,
                                     file=sql_filename, error=error))
                # Later migrations likely depend on this one; stop here.
                break
        return db_version, total_processed

    def process_migrations_in_directory(self, db_params, sql_directory):
        """Apply every not-yet-applied migration found in ``sql_directory``.

        Discovers migrations via ``Helpers.populate_migrations``, reads the
        current version with ``DatabaseTools.fetch_current_version``, selects
        the pending ones with ``Helpers.get_unprocessed_migrations``, applies
        them via ``process_migrations`` and logs a summary.

        :param db_params: connection parameters passed to DatabaseTools.
        :param sql_directory: directory to scan for migration files.
        """
        self.logger.debug(
            "Looking for migrations in dir: {}".format(sql_directory))
        migrations = self.helpers.populate_migrations(sql_directory)
        self.logger.debug("Migrations found: {}".format(len(migrations)))
        db_version = self.database.fetch_current_version(db_params)
        self.logger.info(
            "Starting with database version: {}".format(db_version))
        unprocessed = self.helpers.get_unprocessed_migrations(db_version,
                                                              migrations)
        self.logger.info(
            "Migrations yet to be processed: {unprocessed} (out of {total} "
            "in dir)".format(
                unprocessed=len(unprocessed), total=len(migrations)
            )
        )
        db_version, total_processed = self.process_migrations(
            db_params, db_version, unprocessed
        )
        self.logger.info(
            "Database version now {version} after processing {processed}"
            " migrations. Remaining: {unprocessed}.".format
            (version=db_version, processed=total_processed,
             unprocessed=(len(unprocessed) - total_processed)))