Close connections after creating the tables
ipsilon/util/data.py (cascardo/ipsilon.git)
1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
2
3 import cherrypy
4 import datetime
5 from ipsilon.util.log import Log
6 from sqlalchemy import create_engine
7 from sqlalchemy import MetaData, Table, Column, Text
8 from sqlalchemy.pool import QueuePool, SingletonThreadPool
9 from sqlalchemy.schema import (PrimaryKeyConstraint, Index, AddConstraint,
10                                CreateIndex)
11 from sqlalchemy.sql import select, and_
12 import ConfigParser
13 import os
14 import uuid
15 import logging
16
17
18 CURRENT_SCHEMA_VERSION = 2
19 OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
20                  'primary_key': ('name', 'option'),
21                  'indexes': [('name',)]
22                  }
23 UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
24                      'primary_key': ('uuid', 'name'),
25                      'indexes': [('uuid',)]
26                      }
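# Both table descriptions above define plain key/value layouts in which every
#  column is Text. Purely as an illustration (not real stored data), an
#  OPTIONS_TABLE row holding a user preference could look like
#  ('admin', 'language', 'en'), and a UNIQUE_DATA_TABLE row holding one
#  session attribute could look like ('<uuid4 string>', 'user', 'admin').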
27
28
29 class DatabaseError(Exception):
30     pass
31
32
33 class BaseStore(Log):
34     # Some helper functions used for upgrades
35     def add_constraint(self, table):
36         raise NotImplementedError()
37
38     def add_index(self, index):
39         raise NotImplementedError()
40
41
42 class SqlStore(BaseStore):
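    # Shared SQL backend: get_connection() keeps one SqlStore per connection
    #  string, so every Store pointed at the same database reuses the same
    #  engine and connection pool.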
43     __instances = {}
44
45     @classmethod
46     def get_connection(cls, name):
47         if name not in cls.__instances:
48             if cherrypy.config.get('db.conn.log', False):
49                 logging.debug('SqlStore new: %s', name)
50             cls.__instances[name] = SqlStore(name)
51         return cls.__instances[name]
52
53     def __init__(self, name):
54         self.db_conn_log = cherrypy.config.get('db.conn.log', False)
55         self.debug('SqlStore init: %s' % name)
56         self.name = name
57         engine_name = name
58         if '://' not in engine_name:
59             engine_name = 'sqlite:///' + engine_name
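            # e.g. a bare path like '/var/lib/ipsilon/foo.sqlite' (made up)
            #  becomes 'sqlite:////var/lib/ipsilon/foo.sqlite'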
60         # This pool size is per configured database. The minimum needed,
61         #  determined by binary search, is 23. We're using 25 so we have a bit
62         #  more headroom, and then the overflow should make sure things don't
63         #  break when we suddenly need more.
64         pool_args = {'poolclass': QueuePool,
65                      'pool_size': 25,
66                      'max_overflow': 50}
67         if engine_name.startswith('sqlite://'):
68             # It's not possible to share connections for SQLite between
69             #  threads, so let's use the SingletonThreadPool for them
70             pool_args = {'poolclass': SingletonThreadPool}
71         self._dbengine = create_engine(engine_name, **pool_args)
72         self.is_readonly = False
73
74     def add_constraint(self, constraint):
75         if self._dbengine.dialect.name != 'sqlite':
76             # It is impossible to add constraints to a pre-existing table for
77             #  SQLite
78             # source: http://www.sqlite.org/omitted.html
79             create_constraint = AddConstraint(constraint, bind=self._dbengine)
80             create_constraint.execute()
81
82     def add_index(self, index):
83         add_index = CreateIndex(index, bind=self._dbengine)
84         add_index.execute()
85
86     def debug(self, fact):
87         if self.db_conn_log:
88             super(SqlStore, self).debug(fact)
89
90     def engine(self):
91         return self._dbengine
92
93     def connection(self):
94         self.debug('SqlStore connect: %s' % self.name)
95         conn = self._dbengine.connect()
96
97         def cleanup_connection():
98             self.debug('SqlStore cleanup: %s' % self.name)
99             conn.close()
100         cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
101         return conn
102
103
104 class SqlQuery(Log):
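    # Wraps a single table on one pooled connection. With trans=True callers
    #  are expected to finish with commit() or rollback(); with trans=False
    #  statements run without an explicit transaction.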
105
106     def __init__(self, db_obj, table, table_def, trans=True):
107         self._db = db_obj
108         self._con = self._db.connection()
109         self._trans = self._con.begin() if trans else None
110         self._table = self._get_table(table, table_def)
111
112     def _get_table(self, name, table_def):
113         if isinstance(table_def, list):
114             table_def = {'columns': table_def,
115                          'indexes': [],
116                          'primary_key': None}
117         table_creation = []
118         for col_name in table_def['columns']:
119             table_creation.append(Column(col_name, Text()))
120         if table_def['primary_key']:
121             table_creation.append(PrimaryKeyConstraint(
122                 *table_def['primary_key']))
123         for index in table_def['indexes']:
124             idx_name = 'idx_%s_%s' % (name, '_'.join(index))
125             table_creation.append(Index(idx_name, *index))
126         table = Table(name, MetaData(self._db.engine()), *table_creation)
127         return table
128
129     def _where(self, kvfilter):
130         where = None
131         if kvfilter is not None:
132             for k in kvfilter:
133                 w = self._table.columns[k] == kvfilter[k]
134                 if where is None:
135                     where = w
136                 else:
137                     where = where & w
138         return where
139
140     def _columns(self, columns=None):
141         cols = None
142         if columns is not None:
143             cols = []
144             for c in columns:
145                 cols.append(self._table.columns[c])
146         else:
147             cols = self._table.columns
148         return cols
149
150     def rollback(self):
151         self._trans.rollback()
152
153     def commit(self):
154         self._trans.commit()
155
156     def create(self):
157         self._table.create(checkfirst=True)
158
159     def drop(self):
160         self._table.drop(checkfirst=True)
161
162     def select(self, kvfilter=None, columns=None):
163         return self._con.execute(select(self._columns(columns),
164                                         self._where(kvfilter)))
165
166     def insert(self, values):
167         self._con.execute(self._table.insert(values))
168
169     def update(self, values, kvfilter):
170         self._con.execute(self._table.update(self._where(kvfilter), values))
171
172     def delete(self, kvfilter):
173         self._con.execute(self._table.delete(self._where(kvfilter)))
174
175
176 class FileStore(BaseStore):
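    # Read-only store backed by an ini-style config file. get_config()
    #  re-reads the file whenever its mtime is newer than the cached copy.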
177
178     def __init__(self, name):
179         self._filename = name
180         self.is_readonly = True
181         self._timestamp = None
182         self._config = None
183
184     def get_config(self):
185         try:
186             stat = os.stat(self._filename)
187         except OSError, e:
188             self.error("Unable to check config file %s: [%s]" % (
189                 self._filename, e))
190             self._config = None
191             raise
192         timestamp = stat.st_mtime
193         if self._config is None or timestamp > self._timestamp:
194             self._config = ConfigParser.RawConfigParser()
195             self._config.optionxform = str
196             self._config.read(self._filename)
197         return self._config
198
199     def add_constraint(self, table):
200         raise NotImplementedError()
201
202     def add_index(self, index):
203         raise NotImplementedError()
204
205
206 class FileQuery(Log):
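    # Read-only query emulation over a single config section. For
    #  three-column tables the first two columns are packed into the option
    #  key as "<col1> <col2>"; e.g. (illustrative only) an option
    #  "admin language" with value "en" stands for ('admin', 'language', 'en').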
207
208     def __init__(self, fstore, table, table_def, trans=True):
209         # We don't need indexes in a FileQuery, so drop that info
210         if isinstance(table_def, dict):
211             columns = table_def['columns']
212         else:
213             columns = table_def
214         self._fstore = fstore
215         self._config = fstore.get_config()
216         self._section = table
217         if len(columns) > 3 or columns[-1] != 'value':
218             raise ValueError('Unsupported configuration format')
219         self._columns = columns
220
221     def rollback(self):
222         return
223
224     def commit(self):
225         return
226
227     def create(self):
228         raise NotImplementedError
229
230     def drop(self):
231         raise NotImplementedError
232
233     def select(self, kvfilter=None, columns=None):
234         if self._section not in self._config.sections():
235             return []
236         kvfilter = kvfilter or {}  # select() may get no filter at all
237         opts = self._config.options(self._section)
238
239         prefix = None
240         prefix_ = ''
241         if self._columns[0] in kvfilter:
242             prefix = kvfilter[self._columns[0]]
243             prefix_ = prefix + ' '
244
245         name = None
246         if len(self._columns) == 3 and self._columns[1] in kvfilter:
247             name = kvfilter[self._columns[1]]
248
249         value = None
250         if self._columns[-1] in kvfilter:
251             value = kvfilter[self._columns[-1]]
252
253         res = []
254         for o in opts:
255             if len(self._columns) == 3:
256                 # 3 cols
257                 if prefix and not o.startswith(prefix_):
258                     continue
259
260                 col1, col2 = o.split(' ', 1)
261                 if name and col2 != name:
262                     continue
263
264                 col3 = self._config.get(self._section, o)
265                 if value and col3 != value:
266                     continue
267
268                 r = [col1, col2, col3]
269             else:
270                 # 2 cols
271                 if prefix and o != prefix:
272                     continue
273                 r = [o, self._config.get(self._section, o)]
274
275             if columns:
276                 s = []
277                 for c in columns:
278                     s.append(r[self._columns.index(c)])
279                 res.append(s)
280             else:
281                 res.append(r)
282
283         self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
284                                                  repr(kvfilter),
285                                                  repr(columns),
286                                                  repr(res)))
287         return res
288
289     def insert(self, values):
290         raise NotImplementedError
291
292     def update(self, values, kvfilter):
293         raise NotImplementedError
294
295     def delete(self, kvfilter):
296         raise NotImplementedError
297
298
299 class Store(Log):
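    # Ties a concrete backend (SqlStore or FileStore) to the generic option
    #  and unique-data helpers plus schema versioning. A config_name is
    #  resolved through the cherrypy config; a database_url is used directly,
    #  and configfile:// URLs select the read-only FileStore backend.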
300     _is_upgrade = False
301
302     def __init__(self, config_name=None, database_url=None):
303         if config_name is None and database_url is None:
304             raise ValueError('config_name or database_url must be provided')
305         if config_name:
306             if config_name not in cherrypy.config:
307                 raise NameError('Unknown database %s' % config_name)
308             name = cherrypy.config[config_name]
309         else:
310             name = database_url
311         if name.startswith('configfile://'):
312             _, filename = name.split('://')
313             self._db = FileStore(filename)
314             self._query = FileQuery
315         else:
316             self._db = SqlStore.get_connection(name)
317             self._query = SqlQuery
318
319         if not self._is_upgrade:
320             self._check_database()
321
322     def _code_schema_version(self):
323         # This function makes it possible for separate plugins to have
324         #  different schema versions. We default to the global schema
325         #  version.
326         return CURRENT_SCHEMA_VERSION
327
328     def _get_schema_version(self):
329         # We are storing multiple versions: one per class
330         # That way, we can support plugins with differing schema versions from
331         #  the main codebase, and even in the same database.
332         q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
333         q.create()
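        # The SqlQuery above checked out a pooled connection that is no
        #  longer needed; close it right away rather than holding it until
        #  the on_end_request cleanup hook runs.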
334         q._con.close()  # pylint: disable=protected-access
335         cls_name = self.__class__.__name__
336         current_version = self.load_options('dbinfo').get('%s_schema'
337                                                           % cls_name, {})
338         if 'version' in current_version:
339             return int(current_version['version'])
340         else:
341             # Also try the old key under which the version was stored.
342             # "scheme" was a typo, but we need to retain it now for compat
343             fallback_version = self.load_options('dbinfo').get('scheme',
344                                                                {})
345             if 'version' in fallback_version:
346                 # Explanation for this is in def upgrade_database(self)
347                 return -1
348             else:
349                 return None
350
351     def _check_database(self):
352         if self.is_readonly:
353             # If the database is readonly, we cannot do anything to the
354             #  schema. Let's just return, and assume people checked the
355             #  upgrade notes
356             return
357
358         current_version = self._get_schema_version()
359         if current_version is None:
360             self.error('Database initialization required! ' +
361                        'Please run ipsilon-upgrade-database')
362             raise DatabaseError('Database initialization required for %s' %
363                                 self.__class__.__name__)
364         if current_version != self._code_schema_version():
365             self.error('Database upgrade required! ' +
366                        'Please run ipsilon-upgrade-database')
367             raise DatabaseError('Database upgrade required for %s' %
368                                 self.__class__.__name__)
369
370     def _store_new_schema_version(self, new_version):
371         cls_name = self.__class__.__name__
372         self.save_options('dbinfo', '%s_schema' % cls_name,
373                           {'version': new_version})
374
375     def _initialize_schema(self):
376         raise NotImplementedError()
377
378     def _upgrade_schema(self, old_version):
379         # Datastores need to decide for themselves what to do when
380         #  old_version is newer than the code schema version.
381         # They might implement downgrading if that's feasible, or just throw
382         #  NotImplementedError
383         # Should return the new schema version
384         raise NotImplementedError()
385
386     def upgrade_database(self):
387         # Do whatever is needed to get schema to current version
388         old_schema_version = self._get_schema_version()
389         if old_schema_version is None:
390             # Just initialize a new schema
391             self._initialize_schema()
392             self._store_new_schema_version(self._code_schema_version())
393         elif old_schema_version == -1:
394             # This is a special case from 1.0: tables used to be created
395             # only the first time they were actually used, but the upgrade
396             # code assumes they already exist. So let's fix this.
397             self._initialize_schema()
398             # The old version was schema version 1
399             self._store_new_schema_version(1)
400             self.upgrade_database()
401         elif old_schema_version != self._code_schema_version():
402             # Upgrade from old_schema_version to code_schema_version
403             self.debug('Upgrading from schema version %i' % old_schema_version)
404             new_version = self._upgrade_schema(old_schema_version)
405             if not new_version:
406                 error = ('Schema upgrade error: %s did not provide a '
407                          'new schema version number!' %
408                          self.__class__.__name__)
409                 self.error(error)
410                 raise Exception(error)
411             self._store_new_schema_version(new_version)
412             # Check if we are now up-to-date
413             self.upgrade_database()
414
415     @property
416     def is_readonly(self):
417         return self._db.is_readonly
418
419     def _row_to_dict_tree(self, data, row):
420         name = row[0]
421         if len(row) > 2:
422             if name not in data:
423                 data[name] = dict()
424             d2 = data[name]
425             self._row_to_dict_tree(d2, row[1:])
426         else:
427             value = row[1]
428             if name in data:
429                 if isinstance(data[name], list):
430                     data[name].append(value)
431                 else:
432                     v = data[name]
433                     data[name] = [v, value]
434             else:
435                 data[name] = value
436
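    # Purely illustrative example (not real data): the rows
    #  [('alice', 'email', 'a@x.org'), ('alice', 'email', 'b@x.org')]
    #  come out of _rows_to_dict_tree() below as
    #  {'alice': {'email': ['a@x.org', 'b@x.org']}}.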
437     def _rows_to_dict_tree(self, rows):
438         data = dict()
439         for r in rows:
440             self._row_to_dict_tree(data, r)
441         return data
442
443     def _load_data(self, table, columns, kvfilter=None):
444         rows = []
445         try:
446             q = self._query(self._db, table, columns, trans=False)
447             rows = q.select(kvfilter)
448         except Exception, e:  # pylint: disable=broad-except
449             self.error("Failed to load data for table %s: [%s]" % (table, e))
450         return self._rows_to_dict_tree(rows)
451
452     def load_config(self):
453         table = 'config'
454         columns = ['name', 'value']
455         return self._load_data(table, columns)
456
457     def load_options(self, table, name=None):
458         kvfilter = dict()
459         if name:
460             kvfilter['name'] = name
461         options = self._load_data(table, OPTIONS_TABLE, kvfilter)
462         if name and name in options:
463             return options[name]
464         return options
465
466     def save_options(self, table, name, options):
467         curvals = dict()
468         q = None
469         try:
470             q = self._query(self._db, table, OPTIONS_TABLE)
471             rows = q.select({'name': name}, ['option', 'value'])
472             for row in rows:
473                 curvals[row[0]] = row[1]
474
475             for opt in options:
476                 if opt in curvals:
477                     q.update({'value': options[opt]},
478                              {'name': name, 'option': opt})
479                 else:
480                     q.insert((name, opt, options[opt]))
481
482             q.commit()
483         except Exception, e:  # pylint: disable=broad-except
484             if q:
485                 q.rollback()
486             self.error("Failed to save options: [%s]" % e)
487             raise
488
489     def delete_options(self, table, name, options=None):
490         kvfilter = {'name': name}
491         q = None
492         try:
493             q = self._query(self._db, table, OPTIONS_TABLE)
494             if options is None:
495                 q.delete(kvfilter)
496             else:
497                 for opt in options:
498                     kvfilter['option'] = opt
499                     q.delete(kvfilter)
500             q.commit()
501         except Exception, e:  # pylint: disable=broad-except
502             if q:
503                 q.rollback()
504             self.error("Failed to delete from %s: [%s]" % (table, e))
505             raise
506
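    # Illustrative use (the table and values are made up):
    #  new_unique_data('sessions', {'user': 'admin', 'origin': 'saml2'})
    #  inserts one row per key, all sharing a freshly generated uuid, and
    #  returns that uuid as a string.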
507     def new_unique_data(self, table, data):
508         newid = str(uuid.uuid4())
509         q = None
510         try:
511             q = self._query(self._db, table, UNIQUE_DATA_TABLE)
512             for name in data:
513                 q.insert((newid, name, data[name]))
514             q.commit()
515         except Exception, e:  # pylint: disable=broad-except
516             if q:
517                 q.rollback()
518             self.error("Failed to store %s data: [%s]" % (table, e))
519             raise
520         return newid
521
522     def get_unique_data(self, table, uuidval=None, name=None, value=None):
523         kvfilter = dict()
524         if uuidval:
525             kvfilter['uuid'] = uuidval
526         if name:
527             kvfilter['name'] = name
528         if value:
529             kvfilter['value'] = value
530         return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)
531
532     def save_unique_data(self, table, data):
533         q = None
534         try:
535             q = self._query(self._db, table, UNIQUE_DATA_TABLE)
536             for uid in data:
537                 curvals = dict()
538                 rows = q.select({'uuid': uid}, ['name', 'value'])
539                 for r in rows:
540                     curvals[r[0]] = r[1]
541
542                 datum = data[uid]
543                 for name in datum:
544                     if name in curvals:
545                         if datum[name] is None:
546                             q.delete({'uuid': uid, 'name': name})
547                         else:
548                             q.update({'value': datum[name]},
549                                      {'uuid': uid, 'name': name})
550                     else:
551                         if datum[name] is not None:
552                             q.insert((uid, name, datum[name]))
553
554             q.commit()
555         except Exception, e:  # pylint: disable=broad-except
556             if q:
557                 q.rollback()
558             self.error("Failed to store data in %s: [%s]" % (table, e))
559             raise
560
561     def del_unique_data(self, table, uuidval):
562         kvfilter = {'uuid': uuidval}
563         try:
564             q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
565             q.delete(kvfilter)
566         except Exception, e:  # pylint: disable=broad-except
567             self.error("Failed to delete data from %s: [%s]" % (table, e))
568
569     def _reset_data(self, table):
570         q = None
571         try:
572             q = self._query(self._db, table, UNIQUE_DATA_TABLE)
573             q.drop()
574             q.create()
575             q.commit()
576         except Exception, e:  # pylint: disable=broad-except
577             if q:
578                 q.rollback()
579             self.error("Failed to erase all data from %s: [%s]" % (table, e))
580
581
582 class AdminStore(Store):
583
584     def __init__(self):
585         super(AdminStore, self).__init__('admin.config.db')
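        # 'admin.config.db' is a cherrypy config key; Store.__init__ resolves
        #  it to the actual database URL.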
586
587     def get_data(self, plugin, idval=None, name=None, value=None):
588         return self.get_unique_data(plugin+"_data", idval, name, value)
589
590     def save_data(self, plugin, data):
591         return self.save_unique_data(plugin+"_data", data)
592
593     def new_datum(self, plugin, datum):
594         table = plugin+"_data"
595         return self.new_unique_data(table, datum)
596
597     def del_datum(self, plugin, idval):
598         table = plugin+"_data"
599         return self.del_unique_data(table, idval)
600
601     def wipe_data(self, plugin):
602         table = plugin+"_data"
603         self._reset_data(table)
604
605     def _initialize_schema(self):
606         for table in ['config',
607                       'info_config',
608                       'login_config',
609                       'provider_config']:
610             q = self._query(self._db, table, OPTIONS_TABLE, trans=False)
611             q.create()
612             q._con.close()  # pylint: disable=protected-access
613
614     def _upgrade_schema(self, old_version):
615         if old_version == 1:
616             # In schema version 2, we added indexes and primary keys
617             for table in ['config',
618                           'info_config',
619                           'login_config',
620                           'provider_config']:
621                 # pylint: disable=protected-access
622                 table = self._query(self._db, table, OPTIONS_TABLE,
623                                     trans=False)._table
624                 self._db.add_constraint(table.primary_key)
625                 for index in table.indexes:
626                     self._db.add_index(index)
627             return 2
628         else:
629             raise NotImplementedError()
630
631     def create_plugin_data_table(self, plugin_name):
632         if not self.is_readonly:
633             table = plugin_name+'_data'
634             q = self._query(self._db, table, UNIQUE_DATA_TABLE,
635                             trans=False)
636             q.create()
637             q._con.close()  # pylint: disable=protected-access
638
639
640 class UserStore(Store):
641
642     def __init__(self, path=None):
643         super(UserStore, self).__init__('user.prefs.db')
644
645     def save_user_preferences(self, user, options):
646         self.save_options('users', user, options)
647
648     def load_user_preferences(self, user):
649         return self.load_options('users', user)
650
651     def save_plugin_data(self, plugin, user, options):
652         self.save_options(plugin+"_data", user, options)
653
654     def load_plugin_data(self, plugin, user):
655         return self.load_options(plugin+"_data", user)
656
657     def _initialize_schema(self):
658         q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
659         q.create()
660         q._con.close()  # pylint: disable=protected-access
661
662     def _upgrade_schema(self, old_version):
663         if old_version == 1:
664             # In schema version 2, we added indexes and primary keys
665             # pylint: disable=protected-access
666             table = self._query(self._db, 'users', OPTIONS_TABLE,
667                                 trans=False)._table
668             self._db.add_constraint(table.primary_key)
669             for index in table.indexes:
670                 self._db.add_index(index)
671             return 2
672         else:
673             raise NotImplementedError()
674
675
676 class TranStore(Store):
677
678     def __init__(self, path=None):
679         super(TranStore, self).__init__('transactions.db')
680
681     def _initialize_schema(self):
682         q = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
683                         trans=False)
684         q.create()
685         q._con.close()  # pylint: disable=protected-access
686
687     def _upgrade_schema(self, old_version):
688         if old_version == 1:
689             # In schema version 2, we added indexes and primary keys
690             # pylint: disable=protected-access
691             table = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
692                                 trans=False)._table
693             self._db.add_constraint(table.primary_key)
694             for index in table.indexes:
695                 self._db.add_index(index)
696             return 2
697         else:
698             raise NotImplementedError()
699
700
701 class SAML2SessionStore(Store):
702
703     def __init__(self, database_url):
704         super(SAML2SessionStore, self).__init__(database_url=database_url)
705         self.table = 'saml2_sessions'
706         # pylint: disable=protected-access
707         table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
708         table.create(checkfirst=True)
709
710     def _get_unique_id_from_column(self, name, value):
711         """
712         The query is going to return only the column in the query.
713         Use this method to get the uuidval which can be used to fetch
714         the entire entry.
715
716         Returns None or the uuid of the first value found.
717         """
718         data = self.get_unique_data(self.table, name=name, value=value)
719         count = len(data)
720         if count == 0:
721             return None
722         elif count != 1:
723             raise ValueError("Multiple entries returned")
724         return data.keys()[0]
725
726     def remove_expired_sessions(self):
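        # sel collects the uuids of sessions whose expiration_time value
        #  compares as already past; the delete below then drops every row
        #  sharing those uuids, i.e. the whole session entry, not just its
        #  expiration attribute.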
727         # pylint: disable=protected-access
728         table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
729         sel = select([table.columns.uuid]). \
730             where(and_(table.c.name == 'expiration_time',
731                        table.c.value <= datetime.datetime.now()))
732         # pylint: disable=no-value-for-parameter
733         d = table.delete().where(table.c.uuid.in_(sel))
734         d.execute()
735
736     def get_data(self, idval=None, name=None, value=None):
737         return self.get_unique_data(self.table, idval, name, value)
738
739     def new_session(self, datum):
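        # The Text value column cannot hold a list, so supported_logout_mechs
        #  is flattened to a comma-separated string here and split back in
        #  get_user_sessions().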
740         if 'supported_logout_mechs' in datum:
741             datum['supported_logout_mechs'] = ','.join(
742                 datum['supported_logout_mechs']
743             )
744         return self.new_unique_data(self.table, datum)
745
746     def get_session(self, session_id=None, request_id=None):
747         if session_id:
748             uuidval = self._get_unique_id_from_column('session_id', session_id)
749         elif request_id:
750             uuidval = self._get_unique_id_from_column('request_id', request_id)
751         else:
752             raise ValueError("Unable to find session")
753         if not uuidval:
754             return None, None
755         data = self.get_unique_data(self.table, uuidval=uuidval)
756         return uuidval, data[uuidval]
757
758     def get_user_sessions(self, user):
759         """
760         Return a list of all sessions for a given user.
761         """
762         rows = self.get_unique_data(self.table, name='user', value=user)
763
764         # We have a list of sessions for this user, now get the details
765         logged_in = []
766         for r in rows:
767             data = self.get_unique_data(self.table, uuidval=r)
768             data[r]['supported_logout_mechs'] = data[r].get(
769                 'supported_logout_mechs', '').split(',')
770             logged_in.append(data)
771
772         return logged_in
773
774     def update_session(self, datum):
775         self.save_unique_data(self.table, datum)
776
777     def remove_session(self, uuidval):
778         self.del_unique_data(self.table, uuidval)
779
780     def wipe_data(self):
781         self._reset_data(self.table)
782
783     def _initialize_schema(self):
784         q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
785                         trans=False)
786         q.create()
787         q._con.close()  # pylint: disable=protected-access
788
789     def _upgrade_schema(self, old_version):
790         if old_version == 1:
791             # In schema version 2, we added indexes and primary keys
792             # pylint: disable=protected-access
793             table = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
794                                 trans=False)._table
795             self._db.add_constraint(table.primary_key)
796             for index in table.indexes:
797                 self._db.add_index(index)
798             return 2
799         else:
800             raise NotImplementedError()