Add support for storing SAML2 sessions
[cascardo/ipsilon.git] / ipsilon / util / data.py
index 0d1c2df..f90519d 100644 (file)
@@ -1,19 +1,4 @@
-# Copyright (C) 2013  Simo Sorce <simo@redhat.com>
-#
-# see file 'COPYING' for use and warranty information
-#
-# This program is free software; you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+# Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
 
 import cherrypy
 from ipsilon.util.log import Log
@@ -27,6 +12,7 @@ import uuid
 import logging
 
 
+CURRENT_SCHEMA_VERSION = 1
 OPTIONS_COLUMNS = ['name', 'option', 'value']
 UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
 
@@ -37,11 +23,13 @@ class SqlStore(Log):
     @classmethod
     def get_connection(cls, name):
         if name not in cls.__instances.keys():
-            logging.debug('SqlStore new: %s', name)
+            if cherrypy.config.get('db.conn.log', False):
+                logging.debug('SqlStore new: %s', name)
             cls.__instances[name] = SqlStore(name)
         return cls.__instances[name]
 
     def __init__(self, name):
+        self.db_conn_log = cherrypy.config.get('db.conn.log', False)
         self.debug('SqlStore init: %s' % name)
         self.name = name
         engine_name = name
@@ -58,10 +46,13 @@ class SqlStore(Log):
             # It's not possible to share connections for SQLite between
             #  threads, so let's use the SingletonThreadPool for them
             pool_args = {'poolclass': SingletonThreadPool}
-        # pylint: disable=star-args
         self._dbengine = create_engine(engine_name, **pool_args)
         self.is_readonly = False
 
+    def debug(self, fact):
+        if self.db_conn_log:
+            super(SqlStore, self).debug(fact)
+
     def engine(self):
         return self._dbengine
 
@@ -277,6 +268,33 @@ class Store(Log):
         else:
             self._db = SqlStore.get_connection(name)
             self._query = SqlQuery
+        self._upgrade_database()
+
+    def _upgrade_database(self):
+        if self.is_readonly:
+            # If the database is readonly, we cannot do anything to the
+            #  schema. Let's just return, and assume people checked the
+            #  upgrade notes
+            return
+        current_version = self.load_options('dbinfo').get('scheme', None)
+        if current_version is None or 'version' not in current_version:
+            # No version stored, storing current version
+            self.save_options('dbinfo', 'scheme',
+                              {'version': CURRENT_SCHEMA_VERSION})
+            current_version = CURRENT_SCHEMA_VERSION
+        else:
+            current_version = int(current_version['version'])
+        if current_version != CURRENT_SCHEMA_VERSION:
+            self.debug('Upgrading database schema from %i to %i' % (
+                       current_version, CURRENT_SCHEMA_VERSION))
+            self._upgrade_database_from(current_version)
+
+    def _upgrade_database_from(self, old_schema_version):
+        # Insert code here to upgrade from old_schema_version to
+        #  CURRENT_SCHEMA_VERSION
+        raise Exception('Unable to upgrade database to current schema'
+                        ' version: version %i is unknown!' %
+                        old_schema_version)
 
     @property
     def is_readonly(self):
@@ -491,3 +509,67 @@ class TranStore(Store):
 
     def __init__(self, path=None):
         super(TranStore, self).__init__('transactions.db')
+
+
+class SAML2SessionStore(Store):
+
+    def __init__(self, path=None):
+        super(SAML2SessionStore, self).__init__('saml2.sessions.db')
+        self.table = 'sessions'
+
+    def _get_unique_id_from_column(self, name, value):
+        """
+        The query is going to return only the column in the query.
+        Use this method to get the uuidval which can be used to fetch
+        the entire entry.
+
+        Returns None or the uuid of the first value found.
+        """
+        data = self.get_unique_data(self.table, name=name, value=value)
+        count = len(data)
+        if count == 0:
+            return None
+        elif count != 1:
+            raise ValueError("Multiple entries returned")
+        return data.keys()[0]
+
+    def get_data(self, idval=None, name=None, value=None):
+        return self.get_unique_data(self.table, idval, name, value)
+
+    def new_session(self, datum):
+        return self.new_unique_data(self.table, datum)
+
+    def get_session(self, session_id=None, request_id=None):
+        if session_id:
+            uuidval = self._get_unique_id_from_column('session_id', session_id)
+        elif request_id:
+            uuidval = self._get_unique_id_from_column('request_id', request_id)
+        else:
+            raise ValueError("Unable to find session")
+        if not uuidval:
+            return None, None
+        data = self.get_unique_data(self.table, uuidval=uuidval)
+        return uuidval, data[uuidval]
+
+    def get_user_sessions(self, user):
+        """
+        Retrun a list of all sessions for a given user.
+        """
+        rows = self.get_unique_data(self.table, name='user', value=user)
+
+        # We have a list of sessions for this user, now get the details
+        logged_in = []
+        for r in rows:
+            data = self.get_unique_data(self.table, uuidval=r)
+            logged_in.append(data)
+
+        return logged_in
+
+    def update_session(self, datum):
+        self.save_unique_data(self.table, datum)
+
+    def remove_session(self, uuidval):
+        self.del_unique_data(self.table, uuidval)
+
+    def wipe_data(self):
+        self._reset_data(self.table)