Also add the store name when reporting data load error
[cascardo/ipsilon.git] / ipsilon / util / data.py
1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
2
3 import cherrypy
4 import datetime
5 from ipsilon.util.log import Log
6 from sqlalchemy import create_engine
7 from sqlalchemy import MetaData, Table, Column, Text
8 from sqlalchemy.pool import QueuePool, SingletonThreadPool
9 from sqlalchemy.schema import (PrimaryKeyConstraint, Index, AddConstraint,
10                                CreateIndex)
11 from sqlalchemy.sql import select, and_
12 import ConfigParser
13 import os
14 import uuid
15 import logging
16 import time
17
18
# Bump this whenever the on-disk layout changes; Store._check_database
# compares it against the version recorded in the 'dbinfo' table.
CURRENT_SCHEMA_VERSION = 2
# Layout for key/value option tables: one row per (name, option) pair.
OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
                 'primary_key': ('name', 'option'),
                 'indexes': [('name',)]
                 }
# Layout for tables keyed by a generated uuid: one row per (uuid, name) pair.
UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
                     'primary_key': ('uuid', 'name'),
                     'indexes': [('uuid',)]
                     }
28
29
class DatabaseError(Exception):
    """Raised when the database needs initialization or an upgrade."""
    pass
32
33
class BaseStore(Log):
    """Common interface shared by the SQL and file backed stores."""

    # Some helper functions used for upgrades
    def add_constraint(self, table):
        """Apply a pre-built constraint object to an existing table."""
        raise NotImplementedError()

    def add_index(self, index):
        """Apply a pre-built index object to an existing table."""
        raise NotImplementedError()
41
42
class SqlStore(BaseStore):
    """SQLAlchemy-backed store; one shared instance per database name."""

    # Cache of SqlStore objects, keyed by database name/URL.
    __instances = {}

    @classmethod
    def get_connection(cls, name):
        """Return the shared SqlStore for *name*, creating it on first use."""
        if name not in cls.__instances:
            if cherrypy.config.get('db.conn.log', False):
                logging.debug('SqlStore new: %s', name)
            cls.__instances[name] = SqlStore(name)
        return cls.__instances[name]

    def __init__(self, name):
        self.db_conn_log = cherrypy.config.get('db.conn.log', False)
        self.debug('SqlStore init: %s' % name)
        self.name = name
        engine_name = name
        # A bare path (no scheme) is treated as an SQLite database file.
        if '://' not in engine_name:
            engine_name = 'sqlite:///' + engine_name
        # This pool size is per configured database. The minimum needed,
        #  determined by binary search, is 23. We're using 25 so we have a bit
        #  more playroom, and then the overflow should make sure things don't
        #  break when we suddenly need more.
        pool_args = {'poolclass': QueuePool,
                     'pool_size': 25,
                     'max_overflow': 50}
        if engine_name.startswith('sqlite://'):
            # It's not possible to share connections for SQLite between
            #  threads, so let's use the SingletonThreadPool for them
            pool_args = {'poolclass': SingletonThreadPool}
        self._dbengine = create_engine(engine_name,
                                       echo=cherrypy.config.get('db.echo',
                                                                False),
                                       **pool_args)
        self.is_readonly = False

    def add_constraint(self, constraint):
        """Apply *constraint* on the live database (no-op for SQLite)."""
        if self._dbengine.dialect.name != 'sqlite':
            # It is impossible to add constraints to a pre-existing table for
            #  SQLite
            # source: http://www.sqlite.org/omitted.html
            create_constraint = AddConstraint(constraint, bind=self._dbengine)
            create_constraint.execute()

    def add_index(self, index):
        """Create *index* on the live database."""
        add_index = CreateIndex(index, bind=self._dbengine)
        add_index.execute()

    def debug(self, fact):
        # Connection debugging is verbose; only log when explicitly enabled.
        if self.db_conn_log:
            super(SqlStore, self).debug(fact)

    def engine(self):
        """Return the underlying SQLAlchemy engine."""
        return self._dbengine

    def connection(self):
        """Check a connection out of the pool.

        The connection is closed automatically at the end of the current
        cherrypy request via an 'on_end_request' hook.
        """
        self.debug('SqlStore connect: %s' % self.name)
        conn = self._dbengine.connect()

        def cleanup_connection():
            self.debug('SqlStore cleanup: %s' % self.name)
            conn.close()
        cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
        return conn
106
107
class SqlQuery(Log):
    """Thin helper around one SQLAlchemy table for simple CRUD operations."""

    def __init__(self, db_obj, table, table_def, trans=True):
        self._db = db_obj
        self._con = self._db.connection()
        # Open a transaction up front unless the caller opted out.
        if trans:
            self._trans = self._con.begin()
        else:
            self._trans = None
        self._table = self._get_table(table, table_def)

    def _get_table(self, name, table_def):
        # A plain list of column names means: no primary key, no indexes.
        if isinstance(table_def, list):
            table_def = {'columns': table_def,
                         'indexes': [],
                         'primary_key': None}
        elements = [Column(cname, Text()) for cname in table_def['columns']]
        if table_def['primary_key']:
            elements.append(PrimaryKeyConstraint(*table_def['primary_key']))
        for index in table_def['indexes']:
            elements.append(Index('idx_%s_%s' % (name, '_'.join(index)),
                                  *index))
        return Table(name, MetaData(self._db.engine()), *elements)

    def _where(self, kvfilter):
        # AND together one equality test per filter key; None means no WHERE.
        clause = None
        if kvfilter is not None:
            for key, val in kvfilter.items():
                term = self._table.columns[key] == val
                clause = term if clause is None else clause & term
        return clause

    def _columns(self, columns=None):
        # Project to the requested column objects, or all of them.
        if columns is None:
            return self._table.columns
        return [self._table.columns[c] for c in columns]

    def rollback(self):
        """Abort the open transaction."""
        self._trans.rollback()

    def commit(self):
        """Commit the open transaction."""
        self._trans.commit()

    def create(self):
        """Create the table unless it already exists."""
        self._table.create(checkfirst=True)

    def drop(self):
        """Drop the table if it exists."""
        self._table.drop(checkfirst=True)

    def select(self, kvfilter=None, columns=None):
        """Return a result proxy for the filtered, projected SELECT."""
        return self._con.execute(select(self._columns(columns),
                                        self._where(kvfilter)))

    def insert(self, values):
        """Insert one row given as a value tuple."""
        self._con.execute(self._table.insert(values))

    def update(self, values, kvfilter):
        """Update rows matching *kvfilter* with the *values* dict."""
        self._con.execute(self._table.update(self._where(kvfilter), values))

    def delete(self, kvfilter):
        """Delete rows matching *kvfilter*."""
        self._con.execute(self._table.delete(self._where(kvfilter)))
178
179
180 class FileStore(BaseStore):
181
182     def __init__(self, name):
183         self._filename = name
184         self.is_readonly = True
185         self._timestamp = None
186         self._config = None
187
188     def get_config(self):
189         try:
190             stat = os.stat(self._filename)
191         except OSError, e:
192             self.error("Unable to check config file %s: [%s]" % (
193                 self._filename, e))
194             self._config = None
195             raise
196         timestamp = stat.st_mtime
197         if self._config is None or timestamp > self._timestamp:
198             self._config = ConfigParser.RawConfigParser()
199             self._config.optionxform = str
200             self._config.read(self._filename)
201         return self._config
202
203     def add_constraint(self, table):
204         raise NotImplementedError()
205
206     def add_index(self, index):
207         raise NotImplementedError()
208
209
class FileQuery(Log):
    """Read-only query interface over one section of a FileStore.

    Only 2- or 3-column table layouts whose last column is 'value' are
    supported; writes raise NotImplementedError.
    """

    def __init__(self, fstore, table, table_def, trans=True):
        # We don't need indexes in a FileQuery, so drop that info
        if isinstance(table_def, dict):
            columns = table_def['columns']
        else:
            columns = table_def
        self._fstore = fstore
        self._config = fstore.get_config()
        self._section = table
        if len(columns) > 3 or columns[-1] != 'value':
            raise ValueError('Unsupported configuration format')
        self._columns = columns

    def rollback(self):
        # Nothing to roll back on a read-only store
        return

    def commit(self):
        # Nothing to commit on a read-only store
        return

    def create(self):
        raise NotImplementedError

    def drop(self):
        raise NotImplementedError

    def select(self, kvfilter=None, columns=None):
        """Return matching rows as lists, optionally projected to *columns*."""
        if self._section not in self._config.sections():
            return []

        opts = self._config.options(self._section)

        # Tolerate a missing filter; previously kvfilter=None crashed with
        # TypeError on the membership tests below (SqlQuery accepts None).
        kvf = kvfilter if kvfilter is not None else {}

        prefix = None
        prefix_ = ''
        if self._columns[0] in kvf:
            prefix = kvf[self._columns[0]]
            prefix_ = prefix + ' '

        name = None
        if len(self._columns) == 3 and self._columns[1] in kvf:
            name = kvf[self._columns[1]]

        value = None
        if self._columns[-1] in kvf:
            value = kvf[self._columns[-1]]

        res = []
        for o in opts:
            if len(self._columns) == 3:
                # 3 cols: the option key encodes '<col1> <col2>'
                if prefix and not o.startswith(prefix_):
                    continue

                col1, col2 = o.split(' ', 1)
                if name and col2 != name:
                    continue

                col3 = self._config.get(self._section, o)
                if value and col3 != value:
                    continue

                r = [col1, col2, col3]
            else:
                # 2 cols
                if prefix and o != prefix:
                    continue
                r = [o, self._config.get(self._section, o)]

            if columns:
                s = []
                for c in columns:
                    s.append(r[self._columns.index(c)])
                res.append(s)
            else:
                res.append(r)

        self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
                                                 repr(kvfilter),
                                                 repr(columns),
                                                 repr(res)))
        return res

    def insert(self, values):
        raise NotImplementedError

    def update(self, values, kvfilter):
        raise NotImplementedError

    def delete(self, kvfilter):
        raise NotImplementedError
301
302
303 class Store(Log):
304     # Static, Store-level variables
305     _is_upgrade = False
306     __cleanups = {}
307
308     # Static, class-level variables
309     # Either set this to False, or implement _cleanup, in child classes
310     _should_cleanup = True
311
312     def __init__(self, config_name=None, database_url=None):
313         if config_name is None and database_url is None:
314             raise ValueError('config_name or database_url must be provided')
315         if config_name:
316             if config_name not in cherrypy.config:
317                 raise NameError('Unknown database %s' % config_name)
318             name = cherrypy.config[config_name]
319         else:
320             name = database_url
321         if name.startswith('configfile://'):
322             _, filename = name.split('://')
323             self._db = FileStore(filename)
324             self._query = FileQuery
325         else:
326             self._db = SqlStore.get_connection(name)
327             self._query = SqlQuery
328
329         if not self._is_upgrade:
330             self._check_database()
331             if self._should_cleanup:
332                 self._schedule_cleanup()
333
334     def _schedule_cleanup(self):
335         store_name = self.__class__.__name__
336         if self.is_readonly:
337             # No use in cleanups on a readonly database
338             self.debug('Not scheduling cleanup for %s due to readonly' %
339                        store_name)
340             return
341         if store_name in Store.__cleanups:
342             # This class was already scheduled, skip
343             return
344         self.debug('Scheduling cleanups for %s' % store_name)
345         # Check once every minute whether we need to clean
346         task = cherrypy.process.plugins.BackgroundTask(
347             60, self._maybe_run_cleanup)
348         task.start()
349         Store.__cleanups[store_name] = task
350
351     def _maybe_run_cleanup(self):
352         # Let's see if we need to do cleanup
353         last_clean = self.load_options('dbinfo').get('%s_last_clean' %
354                                                      self.__class__.__name__,
355                                                      {})
356         time_diff = cherrypy.config.get('cleanup_interval', 30) * 60
357         next_ts = int(time.time()) - time_diff
358         self.debug('Considering cleanup for %s: %s. Next at: %s'
359                    % (self.__class__.__name__, last_clean, next_ts))
360         if ('timestamp' not in last_clean or
361                 int(last_clean['timestamp']) <= next_ts):
362             # First store the current time so that other servers don't start
363             self.save_options('dbinfo', '%s_last_clean'
364                               % self.__class__.__name__,
365                               {'timestamp': int(time.time()),
366                                'removed_entries': -1})
367
368             # Cleanup has been long enough ago, let's run
369             self.debug('Cleaning up for %s' % self.__class__.__name__)
370             removed_entries = self._cleanup()
371             self.debug('Cleaned up %i entries for %s' %
372                        (removed_entries, self.__class__.__name__))
373             self.save_options('dbinfo', '%s_last_clean'
374                               % self.__class__.__name__,
375                               {'timestamp': int(time.time()),
376                                'removed_entries': removed_entries})
377
378     def _cleanup(self):
379         # The default cleanup is to do nothing
380         # This function should return the number of rows it cleaned up.
381         # This information may be used to automatically tune the clean period.
382         self.error('Cleanup for %s not implemented' %
383                    self.__class__.__name__)
384         return 0
385
386     def _code_schema_version(self):
387         # This function makes it possible for separate plugins to have
388         #  different schema versions. We default to the global schema
389         #  version.
390         return CURRENT_SCHEMA_VERSION
391
392     def _get_schema_version(self):
393         # We are storing multiple versions: one per class
394         # That way, we can support plugins with differing schema versions from
395         #  the main codebase, and even in the same database.
396         q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
397         q.create()
398         q._con.close()  # pylint: disable=protected-access
399         cls_name = self.__class__.__name__
400         current_version = self.load_options('dbinfo').get('%s_schema'
401                                                           % cls_name, {})
402         if 'version' in current_version:
403             return int(current_version['version'])
404         else:
405             # Also try the old table name.
406             # "scheme" was a typo, but we need to retain that now for compat
407             fallback_version = self.load_options('dbinfo').get('scheme',
408                                                                {})
409             if 'version' in fallback_version:
410                 # Explanation for this is in def upgrade_database(self)
411                 return -1
412             else:
413                 return None
414
415     def _check_database(self):
416         if self.is_readonly:
417             # If the database is readonly, we cannot do anything to the
418             #  schema. Let's just return, and assume people checked the
419             #  upgrade notes
420             return
421
422         current_version = self._get_schema_version()
423         if current_version is None:
424             self.error('Database initialization required! ' +
425                        'Please run ipsilon-upgrade-database')
426             raise DatabaseError('Database initialization required for %s' %
427                                 self.__class__.__name__)
428         if current_version != self._code_schema_version():
429             self.error('Database upgrade required! ' +
430                        'Please run ipsilon-upgrade-database')
431             raise DatabaseError('Database upgrade required for %s' %
432                                 self.__class__.__name__)
433
434     def _store_new_schema_version(self, new_version):
435         cls_name = self.__class__.__name__
436         self.save_options('dbinfo', '%s_schema' % cls_name,
437                           {'version': new_version})
438
439     def _initialize_schema(self):
440         raise NotImplementedError()
441
442     def _upgrade_schema(self, old_version):
443         # Datastores need to figure out what to do with bigger old_versions
444         #  themselves.
445         # They might implement downgrading if that's feasible, or just throw
446         #  NotImplementedError
447         # Should return the new schema version
448         raise NotImplementedError()
449
450     def upgrade_database(self):
451         # Do whatever is needed to get schema to current version
452         old_schema_version = self._get_schema_version()
453         if old_schema_version is None:
454             # Just initialize a new schema
455             self._initialize_schema()
456             self._store_new_schema_version(self._code_schema_version())
457         elif old_schema_version == -1:
458             # This is a special-case from 1.0: we only created tables at the
459             # first time they were actually used, but the upgrade code assumes
460             # that the tables exist. So let's fix this.
461             self._initialize_schema()
462             # The old version was schema version 1
463             self._store_new_schema_version(1)
464             self.upgrade_database()
465         elif old_schema_version != self._code_schema_version():
466             # Upgrade from old_schema_version to code_schema_version
467             self.debug('Upgrading from schema version %i' % old_schema_version)
468             new_version = self._upgrade_schema(old_schema_version)
469             if not new_version:
470                 error = ('Schema upgrade error: %s did not provide a ' +
471                          'new schema version number!' %
472                          self.__class__.__name__)
473                 self.error(error)
474                 raise Exception(error)
475             self._store_new_schema_version(new_version)
476             # Check if we are now up-to-date
477             self.upgrade_database()
478
479     @property
480     def is_readonly(self):
481         return self._db.is_readonly
482
483     def _row_to_dict_tree(self, data, row):
484         name = row[0]
485         if len(row) > 2:
486             if name not in data:
487                 data[name] = dict()
488             d2 = data[name]
489             self._row_to_dict_tree(d2, row[1:])
490         else:
491             value = row[1]
492             if name in data:
493                 if data[name] is list:
494                     data[name].append(value)
495                 else:
496                     v = data[name]
497                     data[name] = [v, value]
498             else:
499                 data[name] = value
500
501     def _rows_to_dict_tree(self, rows):
502         data = dict()
503         for r in rows:
504             self._row_to_dict_tree(data, r)
505         return data
506
507     def _load_data(self, table, columns, kvfilter=None):
508         rows = []
509         try:
510             q = self._query(self._db, table, columns, trans=False)
511             rows = q.select(kvfilter)
512         except Exception, e:  # pylint: disable=broad-except
513             self.error("Failed to load data for table %s for store %s: [%s]"
514                        % (table, self.__class__.__name__, e))
515         return self._rows_to_dict_tree(rows)
516
517     def load_config(self):
518         table = 'config'
519         columns = ['name', 'value']
520         return self._load_data(table, columns)
521
522     def load_options(self, table, name=None):
523         kvfilter = dict()
524         if name:
525             kvfilter['name'] = name
526         options = self._load_data(table, OPTIONS_TABLE, kvfilter)
527         if name and name in options:
528             return options[name]
529         return options
530
531     def save_options(self, table, name, options):
532         curvals = dict()
533         q = None
534         try:
535             q = self._query(self._db, table, OPTIONS_TABLE)
536             rows = q.select({'name': name}, ['option', 'value'])
537             for row in rows:
538                 curvals[row[0]] = row[1]
539
540             for opt in options:
541                 if opt in curvals:
542                     q.update({'value': options[opt]},
543                              {'name': name, 'option': opt})
544                 else:
545                     q.insert((name, opt, options[opt]))
546
547             q.commit()
548         except Exception, e:  # pylint: disable=broad-except
549             if q:
550                 q.rollback()
551             self.error("Failed to save options: [%s]" % e)
552             raise
553
554     def delete_options(self, table, name, options=None):
555         kvfilter = {'name': name}
556         q = None
557         try:
558             q = self._query(self._db, table, OPTIONS_TABLE)
559             if options is None:
560                 q.delete(kvfilter)
561             else:
562                 for opt in options:
563                     kvfilter['option'] = opt
564                     q.delete(kvfilter)
565             q.commit()
566         except Exception, e:  # pylint: disable=broad-except
567             if q:
568                 q.rollback()
569             self.error("Failed to delete from %s: [%s]" % (table, e))
570             raise
571
572     def new_unique_data(self, table, data):
573         newid = str(uuid.uuid4())
574         q = None
575         try:
576             q = self._query(self._db, table, UNIQUE_DATA_TABLE)
577             for name in data:
578                 q.insert((newid, name, data[name]))
579             q.commit()
580         except Exception, e:  # pylint: disable=broad-except
581             if q:
582                 q.rollback()
583             self.error("Failed to store %s data: [%s]" % (table, e))
584             raise
585         return newid
586
587     def get_unique_data(self, table, uuidval=None, name=None, value=None):
588         kvfilter = dict()
589         if uuidval:
590             kvfilter['uuid'] = uuidval
591         if name:
592             kvfilter['name'] = name
593         if value:
594             kvfilter['value'] = value
595         return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)
596
597     def save_unique_data(self, table, data):
598         q = None
599         try:
600             q = self._query(self._db, table, UNIQUE_DATA_TABLE)
601             for uid in data:
602                 curvals = dict()
603                 rows = q.select({'uuid': uid}, ['name', 'value'])
604                 for r in rows:
605                     curvals[r[0]] = r[1]
606
607                 datum = data[uid]
608                 for name in datum:
609                     if name in curvals:
610                         if datum[name] is None:
611                             q.delete({'uuid': uid, 'name': name})
612                         else:
613                             q.update({'value': datum[name]},
614                                      {'uuid': uid, 'name': name})
615                     else:
616                         if datum[name] is not None:
617                             q.insert((uid, name, datum[name]))
618
619             q.commit()
620         except Exception, e:  # pylint: disable=broad-except
621             if q:
622                 q.rollback()
623             self.error("Failed to store data in %s: [%s]" % (table, e))
624             raise
625
626     def del_unique_data(self, table, uuidval):
627         kvfilter = {'uuid': uuidval}
628         try:
629             q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
630             q.delete(kvfilter)
631         except Exception, e:  # pylint: disable=broad-except
632             self.error("Failed to delete data from %s: [%s]" % (table, e))
633
634     def _reset_data(self, table):
635         q = None
636         try:
637             q = self._query(self._db, table, UNIQUE_DATA_TABLE)
638             q.drop()
639             q.create()
640             q.commit()
641         except Exception, e:  # pylint: disable=broad-except
642             if q:
643                 q.rollback()
644             self.error("Failed to erase all data from %s: [%s]" % (table, e))
645
646
class AdminStore(Store):
    """Store for the admin configuration databases."""

    _should_cleanup = False

    def __init__(self):
        super(AdminStore, self).__init__('admin.config.db')

    def get_data(self, plugin, idval=None, name=None, value=None):
        """Fetch rows from the plugin's data table."""
        return self.get_unique_data(plugin+"_data", idval, name, value)

    def save_data(self, plugin, data):
        """Persist changes to the plugin's data table."""
        return self.save_unique_data(plugin+"_data", data)

    def new_datum(self, plugin, datum):
        """Store a new datum for the plugin; returns its uuid."""
        return self.new_unique_data(plugin+"_data", datum)

    def del_datum(self, plugin, idval):
        """Delete one datum from the plugin's data table."""
        return self.del_unique_data(plugin+"_data", idval)

    def wipe_data(self, plugin):
        """Erase everything in the plugin's data table."""
        self._reset_data(plugin+"_data")

    def _initialize_schema(self):
        # Create each option table without opening a transaction.
        for table_name in ('config',
                           'info_config',
                           'login_config',
                           'provider_config'):
            query = self._query(self._db, table_name, OPTIONS_TABLE,
                                trans=False)
            query.create()
            query._con.close()  # pylint: disable=protected-access

    def _upgrade_schema(self, old_version):
        if old_version != 1:
            raise NotImplementedError()
        # In schema version 2, we added indexes and primary keys
        for table_name in ('config',
                           'info_config',
                           'login_config',
                           'provider_config'):
            # pylint: disable=protected-access
            table = self._query(self._db, table_name, OPTIONS_TABLE,
                                trans=False)._table
            self._db.add_constraint(table.primary_key)
            for index in table.indexes:
                self._db.add_index(index)
        return 2

    def create_plugin_data_table(self, plugin_name):
        """Create the plugin's data table unless the store is read-only."""
        if self.is_readonly:
            return
        query = self._query(self._db, plugin_name+'_data',
                            UNIQUE_DATA_TABLE, trans=False)
        query.create()
        query._con.close()  # pylint: disable=protected-access
704
705
class UserStore(Store):
    """Store holding per-user preferences and per-plugin user data."""

    _should_cleanup = False

    def __init__(self, path=None):
        # 'path' is accepted for backwards compatibility and ignored.
        super(UserStore, self).__init__('user.prefs.db')

    def save_user_preferences(self, user, options):
        """Persist the user's preference options."""
        self.save_options('users', user, options)

    def load_user_preferences(self, user):
        """Return the user's preference options."""
        return self.load_options('users', user)

    def save_plugin_data(self, plugin, user, options):
        """Persist the user's options for one plugin."""
        self.save_options(plugin+"_data", user, options)

    def load_plugin_data(self, plugin, user):
        """Return the user's options for one plugin."""
        return self.load_options(plugin+"_data", user)

    def _initialize_schema(self):
        query = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
        query.create()
        query._con.close()  # pylint: disable=protected-access

    def _upgrade_schema(self, old_version):
        if old_version != 1:
            raise NotImplementedError()
        # In schema version 2, we added indexes and primary keys
        # pylint: disable=protected-access
        table = self._query(self._db, 'users', OPTIONS_TABLE,
                            trans=False)._table
        self._db.add_constraint(table.primary_key)
        for index in table.indexes:
            self._db.add_index(index)
        return 2
741
742
class TranStore(Store):
    """Store tracking in-progress authentication transactions."""

    def __init__(self, path=None):
        # 'path' is accepted for backwards compatibility and ignored.
        super(TranStore, self).__init__('transactions.db')
        self.table = 'transactions'

    def _initialize_schema(self):
        query = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                            trans=False)
        query.create()
        query._con.close()  # pylint: disable=protected-access

    def _upgrade_schema(self, old_version):
        if old_version != 1:
            raise NotImplementedError()
        # In schema version 2, we added indexes and primary keys
        # pylint: disable=protected-access
        table = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                            trans=False)._table
        self._db.add_constraint(table.primary_key)
        for index in table.indexes:
            self._db.add_index(index)
        return 2

    def _cleanup(self):
        """Delete transactions whose origintime is over an hour old and
        return the number of removed rows."""
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
        one_hour_ago = datetime.datetime.now() - datetime.timedelta(hours=1)
        stale = select([table.columns.uuid]). \
            where(and_(table.c.name == 'origintime',
                       table.c.value <= one_hour_ago))
        # pylint: disable=no-value-for-parameter
        d = table.delete().where(table.c.uuid.in_(stale))
        return d.execute().rowcount
778
779
class SAML2SessionStore(Store):
    """Store for active SAML2 login sessions, keyed by a generated uuid."""

    def __init__(self, database_url):
        super(SAML2SessionStore, self).__init__(database_url=database_url)
        self.table = 'saml2_sessions'
        # Make sure the sessions table exists before any query runs
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
        table.create(checkfirst=True)

    def _get_unique_id_from_column(self, name, value):
        """
        The query is going to return only the column in the query.
        Use this method to get the uuidval which can be used to fetch
        the entire entry.

        Returns None or the uuid of the first value found.

        Raises ValueError if more than one entry matches.
        """
        data = self.get_unique_data(self.table, name=name, value=value)
        count = len(data)
        if count == 0:
            return None
        elif count != 1:
            raise ValueError("Multiple entries returned")
        # Python 2: dict.keys() returns a list, so this is the single uuid
        return data.keys()[0]

    def _cleanup(self):
        """Delete sessions whose expiration_time has passed; returns the
        number of removed rows."""
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
        sel = select([table.columns.uuid]). \
            where(and_(table.c.name == 'expiration_time',
                       table.c.value <= datetime.datetime.now()))
        # pylint: disable=no-value-for-parameter
        d = table.delete().where(table.c.uuid.in_(sel))
        return d.execute().rowcount

    def get_data(self, idval=None, name=None, value=None):
        """Fetch session rows, optionally filtered by id, name or value."""
        return self.get_unique_data(self.table, idval, name, value)

    def new_session(self, datum):
        """Store a new session and return its generated uuid."""
        # Flatten the mechanism list for storage as a single column value
        if 'supported_logout_mechs' in datum:
            datum['supported_logout_mechs'] = ','.join(
                datum['supported_logout_mechs']
            )
        return self.new_unique_data(self.table, datum)

    def get_session(self, session_id=None, request_id=None):
        """Return (uuid, data) for a session or (None, None) if unknown.

        Raises ValueError when neither session_id nor request_id is given.
        """
        if session_id:
            uuidval = self._get_unique_id_from_column('session_id', session_id)
        elif request_id:
            uuidval = self._get_unique_id_from_column('request_id', request_id)
        else:
            raise ValueError("Unable to find session")
        if not uuidval:
            return None, None
        data = self.get_unique_data(self.table, uuidval=uuidval)
        return uuidval, data[uuidval]

    def get_user_sessions(self, user):
        """
        Return a list of all sessions for a given user.
        """
        rows = self.get_unique_data(self.table, name='user', value=user)

        # We have a list of sessions for this user, now get the details
        logged_in = []
        for r in rows:
            data = self.get_unique_data(self.table, uuidval=r)
            # Re-expand the comma-separated logout mechanisms stored by
            # new_session
            data[r]['supported_logout_mechs'] = data[r].get(
                'supported_logout_mechs', '').split(',')
            logged_in.append(data)

        return logged_in

    def update_session(self, datum):
        """Persist changes to an existing session."""
        self.save_unique_data(self.table, datum)

    def remove_session(self, uuidval):
        """Delete the session with the given uuid (best-effort)."""
        self.del_unique_data(self.table, uuidval)

    def wipe_data(self):
        """Drop and recreate the sessions table, erasing all sessions."""
        self._reset_data(self.table)

    def _initialize_schema(self):
        q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                        trans=False)
        q.create()
        q._con.close()  # pylint: disable=protected-access

    def _upgrade_schema(self, old_version):
        if old_version == 1:
            # In schema version 2, we added indexes and primary keys
            # pylint: disable=protected-access
            table = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                                trans=False)._table
            self._db.add_constraint(table.primary_key)
            for index in table.indexes:
                self._db.add_index(index)
            return 2
        else:
            raise NotImplementedError()