1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
5 from ipsilon.util.log import Log
6 from sqlalchemy import create_engine
7 from sqlalchemy import MetaData, Table, Column, Text
8 from sqlalchemy.pool import QueuePool, SingletonThreadPool
9 from sqlalchemy.sql import select, and_
# Schema version this code expects. Store._code_schema_version() returns it
# by default, and _check_database() compares the stored version against that.
16 CURRENT_SCHEMA_VERSION = 1
# Column layout of option tables: (name, option, value).
17 OPTIONS_COLUMNS = ['name', 'option', 'value']
# Column layout of "unique data" tables: (uuid, name, value).
18 UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
# Raised, e.g. by Store._check_database(), when the stored schema is missing
# or does not match the version the code expects.
21 class DatabaseError(Exception):
def get_connection(cls, name):
    """Return the SqlStore registered for ``name``, building it if needed.

    One SqlStore per database name is kept in ``cls.__instances``; later
    calls with the same name get that same object back. Creation is logged
    at debug level when the 'db.conn.log' cherrypy setting is enabled.
    """
    registry = cls.__instances
    if name in registry:
        return registry[name]
    if cherrypy.config.get('db.conn.log', False):
        logging.debug('SqlStore new: %s', name)
    registry[name] = SqlStore(name)
    return registry[name]
# Create the SQLAlchemy engine for database `name`. A value without '://'
# is treated as a path to a SQLite file. SQLite gets a SingletonThreadPool
# (see the comment below); other backends use a QueuePool whose remaining
# arguments are on lines not shown in this excerpt.
36 def __init__(self, name):
37 self.db_conn_log = cherrypy.config.get('db.conn.log', False)
38 self.debug('SqlStore init: %s' % name)
# NOTE(review): the lines assigning self.name / engine_name are not shown
# here; presumably engine_name starts out as `name` — confirm.
41 if '://' not in engine_name:
42 engine_name = 'sqlite:///' + engine_name
43 # This pool size is per configured database. The minimum needed,
44 # determined by binary search, is 23. We're using 25 so we have a bit
45 # more headroom, and then the overflow should make sure things don't
46 # break when we suddenly need more.
47 pool_args = {'poolclass': QueuePool,
50 if engine_name.startswith('sqlite://'):
51 # It's not possible to share connections for SQLite between
52 # threads, so let's use the SingletonThreadPool for them
53 pool_args = {'poolclass': SingletonThreadPool}
54 self._dbengine = create_engine(engine_name, **pool_args)
# SQL-backed stores accept writes (FileStore sets this to True instead).
55 self.is_readonly = False
# Debug-log `fact` through the parent class's debug().
# NOTE(review): the line between the def and the call is not shown here;
# presumably it gates logging on self.db_conn_log (set in __init__) — confirm.
57 def debug(self, fact):
59 super(SqlStore, self).debug(fact)
# Body fragment of a connection accessor (its header is not shown here).
# It obtains a connection via self._dbengine.connect() and attaches a
# cherrypy 'on_end_request' hook so cleanup_connection() runs once the
# current request ends. What cleanup does beyond logging is not visible.
65 self.debug('SqlStore connect: %s' % self.name)
66 conn = self._dbengine.connect()
68 def cleanup_connection():
69 self.debug('SqlStore cleanup: %s' % self.name)
71 cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
# Bind a SQL query helper to store `db_obj` for `table` with the given
# `columns`. When `trans` is true, a transaction is begun on the connection;
# otherwise self._trans is None.
# NOTE(review): self._db is presumably assigned from db_obj on a line not
# shown here — confirm.
77 def __init__(self, db_obj, table, columns, trans=True):
79 self._con = self._db.connection()
80 self._trans = self._con.begin() if trans else None
81 self._table = self._get_table(table, columns)
# Build a SQLAlchemy Table named `name` on a MetaData bound to the store's
# engine. Each column is appended with type Text (the loop header, which
# presumably iterates `columns`, and the return line are not shown).
83 def _get_table(self, name, columns):
84 table = Table(name, MetaData(self._db.engine()))
86 table.append_column(Column(c, Text()))
# Build a WHERE clause of column == value terms from `kvfilter` (a mapping
# of column name to value, or None). How several terms are combined, and
# what is returned for None, is on lines not shown here.
89 def _where(self, kvfilter):
91 if kvfilter is not None:
93 w = self._table.columns[k] == kvfilter[k]
# Resolve `columns` (column names) to this table's Column objects. When
# `columns` is None it appears to fall back to all of the table's columns
# (the else/return lines are not shown here).
100 def _columns(self, columns=None):
102 if columns is not None:
105 cols.append(self._table.columns[c])
107 cols = self._table.columns
# Roll back the transaction begun in __init__ (the enclosing method header
# is not shown in this excerpt).
111 self._trans.rollback()
# Create the table unless it already exists (the enclosing method header is
# not shown in this excerpt).
117 self._table.create(checkfirst=True)
# Drop the table only if it exists (the enclosing method header is not shown
# in this excerpt).
120 self._table.drop(checkfirst=True)
def select(self, kvfilter=None, columns=None):
    """Run a SELECT against this query's table and return the result.

    ``columns`` (column names, or None) is resolved by ``self._columns``;
    ``kvfilter`` (column name -> value, or None) becomes the WHERE clause
    via ``self._where``. The statement is executed on this query's
    connection and that execution result is returned.
    """
    chosen = self._columns(columns)
    condition = self._where(kvfilter)
    statement = select(chosen, condition)
    return self._con.execute(statement)
def insert(self, values):
    """Execute an INSERT of ``values`` into this query's table."""
    statement = self._table.insert(values)
    self._con.execute(statement)
def update(self, values, kvfilter):
    """Execute an UPDATE that sets ``values`` on rows matching ``kvfilter``."""
    condition = self._where(kvfilter)
    statement = self._table.update(condition, values)
    self._con.execute(statement)
def delete(self, kvfilter):
    """Execute a DELETE of the rows matching ``kvfilter``."""
    condition = self._where(kvfilter)
    statement = self._table.delete(condition)
    self._con.execute(statement)
# Read-only store backed by an INI-style file parsed with RawConfigParser.
# The parsed config is cached and re-read when the file's mtime is newer
# than the cached timestamp.
136 class FileStore(Log):
138 def __init__(self, name):
139 self._filename = name
140 self.is_readonly = True
141 self._timestamp = None
# Return the parsed config, re-reading the file if it changed on disk.
# optionxform = str keeps option names case-sensitive.
# NOTE(review): this excerpt shows neither self._config being initialized
# nor self._timestamp being updated after a re-read. Confirm both happen
# (on hidden lines), otherwise the file is re-parsed on every call.
144 def get_config(self):
146 stat = os.stat(self._filename)
148 self.error("Unable to check config file %s: [%s]" % (
152 timestamp = stat.st_mtime
153 if self._config is None or timestamp > self._timestamp:
154 self._config = ConfigParser.RawConfigParser()
155 self._config.optionxform = str
156 self._config.read(self._filename)
# Query adapter exposing one section of a FileStore config as rows. Only
# layouts of at most three columns whose last column is 'value' are accepted.
160 class FileQuery(Log):
162 def __init__(self, fstore, table, columns, trans=True):
163 self._fstore = fstore
164 self._config = fstore.get_config()
# The "table" name selects the config section.
165 self._section = table
166 if len(columns) > 3 or columns[-1] != 'value':
167 raise ValueError('Unsupported configuration format')
168 self._columns = columns
# (Bodies of methods whose headers are not shown here; they raise.)
177 raise NotImplementedError
180 raise NotImplementedError
# Emulate a SELECT over this section.
# - Three columns: each option key is split on its first space into
#   (col1, col2), and the option value is col3. col1 is matched through
#   the '<prefix> ' key prefix.
# - Otherwise: the option key itself is col1 (exact match) and the option
#   value is the last column.
# Falsy filter values (e.g. '') do not filter at all. Output rows carry the
# requested `columns`, picked by position in self._columns.
# NOTE(review): kvfilter is tested with `in` below; the hidden lines
# presumably default a None kvfilter to an empty dict — confirm.
182 def select(self, kvfilter=None, columns=None):
183 if self._section not in self._config.sections():
186 opts = self._config.options(self._section)
190 if self._columns[0] in kvfilter:
191 prefix = kvfilter[self._columns[0]]
192 prefix_ = prefix + ' '
195 if len(self._columns) == 3 and self._columns[1] in kvfilter:
196 name = kvfilter[self._columns[1]]
199 if self._columns[-1] in kvfilter:
200 value = kvfilter[self._columns[-1]]
204 if len(self._columns) == 3:
206 if prefix and not o.startswith(prefix_):
209 col1, col2 = o.split(' ', 1)
210 if name and col2 != name:
213 col3 = self._config.get(self._section, o)
214 if value and col3 != value:
217 r = [col1, col2, col3]
220 if prefix and o != prefix:
222 r = [o, self._config.get(self._section, o)]
227 s.append(r[self._columns.index(c)])
232 self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
def insert(self, values):
    """Refuse to insert: file-backed queries are read-only.

    ``values`` is ignored; NotImplementedError is always raised.
    """
    raise NotImplementedError
def update(self, values, kvfilter):
    """Refuse to update: file-backed queries are read-only.

    Both arguments are ignored; NotImplementedError is always raised.
    """
    raise NotImplementedError
def delete(self, kvfilter):
    """Refuse to delete: file-backed queries are read-only.

    ``kvfilter`` is ignored; NotImplementedError is always raised.
    """
    raise NotImplementedError
# Open the backing store, either via cherrypy config key `config_name` or via
# an explicit `database_url`; ValueError if neither is given.
# - A 'configfile://' value selects a read-only FileStore with FileQuery.
# - Anything else selects the shared (cached) SqlStore with SqlQuery.
# Unless self._is_upgrade is set, the schema version is then checked.
# NOTE(review): hidden lines presumably guard the cherrypy lookup on
# config_name being given, use database_url as `name` otherwise, and define
# _is_upgrade (e.g. as a class attribute) — confirm.
251 def __init__(self, config_name=None, database_url=None):
252 if config_name is None and database_url is None:
253 raise ValueError('config_name or database_url must be provided')
255 if config_name not in cherrypy.config:
256 raise NameError('Unknown database %s' % config_name)
257 name = cherrypy.config[config_name]
260 if name.startswith('configfile://'):
261 _, filename = name.split('://')
262 self._db = FileStore(filename)
263 self._query = FileQuery
265 self._db = SqlStore.get_connection(name)
266 self._query = SqlQuery
268 if not self._is_upgrade:
269 self._check_database()
# Schema version this class's code expects; subclasses may override it.
271 def _code_schema_version(self):
272 # This function makes it possible for separate plugins to have
273 # different schema versions. We default to the global schema
275 return CURRENT_SCHEMA_VERSION
# Return the schema version recorded for this class in 'dbinfo' (option set
# '<ClassName>_schema') as an int. If that is absent, fall back to the legacy
# 'scheme' entry. What is returned when neither exists is on hidden lines;
# callers treat None as "not initialized".
277 def _get_schema_version(self):
278 # We are storing multiple versions: one per class
279 # That way, we can support plugins with differing schema versions from
280 # the main codebase, and even in the same database.
281 q = self._query(self._db, 'dbinfo', OPTIONS_COLUMNS, trans=False)
283 cls_name = self.__class__.__name__
284 current_version = self.load_options('dbinfo').get('%s_schema'
286 if 'version' in current_version:
287 return int(current_version['version'])
289 # Also try the legacy option-set name in the same 'dbinfo' table.
290 # "scheme" was a typo, but we need to retain that now for compat
291 fallback_version = self.load_options('dbinfo').get('scheme',
293 if 'version' in fallback_version:
294 return int(fallback_version['version'])
# Verify that the stored schema version equals _code_schema_version().
# If no version is stored, or it differs, log an error and raise
# DatabaseError. Per the comment below, read-only stores skip the check;
# the guard itself is on lines not shown here.
298 def _check_database(self):
300 # If the database is readonly, we cannot do anything to the
301 # schema. Let's just return, and assume people checked the
305 current_version = self._get_schema_version()
306 if current_version is None:
307 self.error('Database initialization required! ' +
308 'Please run ipsilon-upgrade-database')
309 raise DatabaseError('Database initialization required for %s' %
310 self.__class__.__name__)
311 if current_version != self._code_schema_version():
312 self.error('Database upgrade required! ' +
313 'Please run ipsilon-upgrade-database')
314 raise DatabaseError('Database upgrade required for %s' %
315 self.__class__.__name__)
317 def _store_new_schema_version(self, new_version):
318 cls_name = self.__class__.__name__
319 self.save_options('dbinfo', '%s_schema' % cls_name,
320 {'version': new_version})
322 def _initialize_schema(self):
323 raise NotImplementedError()
# Hook for subclasses: migrate the stored schema from `old_version` to the
# version this code expects. The base implementation always raises
# NotImplementedError.
325 def _upgrade_schema(self, old_version):
326 # Datastores need to figure out what to do with bigger old_versions
328 # They might implement downgrading if that's feasible, or just throw
329 # NotImplementedError
330 raise NotImplementedError()
def upgrade_database(self):
    """Bring the stored schema up to the version this code expects.

    If no version is stored, the schema is created via
    ``_initialize_schema``. If the stored version differs, it is passed to
    ``_upgrade_schema``. In either case the code's schema version is then
    recorded. When the versions already match, nothing is done.
    """
    previous = self._get_schema_version()
    if previous is None:
        self._initialize_schema()
    elif previous != self._code_schema_version():
        self._upgrade_schema(previous)
    else:
        # Already at the expected version; nothing to record.
        return
    self._store_new_schema_version(self._code_schema_version())
def is_readonly(self):
    """Return the ``is_readonly`` flag of the backing store (``self._db``)."""
    backend = self._db
    return backend.is_readonly
# Fold one result row into the nested dict `data`. Longer rows recurse on
# row[1:] into a sub-dict (d2); the final name/value pair is stored under
# `name`, and a repeated name is meant to accumulate its values in a list.
# NOTE(review): `data[name] is list` compares the stored value with the
# `list` type object itself, so it is never true for real values. A third
# value for the same name then takes the `[v, value]` branch, nesting lists
# ([[a, b], c]) instead of appending, presumably with v as the previous
# value. This almost certainly should be isinstance(data[name], list).
348 def _row_to_dict_tree(self, data, row):
354 self._row_to_dict_tree(d2, row[1:])
358 if data[name] is list:
359 data[name].append(value)
362 data[name] = [v, value]
# Build a nested dict from `rows` by folding each row into `data` with
# _row_to_dict_tree (the init and return lines are not shown here).
366 def _rows_to_dict_tree(self, rows):
369 self._row_to_dict_tree(data, r)
# SELECT rows from `table` (optionally filtered by `kvfilter`) without a
# transaction, and return them as a nested dict. Query errors are logged,
# not raised.
# NOTE(review): if the query fails, `rows` must already be bound (e.g. to
# an empty list) on a hidden line. Otherwise the final return raises
# UnboundLocalError — confirm.
372 def _load_data(self, table, columns, kvfilter=None):
375 q = self._query(self._db, table, columns, trans=False)
376 rows = q.select(kvfilter)
377 except Exception, e: # pylint: disable=broad-except
378 self.error("Failed to load data for table %s: [%s]" % (table, e))
379 return self._rows_to_dict_tree(rows)
# Load a two-column (name, value) table as a dict via _load_data. `table` is
# assigned on a hidden line (presumably 'config') — confirm.
381 def load_config(self):
383 columns = ['name', 'value']
384 return self._load_data(table, columns)
# Load option rows from `table` as a nested dict, filtered to `name` when
# given. When `name` is given and present, the code branches on it;
# presumably it then returns only that entry's options (hidden lines) —
# confirm.
386 def load_options(self, table, name=None):
389 kvfilter['name'] = name
390 options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
391 if name and name in options:
# Upsert `options` (option -> value) for `name` in option table `table`.
# Current values are read first (curvals); an option is then either
# UPDATEd or INSERTed. Presumably the choice depends on whether it is
# already in curvals (hidden condition). Errors are logged, not raised.
# The query uses the default trans=True; commit/rollback lines are not shown.
395 def save_options(self, table, name, options):
399 q = self._query(self._db, table, OPTIONS_COLUMNS)
400 rows = q.select({'name': name}, ['option', 'value'])
402 curvals[row[0]] = row[1]
406 q.update({'value': options[opt]},
407 {'name': name, 'option': opt})
409 q.insert((name, opt, options[opt]))
412 except Exception, e: # pylint: disable=broad-except
415 self.error("Failed to save options: [%s]" % e)
# Delete option rows for `name` from `table`. When `options` is given, the
# filter is narrowed per option via kvfilter['option'] (the loop and delete
# calls are on hidden lines). Errors are logged, not raised.
418 def delete_options(self, table, name, options=None):
419 kvfilter = {'name': name}
422 q = self._query(self._db, table, OPTIONS_COLUMNS)
427 kvfilter['option'] = opt
430 except Exception, e: # pylint: disable=broad-except
433 self.error("Failed to delete from %s: [%s]" % (table, e))
# Insert every name/value pair of `data` into uuid/name/value table `table`
# under one fresh random uuid4 string. Errors are logged, not raised.
# NOTE(review): the return line is hidden; presumably newid is returned —
# confirm.
436 def new_unique_data(self, table, data):
437 newid = str(uuid.uuid4())
440 q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
442 q.insert((newid, name, data[name]))
444 except Exception, e: # pylint: disable=broad-except
447 self.error("Failed to store %s data: [%s]" % (table, e))
# Load rows from uuid/name/value table `table` as a nested dict, filtering
# on uuidval/name/value. Presumably each is only added to the filter when
# supplied; the guarding conditions are on hidden lines.
451 def get_unique_data(self, table, uuidval=None, name=None, value=None):
454 kvfilter['uuid'] = uuidval
456 kvfilter['name'] = name
458 kvfilter['value'] = value
459 return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
# Write `data` into uuid/name/value table `table`. For each uuid, the
# existing (name, value) rows are read; then for each name:
# - a None value deletes that row,
# - otherwise the row is updated,
# - new non-None names are inserted.
# Errors are logged, not raised.
# NOTE(review): the loop headers and the "already stored" check are hidden;
# confirm the update branch only runs for names that already exist.
461 def save_unique_data(self, table, data):
464 q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
467 rows = q.select({'uuid': uid}, ['name', 'value'])
474 if datum[name] is None:
475 q.delete({'uuid': uid, 'name': name})
477 q.update({'value': datum[name]},
478 {'uuid': uid, 'name': name})
480 if datum[name] is not None:
481 q.insert((uid, name, datum[name]))
484 except Exception, e: # pylint: disable=broad-except
487 self.error("Failed to store data in %s: [%s]" % (table, e))
# Delete the rows whose uuid equals `uuidval` from `table`, without a
# transaction (trans=False). The delete call itself is on a hidden line.
# Errors are logged, not raised.
490 def del_unique_data(self, table, uuidval):
491 kvfilter = {'uuid': uuidval}
493 q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
495 except Exception, e: # pylint: disable=broad-except
496 self.error("Failed to delete data from %s: [%s]" % (table, e))
# Erase all rows of `table`; the statement itself is on hidden lines.
# Errors are logged, not raised.
498 def _reset_data(self, table):
501 q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
505 except Exception, e: # pylint: disable=broad-except
508 self.error("Failed to erase all data from %s: [%s]" % (table, e))
# Store for admin configuration and per-plugin '<plugin>_data' tables,
# opened through the 'admin.config.db' cherrypy config key.
511 class AdminStore(Store):
# (The __init__ header line is not shown in this excerpt.)
514 super(AdminStore, self).__init__('admin.config.db')
def get_data(self, plugin, idval=None, name=None, value=None):
    """Fetch rows from the '<plugin>_data' table, optionally filtered.

    ``idval``, ``name`` and ``value`` are forwarded as the uuid, name and
    value filters of ``get_unique_data``.
    """
    table = plugin + "_data"
    return self.get_unique_data(table, idval, name, value)
def save_data(self, plugin, data):
    """Save ``data`` into the '<plugin>_data' table via save_unique_data."""
    table = plugin + "_data"
    return self.save_unique_data(table, data)
def new_datum(self, plugin, datum):
    """Insert ``datum`` as a new entry in the '<plugin>_data' table."""
    return self.new_unique_data(plugin + "_data", datum)
def del_datum(self, plugin, idval):
    """Delete the entry with uuid ``idval`` from the '<plugin>_data' table."""
    return self.del_unique_data(plugin + "_data", idval)
def wipe_data(self, plugin):
    """Erase every row of the '<plugin>_data' table."""
    self._reset_data(plugin + "_data")
# Create the admin option tables ('config' plus entries listed on hidden
# lines), each as a non-transactional query; the create call is not shown.
534 def _initialize_schema(self):
535 for table in ['config',
539 q = self._query(self._db, table, OPTIONS_COLUMNS, trans=False)
542 def _upgrade_schema(self, old_version):
543 raise NotImplementedError()
# Store for user preferences ('users' table) and per-plugin user data
# ('<plugin>_data' tables), opened through the 'user.prefs.db' config key.
546 class UserStore(Store):
def __init__(self, path=None):
    """Open the user preferences store named by 'user.prefs.db'.

    ``path`` is accepted but not used here.
    """
    super(UserStore, self).__init__('user.prefs.db')
def save_user_preferences(self, user, options):
    """Persist ``options`` as the preference set of ``user``.

    Preferences live in the 'users' option table, keyed by ``user``.
    """
    table = 'users'
    self.save_options(table, user, options)
def load_user_preferences(self, user):
    """Return the preferences stored for ``user`` in the 'users' table."""
    prefs = self.load_options('users', user)
    return prefs
def save_plugin_data(self, plugin, user, options):
    """Persist ``user``'s ``options`` in the '<plugin>_data' option table."""
    table = plugin + "_data"
    self.save_options(table, user, options)
def load_plugin_data(self, plugin, user):
    """Return ``user``'s options from the '<plugin>_data' option table."""
    table = plugin + "_data"
    return self.load_options(table, user)
# Create the 'users' option table through a non-transactional query (the
# follow-up create call is on a hidden line).
563 def _initialize_schema(self):
564 q = self._query(self._db, 'users', OPTIONS_COLUMNS, trans=False)
567 def _upgrade_schema(self, old_version):
568 raise NotImplementedError()
# Store for transaction records in a uuid/name/value 'transactions' table,
# opened through the 'transactions.db' config key.
571 class TranStore(Store):
def __init__(self, path=None):
    """Open the transactions store named by 'transactions.db'.

    ``path`` is accepted but not used here.
    """
    super(TranStore, self).__init__('transactions.db')
# Create the 'transactions' table with the uuid/name/value layout (the rest
# of this method is on hidden lines).
576 def _initialize_schema(self):
577 q = self._query(self._db, 'transactions', UNIQUE_DATA_COLUMNS,
581 def _upgrade_schema(self, old_version):
582 raise NotImplementedError()
# Store for SAML2 session records kept as uuid/name/value rows in a
# 'sessions' table, opened from an explicit database URL.
585 class SAML2SessionStore(Store):
def __init__(self, database_url):
    """Open the session store at ``database_url`` and ensure its table exists.

    The 'sessions' table is created directly through SqlQuery with
    ``checkfirst=True``, so an already existing table is left untouched.
    """
    super(SAML2SessionStore, self).__init__(database_url=database_url)
    self.table = 'sessions'
    # pylint: disable=protected-access
    query = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)
    query._table.create(checkfirst=True)
# Find the uuid of the session whose `name` column equals `value`.
# NOTE(review): despite the docstring's "first value found", the visible
# code raises ValueError("Multiple entries returned") when more than one
# entry matches. Also, data.keys()[0] is Python 2 only (dict views are not
# indexable in Python 3).
594 def _get_unique_id_from_column(self, name, value):
596 The query is going to return only the column in the query.
597 Use this method to get the uuidval which can be used to fetch
600 Returns None or the uuid of the first value found.
602 data = self.get_unique_data(self.table, name=name, value=value)
607 raise ValueError("Multiple entries returned")
608 return data.keys()[0]
# Build a DELETE for every session uuid that has an 'expiration_time' row
# whose value is <= now. Executing `d` happens on a hidden line.
# NOTE(review): `value` is a Text column (all SqlQuery columns are Text)
# compared with a naive datetime.datetime.now(). This relies on the stored
# text and the bound datetime ordering consistently (e.g. the same
# ISO-like format); confirm how expiration_time is written.
610 def remove_expired_sessions(self):
611 # pylint: disable=protected-access
612 table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
613 sel = select([table.columns.uuid]). \
614 where(and_(table.c.name == 'expiration_time',
615 table.c.value <= datetime.datetime.now()))
616 # pylint: disable=no-value-for-parameter
617 d = table.delete().where(table.c.uuid.in_(sel))
def get_data(self, idval=None, name=None, value=None):
    """Return session rows, optionally filtered by uuid, name and value."""
    return self.get_unique_data(self.table, uuidval=idval, name=name,
                                value=value)
# Store a new session built from `datum` and return new_unique_data's
# result. A 'supported_logout_mechs' sequence is first joined into one
# comma-separated string.
# NOTE: this rewrites the caller's `datum` dict in place.
623 def new_session(self, datum):
624 if 'supported_logout_mechs' in datum:
625 datum['supported_logout_mechs'] = ','.join(
626 datum['supported_logout_mechs']
628 return self.new_unique_data(self.table, datum)
# Look up a session by session_id or by request_id and return a
# (uuid, fields) tuple. Raises ValueError("Unable to find session") when
# nothing is found; the exact branch conditions are on hidden lines.
630 def get_session(self, session_id=None, request_id=None):
632 uuidval = self._get_unique_id_from_column('session_id', session_id)
634 uuidval = self._get_unique_id_from_column('request_id', request_id)
636 raise ValueError("Unable to find session")
639 data = self.get_unique_data(self.table, uuidval=uuidval)
640 return uuidval, data[uuidval]
# Collect the sessions of `user`. Each list element is the whole
# {uuid: fields} dict returned by get_unique_data, with
# 'supported_logout_mechs' split back into a list.
# NOTE: when that field is missing, ''.split(',') yields [''], not [].
642 def get_user_sessions(self, user):
644 Return a list of all sessions for a given user.
646 rows = self.get_unique_data(self.table, name='user', value=user)
648 # We have a list of sessions for this user, now get the details
651 data = self.get_unique_data(self.table, uuidval=r)
652 data[r]['supported_logout_mechs'] = data[r].get(
653 'supported_logout_mechs', '').split(',')
654 logged_in.append(data)
def update_session(self, datum):
    """Save ``datum`` to the sessions table via save_unique_data."""
    target = self.table
    self.save_unique_data(target, datum)
def remove_session(self, uuidval):
    """Delete every sessions-table row whose uuid is ``uuidval``."""
    target = self.table
    self.del_unique_data(target, uuidval)
# Erase every row of the sessions table (the enclosing method header is not
# shown in this excerpt).
665 self._reset_data(self.table)
# Create the sessions table with the uuid/name/value layout (the rest of
# this method is on hidden lines).
667 def _initialize_schema(self):
668 q = self._query(self._db, self.table, UNIQUE_DATA_COLUMNS,
672 def _upgrade_schema(self, old_version):
673 raise NotImplementedError()