Add SQL primary key and indexes
authorPatrick Uiterwijk <puiterwijk@redhat.com>
Tue, 11 Aug 2015 09:52:13 +0000 (11:52 +0200)
committerPatrick Uiterwijk <puiterwijk@redhat.com>
Wed, 2 Sep 2015 15:17:52 +0000 (17:17 +0200)
Signed-off-by: Patrick Uiterwijk <puiterwijk@redhat.com>
Reviewed-by: Rob Crittenden <rcritten@redhat.com>
ipsilon/providers/openid/store.py
ipsilon/util/data.py
ipsilon/util/sessions.py

index e759bca..9b2bc4f 100644 (file)
@@ -1,6 +1,6 @@
 # Copyright (C) 2014 Ipsilon project Contributors, for license see COPYING
 
 # Copyright (C) 2014 Ipsilon project Contributors, for license see COPYING
 
-from ipsilon.util.data import Store, UNIQUE_DATA_COLUMNS
+from ipsilon.util.data import Store, UNIQUE_DATA_TABLE
 
 from openid import oidutil
 from openid.association import Association
 
 from openid import oidutil
 from openid.association import Association
@@ -79,7 +79,7 @@ class OpenIDStore(Store, OpenIDStoreInterface):
                 self.del_unique_data('association', iden)
 
     def _initialize_schema(self):
                 self.del_unique_data('association', iden)
 
     def _initialize_schema(self):
-        q = self._query(self._db, 'association', UNIQUE_DATA_COLUMNS,
+        q = self._query(self._db, 'association', UNIQUE_DATA_TABLE,
                         trans=False)
         q.create()
 
                         trans=False)
         q.create()
 
index c0fe4ab..e92aae4 100644 (file)
@@ -6,6 +6,7 @@ from ipsilon.util.log import Log
 from sqlalchemy import create_engine
 from sqlalchemy import MetaData, Table, Column, Text
 from sqlalchemy.pool import QueuePool, SingletonThreadPool
 from sqlalchemy import create_engine
 from sqlalchemy import MetaData, Table, Column, Text
 from sqlalchemy.pool import QueuePool, SingletonThreadPool
+from sqlalchemy.schema import PrimaryKeyConstraint, Index
 from sqlalchemy.sql import select, and_
 import ConfigParser
 import os
 from sqlalchemy.sql import select, and_
 import ConfigParser
 import os
@@ -13,9 +14,15 @@ import uuid
 import logging
 
 
 import logging
 
 
-CURRENT_SCHEMA_VERSION = 1
-OPTIONS_COLUMNS = ['name', 'option', 'value']
-UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
+CURRENT_SCHEMA_VERSION = 2
+OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
+                 'primary_key': ('name', 'option'),
+                 'indexes': [('name',)]
+                 }
+UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
+                     'primary_key': ('uuid', 'name'),
+                     'indexes': [('uuid',)]
+                     }
 
 
 class DatabaseError(Exception):
 
 
 class DatabaseError(Exception):
@@ -74,16 +81,27 @@ class SqlStore(Log):
 
 class SqlQuery(Log):
 
 
 class SqlQuery(Log):
 
-    def __init__(self, db_obj, table, columns, trans=True):
+    def __init__(self, db_obj, table, table_def, trans=True):
         self._db = db_obj
         self._con = self._db.connection()
         self._trans = self._con.begin() if trans else None
         self._db = db_obj
         self._con = self._db.connection()
         self._trans = self._con.begin() if trans else None
-        self._table = self._get_table(table, columns)
-
-    def _get_table(self, name, columns):
-        table = Table(name, MetaData(self._db.engine()))
-        for c in columns:
-            table.append_column(Column(c, Text()))
+        self._table = self._get_table(table, table_def)
+
+    def _get_table(self, name, table_def):
+        if isinstance(table_def, list):
+            table_def = {'columns': table_def,
+                         'indexes': [],
+                         'primary_key': None}
+        table_creation = []
+        for col_name in table_def['columns']:
+            table_creation.append(Column(col_name, Text()))
+        if table_def['primary_key']:
+            table_creation.append(PrimaryKeyConstraint(
+                *table_def['primary_key']))
+        for index in table_def['indexes']:
+            idx_name = 'idx_%s_%s' % (name, '_'.join(index))
+            table_creation.append(Index(idx_name, *index))
+        table = Table(name, MetaData(self._db.engine()), *table_creation)
         return table
 
     def _where(self, kvfilter):
         return table
 
     def _where(self, kvfilter):
@@ -159,7 +177,12 @@ class FileStore(Log):
 
 class FileQuery(Log):
 
 
 class FileQuery(Log):
 
-    def __init__(self, fstore, table, columns, trans=True):
+    def __init__(self, fstore, table, table_def, trans=True):
+        # We don't need indexes in a FileQuery, so drop that info
+        if isinstance(table_def, dict):
+            columns = table_def['columns']
+        else:
+            columns = table_def
         self._fstore = fstore
         self._config = fstore.get_config()
         self._section = table
         self._fstore = fstore
         self._config = fstore.get_config()
         self._section = table
@@ -278,7 +301,7 @@ class Store(Log):
         # We are storing multiple versions: one per class
         # That way, we can support plugins with differing schema versions from
         #  the main codebase, and even in the same database.
         # We are storing multiple versions: one per class
         # That way, we can support plugins with differing schema versions from
         #  the main codebase, and even in the same database.
-        q = self._query(self._db, 'dbinfo', OPTIONS_COLUMNS, trans=False)
+        q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
         q.create()
         cls_name = self.__class__.__name__
         current_version = self.load_options('dbinfo').get('%s_schema'
         q.create()
         cls_name = self.__class__.__name__
         current_version = self.load_options('dbinfo').get('%s_schema'
@@ -387,7 +410,7 @@ class Store(Log):
         kvfilter = dict()
         if name:
             kvfilter['name'] = name
         kvfilter = dict()
         if name:
             kvfilter['name'] = name
-        options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
+        options = self._load_data(table, OPTIONS_TABLE, kvfilter)
         if name and name in options:
             return options[name]
         return options
         if name and name in options:
             return options[name]
         return options
@@ -396,7 +419,7 @@ class Store(Log):
         curvals = dict()
         q = None
         try:
         curvals = dict()
         q = None
         try:
-            q = self._query(self._db, table, OPTIONS_COLUMNS)
+            q = self._query(self._db, table, OPTIONS_TABLE)
             rows = q.select({'name': name}, ['option', 'value'])
             for row in rows:
                 curvals[row[0]] = row[1]
             rows = q.select({'name': name}, ['option', 'value'])
             for row in rows:
                 curvals[row[0]] = row[1]
@@ -419,7 +442,7 @@ class Store(Log):
         kvfilter = {'name': name}
         q = None
         try:
         kvfilter = {'name': name}
         q = None
         try:
-            q = self._query(self._db, table, OPTIONS_COLUMNS)
+            q = self._query(self._db, table, OPTIONS_TABLE)
             if options is None:
                 q.delete(kvfilter)
             else:
             if options is None:
                 q.delete(kvfilter)
             else:
@@ -437,7 +460,7 @@ class Store(Log):
         newid = str(uuid.uuid4())
         q = None
         try:
         newid = str(uuid.uuid4())
         q = None
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
             for name in data:
                 q.insert((newid, name, data[name]))
             q.commit()
             for name in data:
                 q.insert((newid, name, data[name]))
             q.commit()
@@ -456,12 +479,12 @@ class Store(Log):
             kvfilter['name'] = name
         if value:
             kvfilter['value'] = value
             kvfilter['name'] = name
         if value:
             kvfilter['value'] = value
-        return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
+        return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)
 
     def save_unique_data(self, table, data):
         q = None
         try:
 
     def save_unique_data(self, table, data):
         q = None
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
             for uid in data:
                 curvals = dict()
                 rows = q.select({'uuid': uid}, ['name', 'value'])
             for uid in data:
                 curvals = dict()
                 rows = q.select({'uuid': uid}, ['name', 'value'])
@@ -490,7 +513,7 @@ class Store(Log):
     def del_unique_data(self, table, uuidval):
         kvfilter = {'uuid': uuidval}
         try:
     def del_unique_data(self, table, uuidval):
         kvfilter = {'uuid': uuidval}
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
             q.delete(kvfilter)
         except Exception, e:  # pylint: disable=broad-except
             self.error("Failed to delete data from %s: [%s]" % (table, e))
             q.delete(kvfilter)
         except Exception, e:  # pylint: disable=broad-except
             self.error("Failed to delete data from %s: [%s]" % (table, e))
@@ -498,7 +521,7 @@ class Store(Log):
     def _reset_data(self, table):
         q = None
         try:
     def _reset_data(self, table):
         q = None
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
             q.drop()
             q.create()
             q.commit()
             q.drop()
             q.create()
             q.commit()
@@ -536,7 +559,7 @@ class AdminStore(Store):
                       'info_config',
                       'login_config',
                       'provider_config']:
                       'info_config',
                       'login_config',
                       'provider_config']:
-            q = self._query(self._db, table, OPTIONS_COLUMNS, trans=False)
+            q = self._query(self._db, table, OPTIONS_TABLE, trans=False)
             q.create()
 
     def _upgrade_schema(self, old_version):
             q.create()
 
     def _upgrade_schema(self, old_version):
@@ -561,7 +584,7 @@ class UserStore(Store):
         return self.load_options(plugin+"_data", user)
 
     def _initialize_schema(self):
         return self.load_options(plugin+"_data", user)
 
     def _initialize_schema(self):
-        q = self._query(self._db, 'users', OPTIONS_COLUMNS, trans=False)
+        q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
         q.create()
 
     def _upgrade_schema(self, old_version):
         q.create()
 
     def _upgrade_schema(self, old_version):
@@ -574,7 +597,7 @@ class TranStore(Store):
         super(TranStore, self).__init__('transactions.db')
 
     def _initialize_schema(self):
         super(TranStore, self).__init__('transactions.db')
 
     def _initialize_schema(self):
-        q = self._query(self._db, 'transactions', UNIQUE_DATA_COLUMNS,
+        q = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
                         trans=False)
         q.create()
 
                         trans=False)
         q.create()
 
@@ -588,7 +611,7 @@ class SAML2SessionStore(Store):
         super(SAML2SessionStore, self).__init__(database_url=database_url)
         self.table = 'saml2_sessions'
         # pylint: disable=protected-access
         super(SAML2SessionStore, self).__init__(database_url=database_url)
         self.table = 'saml2_sessions'
         # pylint: disable=protected-access
-        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
         table.create(checkfirst=True)
 
     def _get_unique_id_from_column(self, name, value):
         table.create(checkfirst=True)
 
     def _get_unique_id_from_column(self, name, value):
@@ -609,7 +632,7 @@ class SAML2SessionStore(Store):
 
     def remove_expired_sessions(self):
         # pylint: disable=protected-access
 
     def remove_expired_sessions(self):
         # pylint: disable=protected-access
-        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
         sel = select([table.columns.uuid]). \
             where(and_(table.c.name == 'expiration_time',
                        table.c.value <= datetime.datetime.now()))
         sel = select([table.columns.uuid]). \
             where(and_(table.c.name == 'expiration_time',
                        table.c.value <= datetime.datetime.now()))
@@ -665,7 +688,7 @@ class SAML2SessionStore(Store):
         self._reset_data(self.table)
 
     def _initialize_schema(self):
         self._reset_data(self.table)
 
     def _initialize_schema(self):
-        q = self._query(self._db, self.table, UNIQUE_DATA_COLUMNS,
+        q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                         trans=False)
         q.create()
 
                         trans=False)
         q.create()
 
index b870319..ef059d1 100644 (file)
@@ -10,12 +10,15 @@ except ImportError:
     import pickle
 
 
     import pickle
 
 
-SESSION_COLUMNS = ['id', 'data', 'expiration_time']
+SESSION_TABLE = {'columns': ['id', 'data', 'expiration_time'],
+                 'primary_key': ('id', ),
+                 'indexes': [('expiration_time',)]
+                 }
 
 
 class SessionStore(Store):
     def _initialize_schema(self):
 
 
 class SessionStore(Store):
     def _initialize_schema(self):
-        q = self._query(self._db, 'sessions', SESSION_COLUMNS,
+        q = self._query(self._db, 'sessions', SESSION_TABLE,
                         trans=False)
         q.create()
 
                         trans=False)
         q.create()
 
@@ -44,12 +47,12 @@ class SqlSession(Session):
         cls._db = cls._store._db
 
     def _exists(self):
         cls._db = cls._store._db
 
     def _exists(self):
-        q = SqlQuery(self._db, 'sessions', SESSION_COLUMNS)
+        q = SqlQuery(self._db, 'sessions', SESSION_TABLE)
         result = q.select({'id': self.id})
         return True if result.fetchone() else False
 
     def _load(self):
         result = q.select({'id': self.id})
         return True if result.fetchone() else False
 
     def _load(self):
-        q = SqlQuery(self._db, 'sessions', SESSION_COLUMNS)
+        q = SqlQuery(self._db, 'sessions', SESSION_TABLE)
         result = q.select({'id': self.id})
         r = result.fetchone()
         if r:
         result = q.select({'id': self.id})
         r = result.fetchone()
         if r:
@@ -59,7 +62,7 @@ class SqlSession(Session):
     def _save(self, expiration_time):
         q = None
         try:
     def _save(self, expiration_time):
         q = None
         try:
-            q = SqlQuery(self._db, 'sessions', SESSION_COLUMNS, trans=True)
+            q = SqlQuery(self._db, 'sessions', SESSION_TABLE, trans=True)
             q.delete({'id': self.id})
             data = pickle.dumps((self._data, expiration_time), self._proto)
             q.insert((self.id, base64.b64encode(data), expiration_time))
             q.delete({'id': self.id})
             data = pickle.dumps((self._data, expiration_time), self._proto)
             q.insert((self.id, base64.b64encode(data), expiration_time))
@@ -70,7 +73,7 @@ class SqlSession(Session):
             raise
 
     def _delete(self):
             raise
 
     def _delete(self):
-        q = SqlQuery(self._db, 'sessions', SESSION_COLUMNS)
+        q = SqlQuery(self._db, 'sessions', SESSION_TABLE)
         q.delete({'id': self.id})
 
     # copy what RamSession does for now
         q.delete({'id': self.id})
 
     # copy what RamSession does for now