Change SAML2 sessions backend to use Store API
authorRob Crittenden <rcritten@redhat.com>
Tue, 21 Apr 2015 13:35:25 +0000 (09:35 -0400)
committerPatrick Uiterwijk <puiterwijk@redhat.com>
Mon, 11 May 2015 22:39:10 +0000 (00:39 +0200)
The basic session API remains the same, just replace
the calls to pull data out of the user session to
instead pull from the database.

The per-session logout state is now a constant rather than
being a member of either the logged_in or logging_out
dictionaries.

https://fedorahosted.org/ipsilon/ticket/90

Signed-off-by: Rob Crittenden <rcritten@redhat.com>
Reviewed-by: Patrick Uiterwijk <puiterwijk@redhat.com>
ipsilon/providers/saml2/sessions.py

index c9cfd9c..d01bb6e 100644 (file)
 # Copyright (C) 2015 Ipsilon project Contributors, for license see COPYING
 
 from ipsilon.util.log import Log
+from ipsilon.util.data import SAML2SessionStore
+
+LOGGED_IN = 1
+INIT_LOGOUT = 2
+LOGGING_OUT = 4
+LOGGED_OUT = 8
 
 
 class SAMLSession(Log):
     """
-    A SAML login session used to track login/logout state.
+    A SAML login session.
 
+       uuidval - Unique ID stored in the database
        session_id - ID of the login session
        provider_id - ID of the SP
-       session - the Login session object
-       logoutstate - dict containing logout state info
-       session_indexes - the IDs of any login session we've seen
-                         for this user
-
-    When a new session is seen for the same user any existing session
-    is thrown away. We keep the original session_id though and send
-    all that we've seen to the SP when performing a logout to ensure
-    that all sessions get logged out.
-
-    logout state is a dictionary containing (potentially)
-    these attributes:
-
-    relaystate - The relaystate from the Logout Request or Response
-    id         - The Logout request id that initiated the logout
-    request    - Dump of the initial logout request
+       user - the login name of the user that owns the session
+       login_session - the Login session object
+       logoutstate - an integer constant representing where in the
+                     logout process this request is
+       relaystate - where the user will be redirected when logout is
+                    complete
+       request_id - the logout request ID if initiated from IdP. The
+                    logout response will include an InResponseTo value
+                    which matches this.
+       logout_request - the Logout request object
     """
-    def __init__(self, session_id, provider_id, session,
-                 logoutstate=None):
+    def __init__(self, uuidval, session_id, provider_id, user,
+                 login_session, logoutstate=None, relaystate=None,
+                 logout_request=None, request_id=None):
 
+        self.uuidval = uuidval
         self.session_id = session_id
         self.provider_id = provider_id
-        self.session = session
+        self.user = user
+        self.login_session = login_session
         self.logoutstate = logoutstate
-        self.session_indexes = [session_id]
+        self.relaystate = relaystate
+        self.request_id = request_id
+        self.logout_request = logout_request
+
+    def set_logoutstate(self, relaystate=None, request=None, request_id=None):
+        """
+        Update attributes needed to determine the state of the session for
+        logout.
 
-    def set_logoutstate(self, relaystate, request_id, request=None):
-        self.logoutstate = dict(relaystate=relaystate,
-                                id=request_id,
-                                request=request)
+        The database is not updated when these are set. It is expected that
+        this is called prior to start_logout()
+        """
+        if relaystate:
+            self.relaystate = relaystate
+        if request:
+            self.logout_request = request
+        if request_id:
+            self.request_id = request_id
 
     def dump(self):
         self.debug('session_id %s' % self.session_id)
-        self.debug('session_index %s' % self.session_indexes)
         self.debug('provider_id %s' % self.provider_id)
-        self.debug('session %s' % self.session)
+        self.debug('login session %s' % self.login_session)
         self.debug('logoutstate %s' % self.logoutstate)
 
+    def convert(self):
+        """
+        Convert this object into something suitable to store in the
+        data backend.
+        """
+        data = dict()
+        data['session_id'] = self.session_id
+        data['provider_id'] = self.provider_id
+        data['user'] = self.user
+        data['login_session'] = self.login_session
+        data['logoutstate'] = self.logoutstate
+        data['relaystate'] = self.relaystate
+        data['logout_request'] = self.logout_request
+        data['request_id'] = self.request_id
+
+        return {self.uuidval: data}
+
 
-class SAMLSessionsContainer(Log):
+class SAMLSessionFactory(Log):
     """
-    Store SAML session information.
+    Access SAML session information.
 
-    The sessions are stored in two dicts which represent the state that
-    the session is in.
+    The sessions are stored via the data backend.
 
     When a user logs in, add_session() is called and a new SAMLSession
-    created and added to the sessions dict, keyed on provider_id.
+    created and added to the table.
 
     When a user logs out, the next login session is found and moved to
     sessions_logging_out. remove_session() will look in both when trying
     to remove a session.
-    """
 
+    Returns a SAMLSession object representing the new session.
+    """
     def __init__(self):
-        self.sessions = dict()
-        self.sessions_logging_out = dict()
+        self._ss = SAML2SessionStore()
+        self.user = None
 
-    def add_session(self, session_id, provider_id, session):
+    def _data_to_samlsession(self, uuidval, data):
+        """
+        Convert data from the data backend to a SAMLSession object.
+        """
+        return SAMLSession(uuidval,
+                           data.get('session_id'),
+                           data.get('provider_id'),
+                           data.get('user'),
+                           data.get('login_session'),
+                           data.get('logoutstate'),
+                           data.get('relaystate'),
+                           data.get('logout_request'),
+                           data.get('request_id'))
+
+    def add_session(self, session_id, provider_id, user, login_session,
+                    request_id=None):
         """
-        Add a new session to the logged-in bucket.
+        Add a new login session to the table.
+        """
+        self.user = user
 
-        Drop any existing sessions that might exist for this
-        provider. We have no control over the SP's so if it sends
-        us another login, accept it.
+        data = {'session_id': session_id,
+                'provider_id': provider_id,
+                'user': user,
+                'login_session': login_session,
+                'logoutstate': LOGGED_IN}
+        if request_id:
+            data['request_id'] = request_id
 
-        If an existing session exists drop it but keep a copy of
-        its session index. When we logout we send ALL session indexes
-        we've received to ensure that they are all logged out.
-        """
-        samlsession = SAMLSession(session_id, provider_id, session)
+        uuidval = self._ss.new_session(data)
 
-        old_session = self.find_session_by_provider(provider_id)
-        if old_session is not None:
-            samlsession.session_indexes.extend(old_session.session_indexes)
-            self.debug("old session: %s" % old_session.session_indexes)
-            self.debug("new session: %s" % samlsession.session_indexes)
-            self.remove_session_by_provider(provider_id)
-        self.sessions[provider_id] = samlsession
-        self.dump()
+        return SAMLSession(uuidval, session_id, provider_id, user,
+                           login_session, LOGGED_IN,
+                           request_id=request_id)
 
-    def remove_session_by_provider(self, provider_id):
+    def get_session_by_id(self, session_id):
         """
-        Remove all instances of this provider from either session
-        pool.
+        Retrieve a session by session ID
         """
-        if provider_id in self.sessions:
-            self.sessions.pop(provider_id)
-        if provider_id in self.sessions_logging_out:
-            self.sessions_logging_out.pop(provider_id)
+        uuidval, data = self._ss.get_session(session_id=session_id)
+        if uuidval is None:
+            return None
 
-    def find_session_by_provider(self, provider_id):
+        return self._data_to_samlsession(uuidval, data)
+
+    def get_session_id_by_provider_id(self, provider_id):
+        """
+        Return a tuple of logged-in session IDs by provider_id
         """
-        Return a given session from either pool.
+        candidates = self._ss.get_user_sessions(self.user)
+
+        session_ids = []
+        for c in candidates:
+            key = c.keys()[0]
+            if c[key].get('provider_id') == provider_id:
+                samlsession = self._data_to_samlsession(key, c[key])
+                session_ids.append(samlsession.session_id.encode('utf-8'))
 
-        Return None if no session for a provider is found.
+        return tuple(session_ids)
+
+    def get_session_by_request_id(self, request_id):
         """
-        if provider_id in self.sessions:
-            return self.sessions[provider_id]
-        if provider_id in self.sessions_logging_out:
-            return self.sessions_logging_out[provider_id]
-        return None
+        Retrieve a session by logout request ID
+        """
+        uuidval, data = self._ss.get_session(request_id=request_id)
+        if uuidval is None:
+            return None
+
+        return self._data_to_samlsession(uuidval, data)
 
-    def start_logout(self, session):
+    def remove_session(self, samlsession):
+        return self._ss.remove_session(samlsession.uuidval)
+
+    def remove_session_by_session_id(self, session_id):
+        session = self.get_session_by_id(session_id)
+        return self._ss.remove_session(session.uuidval)
+
+    def start_logout(self, samlsession, relaystate=None, initial=True):
         """
         Move a session into the logging_out state
 
+        samlsession: the SAMLSession object to start logging out
+        relaystate: URL to redirect user to when logout is completed
+        initial: boolean to indicate if this session started logout.
+                 Only the initial session's relaystate is used.
+
         No return value
         """
-        if session.provider_id in self.sessions_logging_out:
-            return
-
-        session = self.sessions.pop(session.provider_id)
-
-        self.sessions_logging_out[session.provider_id] = session
+        if initial:
+            samlsession.logoutstate = INIT_LOGOUT
+        else:
+            samlsession.logoutstate = LOGGING_OUT
+        if relaystate:
+            samlsession.relaystate = relaystate
+        datum = samlsession.convert()
+        self._ss.update_session(datum)
 
-    def get_next_logout(self, remove=True):
+    def get_next_logout(self, peek=False):
         """
         Get the next session in the logged-in state and move
         it to the logging_out state.  Return the session that is
         found.
 
-        :param remove: for IdP-initiated logout we can't remove the
-                       session otherwise when the request comes back
-                       in the user won't be seen as being logged-on.
+        :param peek: for IdP-initiated logout we can't remove the
+                     session otherwise when the request comes back
+                     in the user won't be seen as being logged-on.
 
-        Return None if no more sessions in login state.
+        Return None if no more sessions in LOGGED_IN state.
         """
-        try:
-            provider_id = self.sessions.keys()[0]
-        except IndexError:
-            return None
-
-        if remove:
-            session = self.sessions.pop(provider_id)
-        else:
-            session = self.sessions.itervalues().next()
-
-        if provider_id in self.sessions_logging_out:
-            self.sessions_logging_out.pop(provider_id)
-
-        self.sessions_logging_out[provider_id] = session
-
-        return session
+        candidates = self._ss.get_user_sessions(self.user)
+
+        for c in candidates:
+            key = c.keys()[0]
+            if int(c[key].get('logoutstate', 0)) == LOGGED_IN:
+                samlsession = self._data_to_samlsession(key, c[key])
+                self.start_logout(samlsession, initial=False)
+                return samlsession
+        return None
 
-    def get_last_session(self):
-        if self.count() != 1:
-            raise ValueError('Not exactly one session left')
+    def get_initial_logout(self):
+        """
+        Get the initial logout request.
 
-        try:
-            provider_id = self.sessions_logging_out.keys()[0]
-        except IndexError:
-            return None
+        Return None if no sessions in INIT_LOGOUT state.
+        """
+        candidates = self._ss.get_user_sessions(self.user)
+
+        # FIXME: what does it mean if there are multiple in init? We
+        #        just return the first one for now. How do we know
+        #        it's the "right" one if multiple logouts are started
+        #        at the same time from different SPs?
+        for c in candidates:
+            key = c.keys()[0]
+            if int(c[key].get('logoutstate', 0)) == INIT_LOGOUT:
+                samlsession = self._data_to_samlsession(key, c[key])
+                return samlsession
+        return None
 
-        return self.sessions_logging_out.pop(provider_id)
+    def wipe_data(self):
+        self._ss.wipe_data()
 
-    def count(self):
+    def dump(self):
         """
-        Return number of active login/logging out sessions.
+        Dump all sessions to debug log
         """
-        return len(self.sessions) + len(self.sessions_logging_out)
+        candidates = self._ss.get_user_sessions(self.user)
 
-    def dump(self):
         count = 0
-        for s in self.sessions:
-            self.debug('Login Session: %d' % count)
-            session = self.sessions[s]
-            session.dump()
-            self.debug('-----------------------')
-            count += 1
-        for s in self.sessions_logging_out:
-            self.debug('Logging-out Session: %d' % count)
-            session = self.sessions_logging_out[s]
-            session.dump()
-            self.debug('-----------------------')
+        for c in candidates:
+            key = c.keys()[0]
+            samlsession = self._data_to_samlsession(key, c[key])
+            self.debug('session %d: %s' % (count, samlsession.convert()))
             count += 1
 
 if __name__ == '__main__':
+    import cherrypy
+
     provider1 = "http://127.0.0.10/saml2"
     provider2 = "http://127.0.0.11/saml2"
 
-    saml_sessions = SAMLSessionsContainer()
+    # temporary database location for testing
+    cherrypy.config['saml2.sessions.db'] = '/tmp/saml2sessions.sqlite'
+
+    factory = SAMLSessionFactory()
+    factory.wipe_data()
 
-    try:
-        testsession = saml_sessions.get_last_session()
-    except ValueError:
-        assert(saml_sessions.count() == 0)
+    sess1 = factory.add_session('_123456', provider1, "admin", "<Login/>")
+    sess2 = factory.add_session('_789012', provider2, "testuser", "<Login/>")
 
-    saml_sessions.add_session("_123456",
-                              provider1,
-                              "sessiondata")
+    # Test finding sessions by provider
+    ids = factory.get_session_id_by_provider_id(provider2)
+    assert(len(ids) == 1)
 
-    saml_sessions.add_session("_789012",
-                              provider2,
-                              "sessiondata")
+    sess3 = factory.add_session('_345678', provider2, "testuser", "<Login/>")
+    ids = factory.get_session_id_by_provider_id(provider2)
+    assert(len(ids) == 2)
 
-    try:
-        testsession = saml_sessions.get_last_session()
-    except ValueError:
-        assert(saml_sessions.count() == 2)
+    # Test finding sessions by session ID
+    test1 = factory.get_session_by_id('_123456')
+    assert(test1.user == 'admin')
+    assert(test1.provider_id == provider1)
 
-    testsession = saml_sessions.find_session_by_provider(provider1)
-    assert(testsession.provider_id == provider1)
-    assert(testsession.session_id == "_123456")
-    assert(testsession.session == "sessiondata")
+    # Log out and remove the first session
+    test1.set_logoutstate('http://www.example.com/idp')
+    factory.start_logout(test1, initial=True)
+    test1 = factory.get_session_by_id('_123456')
+    assert(test1.relaystate == 'http://www.example.com/idp')
 
-    # Test get_next_logout() by fetching both values out. Do some
-    # basic accounting to ensure we get both values eventually.
-    providers = [provider1, provider2]
-    testsession = saml_sessions.get_next_logout()
-    providers.remove(testsession.provider_id)  # should be one of them
+    factory.remove_session_by_session_id('_123456')
 
-    testsession = saml_sessions.get_next_logout()
-    assert(testsession.provider_id == providers[0])  # should be the other
+    # Make sure it is gone from the db
+    test1 = factory.get_session_by_id('_123456')
+    assert(test1 is None)
 
-    saml_sessions.start_logout(testsession)
-    saml_sessions.remove_session_by_provider(provider2)
+    test2 = factory.get_session_by_id('_789012')
+    factory.start_logout(test2, initial=True)
 
-    assert(saml_sessions.count() == 1)
+    test3 = factory.get_next_logout()
+    assert(test3.session_id == '_345678')
 
-    testsession = saml_sessions.get_last_session()
-    assert(testsession.provider_id == provider1)
+    test4 = factory.get_initial_logout()
+    assert(test4.session_id == '_789012')
 
-    saml_sessions.remove_session_by_provider(provider1)
-    assert(saml_sessions.count() == 0)
+    # Even though we've started logout, make sure we can still find
+    # all sessions for a provider.
+    ids = factory.get_session_id_by_provider_id(provider2)
+    assert(len(ids) == 2)