Close database sesssions
[cascardo/ipsilon.git] / ipsilon / util / data.py
1 # Copyright (C) 2013  Simo Sorce <simo@redhat.com>
2 #
3 # see file 'COPYING' for use and warranty information
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 import cherrypy
19 from ipsilon.util.log import Log
20 from sqlalchemy import create_engine
21 from sqlalchemy import MetaData, Table, Column, Text
22 from sqlalchemy.pool import QueuePool, SingletonThreadPool
23 from sqlalchemy.sql import select
24 import ConfigParser
25 import os
26 import uuid
27
28
29 OPTIONS_COLUMNS = ['name', 'option', 'value']
30 UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
31
32
33 class SqlStore(Log):
34     __instances = {}
35
36     @classmethod
37     def get_connection(cls, name):
38         if name not in cls.__instances.keys():
39             print 'SqlStore new: %s' % name
40             cls.__instances[name] = SqlStore(name)
41         return cls.__instances[name]
42
43     def __init__(self, name):
44         self.debug('SqlStore init: %s' % name)
45         self.name = name
46         engine_name = name
47         if '://' not in engine_name:
48             engine_name = 'sqlite:///' + engine_name
49         # This pool size is per configured database. The minimum needed,
50         #  determined by binary search, is 23. We're using 25 so we have a bit
51         #  more playroom, and then the overflow should make sure things don't
52         #  break when we suddenly need more.
53         pool_args = {'poolclass': QueuePool,
54                      'pool_size': 25,
55                      'max_overflow': 50}
56         if engine_name.startswith('sqlite://'):
57             # It's not possible to share connections for SQLite between
58             #  threads, so let's use the SingletonThreadPool for them
59             pool_args = {'poolclass': SingletonThreadPool}
60         # pylint: disable=star-args
61         self._dbengine = create_engine(engine_name, **pool_args)
62         self.is_readonly = False
63
64     def engine(self):
65         return self._dbengine
66
67     def connection(self):
68         self.debug('SqlStore connect: %s' % self.name)
69         conn = self._dbengine.connect()
70
71         def cleanup_connection():
72             self.debug('SqlStore cleanup: %s' % self.name)
73             conn.close()
74         cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
75         return conn
76
77
78 def SqlAutotable(f):
79     def at(self, *args, **kwargs):
80         self.create()
81         return f(self, *args, **kwargs)
82     return at
83
84
85 class SqlQuery(Log):
86
87     def __init__(self, db_obj, table, columns, trans=True):
88         self._db = db_obj
89         self._con = self._db.connection()
90         self._trans = self._con.begin() if trans else None
91         self._table = self._get_table(table, columns)
92
93     def _get_table(self, name, columns):
94         table = Table(name, MetaData(self._db.engine()))
95         for c in columns:
96             table.append_column(Column(c, Text()))
97         return table
98
99     def _where(self, kvfilter):
100         where = None
101         if kvfilter is not None:
102             for k in kvfilter:
103                 w = self._table.columns[k] == kvfilter[k]
104                 if where is None:
105                     where = w
106                 else:
107                     where = where & w
108         return where
109
110     def _columns(self, columns=None):
111         cols = None
112         if columns is not None:
113             cols = []
114             for c in columns:
115                 cols.append(self._table.columns[c])
116         else:
117             cols = self._table.columns
118         return cols
119
120     def rollback(self):
121         self._trans.rollback()
122
123     def commit(self):
124         self._trans.commit()
125
126     def create(self):
127         self._table.create(checkfirst=True)
128
129     def drop(self):
130         self._table.drop(checkfirst=True)
131
132     @SqlAutotable
133     def select(self, kvfilter=None, columns=None):
134         return self._con.execute(select(self._columns(columns),
135                                         self._where(kvfilter)))
136
137     @SqlAutotable
138     def insert(self, values):
139         self._con.execute(self._table.insert(values))
140
141     @SqlAutotable
142     def update(self, values, kvfilter):
143         self._con.execute(self._table.update(self._where(kvfilter), values))
144
145     @SqlAutotable
146     def delete(self, kvfilter):
147         self._con.execute(self._table.delete(self._where(kvfilter)))
148
149
150 class FileStore(Log):
151
152     def __init__(self, name):
153         self._filename = name
154         self.is_readonly = True
155         self._timestamp = None
156         self._config = None
157
158     def get_config(self):
159         try:
160             stat = os.stat(self._filename)
161         except OSError, e:
162             self.error("Unable to check config file %s: [%s]" % (
163                 self._filename, e))
164             self._config = None
165             raise
166         timestamp = stat.st_mtime
167         if self._config is None or timestamp > self._timestamp:
168             self._config = ConfigParser.RawConfigParser()
169             self._config.optionxform = str
170             self._config.read(self._filename)
171         return self._config
172
173
174 class FileQuery(Log):
175
176     def __init__(self, fstore, table, columns, trans=True):
177         self._fstore = fstore
178         self._config = fstore.get_config()
179         self._section = table
180         if len(columns) > 3 or columns[-1] != 'value':
181             raise ValueError('Unsupported configuration format')
182         self._columns = columns
183
184     def rollback(self):
185         return
186
187     def commit(self):
188         return
189
190     def create(self):
191         raise NotImplementedError
192
193     def drop(self):
194         raise NotImplementedError
195
196     def select(self, kvfilter=None, columns=None):
197         if self._section not in self._config.sections():
198             return []
199
200         opts = self._config.options(self._section)
201
202         prefix = None
203         prefix_ = ''
204         if self._columns[0] in kvfilter:
205             prefix = kvfilter[self._columns[0]]
206             prefix_ = prefix + ' '
207
208         name = None
209         if len(self._columns) == 3 and self._columns[1] in kvfilter:
210             name = kvfilter[self._columns[1]]
211
212         value = None
213         if self._columns[-1] in kvfilter:
214             value = kvfilter[self._columns[-1]]
215
216         res = []
217         for o in opts:
218             if len(self._columns) == 3:
219                 # 3 cols
220                 if prefix and not o.startswith(prefix_):
221                     continue
222
223                 col1, col2 = o.split(' ', 1)
224                 if name and col2 != name:
225                     continue
226
227                 col3 = self._config.get(self._section, o)
228                 if value and col3 != value:
229                     continue
230
231                 r = [col1, col2, col3]
232             else:
233                 # 2 cols
234                 if prefix and o != prefix:
235                     continue
236                 r = [o, self._config.get(self._section, o)]
237
238             if columns:
239                 s = []
240                 for c in columns:
241                     s.append(r[self._columns.index(c)])
242                 res.append(s)
243             else:
244                 res.append(r)
245
246         self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
247                                                  repr(kvfilter),
248                                                  repr(columns),
249                                                  repr(res)))
250         return res
251
252     def insert(self, values):
253         raise NotImplementedError
254
255     def update(self, values, kvfilter):
256         raise NotImplementedError
257
258     def delete(self, kvfilter):
259         raise NotImplementedError
260
261
262 class Store(Log):
263     def __init__(self, config_name=None, database_url=None):
264         if config_name is None and database_url is None:
265             raise ValueError('config_name or database_url must be provided')
266         if config_name:
267             if config_name not in cherrypy.config:
268                 raise NameError('Unknown database %s' % config_name)
269             name = cherrypy.config[config_name]
270         else:
271             name = database_url
272         if name.startswith('configfile://'):
273             _, filename = name.split('://')
274             self._db = FileStore(filename)
275             self._query = FileQuery
276         else:
277             self._db = SqlStore.get_connection(name)
278             self._query = SqlQuery
279
280     @property
281     def is_readonly(self):
282         return self._db.is_readonly
283
284     def _row_to_dict_tree(self, data, row):
285         name = row[0]
286         if len(row) > 2:
287             if name not in data:
288                 data[name] = dict()
289             d2 = data[name]
290             self._row_to_dict_tree(d2, row[1:])
291         else:
292             value = row[1]
293             if name in data:
294                 if data[name] is list:
295                     data[name].append(value)
296                 else:
297                     v = data[name]
298                     data[name] = [v, value]
299             else:
300                 data[name] = value
301
302     def _rows_to_dict_tree(self, rows):
303         data = dict()
304         for r in rows:
305             self._row_to_dict_tree(data, r)
306         return data
307
308     def _load_data(self, table, columns, kvfilter=None):
309         rows = []
310         try:
311             q = self._query(self._db, table, columns, trans=False)
312             rows = q.select(kvfilter)
313         except Exception, e:  # pylint: disable=broad-except
314             self.error("Failed to load data for table %s: [%s]" % (table, e))
315         return self._rows_to_dict_tree(rows)
316
317     def load_config(self):
318         table = 'config'
319         columns = ['name', 'value']
320         return self._load_data(table, columns)
321
322     def load_options(self, table, name=None):
323         kvfilter = dict()
324         if name:
325             kvfilter['name'] = name
326         options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
327         if name and name in options:
328             return options[name]
329         return options
330
331     def save_options(self, table, name, options):
332         curvals = dict()
333         q = None
334         try:
335             q = self._query(self._db, table, OPTIONS_COLUMNS)
336             rows = q.select({'name': name}, ['option', 'value'])
337             for row in rows:
338                 curvals[row[0]] = row[1]
339
340             for opt in options:
341                 if opt in curvals:
342                     q.update({'value': options[opt]},
343                              {'name': name, 'option': opt})
344                 else:
345                     q.insert((name, opt, options[opt]))
346
347             q.commit()
348         except Exception, e:  # pylint: disable=broad-except
349             if q:
350                 q.rollback()
351             self.error("Failed to save options: [%s]" % e)
352             raise
353
354     def delete_options(self, table, name, options=None):
355         kvfilter = {'name': name}
356         q = None
357         try:
358             q = self._query(self._db, table, OPTIONS_COLUMNS)
359             if options is None:
360                 q.delete(kvfilter)
361             else:
362                 for opt in options:
363                     kvfilter['option'] = opt
364                     q.delete(kvfilter)
365             q.commit()
366         except Exception, e:  # pylint: disable=broad-except
367             if q:
368                 q.rollback()
369             self.error("Failed to delete from %s: [%s]" % (table, e))
370             raise
371
372     def new_unique_data(self, table, data):
373         newid = str(uuid.uuid4())
374         q = None
375         try:
376             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
377             for name in data:
378                 q.insert((newid, name, data[name]))
379             q.commit()
380         except Exception, e:  # pylint: disable=broad-except
381             if q:
382                 q.rollback()
383             self.error("Failed to store %s data: [%s]" % (table, e))
384             raise
385         return newid
386
387     def get_unique_data(self, table, uuidval=None, name=None, value=None):
388         kvfilter = dict()
389         if uuidval:
390             kvfilter['uuid'] = uuidval
391         if name:
392             kvfilter['name'] = name
393         if value:
394             kvfilter['value'] = value
395         return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
396
397     def save_unique_data(self, table, data):
398         q = None
399         try:
400             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
401             for uid in data:
402                 curvals = dict()
403                 rows = q.select({'uuid': uid}, ['name', 'value'])
404                 for r in rows:
405                     curvals[r[0]] = r[1]
406
407                 datum = data[uid]
408                 for name in datum:
409                     if name in curvals:
410                         if datum[name] is None:
411                             q.delete({'uuid': uid, 'name': name})
412                         else:
413                             q.update({'value': datum[name]},
414                                      {'uuid': uid, 'name': name})
415                     else:
416                         if datum[name] is not None:
417                             q.insert((uid, name, datum[name]))
418
419             q.commit()
420         except Exception, e:  # pylint: disable=broad-except
421             if q:
422                 q.rollback()
423             self.error("Failed to store data in %s: [%s]" % (table, e))
424             raise
425
426     def del_unique_data(self, table, uuidval):
427         kvfilter = {'uuid': uuidval}
428         try:
429             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
430             q.delete(kvfilter)
431         except Exception, e:  # pylint: disable=broad-except
432             self.error("Failed to delete data from %s: [%s]" % (table, e))
433
434     def _reset_data(self, table):
435         q = None
436         try:
437             q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
438             q.drop()
439             q.create()
440             q.commit()
441         except Exception, e:  # pylint: disable=broad-except
442             if q:
443                 q.rollback()
444             self.error("Failed to erase all data from %s: [%s]" % (table, e))
445
446
447 class AdminStore(Store):
448
449     def __init__(self):
450         super(AdminStore, self).__init__('admin.config.db')
451
452     def get_data(self, plugin, idval=None, name=None, value=None):
453         return self.get_unique_data(plugin+"_data", idval, name, value)
454
455     def save_data(self, plugin, data):
456         return self.save_unique_data(plugin+"_data", data)
457
458     def new_datum(self, plugin, datum):
459         table = plugin+"_data"
460         return self.new_unique_data(table, datum)
461
462     def del_datum(self, plugin, idval):
463         table = plugin+"_data"
464         return self.del_unique_data(table, idval)
465
466     def wipe_data(self, plugin):
467         table = plugin+"_data"
468         self._reset_data(table)
469
470
471 class UserStore(Store):
472
473     def __init__(self, path=None):
474         super(UserStore, self).__init__('user.prefs.db')
475
476     def save_user_preferences(self, user, options):
477         self.save_options('users', user, options)
478
479     def load_user_preferences(self, user):
480         return self.load_options('users', user)
481
482     def save_plugin_data(self, plugin, user, options):
483         self.save_options(plugin+"_data", user, options)
484
485     def load_plugin_data(self, plugin, user):
486         return self.load_options(plugin+"_data", user)
487
488
489 class TranStore(Store):
490
491     def __init__(self, path=None):
492         super(TranStore, self).__init__('transactions.db')