# Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
import cherrypy
+import datetime
from ipsilon.util.log import Log
from sqlalchemy import create_engine
from sqlalchemy import MetaData, Table, Column, Text
from sqlalchemy.pool import QueuePool, SingletonThreadPool
-from sqlalchemy.sql import select
+from sqlalchemy.sql import select, and_
import ConfigParser
import os
import uuid
UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
+class DatabaseError(Exception):
+ pass
+
+
class SqlStore(Log):
__instances = {}
@classmethod
def get_connection(cls, name):
- if name not in cls.__instances.keys():
+ if name not in cls.__instances:
if cherrypy.config.get('db.conn.log', False):
logging.debug('SqlStore new: %s', name)
cls.__instances[name] = SqlStore(name)
return conn
-def SqlAutotable(f):
- def at(self, *args, **kwargs):
- self.create()
- return f(self, *args, **kwargs)
- return at
-
-
class SqlQuery(Log):
def __init__(self, db_obj, table, columns, trans=True):
def drop(self):
self._table.drop(checkfirst=True)
- @SqlAutotable
def select(self, kvfilter=None, columns=None):
return self._con.execute(select(self._columns(columns),
self._where(kvfilter)))
- @SqlAutotable
def insert(self, values):
self._con.execute(self._table.insert(values))
- @SqlAutotable
def update(self, values, kvfilter):
self._con.execute(self._table.update(self._where(kvfilter), values))
- @SqlAutotable
def delete(self, kvfilter):
self._con.execute(self._table.delete(self._where(kvfilter)))
class Store(Log):
+ _is_upgrade = False
+
def __init__(self, config_name=None, database_url=None):
if config_name is None and database_url is None:
raise ValueError('config_name or database_url must be provided')
else:
self._db = SqlStore.get_connection(name)
self._query = SqlQuery
- self._upgrade_database()
- def _upgrade_database(self):
+ if not self._is_upgrade:
+ self._check_database()
+
+ def _code_schema_version(self):
+ # This function makes it possible for separate plugins to have
+ # different schema versions. We default to the global schema
+ # version.
+ return CURRENT_SCHEMA_VERSION
+
+ def _get_schema_version(self):
+ # We are storing multiple versions: one per class
+ # That way, we can support plugins with differing schema versions from
+ # the main codebase, and even in the same database.
+ q = self._query(self._db, 'dbinfo', OPTIONS_COLUMNS, trans=False)
+ q.create()
+ cls_name = self.__class__.__name__
+ current_version = self.load_options('dbinfo').get('%s_schema'
+ % cls_name, {})
+ if 'version' in current_version:
+ return int(current_version['version'])
+ else:
+ # Also try the old table name.
+ # "scheme" was a typo, but we need to retain that now for compat
+ fallback_version = self.load_options('dbinfo').get('scheme',
+ {})
+ if 'version' in fallback_version:
+ return int(fallback_version['version'])
+ else:
+ return None
+
+ def _check_database(self):
if self.is_readonly:
# If the database is readonly, we cannot do anything to the
# schema. Let's just return, and assume people checked the
# upgrade notes
return
- current_version = self.load_options('dbinfo').get('scheme', None)
- if current_version is None or 'version' not in current_version:
- # No version stored, storing current version
- self.save_options('dbinfo', 'scheme',
- {'version': CURRENT_SCHEMA_VERSION})
- current_version = CURRENT_SCHEMA_VERSION
- else:
- current_version = int(current_version['version'])
- if current_version != CURRENT_SCHEMA_VERSION:
- self.debug('Upgrading database schema from %i to %i' % (
- current_version, CURRENT_SCHEMA_VERSION))
- self._upgrade_database_from(current_version)
-
- def _upgrade_database_from(self, old_schema_version):
- # Insert code here to upgrade from old_schema_version to
- # CURRENT_SCHEMA_VERSION
- raise Exception('Unable to upgrade database to current schema'
- ' version: version %i is unknown!' %
- old_schema_version)
+
+ current_version = self._get_schema_version()
+ if current_version is None:
+ self.error('Database initialization required! ' +
+ 'Please run ipsilon-upgrade-database')
+ raise DatabaseError('Database initialization required for %s' %
+ self.__class__.__name__)
+ if current_version != self._code_schema_version():
+ self.error('Database upgrade required! ' +
+ 'Please run ipsilon-upgrade-database')
+ raise DatabaseError('Database upgrade required for %s' %
+ self.__class__.__name__)
+
+ def _store_new_schema_version(self, new_version):
+ cls_name = self.__class__.__name__
+ self.save_options('dbinfo', '%s_schema' % cls_name,
+ {'version': new_version})
+
+ def _initialize_schema(self):
+ raise NotImplementedError()
+
+ def _upgrade_schema(self, old_version):
+ # Datastores need to figure out what to do with bigger old_versions
+ # themselves.
+ # They might implement downgrading if that's feasible, or just throw
+ # NotImplementedError
+ raise NotImplementedError()
+
+ def upgrade_database(self):
+ # Do whatever is needed to get schema to current version
+ old_schema_version = self._get_schema_version()
+ if old_schema_version is None:
+ # Just initialize a new schema
+ self._initialize_schema()
+ self._store_new_schema_version(self._code_schema_version())
+ elif old_schema_version != self._code_schema_version():
+ # Upgrade from old_schema_version to code_schema_version
+ self._upgrade_schema(old_schema_version)
+ self._store_new_schema_version(self._code_schema_version())
@property
def is_readonly(self):
table = plugin+"_data"
self._reset_data(table)
+ def _initialize_schema(self):
+ for table in ['config',
+ 'info_config',
+ 'login_config',
+ 'provider_config']:
+ q = self._query(self._db, table, OPTIONS_COLUMNS, trans=False)
+ q.create()
+
+ def _upgrade_schema(self, old_version):
+ raise NotImplementedError()
+
class UserStore(Store):
def load_plugin_data(self, plugin, user):
return self.load_options(plugin+"_data", user)
+ def _initialize_schema(self):
+ q = self._query(self._db, 'users', OPTIONS_COLUMNS, trans=False)
+ q.create()
+
+ def _upgrade_schema(self, old_version):
+ raise NotImplementedError()
+
class TranStore(Store):
def __init__(self, path=None):
super(TranStore, self).__init__('transactions.db')
+ def _initialize_schema(self):
+ q = self._query(self._db, 'transactions', UNIQUE_DATA_COLUMNS,
+ trans=False)
+ q.create()
+
+ def _upgrade_schema(self, old_version):
+ raise NotImplementedError()
+
class SAML2SessionStore(Store):
- def __init__(self, path=None):
- super(SAML2SessionStore, self).__init__('saml2.sessions.db')
+ def __init__(self, database_url):
+ super(SAML2SessionStore, self).__init__(database_url=database_url)
self.table = 'sessions'
+ # pylint: disable=protected-access
+ table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+ table.create(checkfirst=True)
def _get_unique_id_from_column(self, name, value):
"""
raise ValueError("Multiple entries returned")
return data.keys()[0]
+ def remove_expired_sessions(self):
+ # pylint: disable=protected-access
+ table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+ sel = select([table.columns.uuid]). \
+ where(and_(table.c.name == 'expiration_time',
+ table.c.value <= datetime.datetime.now()))
+ # pylint: disable=no-value-for-parameter
+ d = table.delete().where(table.c.uuid.in_(sel))
+ d.execute()
+
def get_data(self, idval=None, name=None, value=None):
return self.get_unique_data(self.table, idval, name, value)
def new_session(self, datum):
+ if 'supported_logout_mechs' in datum:
+ datum['supported_logout_mechs'] = ','.join(
+ datum['supported_logout_mechs']
+ )
return self.new_unique_data(self.table, datum)
def get_session(self, session_id=None, request_id=None):
def get_user_sessions(self, user):
"""
- Retrun a list of all sessions for a given user.
+ Return a list of all sessions for a given user.
"""
rows = self.get_unique_data(self.table, name='user', value=user)
logged_in = []
for r in rows:
data = self.get_unique_data(self.table, uuidval=r)
+ data[r]['supported_logout_mechs'] = data[r].get(
+ 'supported_logout_mechs', '').split(',')
logged_in.append(data)
return logged_in
def wipe_data(self):
self._reset_data(self.table)
+
+ def _initialize_schema(self):
+ q = self._query(self._db, self.table, UNIQUE_DATA_COLUMNS,
+ trans=False)
+ q.create()
+
+ def _upgrade_schema(self, old_version):
+ raise NotImplementedError()