Rename the SAML2 sessions database to saml2_sessions
[cascardo/ipsilon.git] / ipsilon / util / data.py
1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
2
3 import cherrypy
4 import datetime
5 from ipsilon.util.log import Log
6 from sqlalchemy import create_engine
7 from sqlalchemy import MetaData, Table, Column, Text
8 from sqlalchemy.pool import QueuePool, SingletonThreadPool
9 from sqlalchemy.sql import select, and_
10 import ConfigParser
11 import os
12 import uuid
13 import logging
14
15
# Schema version implemented by this code. Store._code_schema_version()
#  returns it by default and Store._check_database() compares it with the
#  version recorded in the 'dbinfo' table.
CURRENT_SCHEMA_VERSION = 1
# Column names used for option tables (see Store.load_options).
OPTIONS_COLUMNS = ['name', 'option', 'value']
# Column names used for unique-data tables (see Store.get_unique_data).
UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
19
20
class DatabaseError(Exception):
    """Raised when a store's schema needs initialization or an upgrade."""
23
24
class SqlStore(Log):
    """Holder of one SQLAlchemy engine per database name.

    Instances are cached per name in a class-level dict; use
    get_connection() rather than instantiating the class directly.
    """
    __instances = {}

    @classmethod
    def get_connection(cls, name):
        """Return the cached SqlStore for name, creating it on first use."""
        if name in cls.__instances:
            return cls.__instances[name]

        if cherrypy.config.get('db.conn.log', False):
            logging.debug('SqlStore new: %s', name)
        store = SqlStore(name)
        cls.__instances[name] = store
        return store

    def __init__(self, name):
        """Create the engine for name (a database URL or an SQLite path)."""
        self.db_conn_log = cherrypy.config.get('db.conn.log', False)
        self.debug('SqlStore init: %s' % name)
        self.name = name

        # A bare name without a URL scheme is treated as an SQLite file
        if '://' in name:
            engine_name = name
        else:
            engine_name = 'sqlite:///' + name

        if engine_name.startswith('sqlite://'):
            # SQLite connections cannot be shared between threads, so
            #  these engines get the SingletonThreadPool
            pool_args = {'poolclass': SingletonThreadPool}
        else:
            # The pool size is per configured database. The minimum
            #  needed, determined by binary search, is 23; 25 leaves a
            #  little headroom and the overflow covers sudden peaks.
            pool_args = {'poolclass': QueuePool,
                         'pool_size': 25,
                         'max_overflow': 50}

        self._dbengine = create_engine(engine_name, **pool_args)
        self.is_readonly = False

    def debug(self, fact):
        """Log fact only when the 'db.conn.log' option is enabled."""
        if not self.db_conn_log:
            return
        super(SqlStore, self).debug(fact)

    def engine(self):
        """Return the underlying SQLAlchemy engine."""
        return self._dbengine

    def connection(self):
        """Open a connection that is closed when the request ends."""
        self.debug('SqlStore connect: %s' % self.name)
        conn = self._dbengine.connect()

        def cleanup_connection():
            self.debug('SqlStore cleanup: %s' % self.name)
            conn.close()

        # Let cherrypy close the connection at the end of the request
        cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
        return conn
73
74
class SqlQuery(Log):
    """Query helper bound to one table of a SqlStore.

    Each instance takes its own connection from the store and, unless
    trans is False, begins a transaction on it.
    """

    def __init__(self, db_obj, table, columns, trans=True):
        self._db = db_obj
        self._con = self._db.connection()
        if trans:
            self._trans = self._con.begin()
        else:
            self._trans = None
        self._table = self._get_table(table, columns)

    def _get_table(self, name, columns):
        """Build a Table called name whose columns are all Text."""
        metadata = MetaData(self._db.engine())
        table = Table(name, metadata)
        for colname in columns:
            table.append_column(Column(colname, Text()))
        return table

    def _where(self, kvfilter):
        """AND together 'column == value' for every item of kvfilter.

        Returns None when kvfilter is None or empty.
        """
        if kvfilter is None:
            return None

        clause = None
        for key in kvfilter:
            condition = self._table.columns[key] == kvfilter[key]
            clause = condition if clause is None else clause & condition
        return clause

    def _columns(self, columns=None):
        """Return the named column objects, or all of them if None."""
        if columns is None:
            return self._table.columns
        return [self._table.columns[colname] for colname in columns]

    def rollback(self):
        """Roll back the transaction begun in the constructor."""
        self._trans.rollback()

    def commit(self):
        """Commit the transaction begun in the constructor."""
        self._trans.commit()

    def create(self):
        """Create the table unless it already exists."""
        self._table.create(checkfirst=True)

    def drop(self):
        """Drop the table if it exists."""
        self._table.drop(checkfirst=True)

    def select(self, kvfilter=None, columns=None):
        """Run a SELECT and return the result of the execution."""
        selected = self._columns(columns)
        condition = self._where(kvfilter)
        return self._con.execute(select(selected, condition))

    def insert(self, values):
        """Insert one row of values."""
        statement = self._table.insert(values)
        self._con.execute(statement)

    def update(self, values, kvfilter):
        """Set values on all rows matching kvfilter."""
        statement = self._table.update(self._where(kvfilter), values)
        self._con.execute(statement)

    def delete(self, kvfilter):
        """Delete all rows matching kvfilter."""
        statement = self._table.delete(self._where(kvfilter))
        self._con.execute(statement)
134
135
class FileStore(Log):
    """Read-only store backed by an INI-style configuration file."""

    def __init__(self, name):
        self._filename = name
        self.is_readonly = True
        # Modification time of the file when it was last parsed
        self._timestamp = None
        # Cached parser; None until the first successful get_config()
        self._config = None

    def get_config(self):
        """Return the parsed configuration, re-reading it when it changed.

        The file is parsed on the first call and again whenever its
        modification time is newer than the one recorded at the last
        parse.

        Raises OSError if the file cannot be stat()ed; the cached
        configuration is dropped in that case.
        """
        try:
            stat = os.stat(self._filename)
        except OSError as e:
            self.error("Unable to check config file %s: [%s]" % (
                self._filename, e))
            self._config = None
            raise
        timestamp = stat.st_mtime
        stale = (self._config is None or
                 self._timestamp is None or
                 timestamp > self._timestamp)
        if stale:
            config = ConfigParser.RawConfigParser()
            # Keep option names case-sensitive
            config.optionxform = str
            config.read(self._filename)
            self._config = config
            # Remember the mtime so an unchanged file is not parsed again
            self._timestamp = timestamp
        return self._config
156             self._config.read(self._filename)
157         return self._config
158
159
160 class FileQuery(Log):
161
162     def __init__(self, fstore, table, columns, trans=True):
163         self._fstore = fstore
164         self._config = fstore.get_config()
165         self._section = table
166         if len(columns) > 3 or columns[-1] != 'value':
167             raise ValueError('Unsupported configuration format')
168         self._columns = columns
169
170     def rollback(self):
171         return
172
173     def commit(self):
174         return
175
176     def create(self):
177         raise NotImplementedError
178
179     def drop(self):
180         raise NotImplementedError
181
182     def select(self, kvfilter=None, columns=None):
183         if self._section not in self._config.sections():
184             return []
185
186         opts = self._config.options(self._section)
187
188         prefix = None
189         prefix_ = ''
190         if self._columns[0] in kvfilter:
191             prefix = kvfilter[self._columns[0]]
192             prefix_ = prefix + ' '
193
194         name = None
195         if len(self._columns) == 3 and self._columns[1] in kvfilter:
196             name = kvfilter[self._columns[1]]
197
198         value = None
199         if self._columns[-1] in kvfilter:
200             value = kvfilter[self._columns[-1]]
201
202         res = []
203         for o in opts:
204             if len(self._columns) == 3:
205                 # 3 cols
206                 if prefix and not o.startswith(prefix_):
207                     continue
208
209                 col1, col2 = o.split(' ', 1)
210                 if name and col2 != name:
211                     continue
212
213                 col3 = self._config.get(self._section, o)
214                 if value and col3 != value:
215                     continue
216
217                 r = [col1, col2, col3]
218             else:
219                 # 2 cols
220                 if prefix and o != prefix:
221                     continue
222                 r = [o, self._config.get(self._section, o)]
223
224             if columns:
225                 s = []
226                 for c in columns:
227                     s.append(r[self._columns.index(c)])
228                 res.append(s)
229             else:
230                 res.append(r)
231
232         self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
233                                                  repr(kvfilter),
234                                                  repr(columns),
235                                                  repr(res)))
236         return res
237
238     def insert(self, values):
239         raise NotImplementedError
240
241     def update(self, values, kvfilter):
242         raise NotImplementedError
243
244     def delete(self, kvfilter):
245         raise NotImplementedError
246
247
248 class Store(Log):
249     _is_upgrade = False
250
251     def __init__(self, config_name=None, database_url=None):
252         if config_name is None and database_url is None:
253             raise ValueError('config_name or database_url must be provided')
254         if config_name:
255             if config_name not in cherrypy.config:
256                 raise NameError('Unknown database %s' % config_name)
257             name = cherrypy.config[config_name]
258         else:
259             name = database_url
260         if name.startswith('configfile://'):
261             _, filename = name.split('://')
262             self._db = FileStore(filename)
263             self._query = FileQuery
264         else:
265             self._db = SqlStore.get_connection(name)
266             self._query = SqlQuery
267
268         if not self._is_upgrade:
269             self._check_database()
270
271     def _code_schema_version(self):
272         # This function makes it possible for separate plugins to have
273         #  different schema versions. We default to the global schema
274         #  version.
275         return CURRENT_SCHEMA_VERSION
276
277     def _get_schema_version(self):
278         # We are storing multiple versions: one per class
279         # That way, we can support plugins with differing schema versions from
280         #  the main codebase, and even in the same database.
281         q = self._query(self._db, 'dbinfo', OPTIONS_COLUMNS, trans=False)
282         q.create()
283         cls_name = self.__class__.__name__
284         current_version = self.load_options('dbinfo').get('%s_schema'
285                                                           % cls_name, {})
286         if 'version' in current_version:
287             return int(current_version['version'])
288         else:
289             # Also try the old table name.
290             # "scheme" was a typo, but we need to retain that now for compat
291             fallback_version = self.load_options('dbinfo').get('scheme',
292                                                                {})
293             if 'version' in fallback_version:
294                 return int(fallback_version['version'])
295             else:
296                 return None
297
298     def _check_database(self):
299         if self.is_readonly:
300             # If the database is readonly, we cannot do anything to the
301             #  schema. Let's just return, and assume people checked the
302             #  upgrade notes
303             return
304
305         current_version = self._get_schema_version()
306         if current_version is None:
307             self.error('Database initialization required! ' +
308                        'Please run ipsilon-upgrade-database')
309             raise DatabaseError('Database initialization required for %s' %
310                                 self.__class__.__name__)
311         if current_version != self._code_schema_version():
312             self.error('Database upgrade required! ' +
313                        'Please run ipsilon-upgrade-database')
314             raise DatabaseError('Database upgrade required for %s' %
315                                 self.__class__.__name__)
316
317     def _store_new_schema_version(self, new_version):
318         cls_name = self.__class__.__name__
319         self.save_options('dbinfo', '%s_schema' % cls_name,
320                           {'version': new_version})
321
322     def _initialize_schema(self):
323         raise NotImplementedError()
324
325     def _upgrade_schema(self, old_version):
326         # Datastores need to figure out what to do with bigger old_versions
327         #  themselves.
328         # They might implement downgrading if that's feasible, or just throw
329         #  NotImplementedError
330         raise NotImplementedError()
331
332     def upgrade_database(self):
333         # Do whatever is needed to get schema to current version
334         old_schema_version = self._get_schema_version()
335         if old_schema_version is None:
336             # Just initialize a new schema
337             self._initialize_schema()
338             self._store_new_schema_version(self._code_schema_version())
339         elif old_schema_version != self._code_schema_version():
340             # Upgrade from old_schema_version to code_schema_version
341             self._upgrade_schema(old_schema_version)
342             self._store_new_schema_version(self._code_schema_version())
343
344     @property
345     def is_readonly(self):
346         return self._db.is_readonly
347
348     def _row_to_dict_tree(self, data, row):
349         name = row[0]
350         if len(row) > 2:
351             if name not in data:
352                 data[name] = dict()
353             d2 = data[name]
354             self._row_to_dict_tree(d2, row[1:])
355         else:
356             value = row[1]
357             if name in data:
358                 if data[name] is list:
359                     data[name].append(value)
360                 else:
361                     v = data[name]
362                     data[name] = [v, value]
363             else:
364                 data[name] = value
365
366     def _rows_to_dict_tree(self, rows):
367         data = dict()
368         for r in rows:
369             self._row_to_dict_tree(data, r)
370         return data
371
372     def _load_data(self, table, columns, kvfilter=None):
373         rows = []
374         try:
375             q = self._query(self._db, table, columns, trans=False)
376             rows = q.select(kvfilter)
377         except Exception, e:  # pylint: disable=broad-except
378             self.error("Failed to load data for table %s: [%s]" % (table, e))
379         return self._rows_to_dict_tree(rows)
380
381     def load_config(self):
382         table = 'config'
383         columns = ['name', 'value']
384         return self._load_data(table, columns)
385
386     def load_options(self, table, name=None):
387         kvfilter = dict()
388         if name:
389             kvfilter['name'] = name
390         options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
391         if name and name in options:
392             return options[name]
393         return options
394
395     def save_options(self, table, name, options):
396         curvals = dict()
397         q = None
398         try:
399             q = self._query(self._db, table, OPTIONS_COLUMNS)
400             rows = q.select({'name': name}, ['option', 'value'])
401             for row in rows:
402                 curvals[row[0]] = row[1]
403
404             for opt in options:
405                 if opt in curvals:
406                     q.update({'value': options[opt]},
407                              {'name': name, 'option': opt})
408                 else:
409                     q.insert((name, opt, options[opt]))
410
411             q.commit()
412         except Exception, e:  # pylint: disable=broad-except
413             if q:
414                 q.rollback()
415             self.error("Failed to save options: [%s]" % e)
416             raise
417
418     def delete_options(self, table, name, options=None):
419         kvfilter = {'name': name}
420         q = None
421         try:
422             q = self._query(self._db, table, OPTIONS_COLUMNS)
423             if options is None:
424                 q.delete(kvfilter)
425             else:
426                 for opt in options:
427                     kvfilter['option'] = opt
428                     q.delete(kvfilter)
429             q.commit()
430         except Exception, e:  # pylint: disable=broad-except
431             if q:
432                 q.rollback()
433             self.error("Failed to delete from %s: [%s]" % (table, e))
434             raise
435
436     def new_unique_data(self, table, data):
437         newid = str(uuid.uuid4())
438         q = None
439         try:
440             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
441             for name in data:
442                 q.insert((newid, name, data[name]))
443             q.commit()
444         except Exception, e:  # pylint: disable=broad-except
445             if q:
446                 q.rollback()
447             self.error("Failed to store %s data: [%s]" % (table, e))
448             raise
449         return newid
450
451     def get_unique_data(self, table, uuidval=None, name=None, value=None):
452         kvfilter = dict()
453         if uuidval:
454             kvfilter['uuid'] = uuidval
455         if name:
456             kvfilter['name'] = name
457         if value:
458             kvfilter['value'] = value
459         return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
460
461     def save_unique_data(self, table, data):
462         q = None
463         try:
464             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
465             for uid in data:
466                 curvals = dict()
467                 rows = q.select({'uuid': uid}, ['name', 'value'])
468                 for r in rows:
469                     curvals[r[0]] = r[1]
470
471                 datum = data[uid]
472                 for name in datum:
473                     if name in curvals:
474                         if datum[name] is None:
475                             q.delete({'uuid': uid, 'name': name})
476                         else:
477                             q.update({'value': datum[name]},
478                                      {'uuid': uid, 'name': name})
479                     else:
480                         if datum[name] is not None:
481                             q.insert((uid, name, datum[name]))
482
483             q.commit()
484         except Exception, e:  # pylint: disable=broad-except
485             if q:
486                 q.rollback()
487             self.error("Failed to store data in %s: [%s]" % (table, e))
488             raise
489
490     def del_unique_data(self, table, uuidval):
491         kvfilter = {'uuid': uuidval}
492         try:
493             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
494             q.delete(kvfilter)
495         except Exception, e:  # pylint: disable=broad-except
496             self.error("Failed to delete data from %s: [%s]" % (table, e))
497
498     def _reset_data(self, table):
499         q = None
500         try:
501             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
502             q.drop()
503             q.create()
504             q.commit()
505         except Exception, e:  # pylint: disable=broad-except
506             if q:
507                 q.rollback()
508             self.error("Failed to erase all data from %s: [%s]" % (table, e))
509
510
511 class AdminStore(Store):
512
513     def __init__(self):
514         super(AdminStore, self).__init__('admin.config.db')
515
516     def get_data(self, plugin, idval=None, name=None, value=None):
517         return self.get_unique_data(plugin+"_data", idval, name, value)
518
519     def save_data(self, plugin, data):
520         return self.save_unique_data(plugin+"_data", data)
521
522     def new_datum(self, plugin, datum):
523         table = plugin+"_data"
524         return self.new_unique_data(table, datum)
525
526     def del_datum(self, plugin, idval):
527         table = plugin+"_data"
528         return self.del_unique_data(table, idval)
529
530     def wipe_data(self, plugin):
531         table = plugin+"_data"
532         self._reset_data(table)
533
534     def _initialize_schema(self):
535         for table in ['config',
536                       'info_config',
537                       'login_config',
538                       'provider_config']:
539             q = self._query(self._db, table, OPTIONS_COLUMNS, trans=False)
540             q.create()
541
542     def _upgrade_schema(self, old_version):
543         raise NotImplementedError()
544
545
546 class UserStore(Store):
547
548     def __init__(self, path=None):
549         super(UserStore, self).__init__('user.prefs.db')
550
551     def save_user_preferences(self, user, options):
552         self.save_options('users', user, options)
553
554     def load_user_preferences(self, user):
555         return self.load_options('users', user)
556
557     def save_plugin_data(self, plugin, user, options):
558         self.save_options(plugin+"_data", user, options)
559
560     def load_plugin_data(self, plugin, user):
561         return self.load_options(plugin+"_data", user)
562
563     def _initialize_schema(self):
564         q = self._query(self._db, 'users', OPTIONS_COLUMNS, trans=False)
565         q.create()
566
567     def _upgrade_schema(self, old_version):
568         raise NotImplementedError()
569
570
571 class TranStore(Store):
572
573     def __init__(self, path=None):
574         super(TranStore, self).__init__('transactions.db')
575
576     def _initialize_schema(self):
577         q = self._query(self._db, 'transactions', UNIQUE_DATA_COLUMNS,
578                         trans=False)
579         q.create()
580
581     def _upgrade_schema(self, old_version):
582         raise NotImplementedError()
583
584
585 class SAML2SessionStore(Store):
586
587     def __init__(self, database_url):
588         super(SAML2SessionStore, self).__init__(database_url=database_url)
589         self.table = 'saml2_sessions'
590         # pylint: disable=protected-access
591         table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
592         table.create(checkfirst=True)
593
594     def _get_unique_id_from_column(self, name, value):
595         """
596         The query is going to return only the column in the query.
597         Use this method to get the uuidval which can be used to fetch
598         the entire entry.
599
600         Returns None or the uuid of the first value found.
601         """
602         data = self.get_unique_data(self.table, name=name, value=value)
603         count = len(data)
604         if count == 0:
605             return None
606         elif count != 1:
607             raise ValueError("Multiple entries returned")
608         return data.keys()[0]
609
610     def remove_expired_sessions(self):
611         # pylint: disable=protected-access
612         table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
613         sel = select([table.columns.uuid]). \
614             where(and_(table.c.name == 'expiration_time',
615                        table.c.value <= datetime.datetime.now()))
616         # pylint: disable=no-value-for-parameter
617         d = table.delete().where(table.c.uuid.in_(sel))
618         d.execute()
619
620     def get_data(self, idval=None, name=None, value=None):
621         return self.get_unique_data(self.table, idval, name, value)
622
623     def new_session(self, datum):
624         if 'supported_logout_mechs' in datum:
625             datum['supported_logout_mechs'] = ','.join(
626                 datum['supported_logout_mechs']
627             )
628         return self.new_unique_data(self.table, datum)
629
630     def get_session(self, session_id=None, request_id=None):
631         if session_id:
632             uuidval = self._get_unique_id_from_column('session_id', session_id)
633         elif request_id:
634             uuidval = self._get_unique_id_from_column('request_id', request_id)
635         else:
636             raise ValueError("Unable to find session")
637         if not uuidval:
638             return None, None
639         data = self.get_unique_data(self.table, uuidval=uuidval)
640         return uuidval, data[uuidval]
641
642     def get_user_sessions(self, user):
643         """
644         Return a list of all sessions for a given user.
645         """
646         rows = self.get_unique_data(self.table, name='user', value=user)
647
648         # We have a list of sessions for this user, now get the details
649         logged_in = []
650         for r in rows:
651             data = self.get_unique_data(self.table, uuidval=r)
652             data[r]['supported_logout_mechs'] = data[r].get(
653                 'supported_logout_mechs', '').split(',')
654             logged_in.append(data)
655
656         return logged_in
657
658     def update_session(self, datum):
659         self.save_unique_data(self.table, datum)
660
661     def remove_session(self, uuidval):
662         self.del_unique_data(self.table, uuidval)
663
664     def wipe_data(self):
665         self._reset_data(self.table)
666
667     def _initialize_schema(self):
668         q = self._query(self._db, self.table, UNIQUE_DATA_COLUMNS,
669                         trans=False)
670         q.create()
671
672     def _upgrade_schema(self, old_version):
673         raise NotImplementedError()