from sqlalchemy import create_engine
from sqlalchemy import MetaData, Table, Column, Text
from sqlalchemy.pool import QueuePool, SingletonThreadPool
+from sqlalchemy.schema import PrimaryKeyConstraint, Index
from sqlalchemy.sql import select, and_
import ConfigParser
import os
import logging
-CURRENT_SCHEMA_VERSION = 1
-OPTIONS_COLUMNS = ['name', 'option', 'value']
-UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
+CURRENT_SCHEMA_VERSION = 2
+OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
+ 'primary_key': ('name', 'option'),
+ 'indexes': [('name',)]
+ }
+UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
+ 'primary_key': ('uuid', 'name'),
+ 'indexes': [('uuid',)]
+ }
class DatabaseError(Exception):
class SqlQuery(Log):
- def __init__(self, db_obj, table, columns, trans=True):
+ def __init__(self, db_obj, table, table_def, trans=True):
self._db = db_obj
self._con = self._db.connection()
self._trans = self._con.begin() if trans else None
- self._table = self._get_table(table, columns)
-
- def _get_table(self, name, columns):
- table = Table(name, MetaData(self._db.engine()))
- for c in columns:
- table.append_column(Column(c, Text()))
+ self._table = self._get_table(table, table_def)
+
+ def _get_table(self, name, table_def):
+ if isinstance(table_def, list):
+ table_def = {'columns': table_def,
+ 'indexes': [],
+ 'primary_key': None}
+ table_creation = []
+ for col_name in table_def['columns']:
+ table_creation.append(Column(col_name, Text()))
+ if table_def['primary_key']:
+ table_creation.append(PrimaryKeyConstraint(
+ *table_def['primary_key']))
+ for index in table_def['indexes']:
+ idx_name = 'idx_%s_%s' % (name, '_'.join(index))
+ table_creation.append(Index(idx_name, *index))
+ table = Table(name, MetaData(self._db.engine()), *table_creation)
return table
def _where(self, kvfilter):
class FileQuery(Log):
- def __init__(self, fstore, table, columns, trans=True):
+ def __init__(self, fstore, table, table_def, trans=True):
+ # We don't need indexes in a FileQuery, so drop that info
+ if isinstance(table_def, dict):
+ columns = table_def['columns']
+ else:
+ columns = table_def
self._fstore = fstore
self._config = fstore.get_config()
self._section = table
# We are storing multiple versions: one per class
# That way, we can support plugins with differing schema versions from
# the main codebase, and even in the same database.
- q = self._query(self._db, 'dbinfo', OPTIONS_COLUMNS, trans=False)
+ q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
q.create()
cls_name = self.__class__.__name__
current_version = self.load_options('dbinfo').get('%s_schema'
kvfilter = dict()
if name:
kvfilter['name'] = name
- options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
+ options = self._load_data(table, OPTIONS_TABLE, kvfilter)
if name and name in options:
return options[name]
return options
curvals = dict()
q = None
try:
- q = self._query(self._db, table, OPTIONS_COLUMNS)
+ q = self._query(self._db, table, OPTIONS_TABLE)
rows = q.select({'name': name}, ['option', 'value'])
for row in rows:
curvals[row[0]] = row[1]
kvfilter = {'name': name}
q = None
try:
- q = self._query(self._db, table, OPTIONS_COLUMNS)
+ q = self._query(self._db, table, OPTIONS_TABLE)
if options is None:
q.delete(kvfilter)
else:
newid = str(uuid.uuid4())
q = None
try:
- q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+ q = self._query(self._db, table, UNIQUE_DATA_TABLE)
for name in data:
q.insert((newid, name, data[name]))
q.commit()
kvfilter['name'] = name
if value:
kvfilter['value'] = value
- return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
+ return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)
def save_unique_data(self, table, data):
q = None
try:
- q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+ q = self._query(self._db, table, UNIQUE_DATA_TABLE)
for uid in data:
curvals = dict()
rows = q.select({'uuid': uid}, ['name', 'value'])
def del_unique_data(self, table, uuidval):
kvfilter = {'uuid': uuidval}
try:
- q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
+ q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
q.delete(kvfilter)
except Exception, e: # pylint: disable=broad-except
self.error("Failed to delete data from %s: [%s]" % (table, e))
def _reset_data(self, table):
q = None
try:
- q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+ q = self._query(self._db, table, UNIQUE_DATA_TABLE)
q.drop()
q.create()
q.commit()
'info_config',
'login_config',
'provider_config']:
- q = self._query(self._db, table, OPTIONS_COLUMNS, trans=False)
+ q = self._query(self._db, table, OPTIONS_TABLE, trans=False)
q.create()
def _upgrade_schema(self, old_version):
return self.load_options(plugin+"_data", user)
def _initialize_schema(self):
- q = self._query(self._db, 'users', OPTIONS_COLUMNS, trans=False)
+ q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
q.create()
def _upgrade_schema(self, old_version):
super(TranStore, self).__init__('transactions.db')
def _initialize_schema(self):
- q = self._query(self._db, 'transactions', UNIQUE_DATA_COLUMNS,
+ q = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
trans=False)
q.create()
super(SAML2SessionStore, self).__init__(database_url=database_url)
self.table = 'saml2_sessions'
# pylint: disable=protected-access
- table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+ table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
table.create(checkfirst=True)
def _get_unique_id_from_column(self, name, value):
def remove_expired_sessions(self):
# pylint: disable=protected-access
- table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+ table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
sel = select([table.columns.uuid]). \
where(and_(table.c.name == 'expiration_time',
table.c.value <= datetime.datetime.now()))
self._reset_data(self.table)
def _initialize_schema(self):
- q = self._query(self._db, self.table, UNIQUE_DATA_COLUMNS,
+ q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
trans=False)
q.create()
import pickle
-SESSION_COLUMNS = ['id', 'data', 'expiration_time']
+SESSION_TABLE = {'columns': ['id', 'data', 'expiration_time'],
+ 'primary_key': ('id', ),
+ 'indexes': [('expiration_time',)]
+ }
class SessionStore(Store):
def _initialize_schema(self):
- q = self._query(self._db, 'sessions', SESSION_COLUMNS,
+ q = self._query(self._db, 'sessions', SESSION_TABLE,
trans=False)
q.create()
cls._db = cls._store._db
def _exists(self):
- q = SqlQuery(self._db, 'sessions', SESSION_COLUMNS)
+ q = SqlQuery(self._db, 'sessions', SESSION_TABLE)
result = q.select({'id': self.id})
return True if result.fetchone() else False
def _load(self):
- q = SqlQuery(self._db, 'sessions', SESSION_COLUMNS)
+ q = SqlQuery(self._db, 'sessions', SESSION_TABLE)
result = q.select({'id': self.id})
r = result.fetchone()
if r:
def _save(self, expiration_time):
q = None
try:
- q = SqlQuery(self._db, 'sessions', SESSION_COLUMNS, trans=True)
+ q = SqlQuery(self._db, 'sessions', SESSION_TABLE, trans=True)
q.delete({'id': self.id})
data = pickle.dumps((self._data, expiration_time), self._proto)
q.insert((self.id, base64.b64encode(data), expiration_time))
raise
def _delete(self):
- q = SqlQuery(self._db, 'sessions', SESSION_COLUMNS)
+ q = SqlQuery(self._db, 'sessions', SESSION_TABLE)
q.delete({'id': self.id})
# copy what RamSession does for now