Add database schema versioning
[cascardo/ipsilon.git] / ipsilon / util / data.py
1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
2
3 import cherrypy
4 from ipsilon.util.log import Log
5 from sqlalchemy import create_engine
6 from sqlalchemy import MetaData, Table, Column, Text
7 from sqlalchemy.pool import QueuePool, SingletonThreadPool
8 from sqlalchemy.sql import select
9 import ConfigParser
10 import os
11 import uuid
12 import logging
13
14
# Schema version this code expects; Store._upgrade_database() records it
#  in the 'dbinfo' table when no version is stored yet.
CURRENT_SCHEMA_VERSION = 1
# Columns of "options" tables: one row per option of a named entry.
OPTIONS_COLUMNS = [
    'name',
    'option',
    'value',
]
# Columns of "unique data" tables: rows grouped under a generated uuid.
UNIQUE_DATA_COLUMNS = [
    'uuid',
    'name',
    'value',
]
18
19
class SqlStore(Log):
    """Holder of one SQLAlchemy engine per configured database name.

    Instances are memoized in a class-level dict keyed by name, so every
    get_connection() call with the same name gets the same SqlStore and
    therefore shares one engine and its connection pool.
    """

    # name -> SqlStore, filled lazily by get_connection()
    __instances = {}

    @classmethod
    def get_connection(cls, name):
        """Return the SqlStore registered for ``name``, building it if new."""
        registry = cls.__instances
        if name in registry:
            return registry[name]
        if cherrypy.config.get('db.conn.log', False):
            logging.debug('SqlStore new: %s', name)
        store = SqlStore(name)
        registry[name] = store
        return store

    def __init__(self, name):
        """Create the SQLAlchemy engine for ``name``.

        A ``name`` without ``://`` is treated as a SQLite file path.
        """
        self.db_conn_log = cherrypy.config.get('db.conn.log', False)
        self.debug('SqlStore init: %s' % name)
        self.name = name
        if '://' in name:
            engine_name = name
        else:
            engine_name = 'sqlite:///' + name
        if engine_name.startswith('sqlite://'):
            # SQLite connections cannot be shared between threads, hence
            #  SingletonThreadPool for them.
            pool_args = {'poolclass': SingletonThreadPool}
        else:
            # The pool size applies per configured database. A binary search
            #  found 23 to be the minimum needed. 25 leaves some headroom, and
            #  the overflow covers sudden bursts.
            pool_args = {'poolclass': QueuePool,
                         'pool_size': 25,
                         'max_overflow': 50}
        self._dbengine = create_engine(engine_name, **pool_args)
        self.is_readonly = False

    def debug(self, fact):
        """Log ``fact`` via Log.debug, but only when db.conn.log is set."""
        if not self.db_conn_log:
            return
        super(SqlStore, self).debug(fact)

    def engine(self):
        """Return the underlying SQLAlchemy engine."""
        return self._dbengine

    def connection(self):
        """Open a new connection and schedule it to be closed.

        The close is attached as a cherrypy ``on_end_request`` hook.
        NOTE(review): presumably this is only meant to be called while a
        cherrypy request is being served. Confirm.
        """
        self.debug('SqlStore connect: %s' % self.name)
        conn = self._dbengine.connect()

        def _close_at_request_end():
            self.debug('SqlStore cleanup: %s' % self.name)
            conn.close()

        cherrypy.request.hooks.attach('on_end_request', _close_at_request_end)
        return conn
68
69
def SqlAutotable(f):
    """Decorate a query method so its table is created before it runs.

    The wrapper calls ``self.create()`` and then delegates to ``f`` with the
    original arguments, returning its result. SqlQuery.create() passes
    ``checkfirst=True``, so an already existing table is left alone.

    ``functools.wraps`` keeps the wrapped method's name and docstring, so
    introspection, logging and tracebacks still show the real method.
    """
    import functools

    @functools.wraps(f)
    def at(self, *args, **kwargs):
        self.create()
        return f(self, *args, **kwargs)
    return at
75
76
class SqlQuery(Log):
    """Run simple CRUD statements against one table of a SqlStore.

    The table is described on the fly with one Text column per name in
    ``columns``. select/insert/update/delete are wrapped by SqlAutotable,
    so the table is created (if missing) right before each of them runs.
    """

    def __init__(self, db_obj, table, columns, trans=True):
        """Open a connection on ``db_obj`` and describe ``table``.

        :param db_obj: the SqlStore providing engine() and connection().
        :param table: name of the table to operate on.
        :param columns: ordered column names, all declared as Text.
        :param trans: when true, begin a transaction that the caller ends
            with commit() or rollback(). When false, no explicit transaction
            is opened and commit()/rollback() do nothing.
        """
        self._db = db_obj
        self._con = self._db.connection()
        self._trans = self._con.begin() if trans else None
        self._table = self._get_table(table, columns)

    def _get_table(self, name, columns):
        """Describe table ``name`` with a Text column for each entry."""
        table = Table(name, MetaData(self._db.engine()))
        for c in columns:
            table.append_column(Column(c, Text()))
        return table

    def _where(self, kvfilter):
        """AND together ``column == value`` for every item of ``kvfilter``.

        Returns None (no WHERE clause) when kvfilter is None or empty.
        """
        where = None
        if kvfilter is not None:
            for k, v in kvfilter.items():
                w = self._table.columns[k] == v
                where = w if where is None else where & w
        return where

    def _columns(self, columns=None):
        """Map column names to Column objects; None means all columns."""
        if columns is None:
            return self._table.columns
        return [self._table.columns[c] for c in columns]

    def rollback(self):
        """Roll back the open transaction. Does nothing when trans=False."""
        if self._trans is not None:
            self._trans.rollback()

    def commit(self):
        """Commit the open transaction. Does nothing when trans=False."""
        if self._trans is not None:
            self._trans.commit()

    def create(self):
        """Create the table unless it already exists."""
        # NOTE(review): create()/drop() go through the MetaData bound to the
        #  engine rather than self._con, so presumably they are not part of
        #  this query's transaction. Confirm.
        self._table.create(checkfirst=True)

    def drop(self):
        """Drop the table if it exists."""
        self._table.drop(checkfirst=True)

    @SqlAutotable
    def select(self, kvfilter=None, columns=None):
        """Execute a SELECT of ``columns`` filtered by ``kvfilter``.

        Returns the SQLAlchemy result of the execution.
        """
        return self._con.execute(select(self._columns(columns),
                                        self._where(kvfilter)))

    @SqlAutotable
    def insert(self, values):
        """Insert one row; ``values`` is passed to Table.insert()."""
        self._con.execute(self._table.insert(values))

    @SqlAutotable
    def update(self, values, kvfilter):
        """Set ``values`` on the rows matching ``kvfilter``."""
        self._con.execute(self._table.update(self._where(kvfilter), values))

    @SqlAutotable
    def delete(self, kvfilter):
        """Delete the rows matching ``kvfilter``."""
        self._con.execute(self._table.delete(self._where(kvfilter)))
140
141
class FileStore(Log):
    """Read-only store backed by an INI-style configuration file.

    The parsed file is cached. It is re-read only when its modification
    time differs from the one recorded when it was last parsed.
    """

    def __init__(self, name):
        self._filename = name
        self.is_readonly = True
        # mtime of the file at the last successful parse, or None
        self._timestamp = None
        # cached RawConfigParser, or None when not loaded yet
        self._config = None

    def get_config(self):
        """Return the parsed configuration, reloading it if the file changed.

        Option names keep their case (optionxform is ``str``).

        :raises OSError: if the file cannot be stat'ed. The cached config is
            dropped and the error is logged before re-raising.
        """
        try:
            stat = os.stat(self._filename)
        except OSError as e:
            self.error("Unable to check config file %s: [%s]" % (
                self._filename, e))
            self._config = None
            raise
        timestamp = stat.st_mtime
        if self._config is None or timestamp != self._timestamp:
            config = ConfigParser.RawConfigParser()
            config.optionxform = str
            config.read(self._filename)
            self._config = config
            # Remember which version we parsed so unchanged files hit the
            #  cache instead of being re-read on every call.
            self._timestamp = timestamp
        return self._config
164
165
class FileQuery(Log):
    """Read-only query over one section of a FileStore configuration.

    The section is viewed as a table of 2 or 3 columns whose last column
    must be 'value':

    * 2 columns: each option ``o`` is the row ``[o, value]``.
    * 3 columns: each option is named ``"<col1> <col2>"`` (split on the
      first space) and yields the row ``[col1, col2, value]``.

    All mutating operations raise NotImplementedError.
    """

    def __init__(self, fstore, table, columns, trans=True):
        """Bind to section ``table`` of ``fstore``'s configuration.

        ``trans`` is accepted only for interface parity with SqlQuery.

        :raises ValueError: unless ``columns`` has 2 or 3 entries ending
            with 'value'.
        """
        self._fstore = fstore
        self._config = fstore.get_config()
        self._section = table
        if len(columns) not in (2, 3) or columns[-1] != 'value':
            raise ValueError('Unsupported configuration format')
        self._columns = columns

    def rollback(self):
        return

    def commit(self):
        return

    def create(self):
        raise NotImplementedError

    def drop(self):
        raise NotImplementedError

    def select(self, kvfilter=None, columns=None):
        """Return the rows of the section that match ``kvfilter``.

        :param kvfilter: dict mapping column names to required values. None
            or empty selects every row. A filter whose value is None is
            ignored.
        :param columns: optional list of column names to project each row
            onto. Empty or None returns full rows.
        :returns: list of rows, each a list of values.
        :raises ValueError: in 3-column mode, if a visited option name has
            no space to split on.
        """
        if kvfilter is None:
            kvfilter = {}
        if self._section not in self._config.sections():
            return []

        three_cols = len(self._columns) == 3
        prefix = kvfilter.get(self._columns[0])
        name = kvfilter.get(self._columns[1]) if three_cols else None
        value = kvfilter.get(self._columns[-1])
        # In 3-column mode the first column is the part before the space.
        prefix_ = prefix + ' ' if three_cols and prefix is not None else None

        res = []
        for o in self._config.options(self._section):
            if three_cols:
                if prefix_ is not None and not o.startswith(prefix_):
                    continue
                col1, col2 = o.split(' ', 1)
                if name is not None and col2 != name:
                    continue
                col3 = self._config.get(self._section, o)
                if value is not None and col3 != value:
                    continue
                r = [col1, col2, col3]
            else:
                if prefix is not None and o != prefix:
                    continue
                col2 = self._config.get(self._section, o)
                # Apply the 'value' filter here too, as in the 3-column case.
                if value is not None and col2 != value:
                    continue
                r = [o, col2]

            if columns:
                r = [r[self._columns.index(c)] for c in columns]
            res.append(r)

        self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
                                                 repr(kvfilter),
                                                 repr(columns),
                                                 repr(res)))
        return res

    def insert(self, values):
        raise NotImplementedError

    def update(self, values, kvfilter):
        raise NotImplementedError

    def delete(self, kvfilter):
        raise NotImplementedError
252
253
class Store(Log):
    """Persistence front-end shared by the data stores in this module.

    The backend is picked from the database name:

    * A name starting with ``configfile://`` is served read-only from an
      INI-style file through FileStore/FileQuery.
    * Any other name goes to SqlStore/SqlQuery.

    Construction also checks, and if needed upgrades, the schema version
    kept in the ``dbinfo`` table.
    """

    def __init__(self, config_name=None, database_url=None):
        """Open the store named by ``config_name`` or ``database_url``.

        :param config_name: key in cherrypy.config whose value is the
            database name/URL. Used when truthy.
        :param database_url: database name/URL used otherwise.
        :raises ValueError: if both arguments are None.
        :raises NameError: if ``config_name`` is not in cherrypy.config.
        """
        if config_name is None and database_url is None:
            raise ValueError('config_name or database_url must be provided')
        if config_name:
            if config_name not in cherrypy.config:
                raise NameError('Unknown database %s' % config_name)
            name = cherrypy.config[config_name]
        else:
            name = database_url
        if name.startswith('configfile://'):
            # Split only once so a path that itself contains '://' survives.
            _, filename = name.split('://', 1)
            self._db = FileStore(filename)
            self._query = FileQuery
        else:
            self._db = SqlStore.get_connection(name)
            self._query = SqlQuery
        self._upgrade_database()

    def _upgrade_database(self):
        """Record or check the schema version stored in ``dbinfo``.

        A database with no recorded version is stamped with
        CURRENT_SCHEMA_VERSION. A different recorded version is handed to
        _upgrade_database_from(). Read-only stores are left untouched.
        """
        if self.is_readonly:
            # If the database is readonly, we cannot do anything to the
            #  schema. Let's just return, and assume people checked the
            #  upgrade notes
            return
        # NOTE(review): _load_data logs and swallows read errors, so a failed
        #  read looks like "no version stored" and re-stamps the current one.
        stored = self.load_options('dbinfo').get('scheme', None)
        if stored is None or 'version' not in stored:
            # No version stored, storing current version
            self.save_options('dbinfo', 'scheme',
                              {'version': CURRENT_SCHEMA_VERSION})
            current_version = CURRENT_SCHEMA_VERSION
        else:
            current_version = int(stored['version'])
        if current_version != CURRENT_SCHEMA_VERSION:
            self.debug('Upgrading database schema from %i to %i' % (
                       current_version, CURRENT_SCHEMA_VERSION))
            self._upgrade_database_from(current_version)

    def _upgrade_database_from(self, old_schema_version):
        """Upgrade from ``old_schema_version`` to CURRENT_SCHEMA_VERSION.

        No upgrade paths exist yet, so this always raises.
        """
        # Insert code here to upgrade from old_schema_version to
        #  CURRENT_SCHEMA_VERSION
        raise Exception('Unable to upgrade database to current schema'
                        ' version: version %i is unknown!' %
                        old_schema_version)

    @property
    def is_readonly(self):
        """True when the backend cannot be written to."""
        return self._db.is_readonly

    def _row_to_dict_tree(self, data, row):
        """Merge one result row into the nested dict ``data``.

        Every column except the last two adds a nesting level. The last two
        are a key and its value. A key seen more than once collects its
        values into a list, in arrival order.
        """
        name = row[0]
        if len(row) > 2:
            self._row_to_dict_tree(data.setdefault(name, dict()), row[1:])
            return
        value = row[1]
        if name not in data:
            data[name] = value
        elif isinstance(data[name], list):
            data[name].append(value)
        else:
            data[name] = [data[name], value]

    def _rows_to_dict_tree(self, rows):
        """Fold all ``rows`` into a fresh nested dict."""
        data = dict()
        for r in rows:
            self._row_to_dict_tree(data, r)
        return data

    def _load_data(self, table, columns, kvfilter=None):
        """Select rows from ``table`` and fold them into a dict tree.

        Errors are logged, not raised. The result is then an empty dict.
        """
        rows = []
        try:
            q = self._query(self._db, table, columns, trans=False)
            rows = q.select(kvfilter)
        except Exception as e:  # pylint: disable=broad-except
            self.error("Failed to load data for table %s: [%s]" % (table, e))
        return self._rows_to_dict_tree(rows)

    def load_config(self):
        """Return the ``config`` table as a {name: value} dict."""
        table = 'config'
        columns = ['name', 'value']
        return self._load_data(table, columns)

    def load_options(self, table, name=None):
        """Load ``table`` as {name: {option: value}}.

        When ``name`` is given, only that entry's {option: value} dict is
        returned. If it has no rows, the filtered (normally empty) result is
        returned instead.
        """
        kvfilter = dict()
        if name:
            kvfilter['name'] = name
        options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
        if name and name in options:
            return options[name]
        return options

    def save_options(self, table, name, options):
        """Insert or update each option of ``name`` in one transaction.

        :raises Exception: any backend error, after rolling back.
        """
        curvals = dict()
        q = None
        try:
            q = self._query(self._db, table, OPTIONS_COLUMNS)
            rows = q.select({'name': name}, ['option', 'value'])
            for row in rows:
                curvals[row[0]] = row[1]

            for opt in options:
                if opt in curvals:
                    q.update({'value': options[opt]},
                             {'name': name, 'option': opt})
                else:
                    q.insert((name, opt, options[opt]))

            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to save options: [%s]" % e)
            raise

    def delete_options(self, table, name, options=None):
        """Delete every option of ``name``, or only those in ``options``.

        :raises Exception: any backend error, after rolling back.
        """
        kvfilter = {'name': name}
        q = None
        try:
            q = self._query(self._db, table, OPTIONS_COLUMNS)
            if options is None:
                q.delete(kvfilter)
            else:
                for opt in options:
                    kvfilter['option'] = opt
                    q.delete(kvfilter)
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to delete from %s: [%s]" % (table, e))
            raise

    def new_unique_data(self, table, data):
        """Store the {name: value} dict ``data`` under a new uuid4.

        :returns: the new uuid as a string.
        :raises Exception: any backend error, after rolling back.
        """
        newid = str(uuid.uuid4())
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
            for name in data:
                q.insert((newid, name, data[name]))
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to store %s data: [%s]" % (table, e))
            raise
        return newid

    def get_unique_data(self, table, uuidval=None, name=None, value=None):
        """Return {uuid: {name: value}}, filtered by the truthy arguments."""
        kvfilter = dict()
        if uuidval:
            kvfilter['uuid'] = uuidval
        if name:
            kvfilter['name'] = name
        if value:
            kvfilter['value'] = value
        return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)

    def save_unique_data(self, table, data):
        """Apply {uuid: {name: value}} changes in one transaction.

        Names with a None value are deleted. Others are updated or inserted.

        :raises Exception: any backend error, after rolling back.
        """
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
            for uid in data:
                curvals = dict()
                rows = q.select({'uuid': uid}, ['name', 'value'])
                for r in rows:
                    curvals[r[0]] = r[1]

                datum = data[uid]
                for name in datum:
                    if name in curvals:
                        if datum[name] is None:
                            q.delete({'uuid': uid, 'name': name})
                        else:
                            q.update({'value': datum[name]},
                                     {'uuid': uid, 'name': name})
                    else:
                        if datum[name] is not None:
                            q.insert((uid, name, datum[name]))

            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to store data in %s: [%s]" % (table, e))
            raise

    def del_unique_data(self, table, uuidval):
        """Delete all rows of ``uuidval``. Errors are logged, not raised."""
        kvfilter = {'uuid': uuidval}
        try:
            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
            q.delete(kvfilter)
        except Exception as e:  # pylint: disable=broad-except
            self.error("Failed to delete data from %s: [%s]" % (table, e))

    def _reset_data(self, table):
        """Drop and recreate ``table``. Errors are logged, not raised."""
        q = None
        try:
            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
            q.drop()
            q.create()
            q.commit()
        except Exception as e:  # pylint: disable=broad-except
            if q:
                q.rollback()
            self.error("Failed to erase all data from %s: [%s]" % (table, e))
464
465
class AdminStore(Store):
    """Store for per-plugin administrative data.

    Opens the database configured under 'admin.config.db'. Each plugin's
    data lives in its own ``<plugin>_data`` table and is accessed through
    Store's unique-data helpers.
    """

    def __init__(self):
        super(AdminStore, self).__init__(config_name='admin.config.db')

    @staticmethod
    def _plugin_table(plugin):
        # Name of the table that holds ``plugin``'s data.
        return plugin + "_data"

    def get_data(self, plugin, idval=None, name=None, value=None):
        """Return ``plugin``'s data, filtered like Store.get_unique_data."""
        table = self._plugin_table(plugin)
        return self.get_unique_data(table, idval, name, value)

    def save_data(self, plugin, data):
        """Save {uuid: {name: value}} changes for ``plugin``."""
        table = self._plugin_table(plugin)
        return self.save_unique_data(table, data)

    def new_datum(self, plugin, datum):
        """Store ``datum`` for ``plugin`` under a new id and return the id."""
        return self.new_unique_data(self._plugin_table(plugin), datum)

    def del_datum(self, plugin, idval):
        """Delete the entry ``idval`` of ``plugin``."""
        return self.del_unique_data(self._plugin_table(plugin), idval)

    def wipe_data(self, plugin):
        """Drop and recreate ``plugin``'s data table."""
        self._reset_data(self._plugin_table(plugin))
488
489
class UserStore(Store):
    """Store for user preferences and per-user plugin data.

    Opens the database configured under 'user.prefs.db'.
    """

    def __init__(self, path=None):
        # ``path`` is accepted for call compatibility but is not used.
        super(UserStore, self).__init__(config_name='user.prefs.db')

    def save_user_preferences(self, user, options):
        """Save ``options`` as the preferences of ``user``."""
        table = 'users'
        self.save_options(table, user, options)

    def load_user_preferences(self, user):
        """Return the stored preferences of ``user``."""
        table = 'users'
        return self.load_options(table, user)

    def save_plugin_data(self, plugin, user, options):
        """Save ``options`` for ``user`` in ``plugin``'s data table."""
        table = plugin + "_data"
        self.save_options(table, user, options)

    def load_plugin_data(self, plugin, user):
        """Return ``user``'s options from ``plugin``'s data table."""
        table = plugin + "_data"
        return self.load_options(table, user)
506
507
class TranStore(Store):
    """Store opened on the database configured under 'transactions.db'."""

    def __init__(self, path=None):
        # ``path`` is accepted for call compatibility but is not used.
        super(TranStore, self).__init__(config_name='transactions.db')