Also create plugin UserStore data tables
diff --git a/ipsilon/util/data.py b/ipsilon/util/data.py
index c0fe4ab..65bf4b5 100644
--- a/ipsilon/util/data.py
+++ b/ipsilon/util/data.py
@@ -6,23 +6,41 @@ from ipsilon.util.log import Log
 from sqlalchemy import create_engine
 from sqlalchemy import MetaData, Table, Column, Text
 from sqlalchemy.pool import QueuePool, SingletonThreadPool
+from sqlalchemy.schema import (PrimaryKeyConstraint, Index, AddConstraint,
+                               CreateIndex)
 from sqlalchemy.sql import select, and_
 import ConfigParser
 import os
 import uuid
 import logging
+import time
 
 
-CURRENT_SCHEMA_VERSION = 1
-OPTIONS_COLUMNS = ['name', 'option', 'value']
-UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
+CURRENT_SCHEMA_VERSION = 2
+OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
+                 'primary_key': ('name', 'option'),
+                 'indexes': [('name',)]
+                 }
+UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
+                     'primary_key': ('uuid', 'name'),
+                     'indexes': [('uuid',)]
+                     }
 
 
 class DatabaseError(Exception):
     pass
 
 
-class SqlStore(Log):
+class BaseStore(Log):
+    # Some helper functions used for upgrades
+    def add_constraint(self, constraint):
+        raise NotImplementedError()
+
+    def add_index(self, index):
+        raise NotImplementedError()
+
+
+class SqlStore(BaseStore):
     __instances = {}
 
     @classmethod
@@ -51,9 +69,24 @@ class SqlStore(Log):
             # It's not possible to share connections for SQLite between
             #  threads, so let's use the SingletonThreadPool for them
             pool_args = {'poolclass': SingletonThreadPool}
-        self._dbengine = create_engine(engine_name, **pool_args)
+        db_echo = cherrypy.config.get('db.echo', False)
+        self._dbengine = create_engine(engine_name, echo=db_echo,
+                                       **pool_args)
         self.is_readonly = False
 
+    def add_constraint(self, constraint):
+        if self._dbengine.dialect.name != 'sqlite':
+            # SQLite does not support adding constraints to a pre-existing
+            #  table
+            # source: http://www.sqlite.org/omitted.html
+            create_constraint = AddConstraint(constraint, bind=self._dbengine)
+            create_constraint.execute()
+
+    def add_index(self, index):
+        add_index = CreateIndex(index, bind=self._dbengine)
+        add_index.execute()
+
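
A minimal sketch of how these two helpers are driven during an upgrade,
assuming the pre-2.0 SQLAlchemy DDL API used throughout this file (bound
DDL constructs with an .execute() method); the 'example' table and the
in-memory engine are illustrative:

    from sqlalchemy import create_engine, MetaData, Table, Column, Text
    from sqlalchemy.schema import (PrimaryKeyConstraint, Index,
                                   AddConstraint, CreateIndex)

    engine = create_engine('sqlite://')

    # A version-1 style table: plain Text columns, no keys or indexes
    old = Table('example', MetaData(engine),
                Column('uuid', Text()), Column('name', Text()),
                Column('value', Text()))
    old.create()

    # Re-describe the same table the version-2 way, then emit only the
    # missing DDL, mirroring add_constraint()/add_index() above
    new = Table('example', MetaData(engine),
                Column('uuid', Text()), Column('name', Text()),
                Column('value', Text()),
                PrimaryKeyConstraint('uuid', 'name'),
                Index('idx_example_uuid', 'uuid'))
    if engine.dialect.name != 'sqlite':
        AddConstraint(new.primary_key, bind=engine).execute()
    for index in new.indexes:
        CreateIndex(index, bind=engine).execute()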
     def debug(self, fact):
         if self.db_conn_log:
             super(SqlStore, self).debug(fact)
@@ -74,16 +107,27 @@ class SqlStore(Log):
 
 class SqlQuery(Log):
 
-    def __init__(self, db_obj, table, columns, trans=True):
+    def __init__(self, db_obj, table, table_def, trans=True):
         self._db = db_obj
         self._con = self._db.connection()
         self._trans = self._con.begin() if trans else None
-        self._table = self._get_table(table, columns)
-
-    def _get_table(self, name, columns):
-        table = Table(name, MetaData(self._db.engine()))
-        for c in columns:
-            table.append_column(Column(c, Text()))
+        self._table = self._get_table(table, table_def)
+
+    def _get_table(self, name, table_def):
+        if isinstance(table_def, list):
+            table_def = {'columns': table_def,
+                         'indexes': [],
+                         'primary_key': None}
+        table_creation = []
+        for col_name in table_def['columns']:
+            table_creation.append(Column(col_name, Text()))
+        if table_def['primary_key']:
+            table_creation.append(PrimaryKeyConstraint(
+                *table_def['primary_key']))
+        for index in table_def['indexes']:
+            idx_name = 'idx_%s_%s' % (name, '_'.join(index))
+            table_creation.append(Index(idx_name, *index))
+        table = Table(name, MetaData(self._db.engine()), *table_creation)
         return table
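
The list fallback above keeps version-1 callers working: a bare column list
is treated as a table definition with no primary key and no indexes. A short
sketch of the two equivalent call styles, using the unique-data layout from
the top of this file:

    # Old style: columns only; _get_table() fills in the empty metadata
    legacy_def = ['uuid', 'name', 'value']
    # New style: the same columns plus version-2 key and index metadata
    new_def = {'columns': ['uuid', 'name', 'value'],
               'primary_key': ('uuid', 'name'),
               'indexes': [('uuid',)]}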
 
     def _where(self, kvfilter):
@@ -133,7 +177,7 @@ class SqlQuery(Log):
         self._con.execute(self._table.delete(self._where(kvfilter)))
 
 
-class FileStore(Log):
+class FileStore(BaseStore):
 
     def __init__(self, name):
         self._filename = name
@@ -156,10 +200,21 @@ class FileStore(Log):
             self._config.read(self._filename)
         return self._config
 
+    def add_constraint(self, constraint):
+        raise NotImplementedError()
+
+    def add_index(self, index):
+        raise NotImplementedError()
+
 
 class FileQuery(Log):
 
-    def __init__(self, fstore, table, columns, trans=True):
+    def __init__(self, fstore, table, table_def, trans=True):
+        # We don't need indexes in a FileQuery, so drop that info
+        if isinstance(table_def, dict):
+            columns = table_def['columns']
+        else:
+            columns = table_def
         self._fstore = fstore
         self._config = fstore.get_config()
         self._section = table
@@ -246,7 +301,13 @@ class FileQuery(Log):
 
 
 class Store(Log):
+    # Static variables, shared between all Store classes
     _is_upgrade = False
+    __cleanups = {}
+
+    # Class-level variables, meant to be overridden by child classes
+    # Either set _should_cleanup to False or implement _cleanup()
+    _should_cleanup = True
 
     def __init__(self, config_name=None, database_url=None):
         if config_name is None and database_url is None:
@@ -267,6 +328,60 @@ class Store(Log):
 
         if not self._is_upgrade:
             self._check_database()
+            if self._should_cleanup:
+                self._schedule_cleanup()
+
+    def _schedule_cleanup(self):
+        store_name = self.__class__.__name__
+        if self.is_readonly:
+            # No use in cleanups on a readonly database
+            self.debug('Not scheduling cleanup for %s: database is readonly'
+                       % store_name)
+            return
+        if store_name in Store.__cleanups:
+            # This class was already scheduled, skip
+            return
+        self.debug('Scheduling cleanups for %s' % store_name)
+        # Check once every minute whether we need to clean
+        task = cherrypy.process.plugins.BackgroundTask(
+            60, self._maybe_run_cleanup)
+        task.start()
+        Store.__cleanups[store_name] = task
+
+    def _maybe_run_cleanup(self):
+        # Let's see if we need to do cleanup
+        last_clean = self.load_options('dbinfo').get('%s_last_clean' %
+                                                     self.__class__.__name__,
+                                                     {})
+        interval = cherrypy.config.get('cleanup_interval', 30) * 60
+        cleanup_due = int(time.time()) - interval
+        self.debug('Considering cleanup for %s: %s. Due if last run at or '
+                   'before: %s' % (self.__class__.__name__, last_clean,
+                                   cleanup_due))
+        if ('timestamp' not in last_clean or
+                int(last_clean['timestamp']) <= cleanup_due):
+            # Store the current time first so that other servers don't
+            #  start the same cleanup concurrently
+            self.save_options('dbinfo', '%s_last_clean'
+                              % self.__class__.__name__,
+                              {'timestamp': int(time.time()),
+                               'removed_entries': -1})
+
+            # The last cleanup was long enough ago, so run one now
+            self.debug('Cleaning up for %s' % self.__class__.__name__)
+            removed_entries = self._cleanup()
+            self.debug('Cleaned up %i entries for %s' %
+                       (removed_entries, self.__class__.__name__))
+            self.save_options('dbinfo', '%s_last_clean'
+                              % self.__class__.__name__,
+                              {'timestamp': int(time.time()),
+                               'removed_entries': removed_entries})
+
+    def _cleanup(self):
+        # The default cleanup does nothing; child classes that leave
+        #  _should_cleanup enabled must override this function.
+        # Overrides should return the number of rows they cleaned up.
+        # This information may be used to automatically tune the clean period.
+        self.error('Cleanup for %s not implemented' %
+                   self.__class__.__name__)
+        return 0
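
A hedged sketch of what a child store provides to take part in the cleanup
machinery; the class name ExampleStore and its deletion logic are
hypothetical, only the _should_cleanup/_cleanup contract comes from this
file:

    class ExampleStore(Store):
        # _should_cleanup defaults to True, so a BackgroundTask will call
        # _maybe_run_cleanup() for this class once a minute

        def _cleanup(self):
            # hypothetical: delete whatever this store considers expired,
            # then report how many rows went away, as the contract requires
            return 0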
 
     def _code_schema_version(self):
         # This function makes it possible for separate plugins to have
@@ -278,8 +393,9 @@ class Store(Log):
         # We are storing multiple versions: one per class
         # That way, we can support plugins with differing schema versions from
         #  the main codebase, and even in the same database.
-        q = self._query(self._db, 'dbinfo', OPTIONS_COLUMNS, trans=False)
+        q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
         q.create()
+        q._con.close()  # pylint: disable=protected-access
         cls_name = self.__class__.__name__
         current_version = self.load_options('dbinfo').get('%s_schema'
                                                           % cls_name, {})
@@ -291,7 +407,8 @@ class Store(Log):
             fallback_version = self.load_options('dbinfo').get('scheme',
                                                                {})
             if 'version' in fallback_version:
-                return int(fallback_version['version'])
+                # The -1 sentinel value is explained in upgrade_database()
+                return -1
             else:
                 return None
 
@@ -327,6 +444,7 @@ class Store(Log):
         #  themselves.
         # They might implement downgrading if that's feasible, or just throw
         #  NotImplementedError
+        # Should return the new schema version
         raise NotImplementedError()
 
     def upgrade_database(self):
@@ -336,10 +454,27 @@ class Store(Log):
             # Just initialize a new schema
             self._initialize_schema()
             self._store_new_schema_version(self._code_schema_version())
+        elif old_schema_version == -1:
+            # This is a special case from 1.0: tables were only created the
+            # first time they were actually used, but the upgrade code
+            # assumes that they already exist. So create them first.
+            self._initialize_schema()
+            # The old version was schema version 1
+            self._store_new_schema_version(1)
+            self.upgrade_database()
         elif old_schema_version != self._code_schema_version():
             # Upgrade from old_schema_version to code_schema_version
-            self._upgrade_schema(old_schema_version)
-            self._store_new_schema_version(self._code_schema_version())
+            self.debug('Upgrading from schema version %i' %
+                       old_schema_version)
+            new_version = self._upgrade_schema(old_schema_version)
+            if not new_version:
+                error = ('Schema upgrade error: %s did not provide a '
+                         'new schema version number!' %
+                         self.__class__.__name__)
+                self.error(error)
+                raise Exception(error)
+            self._store_new_schema_version(new_version)
+            # Check if we are now up-to-date
+            self.upgrade_database()
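
Putting it together, upgrade_database() moves the schema forward one version
per pass and recurses until the stored version matches
CURRENT_SCHEMA_VERSION. An illustrative trace for a 1.0-era database (the -1
sentinel), followed by a sketch of the per-step contract with a hypothetical
version 3:

    # Trace of upgrade_database() on a 1.0 database:
    #   pass 1: old == -1 -> _initialize_schema(); version recorded as 1
    #   pass 2: old == 1  -> _upgrade_schema(1) adds keys/indexes, returns 2
    #   pass 3: old == 2  == CURRENT_SCHEMA_VERSION -> nothing left to do

    class ExampleStore(Store):
        def _upgrade_schema(self, old_version):
            if old_version == 1:
                # ... DDL for version 2 ...
                return 2
            elif old_version == 2:
                # ... DDL for a hypothetical version 3 ...
                return 3
            else:
                raise NotImplementedError()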
 
     @property
     def is_readonly(self):
@@ -375,7 +510,8 @@ class Store(Log):
             q = self._query(self._db, table, columns, trans=False)
             rows = q.select(kvfilter)
         except Exception, e:  # pylint: disable=broad-except
-            self.error("Failed to load data for table %s: [%s]" % (table, e))
+            self.error("Failed to load data for table %s for store %s: [%s]"
+                       % (table, self.__class__.__name__, e))
         return self._rows_to_dict_tree(rows)
 
     def load_config(self):
@@ -387,7 +523,7 @@ class Store(Log):
         kvfilter = dict()
         if name:
             kvfilter['name'] = name
-        options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
+        options = self._load_data(table, OPTIONS_TABLE, kvfilter)
         if name and name in options:
             return options[name]
         return options
@@ -396,7 +532,7 @@ class Store(Log):
         curvals = dict()
         q = None
         try:
-            q = self._query(self._db, table, OPTIONS_COLUMNS)
+            q = self._query(self._db, table, OPTIONS_TABLE)
             rows = q.select({'name': name}, ['option', 'value'])
             for row in rows:
                 curvals[row[0]] = row[1]
@@ -419,7 +555,7 @@ class Store(Log):
         kvfilter = {'name': name}
         q = None
         try:
-            q = self._query(self._db, table, OPTIONS_COLUMNS)
+            q = self._query(self._db, table, OPTIONS_TABLE)
             if options is None:
                 q.delete(kvfilter)
             else:
@@ -437,7 +573,7 @@ class Store(Log):
         newid = str(uuid.uuid4())
         q = None
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
             for name in data:
                 q.insert((newid, name, data[name]))
             q.commit()
@@ -456,12 +592,12 @@ class Store(Log):
             kvfilter['name'] = name
         if value:
             kvfilter['value'] = value
-        return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
+        return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)
 
     def save_unique_data(self, table, data):
         q = None
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
             for uid in data:
                 curvals = dict()
                 rows = q.select({'uuid': uid}, ['name', 'value'])
@@ -490,7 +626,7 @@ class Store(Log):
     def del_unique_data(self, table, uuidval):
         kvfilter = {'uuid': uuidval}
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
             q.delete(kvfilter)
         except Exception, e:  # pylint: disable=broad-except
             self.error("Failed to delete data from %s: [%s]" % (table, e))
@@ -498,7 +634,7 @@ class Store(Log):
     def _reset_data(self, table):
         q = None
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
             q.drop()
             q.create()
             q.commit()
@@ -509,6 +645,7 @@ class Store(Log):
 
 
 class AdminStore(Store):
+    _should_cleanup = False
 
     def __init__(self):
         super(AdminStore, self).__init__('admin.config.db')
@@ -536,14 +673,38 @@ class AdminStore(Store):
                       'info_config',
                       'login_config',
                       'provider_config']:
-            q = self._query(self._db, table, OPTIONS_COLUMNS, trans=False)
+            q = self._query(self._db, table, OPTIONS_TABLE, trans=False)
             q.create()
+            q._con.close()  # pylint: disable=protected-access
 
     def _upgrade_schema(self, old_version):
-        raise NotImplementedError()
+        if old_version == 1:
+            # In schema version 2, we added indexes and primary keys
+            for table in ['config',
+                          'info_config',
+                          'login_config',
+                          'provider_config']:
+                # pylint: disable=protected-access
+                table = self._query(self._db, table, OPTIONS_TABLE,
+                                    trans=False)._table
+                self._db.add_constraint(table.primary_key)
+                for index in table.indexes:
+                    self._db.add_index(index)
+            return 2
+        else:
+            raise NotImplementedError()
+
+    def create_plugin_data_table(self, plugin_name):
+        if not self.is_readonly:
+            table = plugin_name + '_data'
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE,
+                            trans=False)
+            q.create()
+            q._con.close()  # pylint: disable=protected-access
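
A short usage sketch for the new hook; the plugin name 'openid' is
illustrative. The resulting table is named '<plugin>_data' and uses the
unique-data layout:

    adminstore = AdminStore()
    adminstore.create_plugin_data_table('openid')  # creates 'openid_data'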
 
 
 class UserStore(Store):
+    _should_cleanup = False
 
     def __init__(self, path=None):
         super(UserStore, self).__init__('user.prefs.db')
@@ -561,25 +722,67 @@ class UserStore(Store):
         return self.load_options(plugin+"_data", user)
 
     def _initialize_schema(self):
-        q = self._query(self._db, 'users', OPTIONS_COLUMNS, trans=False)
+        q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
         q.create()
+        q._con.close()  # pylint: disable=protected-access
 
     def _upgrade_schema(self, old_version):
-        raise NotImplementedError()
+        if old_version == 1:
+            # In schema version 2, we added indexes and primary keys
+            # pylint: disable=protected-access
+            table = self._query(self._db, 'users', OPTIONS_TABLE,
+                                trans=False)._table
+            self._db.add_constraint(table.primary_key)
+            for index in table.indexes:
+                self._db.add_index(index)
+            return 2
+        else:
+            raise NotImplementedError()
+
+    def create_plugin_data_table(self, plugin_name):
+        if not self.is_readonly:
+            table = plugin_name + '_data'
+            # Per-user plugin data is stored via save_options(), so this
+            #  uses the options layout rather than the unique-data one
+            q = self._query(self._db, table, OPTIONS_TABLE,
+                            trans=False)
+            q.create()
+            q._con.close()  # pylint: disable=protected-access
 
 
 class TranStore(Store):
 
     def __init__(self, path=None):
         super(TranStore, self).__init__('transactions.db')
+        self.table = 'transactions'
 
     def _initialize_schema(self):
-        q = self._query(self._db, 'transactions', UNIQUE_DATA_COLUMNS,
+        q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                         trans=False)
         q.create()
+        q._con.close()  # pylint: disable=protected-access
 
     def _upgrade_schema(self, old_version):
-        raise NotImplementedError()
+        if old_version == 1:
+            # In schema version 2, we added indexes and primary keys
+            # pylint: disable=protected-access
+            table = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
+                                trans=False)._table
+            self._db.add_constraint(table.primary_key)
+            for index in table.indexes:
+                self._db.add_index(index)
+            return 2
+        else:
+            raise NotImplementedError()
+
+    def _cleanup(self):
+        # Remove transactions whose origintime is more than one hour old
+        # pylint: disable=protected-access
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
+        one_hour_ago = datetime.datetime.now() - datetime.timedelta(hours=1)
+        sel = select([table.columns.uuid]). \
+            where(and_(table.c.name == 'origintime',
+                       table.c.value <= one_hour_ago))
+        # pylint: disable=no-value-for-parameter
+        d = table.delete().where(table.c.uuid.in_(sel))
+        return d.execute().rowcount
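
Why the delete goes through a uuid sub-select: unique data is stored as
(uuid, name, value) triplets, so one transaction spans several rows sharing
a uuid, and all of them have to go, not just the 'origintime' row. An
illustrative transaction (the uuid, the second key, and the values are
invented):

    rows = [('5560e27f-f4f8-4f56-b0b7-77e0b9bbbd76', 'origintime',
             '2015-01-01 12:00:00.000000'),
            ('5560e27f-f4f8-4f56-b0b7-77e0b9bbbd76', 'provider', 'saml2')]
    # Deleting by uuid removes both rows; matching name == 'origintime'
    # alone would leave the 'provider' row stranded.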
 
 
 class SAML2SessionStore(Store):
@@ -588,7 +791,7 @@ class SAML2SessionStore(Store):
         super(SAML2SessionStore, self).__init__(database_url=database_url)
         self.table = 'saml2_sessions'
         # pylint: disable=protected-access
-        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
         table.create(checkfirst=True)
 
     def _get_unique_id_from_column(self, name, value):
@@ -607,15 +810,15 @@ class SAML2SessionStore(Store):
             raise ValueError("Multiple entries returned")
         return data.keys()[0]
 
-    def remove_expired_sessions(self):
+    def _cleanup(self):
         # pylint: disable=protected-access
-        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
         sel = select([table.columns.uuid]). \
             where(and_(table.c.name == 'expiration_time',
                        table.c.value <= datetime.datetime.now()))
         # pylint: disable=no-value-for-parameter
         d = table.delete().where(table.c.uuid.in_(sel))
-        d.execute()
+        return d.execute().rowcount
 
     def get_data(self, idval=None, name=None, value=None):
         return self.get_unique_data(self.table, idval, name, value)
@@ -665,9 +868,20 @@ class SAML2SessionStore(Store):
         self._reset_data(self.table)
 
     def _initialize_schema(self):
-        q = self._query(self._db, self.table, UNIQUE_DATA_COLUMNS,
+        q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                         trans=False)
         q.create()
+        q._con.close()  # pylint: disable=protected-access
 
     def _upgrade_schema(self, old_version):
-        raise NotImplementedError()
+        if old_version == 1:
+            # In schema version 2, we added indexes and primary keys
+            # pylint: disable=protected-access
+            table = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
+                                trans=False)._table
+            self._db.add_constraint(table.primary_key)
+            for index in table.indexes:
+                self._db.add_index(index)
+            return 2
+        else:
+            raise NotImplementedError()
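
Finally, a sketch of the bookkeeping this patch leaves behind in the dbinfo
table: _store_new_schema_version() records a per-class schema version, and
_maybe_run_cleanup() records when a cleanup last ran and what _cleanup()
returned. The key names come from this file; the values are invented:

    dbinfo = {'SAML2SessionStore_schema': {'version': '2'},
              'SAML2SessionStore_last_clean': {'timestamp': '1420113600',
                                               'removed_entries': '17'}}
    # i.e. what load_options('dbinfo') would return after an upgrade and
    # one cleanup pass (values come back as strings from the Text columns)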