f90519d5e6ee5007145b297b2995f7e6f33dde22
[cascardo/ipsilon.git] / ipsilon / util / data.py
1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
2
3 import cherrypy
4 from ipsilon.util.log import Log
5 from sqlalchemy import create_engine
6 from sqlalchemy import MetaData, Table, Column, Text
7 from sqlalchemy.pool import QueuePool, SingletonThreadPool
8 from sqlalchemy.sql import select
9 import ConfigParser
10 import os
11 import uuid
12 import logging
13
14
15 CURRENT_SCHEMA_VERSION = 1
16 OPTIONS_COLUMNS = ['name', 'option', 'value']
17 UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
18
19
20 class SqlStore(Log):
21     __instances = {}
22
23     @classmethod
24     def get_connection(cls, name):
25         if name not in cls.__instances.keys():
26             if cherrypy.config.get('db.conn.log', False):
27                 logging.debug('SqlStore new: %s', name)
28             cls.__instances[name] = SqlStore(name)
29         return cls.__instances[name]
30
31     def __init__(self, name):
32         self.db_conn_log = cherrypy.config.get('db.conn.log', False)
33         self.debug('SqlStore init: %s' % name)
34         self.name = name
35         engine_name = name
36         if '://' not in engine_name:
37             engine_name = 'sqlite:///' + engine_name
38         # This pool size is per configured database. The minimum needed,
39         #  determined by binary search, is 23. We're using 25 so we have a bit
40         #  more playroom, and then the overflow should make sure things don't
41         #  break when we suddenly need more.
42         pool_args = {'poolclass': QueuePool,
43                      'pool_size': 25,
44                      'max_overflow': 50}
45         if engine_name.startswith('sqlite://'):
46             # It's not possible to share connections for SQLite between
47             #  threads, so let's use the SingletonThreadPool for them
48             pool_args = {'poolclass': SingletonThreadPool}
49         self._dbengine = create_engine(engine_name, **pool_args)
50         self.is_readonly = False
51
52     def debug(self, fact):
53         if self.db_conn_log:
54             super(SqlStore, self).debug(fact)
55
56     def engine(self):
57         return self._dbengine
58
59     def connection(self):
60         self.debug('SqlStore connect: %s' % self.name)
61         conn = self._dbengine.connect()
62
63         def cleanup_connection():
64             self.debug('SqlStore cleanup: %s' % self.name)
65             conn.close()
66         cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
67         return conn
68
69
70 def SqlAutotable(f):
71     def at(self, *args, **kwargs):
72         self.create()
73         return f(self, *args, **kwargs)
74     return at
75
76
77 class SqlQuery(Log):
78
79     def __init__(self, db_obj, table, columns, trans=True):
80         self._db = db_obj
81         self._con = self._db.connection()
82         self._trans = self._con.begin() if trans else None
83         self._table = self._get_table(table, columns)
84
85     def _get_table(self, name, columns):
86         table = Table(name, MetaData(self._db.engine()))
87         for c in columns:
88             table.append_column(Column(c, Text()))
89         return table
90
91     def _where(self, kvfilter):
92         where = None
93         if kvfilter is not None:
94             for k in kvfilter:
95                 w = self._table.columns[k] == kvfilter[k]
96                 if where is None:
97                     where = w
98                 else:
99                     where = where & w
100         return where
101
102     def _columns(self, columns=None):
103         cols = None
104         if columns is not None:
105             cols = []
106             for c in columns:
107                 cols.append(self._table.columns[c])
108         else:
109             cols = self._table.columns
110         return cols
111
112     def rollback(self):
113         self._trans.rollback()
114
115     def commit(self):
116         self._trans.commit()
117
118     def create(self):
119         self._table.create(checkfirst=True)
120
121     def drop(self):
122         self._table.drop(checkfirst=True)
123
124     @SqlAutotable
125     def select(self, kvfilter=None, columns=None):
126         return self._con.execute(select(self._columns(columns),
127                                         self._where(kvfilter)))
128
129     @SqlAutotable
130     def insert(self, values):
131         self._con.execute(self._table.insert(values))
132
133     @SqlAutotable
134     def update(self, values, kvfilter):
135         self._con.execute(self._table.update(self._where(kvfilter), values))
136
137     @SqlAutotable
138     def delete(self, kvfilter):
139         self._con.execute(self._table.delete(self._where(kvfilter)))
140
141
142 class FileStore(Log):
143
144     def __init__(self, name):
145         self._filename = name
146         self.is_readonly = True
147         self._timestamp = None
148         self._config = None
149
150     def get_config(self):
151         try:
152             stat = os.stat(self._filename)
153         except OSError, e:
154             self.error("Unable to check config file %s: [%s]" % (
155                 self._filename, e))
156             self._config = None
157             raise
158         timestamp = stat.st_mtime
159         if self._config is None or timestamp > self._timestamp:
160             self._config = ConfigParser.RawConfigParser()
161             self._config.optionxform = str
162             self._config.read(self._filename)
163         return self._config
164
165
166 class FileQuery(Log):
167
168     def __init__(self, fstore, table, columns, trans=True):
169         self._fstore = fstore
170         self._config = fstore.get_config()
171         self._section = table
172         if len(columns) > 3 or columns[-1] != 'value':
173             raise ValueError('Unsupported configuration format')
174         self._columns = columns
175
176     def rollback(self):
177         return
178
179     def commit(self):
180         return
181
182     def create(self):
183         raise NotImplementedError
184
185     def drop(self):
186         raise NotImplementedError
187
188     def select(self, kvfilter=None, columns=None):
189         if self._section not in self._config.sections():
190             return []
191
192         opts = self._config.options(self._section)
193
194         prefix = None
195         prefix_ = ''
196         if self._columns[0] in kvfilter:
197             prefix = kvfilter[self._columns[0]]
198             prefix_ = prefix + ' '
199
200         name = None
201         if len(self._columns) == 3 and self._columns[1] in kvfilter:
202             name = kvfilter[self._columns[1]]
203
204         value = None
205         if self._columns[-1] in kvfilter:
206             value = kvfilter[self._columns[-1]]
207
208         res = []
209         for o in opts:
210             if len(self._columns) == 3:
211                 # 3 cols
212                 if prefix and not o.startswith(prefix_):
213                     continue
214
215                 col1, col2 = o.split(' ', 1)
216                 if name and col2 != name:
217                     continue
218
219                 col3 = self._config.get(self._section, o)
220                 if value and col3 != value:
221                     continue
222
223                 r = [col1, col2, col3]
224             else:
225                 # 2 cols
226                 if prefix and o != prefix:
227                     continue
228                 r = [o, self._config.get(self._section, o)]
229
230             if columns:
231                 s = []
232                 for c in columns:
233                     s.append(r[self._columns.index(c)])
234                 res.append(s)
235             else:
236                 res.append(r)
237
238         self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
239                                                  repr(kvfilter),
240                                                  repr(columns),
241                                                  repr(res)))
242         return res
243
244     def insert(self, values):
245         raise NotImplementedError
246
247     def update(self, values, kvfilter):
248         raise NotImplementedError
249
250     def delete(self, kvfilter):
251         raise NotImplementedError
252
253
254 class Store(Log):
255     def __init__(self, config_name=None, database_url=None):
256         if config_name is None and database_url is None:
257             raise ValueError('config_name or database_url must be provided')
258         if config_name:
259             if config_name not in cherrypy.config:
260                 raise NameError('Unknown database %s' % config_name)
261             name = cherrypy.config[config_name]
262         else:
263             name = database_url
264         if name.startswith('configfile://'):
265             _, filename = name.split('://')
266             self._db = FileStore(filename)
267             self._query = FileQuery
268         else:
269             self._db = SqlStore.get_connection(name)
270             self._query = SqlQuery
271         self._upgrade_database()
272
273     def _upgrade_database(self):
274         if self.is_readonly:
275             # If the database is readonly, we cannot do anything to the
276             #  schema. Let's just return, and assume people checked the
277             #  upgrade notes
278             return
279         current_version = self.load_options('dbinfo').get('scheme', None)
280         if current_version is None or 'version' not in current_version:
281             # No version stored, storing current version
282             self.save_options('dbinfo', 'scheme',
283                               {'version': CURRENT_SCHEMA_VERSION})
284             current_version = CURRENT_SCHEMA_VERSION
285         else:
286             current_version = int(current_version['version'])
287         if current_version != CURRENT_SCHEMA_VERSION:
288             self.debug('Upgrading database schema from %i to %i' % (
289                        current_version, CURRENT_SCHEMA_VERSION))
290             self._upgrade_database_from(current_version)
291
292     def _upgrade_database_from(self, old_schema_version):
293         # Insert code here to upgrade from old_schema_version to
294         #  CURRENT_SCHEMA_VERSION
295         raise Exception('Unable to upgrade database to current schema'
296                         ' version: version %i is unknown!' %
297                         old_schema_version)
298
299     @property
300     def is_readonly(self):
301         return self._db.is_readonly
302
303     def _row_to_dict_tree(self, data, row):
304         name = row[0]
305         if len(row) > 2:
306             if name not in data:
307                 data[name] = dict()
308             d2 = data[name]
309             self._row_to_dict_tree(d2, row[1:])
310         else:
311             value = row[1]
312             if name in data:
313                 if data[name] is list:
314                     data[name].append(value)
315                 else:
316                     v = data[name]
317                     data[name] = [v, value]
318             else:
319                 data[name] = value
320
321     def _rows_to_dict_tree(self, rows):
322         data = dict()
323         for r in rows:
324             self._row_to_dict_tree(data, r)
325         return data
326
327     def _load_data(self, table, columns, kvfilter=None):
328         rows = []
329         try:
330             q = self._query(self._db, table, columns, trans=False)
331             rows = q.select(kvfilter)
332         except Exception, e:  # pylint: disable=broad-except
333             self.error("Failed to load data for table %s: [%s]" % (table, e))
334         return self._rows_to_dict_tree(rows)
335
336     def load_config(self):
337         table = 'config'
338         columns = ['name', 'value']
339         return self._load_data(table, columns)
340
341     def load_options(self, table, name=None):
342         kvfilter = dict()
343         if name:
344             kvfilter['name'] = name
345         options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
346         if name and name in options:
347             return options[name]
348         return options
349
350     def save_options(self, table, name, options):
351         curvals = dict()
352         q = None
353         try:
354             q = self._query(self._db, table, OPTIONS_COLUMNS)
355             rows = q.select({'name': name}, ['option', 'value'])
356             for row in rows:
357                 curvals[row[0]] = row[1]
358
359             for opt in options:
360                 if opt in curvals:
361                     q.update({'value': options[opt]},
362                              {'name': name, 'option': opt})
363                 else:
364                     q.insert((name, opt, options[opt]))
365
366             q.commit()
367         except Exception, e:  # pylint: disable=broad-except
368             if q:
369                 q.rollback()
370             self.error("Failed to save options: [%s]" % e)
371             raise
372
373     def delete_options(self, table, name, options=None):
374         kvfilter = {'name': name}
375         q = None
376         try:
377             q = self._query(self._db, table, OPTIONS_COLUMNS)
378             if options is None:
379                 q.delete(kvfilter)
380             else:
381                 for opt in options:
382                     kvfilter['option'] = opt
383                     q.delete(kvfilter)
384             q.commit()
385         except Exception, e:  # pylint: disable=broad-except
386             if q:
387                 q.rollback()
388             self.error("Failed to delete from %s: [%s]" % (table, e))
389             raise
390
391     def new_unique_data(self, table, data):
392         newid = str(uuid.uuid4())
393         q = None
394         try:
395             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
396             for name in data:
397                 q.insert((newid, name, data[name]))
398             q.commit()
399         except Exception, e:  # pylint: disable=broad-except
400             if q:
401                 q.rollback()
402             self.error("Failed to store %s data: [%s]" % (table, e))
403             raise
404         return newid
405
406     def get_unique_data(self, table, uuidval=None, name=None, value=None):
407         kvfilter = dict()
408         if uuidval:
409             kvfilter['uuid'] = uuidval
410         if name:
411             kvfilter['name'] = name
412         if value:
413             kvfilter['value'] = value
414         return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
415
416     def save_unique_data(self, table, data):
417         q = None
418         try:
419             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
420             for uid in data:
421                 curvals = dict()
422                 rows = q.select({'uuid': uid}, ['name', 'value'])
423                 for r in rows:
424                     curvals[r[0]] = r[1]
425
426                 datum = data[uid]
427                 for name in datum:
428                     if name in curvals:
429                         if datum[name] is None:
430                             q.delete({'uuid': uid, 'name': name})
431                         else:
432                             q.update({'value': datum[name]},
433                                      {'uuid': uid, 'name': name})
434                     else:
435                         if datum[name] is not None:
436                             q.insert((uid, name, datum[name]))
437
438             q.commit()
439         except Exception, e:  # pylint: disable=broad-except
440             if q:
441                 q.rollback()
442             self.error("Failed to store data in %s: [%s]" % (table, e))
443             raise
444
445     def del_unique_data(self, table, uuidval):
446         kvfilter = {'uuid': uuidval}
447         try:
448             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
449             q.delete(kvfilter)
450         except Exception, e:  # pylint: disable=broad-except
451             self.error("Failed to delete data from %s: [%s]" % (table, e))
452
453     def _reset_data(self, table):
454         q = None
455         try:
456             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
457             q.drop()
458             q.create()
459             q.commit()
460         except Exception, e:  # pylint: disable=broad-except
461             if q:
462                 q.rollback()
463             self.error("Failed to erase all data from %s: [%s]" % (table, e))
464
465
466 class AdminStore(Store):
467
468     def __init__(self):
469         super(AdminStore, self).__init__('admin.config.db')
470
471     def get_data(self, plugin, idval=None, name=None, value=None):
472         return self.get_unique_data(plugin+"_data", idval, name, value)
473
474     def save_data(self, plugin, data):
475         return self.save_unique_data(plugin+"_data", data)
476
477     def new_datum(self, plugin, datum):
478         table = plugin+"_data"
479         return self.new_unique_data(table, datum)
480
481     def del_datum(self, plugin, idval):
482         table = plugin+"_data"
483         return self.del_unique_data(table, idval)
484
485     def wipe_data(self, plugin):
486         table = plugin+"_data"
487         self._reset_data(table)
488
489
490 class UserStore(Store):
491
492     def __init__(self, path=None):
493         super(UserStore, self).__init__('user.prefs.db')
494
495     def save_user_preferences(self, user, options):
496         self.save_options('users', user, options)
497
498     def load_user_preferences(self, user):
499         return self.load_options('users', user)
500
501     def save_plugin_data(self, plugin, user, options):
502         self.save_options(plugin+"_data", user, options)
503
504     def load_plugin_data(self, plugin, user):
505         return self.load_options(plugin+"_data", user)
506
507
508 class TranStore(Store):
509
510     def __init__(self, path=None):
511         super(TranStore, self).__init__('transactions.db')
512
513
514 class SAML2SessionStore(Store):
515
516     def __init__(self, path=None):
517         super(SAML2SessionStore, self).__init__('saml2.sessions.db')
518         self.table = 'sessions'
519
520     def _get_unique_id_from_column(self, name, value):
521         """
522         The query is going to return only the column in the query.
523         Use this method to get the uuidval which can be used to fetch
524         the entire entry.
525
526         Returns None or the uuid of the first value found.
527         """
528         data = self.get_unique_data(self.table, name=name, value=value)
529         count = len(data)
530         if count == 0:
531             return None
532         elif count != 1:
533             raise ValueError("Multiple entries returned")
534         return data.keys()[0]
535
536     def get_data(self, idval=None, name=None, value=None):
537         return self.get_unique_data(self.table, idval, name, value)
538
539     def new_session(self, datum):
540         return self.new_unique_data(self.table, datum)
541
542     def get_session(self, session_id=None, request_id=None):
543         if session_id:
544             uuidval = self._get_unique_id_from_column('session_id', session_id)
545         elif request_id:
546             uuidval = self._get_unique_id_from_column('request_id', request_id)
547         else:
548             raise ValueError("Unable to find session")
549         if not uuidval:
550             return None, None
551         data = self.get_unique_data(self.table, uuidval=uuidval)
552         return uuidval, data[uuidval]
553
554     def get_user_sessions(self, user):
555         """
556         Retrun a list of all sessions for a given user.
557         """
558         rows = self.get_unique_data(self.table, name='user', value=user)
559
560         # We have a list of sessions for this user, now get the details
561         logged_in = []
562         for r in rows:
563             data = self.get_unique_data(self.table, uuidval=r)
564             logged_in.append(data)
565
566         return logged_in
567
568     def update_session(self, datum):
569         self.save_unique_data(self.table, datum)
570
571     def remove_session(self, uuidval):
572         self.del_unique_data(self.table, uuidval)
573
574     def wipe_data(self):
575         self._reset_data(self.table)