Fix initialization of plugin_data table in AdminStore
[cascardo/ipsilon.git] ipsilon/util/data.py
# Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING

import cherrypy
import datetime
from ipsilon.util.log import Log
from sqlalchemy import create_engine
from sqlalchemy import MetaData, Table, Column, Text
from sqlalchemy.pool import QueuePool, SingletonThreadPool
from sqlalchemy.schema import (PrimaryKeyConstraint, Index, AddConstraint,
                               CreateIndex)
from sqlalchemy.sql import select, and_
import ConfigParser
import os
import uuid
import logging


CURRENT_SCHEMA_VERSION = 2
OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
                 'primary_key': ('name', 'option'),
                 'indexes': [('name',)]
                 }
UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
                     'primary_key': ('uuid', 'name'),
                     'indexes': [('uuid',)]
                     }
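# Rough illustration of the two layouts above (the values are made up, not
# taken from this module): an OPTIONS_TABLE row is a (name, option, value)
# text triple such as
#   ('global', 'theme', 'dark')
# keyed by (name, option), while a UNIQUE_DATA_TABLE row is a
# (uuid, name, value) triple such as
#   ('9b2f-example-uuid', 'user', 'alice')
# so that all attributes of one logical object share a generated uuid.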


class DatabaseError(Exception):
    pass


class BaseStore(Log):
    # Some helper functions used for upgrades
    def add_constraint(self, table):
        raise NotImplementedError()

    def add_index(self, index):
        raise NotImplementedError()


class SqlStore(BaseStore):
    __instances = {}

    @classmethod
    def get_connection(cls, name):
        if name not in cls.__instances:
            if cherrypy.config.get('db.conn.log', False):
                logging.debug('SqlStore new: %s', name)
            cls.__instances[name] = SqlStore(name)
        return cls.__instances[name]

    def __init__(self, name):
        self.db_conn_log = cherrypy.config.get('db.conn.log', False)
        self.debug('SqlStore init: %s' % name)
        self.name = name
        engine_name = name
        if '://' not in engine_name:
            engine_name = 'sqlite:///' + engine_name
        # This pool size is per configured database. The minimum needed,
        #  determined by binary search, is 23. We're using 25 so we have a bit
        #  more headroom, and the overflow should make sure things don't
        #  break when we suddenly need more connections.
        pool_args = {'poolclass': QueuePool,
                     'pool_size': 25,
                     'max_overflow': 50}
        if engine_name.startswith('sqlite://'):
            # It's not possible to share connections for SQLite between
            #  threads, so let's use the SingletonThreadPool for them
            pool_args = {'poolclass': SingletonThreadPool}
        self._dbengine = create_engine(engine_name, **pool_args)
        self.is_readonly = False

    def add_constraint(self, constraint):
        if self._dbengine.dialect.name != 'sqlite':
            # It is impossible to add constraints to a pre-existing table for
            #  SQLite
            # source: http://www.sqlite.org/omitted.html
            create_constraint = AddConstraint(constraint, bind=self._dbengine)
            create_constraint.execute()

    def add_index(self, index):
        add_index = CreateIndex(index, bind=self._dbengine)
        add_index.execute()

    def debug(self, fact):
        if self.db_conn_log:
            super(SqlStore, self).debug(fact)

    def engine(self):
        return self._dbengine

    def connection(self):
        self.debug('SqlStore connect: %s' % self.name)
        conn = self._dbengine.connect()

        def cleanup_connection():
            self.debug('SqlStore cleanup: %s' % self.name)
            conn.close()
        cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
        return conn


class SqlQuery(Log):

    def __init__(self, db_obj, table, table_def, trans=True):
        self._db = db_obj
        self._con = self._db.connection()
        self._trans = self._con.begin() if trans else None
        self._table = self._get_table(table, table_def)

    def _get_table(self, name, table_def):
        if isinstance(table_def, list):
            table_def = {'columns': table_def,
                         'indexes': [],
                         'primary_key': None}
        table_creation = []
        for col_name in table_def['columns']:
            table_creation.append(Column(col_name, Text()))
        if table_def['primary_key']:
            table_creation.append(PrimaryKeyConstraint(
                *table_def['primary_key']))
        for index in table_def['indexes']:
            idx_name = 'idx_%s_%s' % (name, '_'.join(index))
            table_creation.append(Index(idx_name, *index))
        table = Table(name, MetaData(self._db.engine()), *table_creation)
        return table
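    # A rough sketch of what _get_table() builds (illustrative, nothing is
    # executed here): passing OPTIONS_TABLE for a table named 'config'
    # yields three Text columns (name, option, value), a composite primary
    # key on (name, option) and an index named 'idx_config_name' on the
    # name column; passing a plain list of column names yields just the
    # Text columns with no key or index.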

    def _where(self, kvfilter):
        where = None
        if kvfilter is not None:
            for k in kvfilter:
                w = self._table.columns[k] == kvfilter[k]
                if where is None:
                    where = w
                else:
                    where = where & w
        return where

    def _columns(self, columns=None):
        cols = None
        if columns is not None:
            cols = []
            for c in columns:
                cols.append(self._table.columns[c])
        else:
            cols = self._table.columns
        return cols

    def rollback(self):
        self._trans.rollback()

    def commit(self):
        self._trans.commit()

    def create(self):
        self._table.create(checkfirst=True)

    def drop(self):
        self._table.drop(checkfirst=True)

    def select(self, kvfilter=None, columns=None):
        return self._con.execute(select(self._columns(columns),
                                        self._where(kvfilter)))

    def insert(self, values):
        self._con.execute(self._table.insert(values))

    def update(self, values, kvfilter):
        self._con.execute(self._table.update(self._where(kvfilter), values))

    def delete(self, kvfilter):
        self._con.execute(self._table.delete(self._where(kvfilter)))


class FileStore(BaseStore):

    def __init__(self, name):
        self._filename = name
        self.is_readonly = True
        self._timestamp = None
        self._config = None

    def get_config(self):
        try:
            stat = os.stat(self._filename)
        except OSError, e:
            self.error("Unable to check config file %s: [%s]" % (
                self._filename, e))
            self._config = None
            raise
        timestamp = stat.st_mtime
        if self._config is None or timestamp > self._timestamp:
            self._config = ConfigParser.RawConfigParser()
            self._config.optionxform = str
            self._config.read(self._filename)
            self._timestamp = timestamp
        return self._config

    def add_constraint(self, table):
        raise NotImplementedError()

    def add_index(self, index):
        raise NotImplementedError()


class FileQuery(Log):

    def __init__(self, fstore, table, table_def, trans=True):
        # We don't need indexes in a FileQuery, so drop that info
        if isinstance(table_def, dict):
            columns = table_def['columns']
        else:
            columns = table_def
        self._fstore = fstore
        self._config = fstore.get_config()
        self._section = table
        if len(columns) > 3 or columns[-1] != 'value':
            raise ValueError('Unsupported configuration format')
        self._columns = columns

    def rollback(self):
        return

    def commit(self):
        return

    def create(self):
        raise NotImplementedError

    def drop(self):
        raise NotImplementedError

    def select(self, kvfilter=None, columns=None):
        if self._section not in self._config.sections():
            return []

        opts = self._config.options(self._section)

        if kvfilter is None:
            kvfilter = {}

        prefix = None
        prefix_ = ''
        if self._columns[0] in kvfilter:
            prefix = kvfilter[self._columns[0]]
            prefix_ = prefix + ' '

        name = None
        if len(self._columns) == 3 and self._columns[1] in kvfilter:
            name = kvfilter[self._columns[1]]

        value = None
        if self._columns[-1] in kvfilter:
            value = kvfilter[self._columns[-1]]

        res = []
        for o in opts:
            if len(self._columns) == 3:
                # 3 cols
                if prefix and not o.startswith(prefix_):
                    continue

                col1, col2 = o.split(' ', 1)
                if name and col2 != name:
                    continue

                col3 = self._config.get(self._section, o)
                if value and col3 != value:
                    continue

                r = [col1, col2, col3]
            else:
                # 2 cols
                if prefix and o != prefix:
                    continue
                r = [o, self._config.get(self._section, o)]

            if columns:
                s = []
                for c in columns:
                    s.append(r[self._columns.index(c)])
                res.append(s)
            else:
                res.append(r)

        self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
                                                 repr(kvfilter),
                                                 repr(columns),
                                                 repr(res)))
        return res

    def insert(self, values):
        raise NotImplementedError

    def update(self, values, kvfilter):
        raise NotImplementedError

    def delete(self, kvfilter):
        raise NotImplementedError


class Store(Log):
    _is_upgrade = False

    def __init__(self, config_name=None, database_url=None):
        if config_name is None and database_url is None:
            raise ValueError('config_name or database_url must be provided')
        if config_name:
            if config_name not in cherrypy.config:
                raise NameError('Unknown database %s' % config_name)
            name = cherrypy.config[config_name]
        else:
            name = database_url
        if name.startswith('configfile://'):
            _, filename = name.split('://')
            self._db = FileStore(filename)
            self._query = FileQuery
        else:
            self._db = SqlStore.get_connection(name)
            self._query = SqlQuery

        if not self._is_upgrade:
            self._check_database()
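    # Rough usage sketch (the paths below are made up): AdminStore passes
    # 'admin.config.db' as config_name, so with a cherrypy config entry like
    #   cherrypy.config['admin.config.db'] = '/var/lib/ipsilon/admin.sqlite'
    # this store ends up on the SqlStore/SqlQuery pair, while a value such as
    #   'configfile:///etc/ipsilon/settings.conf'
    # would select the read-only FileStore/FileQuery pair instead.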

    def _code_schema_version(self):
        # This function makes it possible for separate plugins to have
        #  different schema versions. We default to the global schema
        #  version.
        return CURRENT_SCHEMA_VERSION

    def _get_schema_version(self):
        # We are storing multiple versions: one per class
        # That way, we can support plugins with differing schema versions from
        #  the main codebase, and even in the same database.
        q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
        q.create()
        cls_name = self.__class__.__name__
        current_version = self.load_options('dbinfo').get('%s_schema'
                                                          % cls_name, {})
        if 'version' in current_version:
            return int(current_version['version'])
        else:
            # Also try the old option name.
            # "scheme" was a typo, but we need to retain that now for compat
            fallback_version = self.load_options('dbinfo').get('scheme',
                                                               {})
            if 'version' in fallback_version:
                return int(fallback_version['version'])
            else:
                return None
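    # For illustration (the row shown is an example, not a guaranteed value):
    # after an upgrade the dbinfo table holds one options row per class, e.g.
    #   ('AdminStore_schema', 'version', '2')
    # so _get_schema_version() for AdminStore would return 2, falling back to
    # the legacy 'scheme' row only when no per-class row exists.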

    def _check_database(self):
        if self.is_readonly:
            # If the database is readonly, we cannot do anything to the
            #  schema. Let's just return, and assume people checked the
            #  upgrade notes
            return

        current_version = self._get_schema_version()
        if current_version is None:
            self.error('Database initialization required! ' +
                       'Please run ipsilon-upgrade-database')
            raise DatabaseError('Database initialization required for %s' %
                                self.__class__.__name__)
        if current_version != self._code_schema_version():
            self.error('Database upgrade required! ' +
                       'Please run ipsilon-upgrade-database')
            raise DatabaseError('Database upgrade required for %s' %
                                self.__class__.__name__)

    def _store_new_schema_version(self, new_version):
        cls_name = self.__class__.__name__
        self.save_options('dbinfo', '%s_schema' % cls_name,
                          {'version': new_version})

    def _initialize_schema(self):
        raise NotImplementedError()

    def _upgrade_schema(self, old_version):
        # Datastores need to figure out what to do with bigger old_versions
        #  themselves.
        # They might implement downgrading if that's feasible, or just throw
        #  NotImplementedError
        # Should return the new schema version
        raise NotImplementedError()

    def upgrade_database(self):
        # Do whatever is needed to get schema to current version
        old_schema_version = self._get_schema_version()
        if old_schema_version is None:
            # Just initialize a new schema
            self._initialize_schema()
            self._store_new_schema_version(self._code_schema_version())
        elif old_schema_version != self._code_schema_version():
            # Upgrade from old_schema_version to code_schema_version
            self.debug('Upgrading from schema version %i' % old_schema_version)
            new_version = self._upgrade_schema(old_schema_version)
            if not new_version:
                error = ('Schema upgrade error: %s did not provide a new '
                         'schema version number!' % self.__class__.__name__)
                self.error(error)
                raise Exception(error)
            self._store_new_schema_version(new_version)
            # Check if we are now up-to-date
            self.upgrade_database()

    @property
    def is_readonly(self):
        return self._db.is_readonly

    def _row_to_dict_tree(self, data, row):
        name = row[0]
        if len(row) > 2:
            if name not in data:
                data[name] = dict()
            d2 = data[name]
            self._row_to_dict_tree(d2, row[1:])
        else:
            value = row[1]
            if name in data:
                if isinstance(data[name], list):
                    data[name].append(value)
                else:
                    v = data[name]
                    data[name] = [v, value]
            else:
                data[name] = value

    def _rows_to_dict_tree(self, rows):
        data = dict()
        for r in rows:
            self._row_to_dict_tree(data, r)
        return data
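    # A rough sketch of the transformation (the rows below are invented):
    # the unique-data rows
    #   [('uid1', 'user', 'alice'), ('uid1', 'visited', 'a'),
    #    ('uid1', 'visited', 'b')]
    # collapse into the nested dict
    #   {'uid1': {'user': 'alice', 'visited': ['a', 'b']}}
    # with repeated (uuid, name) pairs accumulated into a list of values.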

    def _load_data(self, table, columns, kvfilter=None):
        rows = []
        try:
            q = self._query(self._db, table, columns, trans=False)
            rows = q.select(kvfilter)
        except Exception, e:  # pylint: disable=broad-except
            self.error("Failed to load data for table %s: [%s]" % (table, e))
        return self._rows_to_dict_tree(rows)

    def load_config(self):
        table = 'config'
        columns = ['name', 'value']
        return self._load_data(table, columns)

    def load_options(self, table, name=None):
        kvfilter = dict()
        if name:
            kvfilter['name'] = name
        options = self._load_data(table, OPTIONS_TABLE, kvfilter)
        if name and name in options:
            return options[name]
        return options

    def save_options(self, table, name, options):
        curvals = dict()
        q = None
        try:
            q = self._query(self._db, table, OPTIONS_TABLE)
            rows = q.select({'name': name}, ['option', 'value'])
            for row in rows:
                curvals[row[0]] = row[1]

            for opt in options:
                if opt in curvals:
                    q.update({'value': options[opt]},
                             {'name': name, 'option': opt})
                else:
                    q.insert((name, opt, options[opt]))

            q.commit()
        except Exception, e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to save options: [%s]" % e)
            raise
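    # Rough usage sketch (table and values are illustrative): a call such as
    #   store.save_options('config', 'global', {'theme': 'dark'})
    # inserts or updates the row ('global', 'theme', 'dark'), and
    #   store.load_options('config', 'global')
    # later returns {'theme': 'dark'} for that name.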

    def delete_options(self, table, name, options=None):
        kvfilter = {'name': name}
        q = None
        try:
            q = self._query(self._db, table, OPTIONS_TABLE)
            if options is None:
                q.delete(kvfilter)
            else:
                for opt in options:
                    kvfilter['option'] = opt
                    q.delete(kvfilter)
            q.commit()
        except Exception, e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to delete from %s: [%s]" % (table, e))
            raise

    def new_unique_data(self, table, data):
        newid = str(uuid.uuid4())
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
            for name in data:
                q.insert((newid, name, data[name]))
            q.commit()
        except Exception, e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to store %s data: [%s]" % (table, e))
            raise
        return newid
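    # Rough usage sketch (illustrative values): something like
    #   uid = store.new_unique_data('saml2_sessions', {'user': 'alice'})
    # inserts one (uuid, 'user', 'alice') row and returns the generated
    # uuid, which get_unique_data() then accepts as uuidval to fetch the
    # whole entry back as {uid: {'user': 'alice'}}.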

    def get_unique_data(self, table, uuidval=None, name=None, value=None):
        kvfilter = dict()
        if uuidval:
            kvfilter['uuid'] = uuidval
        if name:
            kvfilter['name'] = name
        if value:
            kvfilter['value'] = value
        return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)

    def save_unique_data(self, table, data):
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
            for uid in data:
                curvals = dict()
                rows = q.select({'uuid': uid}, ['name', 'value'])
                for r in rows:
                    curvals[r[0]] = r[1]

                datum = data[uid]
                for name in datum:
                    if name in curvals:
                        if datum[name] is None:
                            q.delete({'uuid': uid, 'name': name})
                        else:
                            q.update({'value': datum[name]},
                                     {'uuid': uid, 'name': name})
                    else:
                        if datum[name] is not None:
                            q.insert((uid, name, datum[name]))

            q.commit()
        except Exception, e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to store data in %s: [%s]" % (table, e))
            raise

    def del_unique_data(self, table, uuidval):
        kvfilter = {'uuid': uuidval}
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
            q.delete(kvfilter)
        except Exception, e:  # pylint: disable=broad-except
            self.error("Failed to delete data from %s: [%s]" % (table, e))

    def _reset_data(self, table):
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
            q.drop()
            q.create()
            q.commit()
        except Exception, e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to erase all data from %s: [%s]" % (table, e))


class AdminStore(Store):

    def __init__(self):
        super(AdminStore, self).__init__('admin.config.db')

    def get_data(self, plugin, idval=None, name=None, value=None):
        return self.get_unique_data(plugin+"_data", idval, name, value)

    def save_data(self, plugin, data):
        return self.save_unique_data(plugin+"_data", data)

    def new_datum(self, plugin, datum):
        table = plugin+"_data"
        return self.new_unique_data(table, datum)

    def del_datum(self, plugin, idval):
        table = plugin+"_data"
        return self.del_unique_data(table, idval)

    def wipe_data(self, plugin):
        table = plugin+"_data"
        self._reset_data(table)

    def _initialize_schema(self):
        for table in ['config',
                      'info_config',
                      'login_config',
                      'provider_config']:
            q = self._query(self._db, table, OPTIONS_TABLE, trans=False)
            q.create()

    def _upgrade_schema(self, old_version):
        if old_version == 1:
            # In schema version 2, we added indexes and primary keys
            for table in ['config',
                          'info_config',
                          'login_config',
                          'provider_config']:
                # pylint: disable=protected-access
                table = self._query(self._db, table, OPTIONS_TABLE,
                                    trans=False)._table
                self._db.add_constraint(table.primary_key)
                for index in table.indexes:
                    self._db.add_index(index)
            return 2
        else:
            raise NotImplementedError()

    def create_plugin_data_table(self, plugin_name):
        table = plugin_name+'_data'
        q = self._query(self._db, table, UNIQUE_DATA_TABLE,
                        trans=False)
        q.create()
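    # Rough sketch of the intended flow (the plugin name is illustrative): a
    # plugin is expected to have its backing table created up front, e.g.
    #   adminstore.create_plugin_data_table('openid')
    # so that later get_data('openid', ...), save_data() and new_datum()
    # calls find an existing 'openid_data' table instead of failing.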


class UserStore(Store):

    def __init__(self, path=None):
        super(UserStore, self).__init__('user.prefs.db')

    def save_user_preferences(self, user, options):
        self.save_options('users', user, options)

    def load_user_preferences(self, user):
        return self.load_options('users', user)

    def save_plugin_data(self, plugin, user, options):
        self.save_options(plugin+"_data", user, options)

    def load_plugin_data(self, plugin, user):
        return self.load_options(plugin+"_data", user)

    def _initialize_schema(self):
        q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
        q.create()

    def _upgrade_schema(self, old_version):
        if old_version == 1:
            # In schema version 2, we added indexes and primary keys
            # pylint: disable=protected-access
            table = self._query(self._db, 'users', OPTIONS_TABLE,
                                trans=False)._table
            self._db.add_constraint(table.primary_key)
            for index in table.indexes:
                self._db.add_index(index)
            return 2
        else:
            raise NotImplementedError()


class TranStore(Store):

    def __init__(self, path=None):
        super(TranStore, self).__init__('transactions.db')

    def _initialize_schema(self):
        q = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
                        trans=False)
        q.create()

    def _upgrade_schema(self, old_version):
        if old_version == 1:
            # In schema version 2, we added indexes and primary keys
            # pylint: disable=protected-access
            table = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
                                trans=False)._table
            self._db.add_constraint(table.primary_key)
            for index in table.indexes:
                self._db.add_index(index)
            return 2
        else:
            raise NotImplementedError()


class SAML2SessionStore(Store):

    def __init__(self, database_url):
        super(SAML2SessionStore, self).__init__(database_url=database_url)
        self.table = 'saml2_sessions'
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
        table.create(checkfirst=True)

    def _get_unique_id_from_column(self, name, value):
        """
        A query filtered on name and value returns only that column of an
        entry, not the whole entry. Use this method to get the uuidval,
        which can then be used to fetch the entire entry.

        Returns None or the uuid of the first value found.
        """
        data = self.get_unique_data(self.table, name=name, value=value)
        count = len(data)
        if count == 0:
            return None
        elif count != 1:
            raise ValueError("Multiple entries returned")
        return data.keys()[0]

    def remove_expired_sessions(self):
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
        sel = select([table.columns.uuid]). \
            where(and_(table.c.name == 'expiration_time',
                       table.c.value <= datetime.datetime.now()))
        # pylint: disable=no-value-for-parameter
        d = table.delete().where(table.c.uuid.in_(sel))
        d.execute()

    def get_data(self, idval=None, name=None, value=None):
        return self.get_unique_data(self.table, idval, name, value)

    def new_session(self, datum):
        if 'supported_logout_mechs' in datum:
            datum['supported_logout_mechs'] = ','.join(
                datum['supported_logout_mechs']
            )
        return self.new_unique_data(self.table, datum)

    def get_session(self, session_id=None, request_id=None):
        if session_id:
            uuidval = self._get_unique_id_from_column('session_id', session_id)
        elif request_id:
            uuidval = self._get_unique_id_from_column('request_id', request_id)
        else:
            raise ValueError("Unable to find session")
        if not uuidval:
            return None, None
        data = self.get_unique_data(self.table, uuidval=uuidval)
        return uuidval, data[uuidval]

    def get_user_sessions(self, user):
        """
        Return a list of all sessions for a given user.
        """
        rows = self.get_unique_data(self.table, name='user', value=user)

        # We have a list of sessions for this user, now get the details
        logged_in = []
        for r in rows:
            data = self.get_unique_data(self.table, uuidval=r)
            data[r]['supported_logout_mechs'] = data[r].get(
                'supported_logout_mechs', '').split(',')
            logged_in.append(data)

        return logged_in

    def update_session(self, datum):
        self.save_unique_data(self.table, datum)

    def remove_session(self, uuidval):
        self.del_unique_data(self.table, uuidval)

    def wipe_data(self):
        self._reset_data(self.table)

    def _initialize_schema(self):
        q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                        trans=False)
        q.create()

    def _upgrade_schema(self, old_version):
        if old_version == 1:
            # In schema version 2, we added indexes and primary keys
            # pylint: disable=protected-access
            table = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                                trans=False)._table
            self._db.add_constraint(table.primary_key)
            for index in table.indexes:
                self._db.add_index(index)
            return 2
        else:
            raise NotImplementedError()