# Source code for migration_runner.database_tools

# -*- coding: utf-8 -*-
import io
import logging
import sys

import mysql.connector


class DatabaseTools:
    """Helpers for connecting to MySQL, reading the schema version and
    applying SQL migration files."""

    def __init__(self, logger=None):
        """Create the helper.

        :param logger: logger used for all reporting; when ``None`` the
            module logger ``logging.getLogger(__name__)`` is used.
        """
        if logger is None:
            self.logger = logging.getLogger(__name__)
        else:
            self.logger = logger

    def connect_database(self, db_params):
        """Open a MySQL connection with autocommit enabled.

        :param db_params: 4-item sequence ``(host, user, password, name)``.
        :returns: an open ``mysql.connector`` connection whose ``autocommit``
            attribute is set to ``True``.

        On a ``mysql.connector.Error`` the error is logged and the process
        terminates via ``sys.exit(1)`` instead of raising to the caller.
        """
        try:
            host, user, password, name = db_params
            # Never write the real password to the log, even at debug level.
            self.logger.debug(
                "Connecting to database with details: "
                "user=%s, password=%s, host=%s, db=%s",
                user, "********", host, name,
            )
            db_connection = mysql.connector.connect(
                user=user, password=password, host=host, database=name
            )
            db_connection.autocommit = True
            return db_connection
        except mysql.connector.Error as error:
            self.logger.error(
                "%s while connecting to database: %s",
                type(error).__name__, error,
            )
            sys.exit(1)

    def fetch_current_version(self, db_params):
        """Return the schema version recorded in ``versionTable``.

        :param db_params: connection parameters, as for ``connect_database``.
        :returns: ``int`` of the ``version`` column of the first row of
            ``versionTable``. Returns ``0`` when the table has no rows, or
            when a ``mysql.connector.Error`` occurs (the error is logged).
        """
        current_db_version = 0
        db_connection = None
        try:
            db_connection = self.connect_database(db_params)
            cursor = db_connection.cursor()
            cursor.execute("SELECT version FROM versionTable LIMIT 1")
            row = cursor.fetchone()
            if row is None:
                # fetchone() yields None when the table holds no rows.
                self.logger.warning(
                    "versionTable is empty, assuming version 0")
            else:
                current_db_version = int(row[0])
        except mysql.connector.Error as error:
            self.logger.error(
                "%s while attempting to fetch database version, assuming"
                " version 0: %s", type(error).__name__, error,
            )
        finally:
            # Release the connection on the success and the error path alike.
            if db_connection is not None:
                db_connection.close()
        return current_db_version

    def apply_migration(self, db_params, sql_filename):
        """Execute every statement of an SQL script against the database.

        :param db_params: connection parameters, as for ``connect_database``.
        :param sql_filename: path of the SQL script; it may contain several
            statements.
        :raises mysql.connector.Error: if any statement in the script fails.
            When the connection comes from ``connect_database`` (autocommit
            enabled), statements before the failing one stay applied.
        """
        with io.open(sql_filename) as sql_file:
            sql_script = sql_file.read()
        db_connection = self.connect_database(db_params)
        try:
            cursor = db_connection.cursor()
            # With multi=True, execute() returns a lazy iterator. Later
            # statements are processed, and their errors raised, only as
            # the iterator is consumed, so drain it completely.
            for result in cursor.execute(sql_script, multi=True):
                if result.with_rows:
                    # Read pending rows so the next result can be fetched.
                    result.fetchall()
        finally:
            db_connection.close()