Use plugin-specific configuration, better expiration
[cascardo/ipsilon.git] / ipsilon / util / data.py
1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
2
3 import cherrypy
4 import datetime
5 from ipsilon.util.log import Log
6 from sqlalchemy import create_engine
7 from sqlalchemy import MetaData, Table, Column, Text
8 from sqlalchemy.pool import QueuePool, SingletonThreadPool
9 from sqlalchemy.sql import select, and_
10 import ConfigParser
11 import os
12 import uuid
13 import logging
14
15
# Schema version written to / compared against the 'dbinfo' options table
# ('scheme' -> 'version') by Store._upgrade_database().
CURRENT_SCHEMA_VERSION = 1
# Column layout of "options" tables: rows of (name, option, value), used by
# Store.load_options / save_options / delete_options.
OPTIONS_COLUMNS = ['name', 'option', 'value']
# Column layout of "unique data" tables: rows of (uuid, name, value), where
# all rows sharing a uuid form one logical record.
UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
19
20
class SqlStore(Log):
    """Shared SQLAlchemy engine wrapper for one configured database.

    Obtain instances through get_connection() rather than the constructor:
    it keeps a single instance per database name for the whole process, so
    every user of a given database shares one engine and connection pool.
    """
    # database name -> SqlStore, filled lazily by get_connection()
    __instances = {}

    @classmethod
    def get_connection(cls, name):
        """Return the cached SqlStore for ``name``, building it on first use."""
        known = cls.__instances
        if name not in known:
            if cherrypy.config.get('db.conn.log', False):
                logging.debug('SqlStore new: %s', name)
            known[name] = SqlStore(name)
        return known[name]

    def __init__(self, name):
        """Create the engine for ``name``.

        ``name`` is either a full SQLAlchemy URL or, when it contains no
        '://', a filesystem path that is opened as an SQLite database.
        """
        self.db_conn_log = cherrypy.config.get('db.conn.log', False)
        self.debug('SqlStore init: %s' % name)
        self.name = name
        engine_name = name if '://' in name else 'sqlite:///' + name
        if engine_name.startswith('sqlite://'):
            # SQLite connections cannot be shared between threads, so hand
            #  each thread its own through SingletonThreadPool.
            pool_args = {'poolclass': SingletonThreadPool}
        else:
            # This pool size is per configured database. The minimum
            #  needed, determined by binary search, is 23. 25 leaves a bit
            #  of headroom, and the overflow keeps things from breaking
            #  when many more connections are suddenly needed.
            pool_args = {'poolclass': QueuePool,
                         'pool_size': 25,
                         'max_overflow': 50}
        self._dbengine = create_engine(engine_name, **pool_args)
        self.is_readonly = False

    def debug(self, fact):
        """Forward ``fact`` to the base class debug() only if db.conn.log."""
        if not self.db_conn_log:
            return
        super(SqlStore, self).debug(fact)

    def engine(self):
        """Return the underlying SQLAlchemy engine."""
        return self._dbengine

    def connection(self):
        """Open a new connection that is closed when the request ends."""
        self.debug('SqlStore connect: %s' % self.name)
        conn = self._dbengine.connect()

        def _close_at_request_end():
            self.debug('SqlStore cleanup: %s' % self.name)
            conn.close()

        cherrypy.request.hooks.attach('on_end_request', _close_at_request_end)
        return conn
69
70
def SqlAutotable(f):
    """Decorator for query methods: make sure the table exists first.

    The wrapped method calls ``self.create()`` and then delegates to ``f``
    with the same arguments, returning its result.  functools.wraps keeps
    the decorated method's name and docstring, which the bare wrapper
    otherwise replaced with its own.
    """
    import functools

    @functools.wraps(f)
    def at(self, *args, **kwargs):
        self.create()
        return f(self, *args, **kwargs)
    return at
76
77
class SqlQuery(Log):
    """Simple CRUD access to one table of a SqlStore.

    Every column is declared as Text.  Unless ``trans`` is False, a
    transaction is begun on construction and finished with commit() or
    rollback().  The CRUD methods create the table on demand (via
    SqlAutotable).
    """

    def __init__(self, db_obj, table, columns, trans=True):
        self._db = db_obj
        self._con = self._db.connection()
        # None when the caller asked for no explicit transaction.
        self._trans = self._con.begin() if trans else None
        self._table = self._get_table(table, columns)

    def _get_table(self, name, columns):
        """Build Table ``name`` with one Text column per entry of columns."""
        table = Table(name, MetaData(self._db.engine()))
        for c in columns:
            table.append_column(Column(c, Text()))
        return table

    def _where(self, kvfilter):
        """AND together ``column == value`` for every item of kvfilter.

        Returns None (i.e. no WHERE clause) when kvfilter is None or empty.
        """
        where = None
        if kvfilter is not None:
            for k in kvfilter:
                w = self._table.columns[k] == kvfilter[k]
                if where is None:
                    where = w
                else:
                    where = where & w
        return where

    def _columns(self, columns=None):
        """Map column names to Column objects; all columns when None."""
        if columns is None:
            return self._table.columns
        return [self._table.columns[c] for c in columns]

    def rollback(self):
        """Roll back the transaction, if one was begun."""
        # With trans=False there is no transaction object; do nothing
        # (like FileQuery) instead of failing with AttributeError on None.
        if self._trans is not None:
            self._trans.rollback()

    def commit(self):
        """Commit the transaction, if one was begun."""
        if self._trans is not None:
            self._trans.commit()

    def create(self):
        """Create the table unless it already exists."""
        self._table.create(checkfirst=True)

    def drop(self):
        """Drop the table if it exists."""
        self._table.drop(checkfirst=True)

    @SqlAutotable
    def select(self, kvfilter=None, columns=None):
        """Return rows matching kvfilter, restricted to ``columns``."""
        return self._con.execute(select(self._columns(columns),
                                        self._where(kvfilter)))

    @SqlAutotable
    def insert(self, values):
        """Insert one row; ``values`` is in table column order."""
        self._con.execute(self._table.insert(values))

    @SqlAutotable
    def update(self, values, kvfilter):
        """Set the ``values`` dict on every row matching kvfilter."""
        self._con.execute(self._table.update(self._where(kvfilter), values))

    @SqlAutotable
    def delete(self, kvfilter):
        """Delete every row matching kvfilter."""
        self._con.execute(self._table.delete(self._where(kvfilter)))
141
142
class FileStore(Log):
    """Read-only store backed by an INI-style configuration file.

    The parsed file is cached and re-parsed only when its modification
    time changes.
    """

    def __init__(self, name):
        self._filename = name
        self.is_readonly = True
        # mtime of the file when self._config was last parsed
        self._timestamp = None
        self._config = None

    def get_config(self):
        """Return a RawConfigParser for the file, reloading it if changed.

        Raises:
            OSError: if the file cannot be stat()ed; the cached
                configuration is discarded in that case.
        """
        try:
            stat = os.stat(self._filename)
        except OSError as e:
            self.error("Unable to check config file %s: [%s]" % (
                self._filename, e))
            self._config = None
            raise
        timestamp = stat.st_mtime
        # Reload on first use or whenever the mtime differs from the one
        # the cached copy was parsed at.  Recording the timestamp is what
        # keeps the file from being re-parsed on every call; '!=' (rather
        # than '>') also catches a file replaced by an older copy.
        if self._config is None or timestamp != self._timestamp:
            config = ConfigParser.RawConfigParser()
            config.optionxform = str  # keep option names case-sensitive
            config.read(self._filename)
            self._config = config
            self._timestamp = timestamp
        return self._config
165
166
class FileQuery(Log):
    """Read-only query interface over one section of a FileStore.

    Tables map to config-file sections.  Two table layouts are supported:

    * 2 columns ``[key, 'value']``: each option is ``key = value``.
    * 3 columns ``[key1, key2, 'value']``: each option name is
      ``"<key1> <key2>"`` (split on the first space) and maps to value.

    Mutating operations raise NotImplementedError; commit() and rollback()
    are no-ops.
    """

    def __init__(self, fstore, table, columns, trans=True):
        # ``trans`` is accepted for interface compatibility with SqlQuery
        # and ignored: a config file has no transactions.
        self._fstore = fstore
        self._config = fstore.get_config()
        self._section = table
        # Only 2- or 3-column layouts ending in 'value' can be mapped onto
        # config options.
        if not 2 <= len(columns) <= 3 or columns[-1] != 'value':
            raise ValueError('Unsupported configuration format')
        self._columns = columns

    def rollback(self):
        return

    def commit(self):
        return

    def create(self):
        raise NotImplementedError

    def drop(self):
        raise NotImplementedError

    def select(self, kvfilter=None, columns=None):
        """Return the rows of the section that match ``kvfilter``.

        Args:
            kvfilter: optional dict mapping column names to required
                values; a missing key or a None value means "no filter"
                on that column.
            columns: optional list of column names to project each row
                onto, in the given order.

        Returns:
            A list of rows, each a list of strings.
        """
        # kvfilter defaults to None (e.g. Store.load_config); treat that as
        # "no filter" instead of failing on 'key in None'.
        filters = kvfilter if kvfilter is not None else {}
        if self._section not in self._config.sections():
            return []

        three_cols = len(self._columns) == 3
        prefix = filters.get(self._columns[0])
        name = filters.get(self._columns[1]) if three_cols else None
        value = filters.get(self._columns[-1])

        res = []
        for o in self._config.options(self._section):
            if three_cols:
                if prefix is not None and not o.startswith(prefix + ' '):
                    continue
                # Option names are expected to be "<col1> <col2>".
                col1, col2 = o.split(' ', 1)
                if name is not None and col2 != name:
                    continue
                r = [col1, col2, self._config.get(self._section, o)]
            else:
                if prefix is not None and o != prefix:
                    continue
                r = [o, self._config.get(self._section, o)]

            # The value filter applies to both layouts.
            if value is not None and r[-1] != value:
                continue

            if columns:
                r = [r[self._columns.index(c)] for c in columns]
            res.append(r)

        self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
                                                 repr(kvfilter),
                                                 repr(columns),
                                                 repr(res)))
        return res

    def insert(self, values):
        raise NotImplementedError

    def update(self, values, kvfilter):
        raise NotImplementedError

    def delete(self, kvfilter):
        raise NotImplementedError
253
254
class Store(Log):
    """Key/value storage front-end over a SQL or config-file backend.

    A ``configfile://`` name selects a read-only FileStore/FileQuery pair;
    any other name selects a shared SqlStore/SqlQuery pair.  Data is
    exposed as "options" tables (name, option, value) or as "unique data"
    tables (uuid, name, value).
    """

    def __init__(self, config_name=None, database_url=None):
        """Open the store and check its schema version.

        Args:
            config_name: key in cherrypy.config whose value is the
                database name/URL; used when truthy.
            database_url: database name/URL used otherwise.

        Raises:
            ValueError: if neither argument is provided.
            NameError: if config_name is not present in cherrypy.config.
        """
        if config_name is None and database_url is None:
            raise ValueError('config_name or database_url must be provided')
        if config_name:
            if config_name not in cherrypy.config:
                raise NameError('Unknown database %s' % config_name)
            name = cherrypy.config[config_name]
        else:
            name = database_url
        if name.startswith('configfile://'):
            # Split only once so a '://' inside the path cannot break the
            # two-way unpacking.
            _, filename = name.split('://', 1)
            self._db = FileStore(filename)
            self._query = FileQuery
        else:
            self._db = SqlStore.get_connection(name)
            self._query = SqlQuery
        self._upgrade_database()

    def _upgrade_database(self):
        """Record or verify the schema version kept in 'dbinfo'.

        Read-only backends are skipped.  When no version is stored, the
        current one is written; when a different one is found,
        _upgrade_database_from() is called.
        """
        if self.is_readonly:
            # If the database is readonly, we cannot do anything to the
            #  schema. Let's just return, and assume people checked the
            #  upgrade notes
            return
        current_version = self.load_options('dbinfo').get('scheme', None)
        if current_version is None or 'version' not in current_version:
            # No version stored, storing current version
            self.save_options('dbinfo', 'scheme',
                              {'version': CURRENT_SCHEMA_VERSION})
            current_version = CURRENT_SCHEMA_VERSION
        else:
            current_version = int(current_version['version'])
        if current_version != CURRENT_SCHEMA_VERSION:
            self.debug('Upgrading database schema from %i to %i' % (
                       current_version, CURRENT_SCHEMA_VERSION))
            self._upgrade_database_from(current_version)

    def _upgrade_database_from(self, old_schema_version):
        # Insert code here to upgrade from old_schema_version to
        #  CURRENT_SCHEMA_VERSION
        raise Exception('Unable to upgrade database to current schema'
                        ' version: version %i is unknown!' %
                        old_schema_version)

    @property
    def is_readonly(self):
        """True when the backend cannot be written to."""
        return self._db.is_readonly

    def _row_to_dict_tree(self, data, row):
        """Merge one result row into the nested dict ``data``.

        All columns but the last two become nested dict keys; the
        second-to-last is the leaf key and the last is the value.  Values
        of a repeated leaf key are accumulated into a flat list.
        """
        name = row[0]
        if len(row) > 2:
            subtree = data.setdefault(name, dict())
            self._row_to_dict_tree(subtree, row[1:])
            return
        value = row[1]
        if name not in data:
            data[name] = value
        elif isinstance(data[name], list):
            # isinstance, not 'is list': the latter is never true for a
            # list instance and turned a third value into nested lists.
            data[name].append(value)
        else:
            data[name] = [data[name], value]

    def _rows_to_dict_tree(self, rows):
        """Fold all ``rows`` into a fresh nested dict."""
        data = dict()
        for r in rows:
            self._row_to_dict_tree(data, r)
        return data

    def _load_data(self, table, columns, kvfilter=None):
        """Select rows from ``table`` as a nested dict.

        Errors are logged and yield an empty dict.
        """
        rows = []
        try:
            q = self._query(self._db, table, columns, trans=False)
            rows = q.select(kvfilter)
        except Exception as e:  # pylint: disable=broad-except
            self.error("Failed to load data for table %s: [%s]" % (table, e))
        return self._rows_to_dict_tree(rows)

    def load_config(self):
        """Return the 'config' table as a {name: value} dict."""
        table = 'config'
        columns = ['name', 'value']
        return self._load_data(table, columns)

    def load_options(self, table, name=None):
        """Return options of ``table``; only those of ``name`` if given."""
        kvfilter = dict()
        if name:
            kvfilter['name'] = name
        options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
        if name and name in options:
            return options[name]
        return options

    def save_options(self, table, name, options):
        """Insert or update every option of ``name`` in one transaction.

        Raises whatever the backend raised, after rolling back.
        """
        curvals = dict()
        q = None
        try:
            q = self._query(self._db, table, OPTIONS_COLUMNS)
            rows = q.select({'name': name}, ['option', 'value'])
            for row in rows:
                curvals[row[0]] = row[1]

            for opt in options:
                if opt in curvals:
                    q.update({'value': options[opt]},
                             {'name': name, 'option': opt})
                else:
                    q.insert((name, opt, options[opt]))

            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to save options: [%s]" % e)
            raise

    def delete_options(self, table, name, options=None):
        """Delete all options of ``name``, or only those listed."""
        kvfilter = {'name': name}
        q = None
        try:
            q = self._query(self._db, table, OPTIONS_COLUMNS)
            if options is None:
                q.delete(kvfilter)
            else:
                for opt in options:
                    kvfilter['option'] = opt
                    q.delete(kvfilter)
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to delete from %s: [%s]" % (table, e))
            raise

    def new_unique_data(self, table, data):
        """Store the ``data`` dict under a fresh uuid4 and return it."""
        newid = str(uuid.uuid4())
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
            for name in data:
                q.insert((newid, name, data[name]))
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to store %s data: [%s]" % (table, e))
            raise
        return newid

    def get_unique_data(self, table, uuidval=None, name=None, value=None):
        """Return {uuid: {name: value}} filtered by the truthy arguments."""
        kvfilter = dict()
        if uuidval:
            kvfilter['uuid'] = uuidval
        if name:
            kvfilter['name'] = name
        if value:
            kvfilter['value'] = value
        return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)

    def save_unique_data(self, table, data):
        """Apply {uuid: {name: value}} changes in one transaction.

        A None value deletes an existing name and is skipped for a new one.
        """
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
            for uid in data:
                curvals = dict()
                rows = q.select({'uuid': uid}, ['name', 'value'])
                for r in rows:
                    curvals[r[0]] = r[1]

                datum = data[uid]
                for name in datum:
                    if name in curvals:
                        if datum[name] is None:
                            q.delete({'uuid': uid, 'name': name})
                        else:
                            q.update({'value': datum[name]},
                                     {'uuid': uid, 'name': name})
                    else:
                        if datum[name] is not None:
                            q.insert((uid, name, datum[name]))

            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to store data in %s: [%s]" % (table, e))
            raise

    def del_unique_data(self, table, uuidval):
        """Delete every row of ``uuidval``; errors are logged, not raised."""
        kvfilter = {'uuid': uuidval}
        try:
            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
            q.delete(kvfilter)
        except Exception as e:  # pylint: disable=broad-except
            self.error("Failed to delete data from %s: [%s]" % (table, e))

    def _reset_data(self, table):
        """Drop and recreate ``table``; errors are logged, not raised."""
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
            q.drop()
            q.create()
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to erase all data from %s: [%s]" % (table, e))
464             self.error("Failed to erase all data from %s: [%s]" % (table, e))
465
466
class AdminStore(Store):
    """Per-plugin administrative data kept in the 'admin.config.db' store.

    Each plugin gets its own "<plugin>_data" table of uuid/name/value rows.
    """

    def __init__(self):
        super(AdminStore, self).__init__(config_name='admin.config.db')

    @staticmethod
    def _data_table(plugin):
        # Name of the table holding ``plugin``'s data.
        return plugin + "_data"

    def get_data(self, plugin, idval=None, name=None, value=None):
        """Return ``plugin`` data filtered by uuid, name and/or value."""
        return self.get_unique_data(self._data_table(plugin),
                                    idval, name, value)

    def save_data(self, plugin, data):
        """Apply {uuid: {name: value}} changes to ``plugin``'s data."""
        return self.save_unique_data(self._data_table(plugin), data)

    def new_datum(self, plugin, datum):
        """Store ``datum`` as a new record; return its uuid."""
        return self.new_unique_data(self._data_table(plugin), datum)

    def del_datum(self, plugin, idval):
        """Delete the record ``idval`` of ``plugin``."""
        return self.del_unique_data(self._data_table(plugin), idval)

    def wipe_data(self, plugin):
        """Drop and recreate ``plugin``'s data table."""
        self._reset_data(self._data_table(plugin))
489
490
class UserStore(Store):
    """Per-user preferences and per-plugin user data ('user.prefs.db')."""

    def __init__(self, path=None):
        # ``path`` is accepted but unused; the database location comes from
        # the 'user.prefs.db' configuration key.
        super(UserStore, self).__init__(config_name='user.prefs.db')

    def save_user_preferences(self, user, options):
        """Persist the ``options`` dict as ``user``'s preferences."""
        self.save_options('users', user, options)

    def load_user_preferences(self, user):
        """Return ``user``'s stored preferences."""
        return self.load_options('users', user)

    def save_plugin_data(self, plugin, user, options):
        """Persist ``options`` for ``user`` in ``plugin``'s data table."""
        table = plugin + "_data"
        self.save_options(table, user, options)

    def load_plugin_data(self, plugin, user):
        """Return ``user``'s options from ``plugin``'s data table."""
        table = plugin + "_data"
        return self.load_options(table, user)
507
508
class TranStore(Store):
    """Store backed by the database configured under 'transactions.db'."""

    def __init__(self, path=None):
        # ``path`` is accepted but unused; the location comes from config.
        super(TranStore, self).__init__(config_name='transactions.db')
513
514
class SAML2SessionStore(Store):
    """SAML2 session storage in a uuid/name/value 'sessions' table.

    Each session is the group of rows sharing one uuid.  The table is
    created on construction if missing.
    NOTE(review): SqlQuery is used directly, so this presumably requires a
    SQL database_url rather than a configfile:// one — confirm.
    """

    def __init__(self, database_url):
        super(SAML2SessionStore, self).__init__(database_url=database_url)
        self.table = 'sessions'
        # Only the Table object is needed, so do not begin a transaction
        # that would never be committed or rolled back.
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS,
                         trans=False)._table
        table.create(checkfirst=True)

    def _get_unique_id_from_column(self, name, value):
        """
        The query is going to return only the column in the query.
        Use this method to get the uuidval which can be used to fetch
        the entire entry.

        Returns None or the uuid of the single matching entry.
        Raises ValueError if more than one entry matches.
        """
        data = self.get_unique_data(self.table, name=name, value=value)
        count = len(data)
        if count == 0:
            return None
        elif count != 1:
            raise ValueError("Multiple entries returned")
        # next(iter(...)) works whether keys() is a list or a view.
        return next(iter(data))

    def remove_expired_sessions(self):
        """Delete every session whose 'expiration_time' is not in the future.

        NOTE(review): 'value' is a Text column compared against a datetime,
        so this relies on expiration times being stored in a format that
        orders correctly as text — confirm against the writers.  The DELETE
        runs through the table's bound engine, not a request connection.
        """
        # pylint: disable=protected-access
        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS,
                         trans=False)._table
        sel = select([table.columns.uuid]). \
            where(and_(table.c.name == 'expiration_time',
                       table.c.value <= datetime.datetime.now()))
        # pylint: disable=no-value-for-parameter
        d = table.delete().where(table.c.uuid.in_(sel))
        d.execute()

    def get_data(self, idval=None, name=None, value=None):
        """Return sessions filtered by uuid, name and/or value."""
        return self.get_unique_data(self.table, idval, name, value)

    def new_session(self, datum):
        """Store ``datum`` as a new session; return its uuid."""
        return self.new_unique_data(self.table, datum)

    def get_session(self, session_id=None, request_id=None):
        """Look a session up by session_id (preferred) or request_id.

        Returns (uuid, data) or (None, None) when nothing matches.
        Raises ValueError when neither identifier is given.
        """
        if session_id:
            uuidval = self._get_unique_id_from_column('session_id', session_id)
        elif request_id:
            uuidval = self._get_unique_id_from_column('request_id', request_id)
        else:
            raise ValueError("Unable to find session")
        if not uuidval:
            return None, None
        data = self.get_unique_data(self.table, uuidval=uuidval)
        return uuidval, data[uuidval]

    def get_user_sessions(self, user):
        """
        Return a list of all sessions for a given user.

        Each element is a {uuid: {name: value}} dict for one session.
        """
        rows = self.get_unique_data(self.table, name='user', value=user)

        # We have a list of sessions for this user, now get the details
        logged_in = []
        for r in rows:
            data = self.get_unique_data(self.table, uuidval=r)
            logged_in.append(data)

        return logged_in

    def update_session(self, datum):
        """Apply {uuid: {name: value}} changes to stored sessions."""
        self.save_unique_data(self.table, datum)

    def remove_session(self, uuidval):
        """Delete the session ``uuidval``."""
        self.del_unique_data(self.table, uuidval)

    def wipe_data(self):
        """Drop and recreate the sessions table."""
        self._reset_data(self.table)
588     def wipe_data(self):
589         self._reset_data(self.table)