Add SQL primary key and indexes
[cascardo/ipsilon.git] / ipsilon / util / data.py
index c0fe4ab..e92aae4 100644 (file)
@@ -6,6 +6,7 @@ from ipsilon.util.log import Log
 from sqlalchemy import create_engine
 from sqlalchemy import MetaData, Table, Column, Text
 from sqlalchemy.pool import QueuePool, SingletonThreadPool
+from sqlalchemy.schema import PrimaryKeyConstraint, Index
 from sqlalchemy.sql import select, and_
 import ConfigParser
 import os
@@ -13,9 +14,15 @@ import uuid
 import logging
 
 
-CURRENT_SCHEMA_VERSION = 1
-OPTIONS_COLUMNS = ['name', 'option', 'value']
-UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
+CURRENT_SCHEMA_VERSION = 2
+OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
+                 'primary_key': ('name', 'option'),
+                 'indexes': [('name',)]
+                 }
+UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
+                     'primary_key': ('uuid', 'name'),
+                     'indexes': [('uuid',)]
+                     }
 
 
 class DatabaseError(Exception):
@@ -74,16 +81,27 @@ class SqlStore(Log):
 
 class SqlQuery(Log):
 
-    def __init__(self, db_obj, table, columns, trans=True):
+    def __init__(self, db_obj, table, table_def, trans=True):
         self._db = db_obj
         self._con = self._db.connection()
         self._trans = self._con.begin() if trans else None
-        self._table = self._get_table(table, columns)
-
-    def _get_table(self, name, columns):
-        table = Table(name, MetaData(self._db.engine()))
-        for c in columns:
-            table.append_column(Column(c, Text()))
+        self._table = self._get_table(table, table_def)
+
+    def _get_table(self, name, table_def):
+        if isinstance(table_def, list):
+            table_def = {'columns': table_def,
+                         'indexes': [],
+                         'primary_key': None}
+        table_creation = []
+        for col_name in table_def['columns']:
+            table_creation.append(Column(col_name, Text()))
+        if table_def['primary_key']:
+            table_creation.append(PrimaryKeyConstraint(
+                *table_def['primary_key']))
+        for index in table_def['indexes']:
+            idx_name = 'idx_%s_%s' % (name, '_'.join(index))
+            table_creation.append(Index(idx_name, *index))
+        table = Table(name, MetaData(self._db.engine()), *table_creation)
         return table
 
     def _where(self, kvfilter):
@@ -159,7 +177,12 @@ class FileStore(Log):
 
 class FileQuery(Log):
 
-    def __init__(self, fstore, table, columns, trans=True):
+    def __init__(self, fstore, table, table_def, trans=True):
+        # We don't need indexes in a FileQuery, so drop that info
+        if isinstance(table_def, dict):
+            columns = table_def['columns']
+        else:
+            columns = table_def
         self._fstore = fstore
         self._config = fstore.get_config()
         self._section = table
@@ -278,7 +301,7 @@ class Store(Log):
         # We are storing multiple versions: one per class
         # That way, we can support plugins with differing schema versions from
         #  the main codebase, and even in the same database.
-        q = self._query(self._db, 'dbinfo', OPTIONS_COLUMNS, trans=False)
+        q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
         q.create()
         cls_name = self.__class__.__name__
         current_version = self.load_options('dbinfo').get('%s_schema'
@@ -387,7 +410,7 @@ class Store(Log):
         kvfilter = dict()
         if name:
             kvfilter['name'] = name
-        options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
+        options = self._load_data(table, OPTIONS_TABLE, kvfilter)
         if name and name in options:
             return options[name]
         return options
@@ -396,7 +419,7 @@ class Store(Log):
         curvals = dict()
         q = None
         try:
-            q = self._query(self._db, table, OPTIONS_COLUMNS)
+            q = self._query(self._db, table, OPTIONS_TABLE)
             rows = q.select({'name': name}, ['option', 'value'])
             for row in rows:
                 curvals[row[0]] = row[1]
@@ -419,7 +442,7 @@ class Store(Log):
         kvfilter = {'name': name}
         q = None
         try:
-            q = self._query(self._db, table, OPTIONS_COLUMNS)
+            q = self._query(self._db, table, OPTIONS_TABLE)
             if options is None:
                 q.delete(kvfilter)
             else:
@@ -437,7 +460,7 @@ class Store(Log):
         newid = str(uuid.uuid4())
         q = None
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
             for name in data:
                 q.insert((newid, name, data[name]))
             q.commit()
@@ -456,12 +479,12 @@ class Store(Log):
             kvfilter['name'] = name
         if value:
             kvfilter['value'] = value
-        return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
+        return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)
 
     def save_unique_data(self, table, data):
         q = None
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
             for uid in data:
                 curvals = dict()
                 rows = q.select({'uuid': uid}, ['name', 'value'])
@@ -490,7 +513,7 @@ class Store(Log):
     def del_unique_data(self, table, uuidval):
         kvfilter = {'uuid': uuidval}
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
             q.delete(kvfilter)
         except Exception, e:  # pylint: disable=broad-except
             self.error("Failed to delete data from %s: [%s]" % (table, e))
@@ -498,7 +521,7 @@ class Store(Log):
     def _reset_data(self, table):
         q = None
         try:
-            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+            q = self._query(self._db, table, UNIQUE_DATA_TABLE)
             q.drop()
             q.create()
             q.commit()
@@ -536,7 +559,7 @@ class AdminStore(Store):
                       'info_config',
                       'login_config',
                       'provider_config']:
-            q = self._query(self._db, table, OPTIONS_COLUMNS, trans=False)
+            q = self._query(self._db, table, OPTIONS_TABLE, trans=False)
             q.create()
 
     def _upgrade_schema(self, old_version):
@@ -561,7 +584,7 @@ class UserStore(Store):
         return self.load_options(plugin+"_data", user)
 
     def _initialize_schema(self):
-        q = self._query(self._db, 'users', OPTIONS_COLUMNS, trans=False)
+        q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
         q.create()
 
     def _upgrade_schema(self, old_version):
@@ -574,7 +597,7 @@ class TranStore(Store):
         super(TranStore, self).__init__('transactions.db')
 
     def _initialize_schema(self):
-        q = self._query(self._db, 'transactions', UNIQUE_DATA_COLUMNS,
+        q = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
                         trans=False)
         q.create()
 
@@ -588,7 +611,7 @@ class SAML2SessionStore(Store):
         super(SAML2SessionStore, self).__init__(database_url=database_url)
         self.table = 'saml2_sessions'
         # pylint: disable=protected-access
-        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
         table.create(checkfirst=True)
 
     def _get_unique_id_from_column(self, name, value):
@@ -609,7 +632,7 @@ class SAML2SessionStore(Store):
 
     def remove_expired_sessions(self):
         # pylint: disable=protected-access
-        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
         sel = select([table.columns.uuid]). \
             where(and_(table.c.name == 'expiration_time',
                        table.c.value <= datetime.datetime.now()))
@@ -665,7 +688,7 @@ class SAML2SessionStore(Store):
         self._reset_data(self.table)
 
     def _initialize_schema(self):
-        q = self._query(self._db, self.table, UNIQUE_DATA_COLUMNS,
+        q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
                         trans=False)
         q.create()