0981c52408a18e69cbaa0d82decc490b2f9b8d6a
[cascardo/ipsilon.git] / ipsilon / util / data.py
1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
2
3 import cherrypy
4 import datetime
5 from ipsilon.util.log import Log
6 from sqlalchemy import create_engine
7 from sqlalchemy import MetaData, Table, Column, Text
8 from sqlalchemy.pool import QueuePool, SingletonThreadPool
9 from sqlalchemy.schema import (PrimaryKeyConstraint, Index, AddConstraint,
10                                CreateIndex)
11 from sqlalchemy.sql import select, and_
12 import ConfigParser
13 import os
14 import uuid
15 import logging
16 import time
17
18
# Schema version implemented by this code; databases at older versions
#  must be migrated with ipsilon-upgrade-database before use.
CURRENT_SCHEMA_VERSION = 2
# Layout of the generic key/value option tables: one row per
#  (name, option) pair, all columns stored as text.
OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
                 'primary_key': ('name', 'option'),
                 'indexes': [('name',)]
                 }
# Layout of the uuid-keyed data tables: one row per (uuid, name) pair.
UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
                     'primary_key': ('uuid', 'name'),
                     'indexes': [('uuid',)]
                     }
28
29
class DatabaseError(Exception):
    """Raised when the backing database needs initialization or upgrade."""
32
33
class BaseStore(Log):
    # Some helper functions used for upgrades
    def add_constraint(self, table):
        # NOTE(review): despite the parameter name, callers pass a
        #  constraint object here (see SqlStore.add_constraint and the
        #  _upgrade_schema implementations); name kept for compatibility.
        raise NotImplementedError()

    def add_index(self, index):
        # Create an index on a pre-existing table; subclasses that support
        #  in-place schema upgrades override this.
        raise NotImplementedError()
41
42
class SqlStore(BaseStore):
    # One shared SqlStore per database name for the whole process, so every
    #  Store instance using the same database shares one engine and pool.
    __instances = {}

    @classmethod
    def get_connection(cls, name):
        """Return the process-wide SqlStore for *name*, creating it once."""
        if name not in cls.__instances:
            if cherrypy.config.get('db.conn.log', False):
                logging.debug('SqlStore new: %s', name)
            cls.__instances[name] = SqlStore(name)
        return cls.__instances[name]

    def __init__(self, name):
        """Create the engine for *name* (a URL, or a bare SQLite path)."""
        self.db_conn_log = cherrypy.config.get('db.conn.log', False)
        self.debug('SqlStore init: %s' % name)
        self.name = name
        engine_name = name
        # A name without a scheme is shorthand for a SQLite database file
        if '://' not in engine_name:
            engine_name = 'sqlite:///' + engine_name
        # This pool size is per configured database. The minimum needed,
        #  determined by binary search, is 23. We're using 25 so we have a bit
        #  more playroom, and then the overflow should make sure things don't
        #  break when we suddenly need more.
        pool_args = {'poolclass': QueuePool,
                     'pool_size': 25,
                     'max_overflow': 50}
        if engine_name.startswith('sqlite://'):
            # It's not possible to share connections for SQLite between
            #  threads, so let's use the SingletonThreadPool for them
            pool_args = {'poolclass': SingletonThreadPool}
        self._dbengine = create_engine(engine_name,
                                       echo=cherrypy.config.get('db.echo',
                                                                False),
                                       **pool_args)
        self.is_readonly = False

    def add_constraint(self, constraint):
        """Add *constraint* to its (pre-existing) table, where supported."""
        if self._dbengine.dialect.name != 'sqlite':
            # It is impossible to add constraints to a pre-existing table for
            #  SQLite
            # source: http://www.sqlite.org/omitted.html
            create_constraint = AddConstraint(constraint, bind=self._dbengine)
            create_constraint.execute()

    def add_index(self, index):
        """Create *index* on its (pre-existing) table."""
        add_index = CreateIndex(index, bind=self._dbengine)
        add_index.execute()

    def debug(self, fact):
        # Connection-level debug messages are gated on db.conn.log
        if self.db_conn_log:
            super(SqlStore, self).debug(fact)

    def engine(self):
        """Return the underlying SQLAlchemy engine."""
        return self._dbengine

    def connection(self):
        """Check a connection out of the pool.

        The connection is closed (returned to the pool) automatically when
        the current cherrypy request ends.
        """
        self.debug('SqlStore connect: %s' % self.name)
        conn = self._dbengine.connect()

        def cleanup_connection():
            self.debug('SqlStore cleanup: %s' % self.name)
            conn.close()
        cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
        return conn
106
107
class SqlQuery(Log):
    """Query helper bound to one table of a SqlStore.

    Provides the same interface FileQuery offers for config files:
    select/insert/update/delete plus transaction control.
    """

    def __init__(self, db_obj, table, table_def, trans=True):
        self._db = db_obj
        self._con = self._db.connection()
        # Open a transaction unless the caller asked for autocommit-style use
        self._trans = self._con.begin() if trans else None
        self._table = self._get_table(table, table_def)

    def _get_table(self, name, table_def):
        """Build the SQLAlchemy Table object described by *table_def*."""
        if isinstance(table_def, list):
            # A bare column list means: no primary key and no indexes
            table_def = {'columns': table_def,
                         'primary_key': None,
                         'indexes': []}
        elements = [Column(cname, Text()) for cname in table_def['columns']]
        if table_def['primary_key']:
            elements.append(PrimaryKeyConstraint(*table_def['primary_key']))
        for idx_cols in table_def['indexes']:
            idx_name = 'idx_%s_%s' % (name, '_'.join(idx_cols))
            elements.append(Index(idx_name, *idx_cols))
        return Table(name, MetaData(self._db.engine()), *elements)

    def _where(self, kvfilter):
        """Turn a {column: value} filter into a WHERE clause (or None)."""
        if kvfilter is None:
            return None
        clause = None
        for key in kvfilter:
            cond = self._table.columns[key] == kvfilter[key]
            clause = cond if clause is None else clause & cond
        return clause

    def _columns(self, columns=None):
        """Map column names to column objects; default to all columns."""
        if columns is None:
            return self._table.columns
        return [self._table.columns[cname] for cname in columns]

    def rollback(self):
        self._trans.rollback()

    def commit(self):
        self._trans.commit()

    def create(self):
        self._table.create(checkfirst=True)

    def drop(self):
        self._table.drop(checkfirst=True)

    def select(self, kvfilter=None, columns=None):
        return self._con.execute(select(self._columns(columns),
                                        self._where(kvfilter)))

    def insert(self, values):
        self._con.execute(self._table.insert(values))

    def update(self, values, kvfilter):
        self._con.execute(self._table.update(self._where(kvfilter), values))

    def delete(self, kvfilter):
        self._con.execute(self._table.delete(self._where(kvfilter)))
178
179
class FileStore(BaseStore):
    """Read-only store backed by an ini-style configuration file."""

    def __init__(self, name):
        self._filename = name
        self.is_readonly = True
        # mtime of the file at the last successful parse; used to skip
        #  re-parsing an unchanged file
        self._timestamp = None
        self._config = None

    def get_config(self):
        """Return the parsed configuration, re-reading the file if changed.

        Raises OSError when the file cannot be stat'ed.
        """
        try:
            stat = os.stat(self._filename)
        except OSError as e:
            self.error("Unable to check config file %s: [%s]" % (
                self._filename, e))
            self._config = None
            raise
        timestamp = stat.st_mtime
        if self._config is None or timestamp > self._timestamp:
            self._config = ConfigParser.RawConfigParser()
            # Keep option names case sensitive
            self._config.optionxform = str
            self._config.read(self._filename)
            # Record the parse time: without this the comparison above is
            #  always against None and the file is re-read on every call.
            self._timestamp = timestamp
        return self._config

    def add_constraint(self, table):
        # Schema changes are meaningless for a config file
        raise NotImplementedError()

    def add_index(self, index):
        # Schema changes are meaningless for a config file
        raise NotImplementedError()
208
209
class FileQuery(Log):
    """Read-only query emulation over one section of a FileStore.

    The config section is the "table": for three-column tables the option
    name encodes the first two columns separated by a space, and the option
    value is the last column; for two-column tables the option name is the
    first column.
    """

    def __init__(self, fstore, table, table_def, trans=True):
        # We don't need indexes in a FileQuery, so drop that info
        if isinstance(table_def, dict):
            columns = table_def['columns']
        else:
            columns = table_def
        self._fstore = fstore
        self._config = fstore.get_config()
        self._section = table
        if len(columns) > 3 or columns[-1] != 'value':
            raise ValueError('Unsupported configuration format')
        self._columns = columns

    def rollback(self):
        # Config files are never written to, nothing to roll back
        return

    def commit(self):
        # Config files are never written to, nothing to commit
        return

    def create(self):
        raise NotImplementedError

    def drop(self):
        raise NotImplementedError

    def select(self, kvfilter=None, columns=None):
        """Return matching rows as lists, optionally projected to *columns*."""
        if self._section not in self._config.sections():
            return []

        # Callers such as Store._load_data may pass kvfilter=None; the
        #  membership tests below require a dict, so normalize it here.
        if kvfilter is None:
            kvfilter = {}

        opts = self._config.options(self._section)

        prefix = None
        prefix_ = ''
        if self._columns[0] in kvfilter:
            prefix = kvfilter[self._columns[0]]
            prefix_ = prefix + ' '

        name = None
        if len(self._columns) == 3 and self._columns[1] in kvfilter:
            name = kvfilter[self._columns[1]]

        value = None
        if self._columns[-1] in kvfilter:
            value = kvfilter[self._columns[-1]]

        res = []
        for o in opts:
            if len(self._columns) == 3:
                # 3 cols: option name is "<col1> <col2>", value is col3
                if prefix and not o.startswith(prefix_):
                    continue

                col1, col2 = o.split(' ', 1)
                if name and col2 != name:
                    continue

                col3 = self._config.get(self._section, o)
                if value and col3 != value:
                    continue

                r = [col1, col2, col3]
            else:
                # 2 cols: option name is col1, value is col2
                if prefix and o != prefix:
                    continue
                r = [o, self._config.get(self._section, o)]

            if columns:
                s = []
                for c in columns:
                    s.append(r[self._columns.index(c)])
                res.append(s)
            else:
                res.append(r)

        self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
                                                 repr(kvfilter),
                                                 repr(columns),
                                                 repr(res)))
        return res

    def insert(self, values):
        raise NotImplementedError

    def update(self, values, kvfilter):
        raise NotImplementedError

    def delete(self, kvfilter):
        raise NotImplementedError
301
302
class Store(Log):
    """Common base class for all Ipsilon datastores.

    Handles schema version checking/upgrading, optional periodic cleanup
    and generic option / uuid-keyed data helpers on top of either a
    SqlStore or a (read-only) FileStore backend.
    """

    # Static, Store-level variables
    _is_upgrade = False
    __cleanups = {}

    # Static, class-level variables
    # Either set this to False, or implement _cleanup, in child classes
    _should_cleanup = True

    def __init__(self, config_name=None, database_url=None):
        """Open the database named by cherrypy config key *config_name*,
        or directly by *database_url*. At least one must be given.
        """
        if config_name is None and database_url is None:
            raise ValueError('config_name or database_url must be provided')
        if config_name:
            if config_name not in cherrypy.config:
                raise NameError('Unknown database %s' % config_name)
            name = cherrypy.config[config_name]
        else:
            name = database_url
        if name.startswith('configfile://'):
            _, filename = name.split('://')
            self._db = FileStore(filename)
            self._query = FileQuery
        else:
            self._db = SqlStore.get_connection(name)
            self._query = SqlQuery

        if not self._is_upgrade:
            self._check_database()
            if self._should_cleanup:
                self._schedule_cleanup()

    def _schedule_cleanup(self):
        """Start the periodic cleanup task, once per Store subclass."""
        store_name = self.__class__.__name__
        if self.is_readonly:
            # No use in cleanups on a readonly database
            self.debug('Not scheduling cleanup for %s due to readonly' %
                       store_name)
            return
        if store_name in Store.__cleanups:
            # This class was already scheduled, skip
            return
        self.debug('Scheduling cleanups for %s' % store_name)
        # Check once every minute whether we need to clean
        task = cherrypy.process.plugins.BackgroundTask(
            60, self._maybe_run_cleanup)
        task.start()
        Store.__cleanups[store_name] = task

    def _maybe_run_cleanup(self):
        """Run _cleanup() when the configured cleanup interval has passed."""
        # Let's see if we need to do cleanup
        last_clean = self.load_options('dbinfo').get('%s_last_clean' %
                                                     self.__class__.__name__,
                                                     {})
        time_diff = cherrypy.config.get('cleanup_interval', 30) * 60
        next_ts = int(time.time()) - time_diff
        self.debug('Considering cleanup for %s: %s. Next at: %s'
                   % (self.__class__.__name__, last_clean, next_ts))
        if ('timestamp' not in last_clean or
                int(last_clean['timestamp']) <= next_ts):
            # First store the current time so that other servers don't start
            self.save_options('dbinfo', '%s_last_clean'
                              % self.__class__.__name__,
                              {'timestamp': int(time.time()),
                               'removed_entries': -1})

            # Cleanup has been long enough ago, let's run
            self.debug('Cleaning up for %s' % self.__class__.__name__)
            removed_entries = self._cleanup()
            self.debug('Cleaned up %i entries for %s' %
                       (removed_entries, self.__class__.__name__))
            self.save_options('dbinfo', '%s_last_clean'
                              % self.__class__.__name__,
                              {'timestamp': int(time.time()),
                               'removed_entries': removed_entries})

    def _cleanup(self):
        # The default cleanup is to do nothing
        # This function should return the number of rows it cleaned up.
        # This information may be used to automatically tune the clean period.
        self.error('Cleanup for %s not implemented' %
                   self.__class__.__name__)
        return 0

    def _code_schema_version(self):
        # This function makes it possible for separate plugins to have
        #  different schema versions. We default to the global schema
        #  version.
        return CURRENT_SCHEMA_VERSION

    def _get_schema_version(self):
        """Return the stored schema version for this class.

        Returns None when the database was never initialized, and -1 when
        only the pre-1.1 "scheme" bookkeeping exists (see
        upgrade_database()).
        """
        # We are storing multiple versions: one per class
        # That way, we can support plugins with differing schema versions from
        #  the main codebase, and even in the same database.
        q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
        q.create()
        q._con.close()  # pylint: disable=protected-access
        cls_name = self.__class__.__name__
        current_version = self.load_options('dbinfo').get('%s_schema'
                                                          % cls_name, {})
        if 'version' in current_version:
            return int(current_version['version'])
        else:
            # Also try the old table name.
            # "scheme" was a typo, but we need to retain that now for compat
            fallback_version = self.load_options('dbinfo').get('scheme',
                                                               {})
            if 'version' in fallback_version:
                # Explanation for this is in def upgrade_database(self)
                return -1
            else:
                return None

    def _check_database(self):
        """Raise DatabaseError unless the schema matches the code version."""
        if self.is_readonly:
            # If the database is readonly, we cannot do anything to the
            #  schema. Let's just return, and assume people checked the
            #  upgrade notes
            return

        current_version = self._get_schema_version()
        if current_version is None:
            self.error('Database initialization required! ' +
                       'Please run ipsilon-upgrade-database')
            raise DatabaseError('Database initialization required for %s' %
                                self.__class__.__name__)
        if current_version != self._code_schema_version():
            self.error('Database upgrade required! ' +
                       'Please run ipsilon-upgrade-database')
            raise DatabaseError('Database upgrade required for %s' %
                                self.__class__.__name__)

    def _store_new_schema_version(self, new_version):
        cls_name = self.__class__.__name__
        self.save_options('dbinfo', '%s_schema' % cls_name,
                          {'version': new_version})

    def _initialize_schema(self):
        raise NotImplementedError()

    def _upgrade_schema(self, old_version):
        # Datastores need to figure out what to do with bigger old_versions
        #  themselves.
        # They might implement downgrading if that's feasible, or just throw
        #  NotImplementedError
        # Should return the new schema version
        raise NotImplementedError()

    def upgrade_database(self):
        """Initialize or step-upgrade the schema to the current version."""
        # Do whatever is needed to get schema to current version
        old_schema_version = self._get_schema_version()
        if old_schema_version is None:
            # Just initialize a new schema
            self._initialize_schema()
            self._store_new_schema_version(self._code_schema_version())
        elif old_schema_version == -1:
            # This is a special-case from 1.0: we only created tables at the
            # first time they were actually used, but the upgrade code assumes
            # that the tables exist. So let's fix this.
            self._initialize_schema()
            # The old version was schema version 1
            self._store_new_schema_version(1)
            self.upgrade_database()
        elif old_schema_version != self._code_schema_version():
            # Upgrade from old_schema_version to code_schema_version
            self.debug('Upgrading from schema version %i' % old_schema_version)
            new_version = self._upgrade_schema(old_schema_version)
            if not new_version:
                # Format the whole message at once: % binds tighter than +,
                #  so applying it to only the second fragment (which has no
                #  %s) raises TypeError instead of building the message.
                error = ('Schema upgrade error: %s did not provide a '
                         'new schema version number!' %
                         self.__class__.__name__)
                self.error(error)
                raise Exception(error)
            self._store_new_schema_version(new_version)
            # Check if we are now up-to-date
            self.upgrade_database()

    @property
    def is_readonly(self):
        # Read-only is a property of the backing store (e.g. FileStore)
        return self._db.is_readonly

    def _row_to_dict_tree(self, data, row):
        """Merge one row into the nested dict *data*.

        All but the last field become nested dict keys; the final field is
        the leaf value. Repeated key paths accumulate their leaf values
        into a list.
        """
        name = row[0]
        if len(row) > 2:
            if name not in data:
                data[name] = dict()
            d2 = data[name]
            self._row_to_dict_tree(d2, row[1:])
        else:
            value = row[1]
            if name in data:
                # isinstance check: "data[name] is list" compares against
                #  the type object and is never true, so a third repeated
                #  value would nest lists instead of appending.
                if isinstance(data[name], list):
                    data[name].append(value)
                else:
                    v = data[name]
                    data[name] = [v, value]
            else:
                data[name] = value

    def _rows_to_dict_tree(self, rows):
        data = dict()
        for r in rows:
            self._row_to_dict_tree(data, r)
        return data

    def _load_data(self, table, columns, kvfilter=None):
        """Select rows from *table* and return them as a nested dict.

        Errors are logged and result in an empty dict, not an exception.
        """
        rows = []
        try:
            q = self._query(self._db, table, columns, trans=False)
            rows = q.select(kvfilter)
        except Exception as e:  # pylint: disable=broad-except
            self.error("Failed to load data for table %s: [%s]" % (table, e))
        return self._rows_to_dict_tree(rows)

    def load_config(self):
        table = 'config'
        columns = ['name', 'value']
        return self._load_data(table, columns)

    def load_options(self, table, name=None):
        """Load options; when *name* is given return only that name's dict."""
        kvfilter = dict()
        if name:
            kvfilter['name'] = name
        options = self._load_data(table, OPTIONS_TABLE, kvfilter)
        if name and name in options:
            return options[name]
        return options

    def save_options(self, table, name, options):
        """Insert or update *options* (a dict) for *name*, transactionally."""
        curvals = dict()
        q = None
        try:
            q = self._query(self._db, table, OPTIONS_TABLE)
            rows = q.select({'name': name}, ['option', 'value'])
            for row in rows:
                curvals[row[0]] = row[1]

            for opt in options:
                if opt in curvals:
                    q.update({'value': options[opt]},
                             {'name': name, 'option': opt})
                else:
                    q.insert((name, opt, options[opt]))

            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to save options: [%s]" % e)
            raise

    def delete_options(self, table, name, options=None):
        """Delete all options for *name*, or only those listed in *options*."""
        kvfilter = {'name': name}
        q = None
        try:
            q = self._query(self._db, table, OPTIONS_TABLE)
            if options is None:
                q.delete(kvfilter)
            else:
                for opt in options:
                    kvfilter['option'] = opt
                    q.delete(kvfilter)
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to delete from %s: [%s]" % (table, e))
            raise

    def new_unique_data(self, table, data):
        """Store *data* (a flat dict) under a fresh uuid; return the uuid."""
        newid = str(uuid.uuid4())
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
            for name in data:
                q.insert((newid, name, data[name]))
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to store %s data: [%s]" % (table, e))
            raise
        return newid

    def get_unique_data(self, table, uuidval=None, name=None, value=None):
        """Fetch uuid-keyed data, optionally filtered by uuid/name/value."""
        kvfilter = dict()
        if uuidval:
            kvfilter['uuid'] = uuidval
        if name:
            kvfilter['name'] = name
        if value:
            kvfilter['value'] = value
        return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)

    def save_unique_data(self, table, data):
        """Sync {uuid: {name: value}} into *table*.

        A value of None deletes the corresponding row.
        """
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
            for uid in data:
                curvals = dict()
                rows = q.select({'uuid': uid}, ['name', 'value'])
                for r in rows:
                    curvals[r[0]] = r[1]

                datum = data[uid]
                for name in datum:
                    if name in curvals:
                        if datum[name] is None:
                            q.delete({'uuid': uid, 'name': name})
                        else:
                            q.update({'value': datum[name]},
                                     {'uuid': uid, 'name': name})
                    else:
                        if datum[name] is not None:
                            q.insert((uid, name, datum[name]))

            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to store data in %s: [%s]" % (table, e))
            raise

    def del_unique_data(self, table, uuidval):
        """Best-effort delete of all rows for *uuidval*; errors are logged."""
        kvfilter = {'uuid': uuidval}
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
            q.delete(kvfilter)
        except Exception as e:  # pylint: disable=broad-except
            self.error("Failed to delete data from %s: [%s]" % (table, e))

    def _reset_data(self, table):
        """Drop and re-create *table*, erasing all of its data."""
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
            q.drop()
            q.create()
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to erase all data from %s: [%s]" % (table, e))
644
645
class AdminStore(Store):
    """Store for the administrative configuration database."""

    _should_cleanup = False

    def __init__(self):
        super(AdminStore, self).__init__('admin.config.db')

    def get_data(self, plugin, idval=None, name=None, value=None):
        return self.get_unique_data(plugin + "_data", idval, name, value)

    def save_data(self, plugin, data):
        return self.save_unique_data(plugin + "_data", data)

    def new_datum(self, plugin, datum):
        return self.new_unique_data(plugin + "_data", datum)

    def del_datum(self, plugin, idval):
        return self.del_unique_data(plugin + "_data", idval)

    def wipe_data(self, plugin):
        self._reset_data(plugin + "_data")

    def _initialize_schema(self):
        for section in ('config', 'info_config', 'login_config',
                        'provider_config'):
            q = self._query(self._db, section, OPTIONS_TABLE, trans=False)
            q.create()
            q._con.close()  # pylint: disable=protected-access

    def _upgrade_schema(self, old_version):
        if old_version != 1:
            raise NotImplementedError()
        # In schema version 2, we added indexes and primary keys
        for section in ('config', 'info_config', 'login_config',
                        'provider_config'):
            # pylint: disable=protected-access
            table = self._query(self._db, section, OPTIONS_TABLE,
                                trans=False)._table
            self._db.add_constraint(table.primary_key)
            for index in table.indexes:
                self._db.add_index(index)
        return 2

    def create_plugin_data_table(self, plugin_name):
        if self.is_readonly:
            return
        q = self._query(self._db, plugin_name + '_data', UNIQUE_DATA_TABLE,
                        trans=False)
        q.create()
        q._con.close()  # pylint: disable=protected-access
703
704
class UserStore(Store):
    """Store holding per-user preferences and per-plugin user data."""

    _should_cleanup = False

    def __init__(self, path=None):
        super(UserStore, self).__init__('user.prefs.db')

    def save_user_preferences(self, user, options):
        self.save_options('users', user, options)

    def load_user_preferences(self, user):
        return self.load_options('users', user)

    def save_plugin_data(self, plugin, user, options):
        self.save_options(plugin + "_data", user, options)

    def load_plugin_data(self, plugin, user):
        return self.load_options(plugin + "_data", user)

    def _initialize_schema(self):
        q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
        q.create()
        q._con.close()  # pylint: disable=protected-access

    def _upgrade_schema(self, old_version):
        if old_version != 1:
            raise NotImplementedError()
        # In schema version 2, we added indexes and primary keys
        # pylint: disable=protected-access
        table = self._query(self._db, 'users', OPTIONS_TABLE,
                            trans=False)._table
        self._db.add_constraint(table.primary_key)
        for index in table.indexes:
            self._db.add_index(index)
        return 2
740
741
class TranStore(Store):
    """Store tracking in-flight authentication transactions."""

    def __init__(self, path=None):
        super(TranStore, self).__init__('transactions.db')
        self.table = 'transactions'

    def _initialize_schema(self):
        q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                        trans=False)
        q.create()
        q._con.close()  # pylint: disable=protected-access

    def _upgrade_schema(self, old_version):
        if old_version != 1:
            raise NotImplementedError()
        # In schema version 2, we added indexes and primary keys
        # pylint: disable=protected-access
        table = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                            trans=False)._table
        self._db.add_constraint(table.primary_key)
        for index in table.indexes:
            self._db.add_index(index)
        return 2

    def _cleanup(self):
        """Delete transactions started over an hour ago; return rows removed."""
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
        cutoff = datetime.datetime.now() - datetime.timedelta(hours=1)
        stale = select([table.columns.uuid]).where(
            and_(table.c.name == 'origintime',
                 table.c.value <= cutoff))
        # pylint: disable=no-value-for-parameter
        d = table.delete().where(table.c.uuid.in_(stale))
        return d.execute().rowcount
777
778
class SAML2SessionStore(Store):
    """Store tracking active SAML2 login sessions."""

    def __init__(self, database_url):
        super(SAML2SessionStore, self).__init__(database_url=database_url)
        self.table = 'saml2_sessions'
        # Ensure the table exists before the first query runs
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
        table.create(checkfirst=True)

    def _get_unique_id_from_column(self, name, value):
        """
        The query is going to return only the column in the query.
        Use this method to get the uuidval which can be used to fetch
        the entire entry.

        Returns None or the uuid of the first value found.
        """
        data = self.get_unique_data(self.table, name=name, value=value)
        count = len(data)
        if count == 0:
            return None
        elif count != 1:
            raise ValueError("Multiple entries returned")
        # next(iter(...)) rather than data.keys()[0]: dict views are not
        #  indexable on python 3, and this is identical on python 2.
        return next(iter(data))

    def _cleanup(self):
        """Remove sessions whose expiration_time has passed."""
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
        sel = select([table.columns.uuid]). \
            where(and_(table.c.name == 'expiration_time',
                       table.c.value <= datetime.datetime.now()))
        # pylint: disable=no-value-for-parameter
        d = table.delete().where(table.c.uuid.in_(sel))
        return d.execute().rowcount

    def get_data(self, idval=None, name=None, value=None):
        """Fetch session data, optionally filtered by uuid/name/value."""
        return self.get_unique_data(self.table, idval, name, value)

    def new_session(self, datum):
        """Store a new session; return its uuid."""
        if 'supported_logout_mechs' in datum:
            # Stored as a comma-separated string; split again on load
            datum['supported_logout_mechs'] = ','.join(
                datum['supported_logout_mechs']
            )
        return self.new_unique_data(self.table, datum)

    def get_session(self, session_id=None, request_id=None):
        """Look up a session by session_id or request_id.

        Returns (uuid, data), or (None, None) when not found.
        """
        if session_id:
            uuidval = self._get_unique_id_from_column('session_id', session_id)
        elif request_id:
            uuidval = self._get_unique_id_from_column('request_id', request_id)
        else:
            raise ValueError("Unable to find session")
        if not uuidval:
            return None, None
        data = self.get_unique_data(self.table, uuidval=uuidval)
        return uuidval, data[uuidval]

    def get_user_sessions(self, user):
        """
        Return a list of all sessions for a given user.
        """
        rows = self.get_unique_data(self.table, name='user', value=user)

        # We have a list of sessions for this user, now get the details
        logged_in = []
        for r in rows:
            data = self.get_unique_data(self.table, uuidval=r)
            data[r]['supported_logout_mechs'] = data[r].get(
                'supported_logout_mechs', '').split(',')
            logged_in.append(data)

        return logged_in

    def update_session(self, datum):
        """Merge updated fields into existing session(s)."""
        self.save_unique_data(self.table, datum)

    def remove_session(self, uuidval):
        self.del_unique_data(self.table, uuidval)

    def wipe_data(self):
        """Erase all stored sessions."""
        self._reset_data(self.table)

    def _initialize_schema(self):
        q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                        trans=False)
        q.create()
        q._con.close()  # pylint: disable=protected-access

    def _upgrade_schema(self, old_version):
        if old_version == 1:
            # In schema version 2, we added indexes and primary keys
            # pylint: disable=protected-access
            table = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                                trans=False)._table
            self._db.add_constraint(table.primary_key)
            for index in table.indexes:
                self._db.add_index(index)
            return 2
        else:
            raise NotImplementedError()