Create database upgrade framework
[cascardo/ipsilon.git] / ipsilon / util / data.py
index f90519d..200feb8 100644 (file)
@@ -1,11 +1,12 @@
 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
 
 import cherrypy
+import datetime
 from ipsilon.util.log import Log
 from sqlalchemy import create_engine
 from sqlalchemy import MetaData, Table, Column, Text
 from sqlalchemy.pool import QueuePool, SingletonThreadPool
-from sqlalchemy.sql import select
+from sqlalchemy.sql import select, and_
 import ConfigParser
 import os
 import uuid
@@ -17,12 +18,16 @@ OPTIONS_COLUMNS = ['name', 'option', 'value']
 UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
 
 
+class DatabaseError(Exception):
+    pass
+
+
 class SqlStore(Log):
     __instances = {}
 
     @classmethod
     def get_connection(cls, name):
-        if name not in cls.__instances.keys():
+        if name not in cls.__instances:
             if cherrypy.config.get('db.conn.log', False):
                 logging.debug('SqlStore new: %s', name)
             cls.__instances[name] = SqlStore(name)
@@ -67,13 +72,6 @@ class SqlStore(Log):
         return conn
 
 
-def SqlAutotable(f):
-    def at(self, *args, **kwargs):
-        self.create()
-        return f(self, *args, **kwargs)
-    return at
-
-
 class SqlQuery(Log):
 
     def __init__(self, db_obj, table, columns, trans=True):
@@ -121,20 +119,16 @@ class SqlQuery(Log):
     def drop(self):
         self._table.drop(checkfirst=True)
 
-    @SqlAutotable
     def select(self, kvfilter=None, columns=None):
         return self._con.execute(select(self._columns(columns),
                                         self._where(kvfilter)))
 
-    @SqlAutotable
     def insert(self, values):
         self._con.execute(self._table.insert(values))
 
-    @SqlAutotable
     def update(self, values, kvfilter):
         self._con.execute(self._table.update(self._where(kvfilter), values))
 
-    @SqlAutotable
     def delete(self, kvfilter):
         self._con.execute(self._table.delete(self._where(kvfilter)))
 
@@ -252,6 +246,8 @@ class FileQuery(Log):
 
 
 class Store(Log):
+    _is_upgrade = False
+
     def __init__(self, config_name=None, database_url=None):
         if config_name is None and database_url is None:
             raise ValueError('config_name or database_url must be provided')
@@ -268,33 +264,82 @@ class Store(Log):
         else:
             self._db = SqlStore.get_connection(name)
             self._query = SqlQuery
-        self._upgrade_database()
 
-    def _upgrade_database(self):
+        if not self._is_upgrade:
+            self._check_database()
+
+    def _code_schema_version(self):
+        # This function makes it possible for separate plugins to have
+        #  different schema versions. We default to the global schema
+        #  version.
+        return CURRENT_SCHEMA_VERSION
+
+    def _get_schema_version(self):
+        # We are storing multiple versions: one per class
+        # That way, we can support plugins with differing schema versions from
+        #  the main codebase, and even in the same database.
+        q = self._query(self._db, 'dbinfo', OPTIONS_COLUMNS, trans=False)
+        q.create()
+        cls_name = self.__class__.__name__
+        current_version = self.load_options('dbinfo').get('%s_schema'
+                                                          % cls_name, {})
+        if 'version' in current_version:
+            return int(current_version['version'])
+        else:
+            # Also try the old table name.
+            # "scheme" was a typo, but we need to retain that now for compat
+            fallback_version = self.load_options('dbinfo').get('scheme',
+                                                               {})
+            if 'version' in fallback_version:
+                return int(fallback_version['version'])
+            else:
+                return None
+
+    def _check_database(self):
         if self.is_readonly:
             # If the database is readonly, we cannot do anything to the
             #  schema. Let's just return, and assume people checked the
             #  upgrade notes
             return
-        current_version = self.load_options('dbinfo').get('scheme', None)
-        if current_version is None or 'version' not in current_version:
-            # No version stored, storing current version
-            self.save_options('dbinfo', 'scheme',
-                              {'version': CURRENT_SCHEMA_VERSION})
-            current_version = CURRENT_SCHEMA_VERSION
-        else:
-            current_version = int(current_version['version'])
-        if current_version != CURRENT_SCHEMA_VERSION:
-            self.debug('Upgrading database schema from %i to %i' % (
-                       current_version, CURRENT_SCHEMA_VERSION))
-            self._upgrade_database_from(current_version)
-
-    def _upgrade_database_from(self, old_schema_version):
-        # Insert code here to upgrade from old_schema_version to
-        #  CURRENT_SCHEMA_VERSION
-        raise Exception('Unable to upgrade database to current schema'
-                        ' version: version %i is unknown!' %
-                        old_schema_version)
+
+        current_version = self._get_schema_version()
+        if current_version is None:
+            self.error('Database initialization required! ' +
+                       'Please run ipsilon-upgrade-database')
+            raise DatabaseError('Database initialization required for %s' %
+                                self.__class__.__name__)
+        if current_version != self._code_schema_version():
+            self.error('Database upgrade required! ' +
+                       'Please run ipsilon-upgrade-database')
+            raise DatabaseError('Database upgrade required for %s' %
+                                self.__class__.__name__)
+
+    def _store_new_schema_version(self, new_version):
+        cls_name = self.__class__.__name__
+        self.save_options('dbinfo', '%s_schema' % cls_name,
+                          {'version': new_version})
+
+    def _initialize_schema(self):
+        raise NotImplementedError()
+
+    def _upgrade_schema(self, old_version):
+        # Datastores need to figure out what to do with bigger old_versions
+        #  themselves.
+        # They might implement downgrading if that's feasible, or just throw
+        #  NotImplementedError
+        raise NotImplementedError()
+
+    def upgrade_database(self):
+        # Do whatever is needed to get schema to current version
+        old_schema_version = self._get_schema_version()
+        if old_schema_version is None:
+            # Just initialize a new schema
+            self._initialize_schema()
+            self._store_new_schema_version(self._code_schema_version())
+        elif old_schema_version != self._code_schema_version():
+            # Upgrade from old_schema_version to code_schema_version
+            self._upgrade_schema(old_schema_version)
+            self._store_new_schema_version(self._code_schema_version())
 
     @property
     def is_readonly(self):
@@ -486,6 +531,17 @@ class AdminStore(Store):
         table = plugin+"_data"
         self._reset_data(table)
 
+    def _initialize_schema(self):
+        for table in ['config',
+                      'info_config',
+                      'login_config',
+                      'provider_config']:
+            q = self._query(self._db, table, OPTIONS_COLUMNS, trans=False)
+            q.create()
+
+    def _upgrade_schema(self, old_version):
+        raise NotImplementedError()
+
 
 class UserStore(Store):
 
@@ -504,18 +560,36 @@ class UserStore(Store):
     def load_plugin_data(self, plugin, user):
         return self.load_options(plugin+"_data", user)
 
+    def _initialize_schema(self):
+        q = self._query(self._db, 'users', OPTIONS_COLUMNS, trans=False)
+        q.create()
+
+    def _upgrade_schema(self, old_version):
+        raise NotImplementedError()
+
 
 class TranStore(Store):
 
     def __init__(self, path=None):
         super(TranStore, self).__init__('transactions.db')
 
+    def _initialize_schema(self):
+        q = self._query(self._db, 'transactions', UNIQUE_DATA_COLUMNS,
+                        trans=False)
+        q.create()
+
+    def _upgrade_schema(self, old_version):
+        raise NotImplementedError()
+
 
 class SAML2SessionStore(Store):
 
-    def __init__(self, path=None):
-        super(SAML2SessionStore, self).__init__('saml2.sessions.db')
+    def __init__(self, database_url):
+        super(SAML2SessionStore, self).__init__(database_url=database_url)
         self.table = 'sessions'
+        # pylint: disable=protected-access
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+        table.create(checkfirst=True)
 
     def _get_unique_id_from_column(self, name, value):
         """
@@ -533,10 +607,24 @@ class SAML2SessionStore(Store):
             raise ValueError("Multiple entries returned")
         return data.keys()[0]
 
+    def remove_expired_sessions(self):
+        # pylint: disable=protected-access
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+        sel = select([table.columns.uuid]). \
+            where(and_(table.c.name == 'expiration_time',
+                       table.c.value <= datetime.datetime.now()))
+        # pylint: disable=no-value-for-parameter
+        d = table.delete().where(table.c.uuid.in_(sel))
+        d.execute()
+
     def get_data(self, idval=None, name=None, value=None):
         return self.get_unique_data(self.table, idval, name, value)
 
     def new_session(self, datum):
+        if 'supported_logout_mechs' in datum:
+            datum['supported_logout_mechs'] = ','.join(
+                datum['supported_logout_mechs']
+            )
         return self.new_unique_data(self.table, datum)
 
     def get_session(self, session_id=None, request_id=None):
@@ -553,7 +641,7 @@ class SAML2SessionStore(Store):
 
     def get_user_sessions(self, user):
         """
-        Retrun a list of all sessions for a given user.
+        Return a list of all sessions for a given user.
         """
         rows = self.get_unique_data(self.table, name='user', value=user)
 
@@ -561,6 +649,8 @@ class SAML2SessionStore(Store):
         logged_in = []
         for r in rows:
             data = self.get_unique_data(self.table, uuidval=r)
+            data[r]['supported_logout_mechs'] = data[r].get(
+                'supported_logout_mechs', '').split(',')
             logged_in.append(data)
 
         return logged_in
@@ -573,3 +663,11 @@ class SAML2SessionStore(Store):
 
     def wipe_data(self):
         self._reset_data(self.table)
+
+    def _initialize_schema(self):
+        q = self._query(self._db, self.table, UNIQUE_DATA_COLUMNS,
+                        trans=False)
+        q.create()
+
+    def _upgrade_schema(self, old_version):
+        raise NotImplementedError()