1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
5 from ipsilon.util.log import Log
6 from sqlalchemy import create_engine
7 from sqlalchemy import MetaData, Table, Column, Text
8 from sqlalchemy.pool import QueuePool, SingletonThreadPool
9 from sqlalchemy.schema import PrimaryKeyConstraint, Index
10 from sqlalchemy.sql import select, and_
17 CURRENT_SCHEMA_VERSION = 2
18 OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
19 'primary_key': ('name', 'option'),
20 'indexes': [('name',)]
22 UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
23 'primary_key': ('uuid', 'name'),
24 'indexes': [('uuid',)]
28 class DatabaseError(Exception):
def get_connection(cls, name):
    """Return the shared SqlStore for the database called *name*.

    Instances are cached in the registry held on the class
    (``cls.__instances``), so repeated calls with the same name hand back
    the same SqlStore. When the cherrypy config option 'db.conn.log' is
    truthy, creation of a new store is logged at debug level.
    """
    registry = cls.__instances
    if name in registry:
        return registry[name]
    if cherrypy.config.get('db.conn.log', False):
        logging.debug('SqlStore new: %s', name)
    registry[name] = SqlStore(name)
    return registry[name]
43 def __init__(self, name):
44 self.db_conn_log = cherrypy.config.get('db.conn.log', False)
45 self.debug('SqlStore init: %s' % name)
48 if '://' not in engine_name:
49 engine_name = 'sqlite:///' + engine_name
50 # This pool size is per configured database. The minimum needed,
51 # determined by binary search, is 23. We're using 25 so we have a bit
52 # more playroom, and then the overflow should make sure things don't
53 # break when we suddenly need more.
54 pool_args = {'poolclass': QueuePool,
57 if engine_name.startswith('sqlite://'):
58 # It's not possible to share connections for SQLite between
59 # threads, so let's use the SingletonThreadPool for them
60 pool_args = {'poolclass': SingletonThreadPool}
61 self._dbengine = create_engine(engine_name, **pool_args)
62 self.is_readonly = False
64 def debug(self, fact):
66 super(SqlStore, self).debug(fact)
72 self.debug('SqlStore connect: %s' % self.name)
73 conn = self._dbengine.connect()
75 def cleanup_connection():
76 self.debug('SqlStore cleanup: %s' % self.name)
78 cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
84 def __init__(self, db_obj, table, table_def, trans=True):
86 self._con = self._db.connection()
87 self._trans = self._con.begin() if trans else None
88 self._table = self._get_table(table, table_def)
90 def _get_table(self, name, table_def):
91 if isinstance(table_def, list):
92 table_def = {'columns': table_def,
96 for col_name in table_def['columns']:
97 table_creation.append(Column(col_name, Text()))
98 if table_def['primary_key']:
99 table_creation.append(PrimaryKeyConstraint(
100 *table_def['primary_key']))
101 for index in table_def['indexes']:
102 idx_name = 'idx_%s_%s' % (name, '_'.join(index))
103 table_creation.append(Index(idx_name, *index))
104 table = Table(name, MetaData(self._db.engine()), *table_creation)
107 def _where(self, kvfilter):
109 if kvfilter is not None:
111 w = self._table.columns[k] == kvfilter[k]
118 def _columns(self, columns=None):
120 if columns is not None:
123 cols.append(self._table.columns[c])
125 cols = self._table.columns
129 self._trans.rollback()
135 self._table.create(checkfirst=True)
138 self._table.drop(checkfirst=True)
def select(self, kvfilter=None, columns=None):
    """Run a SELECT on this query's table and return the execution result.

    :param kvfilter: optional mapping of column name -> value; ``_where``
        turns it into the WHERE clause.
    :param columns: optional list of column names to fetch; ``_columns``
        resolves it (None selects the table's columns).
    :returns: the result of ``self._con.execute`` for the statement.
    """
    wanted_columns = self._columns(columns)
    where_clause = self._where(kvfilter)
    statement = select(wanted_columns, where_clause)
    return self._con.execute(statement)
def insert(self, values):
    """Execute an INSERT of *values* into this query's table."""
    statement = self._table.insert(values)
    self._con.execute(statement)
def update(self, values, kvfilter):
    """Execute an UPDATE setting *values* on rows matching *kvfilter*.

    *kvfilter* is a mapping of column name -> value converted to a WHERE
    clause by ``_where``.
    """
    condition = self._where(kvfilter)
    statement = self._table.update(condition, values)
    self._con.execute(statement)
def delete(self, kvfilter):
    """Execute a DELETE of the rows matching *kvfilter*.

    *kvfilter* is a mapping of column name -> value converted to a WHERE
    clause by ``_where``.
    """
    condition = self._where(kvfilter)
    statement = self._table.delete(condition)
    self._con.execute(statement)
154 class FileStore(Log):
156 def __init__(self, name):
157 self._filename = name
158 self.is_readonly = True
159 self._timestamp = None
162 def get_config(self):
164 stat = os.stat(self._filename)
166 self.error("Unable to check config file %s: [%s]" % (
170 timestamp = stat.st_mtime
171 if self._config is None or timestamp > self._timestamp:
172 self._config = ConfigParser.RawConfigParser()
173 self._config.optionxform = str
174 self._config.read(self._filename)
178 class FileQuery(Log):
180 def __init__(self, fstore, table, table_def, trans=True):
181 # We don't need indexes in a FileQuery, so drop that info
182 if isinstance(table_def, dict):
183 columns = table_def['columns']
186 self._fstore = fstore
187 self._config = fstore.get_config()
188 self._section = table
189 if len(columns) > 3 or columns[-1] != 'value':
190 raise ValueError('Unsupported configuration format')
191 self._columns = columns
200 raise NotImplementedError
203 raise NotImplementedError
205 def select(self, kvfilter=None, columns=None):
206 if self._section not in self._config.sections():
209 opts = self._config.options(self._section)
213 if self._columns[0] in kvfilter:
214 prefix = kvfilter[self._columns[0]]
215 prefix_ = prefix + ' '
218 if len(self._columns) == 3 and self._columns[1] in kvfilter:
219 name = kvfilter[self._columns[1]]
222 if self._columns[-1] in kvfilter:
223 value = kvfilter[self._columns[-1]]
227 if len(self._columns) == 3:
229 if prefix and not o.startswith(prefix_):
232 col1, col2 = o.split(' ', 1)
233 if name and col2 != name:
236 col3 = self._config.get(self._section, o)
237 if value and col3 != value:
240 r = [col1, col2, col3]
243 if prefix and o != prefix:
245 r = [o, self._config.get(self._section, o)]
250 s.append(r[self._columns.index(c)])
255 self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
def insert(self, values):
    """Inserting is not supported for file-backed configuration.

    FileStore marks itself read-only (``is_readonly = True``).

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
def update(self, values, kvfilter):
    """Updating is not supported for file-backed configuration.

    FileStore marks itself read-only (``is_readonly = True``).

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
def delete(self, kvfilter):
    """Deleting is not supported for file-backed configuration.

    FileStore marks itself read-only (``is_readonly = True``).

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
274 def __init__(self, config_name=None, database_url=None):
275 if config_name is None and database_url is None:
276 raise ValueError('config_name or database_url must be provided')
278 if config_name not in cherrypy.config:
279 raise NameError('Unknown database %s' % config_name)
280 name = cherrypy.config[config_name]
283 if name.startswith('configfile://'):
284 _, filename = name.split('://')
285 self._db = FileStore(filename)
286 self._query = FileQuery
288 self._db = SqlStore.get_connection(name)
289 self._query = SqlQuery
291 if not self._is_upgrade:
292 self._check_database()
294 def _code_schema_version(self):
295 # This function makes it possible for separate plugins to have
296 # different schema versions. We default to the global schema
298 return CURRENT_SCHEMA_VERSION
300 def _get_schema_version(self):
301 # We are storing multiple versions: one per class
302 # That way, we can support plugins with differing schema versions from
303 # the main codebase, and even in the same database.
304 q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
306 cls_name = self.__class__.__name__
307 current_version = self.load_options('dbinfo').get('%s_schema'
309 if 'version' in current_version:
310 return int(current_version['version'])
312 # Also try the old table name.
313 # "scheme" was a typo, but we need to retain that now for compat
314 fallback_version = self.load_options('dbinfo').get('scheme',
316 if 'version' in fallback_version:
317 return int(fallback_version['version'])
321 def _check_database(self):
323 # If the database is readonly, we cannot do anything to the
324 # schema. Let's just return, and assume people checked the
328 current_version = self._get_schema_version()
329 if current_version is None:
330 self.error('Database initialization required! ' +
331 'Please run ipsilon-upgrade-database')
332 raise DatabaseError('Database initialization required for %s' %
333 self.__class__.__name__)
334 if current_version != self._code_schema_version():
335 self.error('Database upgrade required! ' +
336 'Please run ipsilon-upgrade-database')
337 raise DatabaseError('Database upgrade required for %s' %
338 self.__class__.__name__)
def _store_new_schema_version(self, new_version):
    """Record *new_version* as this class's schema version.

    The version is saved in the 'dbinfo' options table under the name
    '<ClassName>_schema', so each Store subclass tracks its own version.
    """
    option_name = '%s_schema' % self.__class__.__name__
    self.save_options('dbinfo', option_name, {'version': new_version})
def _initialize_schema(self):
    """Create the tables for a database that has no recorded schema.

    upgrade_database() calls this when no schema version is found.
    Concrete stores must override it.

    :raises NotImplementedError: always, in the base class.
    """
    raise NotImplementedError()
348 def _upgrade_schema(self, old_version):
349 # Datastores need to figure out what to do with bigger old_versions
351 # They might implement downgrading if that's feasible, or just throw
352 # NotImplementedError
353 raise NotImplementedError()
def upgrade_database(self):
    """Bring the stored schema up to the version this code expects.

    - No recorded version: create the schema via _initialize_schema().
    - Recorded version differs from _code_schema_version(): migrate via
      _upgrade_schema(old_version).
    - Already current: do nothing.

    After initializing or migrating, the code's schema version is stored.
    """
    stored_version = self._get_schema_version()
    if stored_version is None:
        # Fresh database: build everything from scratch.
        self._initialize_schema()
    elif stored_version != self._code_schema_version():
        self._upgrade_schema(stored_version)
    else:
        # Schema already matches the code; nothing to record.
        return
    self._store_new_schema_version(self._code_schema_version())
def is_readonly(self):
    """Whether the backing store is read-only.

    Delegates to the underlying store object: SqlStore sets this to False
    and FileStore to True.
    """
    backend = self._db
    return backend.is_readonly
371 def _row_to_dict_tree(self, data, row):
377 self._row_to_dict_tree(d2, row[1:])
381 if data[name] is list:
382 data[name].append(value)
385 data[name] = [v, value]
389 def _rows_to_dict_tree(self, rows):
392 self._row_to_dict_tree(data, r)
395 def _load_data(self, table, columns, kvfilter=None):
398 q = self._query(self._db, table, columns, trans=False)
399 rows = q.select(kvfilter)
400 except Exception, e: # pylint: disable=broad-except
401 self.error("Failed to load data for table %s: [%s]" % (table, e))
402 return self._rows_to_dict_tree(rows)
404 def load_config(self):
406 columns = ['name', 'value']
407 return self._load_data(table, columns)
409 def load_options(self, table, name=None):
412 kvfilter['name'] = name
413 options = self._load_data(table, OPTIONS_TABLE, kvfilter)
414 if name and name in options:
418 def save_options(self, table, name, options):
422 q = self._query(self._db, table, OPTIONS_TABLE)
423 rows = q.select({'name': name}, ['option', 'value'])
425 curvals[row[0]] = row[1]
429 q.update({'value': options[opt]},
430 {'name': name, 'option': opt})
432 q.insert((name, opt, options[opt]))
435 except Exception, e: # pylint: disable=broad-except
438 self.error("Failed to save options: [%s]" % e)
441 def delete_options(self, table, name, options=None):
442 kvfilter = {'name': name}
445 q = self._query(self._db, table, OPTIONS_TABLE)
450 kvfilter['option'] = opt
453 except Exception, e: # pylint: disable=broad-except
456 self.error("Failed to delete from %s: [%s]" % (table, e))
459 def new_unique_data(self, table, data):
460 newid = str(uuid.uuid4())
463 q = self._query(self._db, table, UNIQUE_DATA_TABLE)
465 q.insert((newid, name, data[name]))
467 except Exception, e: # pylint: disable=broad-except
470 self.error("Failed to store %s data: [%s]" % (table, e))
474 def get_unique_data(self, table, uuidval=None, name=None, value=None):
477 kvfilter['uuid'] = uuidval
479 kvfilter['name'] = name
481 kvfilter['value'] = value
482 return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)
484 def save_unique_data(self, table, data):
487 q = self._query(self._db, table, UNIQUE_DATA_TABLE)
490 rows = q.select({'uuid': uid}, ['name', 'value'])
497 if datum[name] is None:
498 q.delete({'uuid': uid, 'name': name})
500 q.update({'value': datum[name]},
501 {'uuid': uid, 'name': name})
503 if datum[name] is not None:
504 q.insert((uid, name, datum[name]))
507 except Exception, e: # pylint: disable=broad-except
510 self.error("Failed to store data in %s: [%s]" % (table, e))
513 def del_unique_data(self, table, uuidval):
514 kvfilter = {'uuid': uuidval}
516 q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
518 except Exception, e: # pylint: disable=broad-except
519 self.error("Failed to delete data from %s: [%s]" % (table, e))
521 def _reset_data(self, table):
524 q = self._query(self._db, table, UNIQUE_DATA_TABLE)
528 except Exception, e: # pylint: disable=broad-except
531 self.error("Failed to erase all data from %s: [%s]" % (table, e))
534 class AdminStore(Store):
537 super(AdminStore, self).__init__('admin.config.db')
def get_data(self, plugin, idval=None, name=None, value=None):
    """Load entries from the '<plugin>_data' unique-data table.

    The optional *idval*, *name* and *value* filters are passed through
    unchanged to get_unique_data.
    """
    table = plugin + "_data"
    return self.get_unique_data(table, idval, name, value)
def save_data(self, plugin, data):
    """Persist *data* into the '<plugin>_data' unique-data table."""
    table = plugin + "_data"
    return self.save_unique_data(table, data)
def new_datum(self, plugin, datum):
    """Add *datum* as a new entry in the '<plugin>_data' table.

    Returns whatever new_unique_data returns (presumably the generated
    uuid -- confirm against that method).
    """
    return self.new_unique_data(plugin + "_data", datum)
def del_datum(self, plugin, idval):
    """Delete the entry whose uuid is *idval* from '<plugin>_data'."""
    return self.del_unique_data(plugin + "_data", idval)
def wipe_data(self, plugin):
    """Erase every row in the '<plugin>_data' table."""
    self._reset_data(plugin + "_data")
557 def _initialize_schema(self):
558 for table in ['config',
562 q = self._query(self._db, table, OPTIONS_TABLE, trans=False)
def _upgrade_schema(self, old_version):
    """No migration path is implemented for the admin store schema.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
569 class UserStore(Store):
def __init__(self, path=None):
    """Open the store configured under the 'user.prefs.db' config name.

    *path* is accepted but ignored.
    """
    config_name = 'user.prefs.db'
    super(UserStore, self).__init__(config_name)
def save_user_preferences(self, user, options):
    """Save the *options* mapping for *user* in the 'users' table."""
    table = 'users'
    self.save_options(table, user, options)
def load_user_preferences(self, user):
    """Return the options stored for *user* in the 'users' table."""
    table = 'users'
    return self.load_options(table, user)
def save_plugin_data(self, plugin, user, options):
    """Save *user*'s *options* in the '<plugin>_data' options table."""
    table = plugin + "_data"
    self.save_options(table, user, options)
def load_plugin_data(self, plugin, user):
    """Return *user*'s options from the '<plugin>_data' options table."""
    table = plugin + "_data"
    return self.load_options(table, user)
586 def _initialize_schema(self):
587 q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
def _upgrade_schema(self, old_version):
    """No migration path is implemented for the user store schema.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
594 class TranStore(Store):
def __init__(self, path=None):
    """Open the store configured under the 'transactions.db' config name.

    *path* is accepted but ignored.
    """
    config_name = 'transactions.db'
    super(TranStore, self).__init__(config_name)
599 def _initialize_schema(self):
600 q = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
def _upgrade_schema(self, old_version):
    """No migration path is implemented for the transactions schema.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
608 class SAML2SessionStore(Store):
def __init__(self, database_url):
    """Open the SAML2 session store located by *database_url*.

    The 'saml2_sessions' table (unique-data layout: uuid, name, value) is
    created if it does not already exist.

    :param database_url: database URL handed to Store.__init__.
    """
    super(SAML2SessionStore, self).__init__(database_url=database_url)
    self.table = 'saml2_sessions'
    # Only the Table object is needed to issue CREATE TABLE, so don't
    # begin a transaction on the query's connection: nothing here would
    # ever commit or roll it back. The _initialize_schema implementations
    # likewise build their queries with trans=False.
    # pylint: disable=protected-access
    table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE,
                     trans=False)._table
    table.create(checkfirst=True)
617 def _get_unique_id_from_column(self, name, value):
619 The query is going to return only the column in the query.
620 Use this method to get the uuidval which can be used to fetch
623 Returns None or the uuid of the first value found.
625 data = self.get_unique_data(self.table, name=name, value=value)
630 raise ValueError("Multiple entries returned")
631 return data.keys()[0]
633 def remove_expired_sessions(self):
634 # pylint: disable=protected-access
635 table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
636 sel = select([table.columns.uuid]). \
637 where(and_(table.c.name == 'expiration_time',
638 table.c.value <= datetime.datetime.now()))
639 # pylint: disable=no-value-for-parameter
640 d = table.delete().where(table.c.uuid.in_(sel))
def get_data(self, idval=None, name=None, value=None):
    """Query the session table, optionally filtered by uuid/name/value."""
    return self.get_unique_data(self.table, uuidval=idval,
                                name=name, value=value)
646 def new_session(self, datum):
647 if 'supported_logout_mechs' in datum:
648 datum['supported_logout_mechs'] = ','.join(
649 datum['supported_logout_mechs']
651 return self.new_unique_data(self.table, datum)
653 def get_session(self, session_id=None, request_id=None):
655 uuidval = self._get_unique_id_from_column('session_id', session_id)
657 uuidval = self._get_unique_id_from_column('request_id', request_id)
659 raise ValueError("Unable to find session")
662 data = self.get_unique_data(self.table, uuidval=uuidval)
663 return uuidval, data[uuidval]
665 def get_user_sessions(self, user):
667 Return a list of all sessions for a given user.
669 rows = self.get_unique_data(self.table, name='user', value=user)
671 # We have a list of sessions for this user, now get the details
674 data = self.get_unique_data(self.table, uuidval=r)
675 data[r]['supported_logout_mechs'] = data[r].get(
676 'supported_logout_mechs', '').split(',')
677 logged_in.append(data)
def update_session(self, datum):
    """Persist *datum* to the session table via save_unique_data."""
    session_table = self.table
    self.save_unique_data(session_table, datum)
def remove_session(self, uuidval):
    """Delete the session rows stored under the uuid *uuidval*."""
    session_table = self.table
    self.del_unique_data(session_table, uuidval)
688 self._reset_data(self.table)
690 def _initialize_schema(self):
691 q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
def _upgrade_schema(self, old_version):
    """No migration path is implemented for the SAML2 session schema.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()