Add SQL primary key and indexes
[cascardo/ipsilon.git] / ipsilon / util / data.py
1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
2
3 import cherrypy
4 import datetime
5 from ipsilon.util.log import Log
6 from sqlalchemy import create_engine
7 from sqlalchemy import MetaData, Table, Column, Text
8 from sqlalchemy.pool import QueuePool, SingletonThreadPool
9 from sqlalchemy.schema import PrimaryKeyConstraint, Index
10 from sqlalchemy.sql import select, and_
11 import ConfigParser
12 import os
13 import uuid
14 import logging
15
16
# Schema version the code in this module expects; Store compares it with the
# version recorded in the 'dbinfo' table.
CURRENT_SCHEMA_VERSION = 2

# "Options" table layout: (name, option, value) rows, unique per
# (name, option), with an extra index for lookups by name.
OPTIONS_TABLE = dict(
    columns=['name', 'option', 'value'],
    primary_key=('name', 'option'),
    indexes=[('name',)],
)

# "Unique data" table layout: (uuid, name, value) rows, unique per
# (uuid, name), with an extra index for lookups by uuid.
UNIQUE_DATA_TABLE = dict(
    columns=['uuid', 'name', 'value'],
    primary_key=('uuid', 'name'),
    indexes=[('uuid',)],
)
26
27
class DatabaseError(Exception):
    """Signals that a store's database schema is missing or does not match
    the schema version the code expects."""
30
31
class SqlStore(Log):
    """SQLAlchemy engine wrapper for one configured database.

    Instances are shared per database name: obtain them through
    get_connection() rather than constructing them directly.
    """
    __instances = {}

    @classmethod
    def get_connection(cls, name):
        """Return the shared SqlStore for *name*, creating it on first use."""
        store = cls.__instances.get(name)
        if store is None:
            if cherrypy.config.get('db.conn.log', False):
                logging.debug('SqlStore new: %s', name)
            store = SqlStore(name)
            cls.__instances[name] = store
        return store

    def __init__(self, name):
        self.db_conn_log = cherrypy.config.get('db.conn.log', False)
        self.debug('SqlStore init: %s' % name)
        self.name = name
        # A name without a URL scheme is treated as a SQLite database path.
        engine_name = name if '://' in name else 'sqlite:///' + name
        if engine_name.startswith('sqlite://'):
            # SQLite connections cannot be shared between threads, so give
            # each thread its own connection.
            pool_args = {'poolclass': SingletonThreadPool}
        else:
            # Pool size is per configured database. The minimum needed,
            # determined by binary search, is 23; 25 leaves a little
            # headroom and the overflow absorbs sudden bursts.
            pool_args = {'poolclass': QueuePool,
                         'pool_size': 25,
                         'max_overflow': 50}
        self._dbengine = create_engine(engine_name, **pool_args)
        self.is_readonly = False

    def debug(self, fact):
        """Log *fact*, but only when 'db.conn.log' is enabled."""
        if not self.db_conn_log:
            return
        super(SqlStore, self).debug(fact)

    def engine(self):
        """Return the underlying SQLAlchemy engine."""
        return self._dbengine

    def connection(self):
        """Open a new connection, closed automatically when the current
        cherrypy request ends."""
        self.debug('SqlStore connect: %s' % self.name)
        conn = self._dbengine.connect()

        def _close_at_request_end():
            self.debug('SqlStore cleanup: %s' % self.name)
            conn.close()

        cherrypy.request.hooks.attach('on_end_request', _close_at_request_end)
        return conn
80
81
class SqlQuery(Log):
    """Simple CRUD helper bound to a single table of a SqlStore.

    *table_def* is either a plain list of column names, or a dict with a
    'columns' list plus optional 'primary_key' (tuple of column names) and
    'indexes' (list of column-name tuples), e.g. OPTIONS_TABLE. Every
    column is created with the Text type.

    With *trans* true a transaction is opened on the connection and should
    be finished with commit() or rollback(). With *trans* false no explicit
    transaction is opened, and commit()/rollback() do nothing.
    """

    def __init__(self, db_obj, table, table_def, trans=True):
        self._db = db_obj
        self._con = self._db.connection()
        self._trans = self._con.begin() if trans else None
        self._table = self._get_table(table, table_def)

    def _get_table(self, name, table_def):
        """Return a sqlalchemy Table named *name* described by *table_def*.

        Indexes are named 'idx_<table>_<col1>_<col2>...'. The table's
        MetaData is bound to the store's engine.
        """
        if isinstance(table_def, list):
            table_def = {'columns': table_def}
        table_creation = [Column(col_name, Text())
                          for col_name in table_def['columns']]
        # 'primary_key' and 'indexes' are optional in a dict definition.
        primary_key = table_def.get('primary_key')
        if primary_key:
            table_creation.append(PrimaryKeyConstraint(*primary_key))
        for index in table_def.get('indexes') or ():
            idx_name = 'idx_%s_%s' % (name, '_'.join(index))
            table_creation.append(Index(idx_name, *index))
        return Table(name, MetaData(self._db.engine()), *table_creation)

    def _where(self, kvfilter):
        """Return the AND of `column == value` for each item of *kvfilter*.

        Returns None (no WHERE clause) when *kvfilter* is None or empty.
        """
        where = None
        if kvfilter is not None:
            for k in kvfilter:
                w = self._table.columns[k] == kvfilter[k]
                where = w if where is None else where & w
        return where

    def _columns(self, columns=None):
        """Map column names to Column objects; None selects all columns."""
        if columns is None:
            return self._table.columns
        return [self._table.columns[c] for c in columns]

    def rollback(self):
        """Roll back the open transaction; no-op if none was opened."""
        if self._trans is not None:
            self._trans.rollback()

    def commit(self):
        """Commit the open transaction; no-op if none was opened."""
        if self._trans is not None:
            self._trans.commit()

    def create(self):
        """Create the table unless it already exists.

        NOTE: runs through the engine bound to the table's MetaData, not
        through this query's connection/transaction.
        """
        self._table.create(checkfirst=True)

    def drop(self):
        """Drop the table if it exists (via the MetaData-bound engine)."""
        self._table.drop(checkfirst=True)

    def select(self, kvfilter=None, columns=None):
        """Return rows matching *kvfilter*, restricted to *columns*."""
        return self._con.execute(select(self._columns(columns),
                                        self._where(kvfilter)))

    def insert(self, values):
        """Insert one row; callers in this module pass a tuple in column
        order."""
        self._con.execute(self._table.insert(values))

    def update(self, values, kvfilter):
        """Set the columns in the *values* dict on rows matching
        *kvfilter*."""
        self._con.execute(self._table.update(self._where(kvfilter), values))

    def delete(self, kvfilter):
        """Delete rows matching *kvfilter*."""
        self._con.execute(self._table.delete(self._where(kvfilter)))
152
153
class FileStore(Log):
    """Read-only store backed by an INI-style configuration file.

    The file is parsed lazily and re-parsed when its modification time
    changes.
    """

    def __init__(self, name):
        self._filename = name
        self.is_readonly = True
        self._timestamp = None
        self._config = None

    def get_config(self):
        """Return the parsed RawConfigParser for the file.

        Option names keep their original case (optionxform is str).

        Raises:
            OSError: if the file cannot be stat()ed; the cached config is
                discarded in that case.
        """
        try:
            stat = os.stat(self._filename)
        except OSError as e:
            self.error("Unable to check config file %s: [%s]" % (
                self._filename, e))
            self._config = None
            raise
        timestamp = stat.st_mtime
        # Remember the mtime we parsed so an unchanged file is not re-read
        # on every call; reload on any change (including an older mtime,
        # e.g. a file restored from backup).
        if self._config is None or timestamp != self._timestamp:
            config = ConfigParser.RawConfigParser()
            config.optionxform = str
            config.read(self._filename)
            self._config = config
            self._timestamp = timestamp
        return self._config
176
177
class FileQuery(Log):
    """Read-only query interface over one section of a FileStore's file.

    Two table layouts are supported:

    * two columns (key, 'value'): each option of the section is a row
      ``key = value``;
    * three columns (name, option, 'value'): options are named
      ``"<name> <option>"``, i.e. the first two columns joined by a space.
    """

    def __init__(self, fstore, table, table_def, trans=True):
        # Primary keys and indexes are meaningless for a config file, so
        # only the column list is used.
        if isinstance(table_def, dict):
            columns = table_def['columns']
        else:
            columns = table_def
        self._fstore = fstore
        self._config = fstore.get_config()
        self._section = table
        if not 2 <= len(columns) <= 3 or columns[-1] != 'value':
            raise ValueError('Unsupported configuration format')
        self._columns = columns

    def rollback(self):
        return

    def commit(self):
        return

    def create(self):
        raise NotImplementedError

    def drop(self):
        raise NotImplementedError

    def select(self, kvfilter=None, columns=None):
        """Return the rows of this section matching *kvfilter*.

        *kvfilter* maps column names to required values; None matches every
        row, and falsy filter values are ignored. Each row is a list in
        table-column order, or projected onto *columns* when given. A
        missing section yields an empty list.
        """
        # Callers such as Store.load_config() pass no filter at all.
        if kvfilter is None:
            kvfilter = {}

        if self._section not in self._config.sections():
            return []

        opts = self._config.options(self._section)

        prefix = None
        prefix_ = ''
        if self._columns[0] in kvfilter:
            prefix = kvfilter[self._columns[0]]
            prefix_ = prefix + ' '

        name = None
        if len(self._columns) == 3 and self._columns[1] in kvfilter:
            name = kvfilter[self._columns[1]]

        value = None
        if self._columns[-1] in kvfilter:
            value = kvfilter[self._columns[-1]]

        res = []
        for o in opts:
            if len(self._columns) == 3:
                # Option names are "<col1> <col2>"; the first space splits.
                if prefix and not o.startswith(prefix_):
                    continue

                col1, col2 = o.split(' ', 1)
                if name and col2 != name:
                    continue

                col3 = self._config.get(self._section, o)
                if value and col3 != value:
                    continue

                r = [col1, col2, col3]
            else:
                # The option name is the key column itself.
                if prefix and o != prefix:
                    continue

                col2 = self._config.get(self._section, o)
                # Honour the value filter here too, as in the 3-column case.
                if value and col2 != value:
                    continue

                r = [o, col2]

            if columns:
                r = [r[self._columns.index(c)] for c in columns]
            res.append(r)

        self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
                                                 repr(kvfilter),
                                                 repr(columns),
                                                 repr(res)))
        return res

    def insert(self, values):
        raise NotImplementedError

    def update(self, values, kvfilter):
        raise NotImplementedError

    def delete(self, kvfilter):
        raise NotImplementedError
268         raise NotImplementedError
269
270
class Store(Log):
    """Base class for persistent stores.

    The database is named either by a cherrypy config key (*config_name*,
    whose value is the database URL) or directly by *database_url*. A
    'configfile://<path>' URL selects the read-only FileStore/FileQuery
    backend; anything else goes to SqlStore/SqlQuery. Unless the class sets
    _is_upgrade, the schema version is checked on construction.

    Subclasses implement _initialize_schema() and _upgrade_schema() and may
    override _code_schema_version().
    """
    _is_upgrade = False

    def __init__(self, config_name=None, database_url=None):
        """Open the backing store.

        Raises:
            ValueError: if neither *config_name* nor *database_url* is given.
            NameError: if *config_name* is not in cherrypy.config.
            DatabaseError: if the schema needs initialization or upgrade.
        """
        if config_name is None and database_url is None:
            raise ValueError('config_name or database_url must be provided')
        if config_name:
            if config_name not in cherrypy.config:
                raise NameError('Unknown database %s' % config_name)
            name = cherrypy.config[config_name]
        else:
            name = database_url
        if name.startswith('configfile://'):
            # Split only once so the path itself may contain '://'.
            _, filename = name.split('://', 1)
            self._db = FileStore(filename)
            self._query = FileQuery
        else:
            self._db = SqlStore.get_connection(name)
            self._query = SqlQuery

        if not self._is_upgrade:
            self._check_database()

    def _code_schema_version(self):
        # This function makes it possible for separate plugins to have
        #  different schema versions. We default to the global schema
        #  version.
        return CURRENT_SCHEMA_VERSION

    def _get_schema_version(self):
        """Return the schema version stored for this class, or None.

        Versions are stored per class in the 'dbinfo' options table under
        '<ClassName>_schema', so plugins can carry schema versions that
        differ from the main codebase, even in the same database. The
        legacy 'scheme' entry is used as a fallback.
        """
        q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
        q.create()
        # Load the table once and look up both entries in it.
        dbinfo = self.load_options('dbinfo')
        cls_name = self.__class__.__name__
        current_version = dbinfo.get('%s_schema' % cls_name, {})
        if 'version' in current_version:
            return int(current_version['version'])
        # "scheme" was a typo, but we need to retain that now for compat
        fallback_version = dbinfo.get('scheme', {})
        if 'version' in fallback_version:
            return int(fallback_version['version'])
        return None

    def _check_database(self):
        """Raise DatabaseError unless the stored schema version matches
        _code_schema_version(). Read-only stores are not checked."""
        if self.is_readonly:
            # If the database is readonly, we cannot do anything to the
            #  schema. Let's just return, and assume people checked the
            #  upgrade notes
            return

        current_version = self._get_schema_version()
        if current_version is None:
            self.error('Database initialization required! ' +
                       'Please run ipsilon-upgrade-database')
            raise DatabaseError('Database initialization required for %s' %
                                self.__class__.__name__)
        if current_version != self._code_schema_version():
            self.error('Database upgrade required! ' +
                       'Please run ipsilon-upgrade-database')
            raise DatabaseError('Database upgrade required for %s' %
                                self.__class__.__name__)

    def _store_new_schema_version(self, new_version):
        """Record *new_version* as this class's schema version."""
        cls_name = self.__class__.__name__
        self.save_options('dbinfo', '%s_schema' % cls_name,
                          {'version': new_version})

    def _initialize_schema(self):
        raise NotImplementedError()

    def _upgrade_schema(self, old_version):
        # Datastores need to figure out what to do with bigger old_versions
        #  themselves.
        # They might implement downgrading if that's feasible, or just throw
        #  NotImplementedError
        raise NotImplementedError()

    def upgrade_database(self):
        """Initialize or upgrade the schema to _code_schema_version()."""
        old_schema_version = self._get_schema_version()
        if old_schema_version is None:
            # Just initialize a new schema
            self._initialize_schema()
            self._store_new_schema_version(self._code_schema_version())
        elif old_schema_version != self._code_schema_version():
            # Upgrade from old_schema_version to code_schema_version
            self._upgrade_schema(old_schema_version)
            self._store_new_schema_version(self._code_schema_version())

    @property
    def is_readonly(self):
        return self._db.is_readonly

    def _row_to_dict_tree(self, data, row):
        """Merge *row* into the nested dict *data*.

        All elements but the last two are successive dict keys; the
        second-to-last is the leaf key and the last its value. When a leaf
        key repeats, its values are collected into a list.
        """
        name = row[0]
        if len(row) > 2:
            if name not in data:
                data[name] = dict()
            self._row_to_dict_tree(data[name], row[1:])
        else:
            value = row[1]
            if name not in data:
                data[name] = value
            elif isinstance(data[name], list):
                # Third and later duplicates extend the existing list
                # instead of nesting it inside a new one.
                data[name].append(value)
            else:
                data[name] = [data[name], value]

    def _rows_to_dict_tree(self, rows):
        """Build a nested dict from *rows* (see _row_to_dict_tree)."""
        data = dict()
        for r in rows:
            self._row_to_dict_tree(data, r)
        return data

    def _load_data(self, table, columns, kvfilter=None):
        """Select rows of *table* matching *kvfilter* as a dict tree.

        Errors are logged and yield an empty result rather than raising.
        """
        rows = []
        try:
            q = self._query(self._db, table, columns, trans=False)
            rows = q.select(kvfilter)
        except Exception as e:  # pylint: disable=broad-except
            self.error("Failed to load data for table %s: [%s]" % (table, e))
        return self._rows_to_dict_tree(rows)

    def load_config(self):
        """Return the 'config' table as a {name: value} dict."""
        table = 'config'
        columns = ['name', 'value']
        return self._load_data(table, columns)

    def load_options(self, table, name=None):
        """Return {name: {option: value}} from *table*.

        With *name*, only that entry's {option: value} dict is returned
        (or an empty dict if it has no options).
        """
        kvfilter = dict()
        if name:
            kvfilter['name'] = name
        options = self._load_data(table, OPTIONS_TABLE, kvfilter)
        if name and name in options:
            return options[name]
        return options

    def save_options(self, table, name, options):
        """Insert or update each item of *options* for *name* in *table*.

        Runs in one transaction; options not mentioned are left as they
        are. On error the transaction is rolled back, the error logged and
        the exception re-raised.
        """
        curvals = dict()
        q = None
        try:
            q = self._query(self._db, table, OPTIONS_TABLE)
            rows = q.select({'name': name}, ['option', 'value'])
            for row in rows:
                curvals[row[0]] = row[1]

            for opt in options:
                if opt in curvals:
                    q.update({'value': options[opt]},
                             {'name': name, 'option': opt})
                else:
                    q.insert((name, opt, options[opt]))

            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to save options: [%s]" % e)
            raise

    def delete_options(self, table, name, options=None):
        """Delete all options of *name* in *table*, or only those listed in
        *options*. Rolls back, logs and re-raises on error."""
        kvfilter = {'name': name}
        q = None
        try:
            q = self._query(self._db, table, OPTIONS_TABLE)
            if options is None:
                q.delete(kvfilter)
            else:
                for opt in options:
                    kvfilter['option'] = opt
                    q.delete(kvfilter)
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to delete from %s: [%s]" % (table, e))
            raise

    def new_unique_data(self, table, data):
        """Store the *data* dict under a fresh random UUID and return it."""
        newid = str(uuid.uuid4())
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
            for name in data:
                q.insert((newid, name, data[name]))
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to store %s data: [%s]" % (table, e))
            raise
        return newid

    def get_unique_data(self, table, uuidval=None, name=None, value=None):
        """Return {uuid: {name: value}} for rows matching the given
        filters; falsy filters are ignored."""
        kvfilter = dict()
        if uuidval:
            kvfilter['uuid'] = uuidval
        if name:
            kvfilter['name'] = name
        if value:
            kvfilter['value'] = value
        return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)

    def save_unique_data(self, table, data):
        """Apply *data*, a {uuid: {name: value}} dict, to *table*.

        Existing names are updated, or deleted when the new value is None;
        new names with non-None values are inserted. Rolls back, logs and
        re-raises on error.
        """
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
            for uid in data:
                curvals = dict()
                rows = q.select({'uuid': uid}, ['name', 'value'])
                for r in rows:
                    curvals[r[0]] = r[1]

                datum = data[uid]
                for name in datum:
                    if name in curvals:
                        if datum[name] is None:
                            q.delete({'uuid': uid, 'name': name})
                        else:
                            q.update({'value': datum[name]},
                                     {'uuid': uid, 'name': name})
                    else:
                        if datum[name] is not None:
                            q.insert((uid, name, datum[name]))

            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to store data in %s: [%s]" % (table, e))
            raise

    def del_unique_data(self, table, uuidval):
        """Delete every row of entry *uuidval* (no explicit transaction).
        Best effort: errors are logged, not raised."""
        kvfilter = {'uuid': uuidval}
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
            q.delete(kvfilter)
        except Exception as e:  # pylint: disable=broad-except
            self.error("Failed to delete data from %s: [%s]" % (table, e))

    def _reset_data(self, table):
        """Drop and recreate *table* with the unique-data layout.
        Best effort: errors are logged, not raised."""
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
            q.drop()
            q.create()
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to erase all data from %s: [%s]" % (table, e))
531             self.error("Failed to erase all data from %s: [%s]" % (table, e))
532
533
class AdminStore(Store):
    """Store backed by the 'admin.config.db' database.

    Plugin data lives in per-plugin '<plugin>_data' tables that use the
    unique-data layout.
    """

    def __init__(self):
        super(AdminStore, self).__init__('admin.config.db')

    def get_data(self, plugin, idval=None, name=None, value=None):
        """Return {uuid: {name: value}} from the plugin's data table."""
        table = plugin + '_data'
        return self.get_unique_data(table, idval, name, value)

    def save_data(self, plugin, data):
        """Apply a {uuid: {name: value}} dict to the plugin's data table."""
        table = plugin + '_data'
        return self.save_unique_data(table, data)

    def new_datum(self, plugin, datum):
        """Store *datum* under a new UUID in the plugin's data table."""
        table = plugin + '_data'
        return self.new_unique_data(table, datum)

    def del_datum(self, plugin, idval):
        """Remove entry *idval* from the plugin's data table."""
        table = plugin + '_data'
        return self.del_unique_data(table, idval)

    def wipe_data(self, plugin):
        """Drop and recreate the plugin's data table."""
        table = plugin + '_data'
        self._reset_data(table)

    def _initialize_schema(self):
        """Create the option tables used by the admin store."""
        option_tables = ('config', 'info_config', 'login_config',
                         'provider_config')
        for table in option_tables:
            query = self._query(self._db, table, OPTIONS_TABLE, trans=False)
            query.create()

    def _upgrade_schema(self, old_version):
        raise NotImplementedError()
567
568
class UserStore(Store):
    """Store backed by the 'user.prefs.db' database.

    User preferences live in the 'users' options table; plugins keep
    per-user data in their own '<plugin>_data' options tables.
    """

    def __init__(self, path=None):
        # NOTE: *path* is accepted but currently unused.
        super(UserStore, self).__init__('user.prefs.db')

    def save_user_preferences(self, user, options):
        """Save the *options* dict as preferences of *user*."""
        self.save_options('users', user, options)

    def load_user_preferences(self, user):
        """Return the stored preferences of *user*."""
        return self.load_options('users', user)

    def save_plugin_data(self, plugin, user, options):
        """Save *options* for *user* in the plugin's data table."""
        table = plugin + '_data'
        self.save_options(table, user, options)

    def load_plugin_data(self, plugin, user):
        """Return the options stored for *user* in the plugin's table."""
        table = plugin + '_data'
        return self.load_options(table, user)

    def _initialize_schema(self):
        """Create the 'users' options table."""
        users = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
        users.create()

    def _upgrade_schema(self, old_version):
        raise NotImplementedError()
592
593
class TranStore(Store):
    """Store backed by the 'transactions.db' database.

    Entries live in the 'transactions' table, which uses the unique-data
    layout.
    """

    def __init__(self, path=None):
        # NOTE: *path* is accepted but currently unused.
        super(TranStore, self).__init__('transactions.db')

    def _initialize_schema(self):
        """Create the 'transactions' table."""
        transactions = self._query(self._db, 'transactions',
                                   UNIQUE_DATA_TABLE, trans=False)
        transactions.create()

    def _upgrade_schema(self, old_version):
        raise NotImplementedError()
606
607
class SAML2SessionStore(Store):
    """Store for SAML2 sessions, kept in the 'saml2_sessions' table.

    Each session is a unique-data entry ({uuid: {name: value}}). The table
    is created on construction if it does not exist yet.
    """

    def __init__(self, database_url):
        super(SAML2SessionStore, self).__init__(database_url=database_url)
        self.table = 'saml2_sessions'
        # Only the Table object is needed, so don't open a transaction that
        # would never be committed.
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE,
                         trans=False)._table
        table.create(checkfirst=True)

    def _get_unique_id_from_column(self, name, value):
        """Return the uuid of the session whose *name* attribute equals
        *value*, or None if there is none.

        Raises:
            ValueError: if more than one session matches.
        """
        data = self.get_unique_data(self.table, name=name, value=value)
        count = len(data)
        if count == 0:
            return None
        elif count != 1:
            raise ValueError("Multiple entries returned")
        # Works on both Python 2 and 3 (dict.keys() is not indexable on 3).
        return next(iter(data))

    def remove_expired_sessions(self):
        """Delete every session whose 'expiration_time' is not after now."""
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE,
                         trans=False)._table
        # NOTE(review): 'value' is a Text column compared with a datetime
        # parameter; this relies on the stored expiration format ordering
        # like the bound datetime's string form -- confirm per backend.
        sel = select([table.columns.uuid]). \
            where(and_(table.c.name == 'expiration_time',
                       table.c.value <= datetime.datetime.now()))
        # pylint: disable=no-value-for-parameter
        d = table.delete().where(table.c.uuid.in_(sel))
        d.execute()

    def get_data(self, idval=None, name=None, value=None):
        """Return {uuid: {name: value}} for sessions matching the filters."""
        return self.get_unique_data(self.table, idval, name, value)

    def new_session(self, datum):
        """Store a new session and return its uuid.

        A 'supported_logout_mechs' sequence is stored comma-joined; note
        that this rewrites the key in the caller's *datum* dict.
        """
        if 'supported_logout_mechs' in datum:
            datum['supported_logout_mechs'] = ','.join(
                datum['supported_logout_mechs']
            )
        return self.new_unique_data(self.table, datum)

    def get_session(self, session_id=None, request_id=None):
        """Return (uuid, data) of the session with the given session_id
        (preferred) or request_id, or (None, None) if not found.

        Raises:
            ValueError: if neither id is given, or several sessions match.
        """
        if session_id:
            uuidval = self._get_unique_id_from_column('session_id', session_id)
        elif request_id:
            uuidval = self._get_unique_id_from_column('request_id', request_id)
        else:
            raise ValueError("Unable to find session")
        if not uuidval:
            return None, None
        data = self.get_unique_data(self.table, uuidval=uuidval)
        return uuidval, data[uuidval]

    def get_user_sessions(self, user):
        """
        Return a list of all sessions for a given user.

        Each item is a {uuid: data} dict whose 'supported_logout_mechs' is
        split back into a list.
        """
        rows = self.get_unique_data(self.table, name='user', value=user)

        # We have a list of sessions for this user, now get the details
        logged_in = []
        for r in rows:
            data = self.get_unique_data(self.table, uuidval=r)
            data[r]['supported_logout_mechs'] = data[r].get(
                'supported_logout_mechs', '').split(',')
            logged_in.append(data)

        return logged_in

    def update_session(self, datum):
        """Apply a {uuid: {name: value}} update to stored sessions."""
        self.save_unique_data(self.table, datum)

    def remove_session(self, uuidval):
        """Delete session *uuidval* (best effort)."""
        self.del_unique_data(self.table, uuidval)

    def wipe_data(self):
        """Drop and recreate the sessions table."""
        self._reset_data(self.table)

    def _initialize_schema(self):
        q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                        trans=False)
        q.create()

    def _upgrade_schema(self, old_version):
        raise NotImplementedError()
695     def _upgrade_schema(self, old_version):
696         raise NotImplementedError()