Add SAML-specific session data for tracking login/logout sessions
authorRob Crittenden <rcritten@redhat.com>
Fri, 30 Jan 2015 15:03:03 +0000 (10:03 -0500)
committerSimo Sorce <simo@redhat.com>
Fri, 13 Feb 2015 22:50:11 +0000 (17:50 -0500)
https://fedorahosted.org/ipsilon/ticket/24

Signed-off-by: Rob Crittenden <rcritten@redhat.com>
Reviewed-by: Simo Sorce <simo@redhat.com>
ipsilon/providers/saml2/sessions.py [new file with mode: 0644]

diff --git a/ipsilon/providers/saml2/sessions.py b/ipsilon/providers/saml2/sessions.py
new file mode 100644 (file)
index 0000000..50b9a14
--- /dev/null
@@ -0,0 +1,226 @@
+# Copyright (C) 2015  Rob Crittenden <rcritten@redhat.com>
+#
+# see file 'COPYING' for use and warranty information
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from ipsilon.util.log import Log
+
+
+class SAMLSession(Log):
+    """
+    A SAML login session used to track login/logout state.
+
+       session_id - ID of the login session
+       provider_id - ID of the SP
+       session - the Login session object
+       logoutstate - dict containing logout state info
+
+    logout state is a dictionary containing (potentially)
+    these attributes:
+
+    relaystate - The relaystate from the Logout Request or Response
+    id         - The Logout request id that initiated the logout
+    request    - Dump of the initial logout request
+    """
+    def __init__(self, session_id, provider_id, session,
+                 logoutstate=None):
+
+        self.session_id = session_id
+        self.provider_id = provider_id
+        self.session = session
+        self.logoutstate = logoutstate
+
+    def set_logoutstate(self, relaystate, request_id, request=None):
+        self.logoutstate = dict(relaystate=relaystate,
+                                id=request_id,
+                                request=request)
+
+    def dump(self):
+        self.debug('session_id %s' % self.session_id)
+        self.debug('provider_id %s' % self.provider_id)
+        self.debug('session %s' % self.session)
+        self.debug('logoutstate %s' % self.logoutstate)
+
+
+class SAMLSessionsContainer(Log):
+    """
+    Store SAML session information.
+
+    The sessions are stored in two dicts which represent the state that
+    the session is in.
+
+    When a user logs in, add_session() is called and a new SAMLSession
+    created and added to the sessions dict, keyed on provider_id.
+
+    When a user logs out, the next login session is found and moved to
+    sessions_logging_out. remove_session() will look in both when trying
+    to remove a session.
+    """
+
+    def __init__(self):
+        self.sessions = dict()
+        self.sessions_logging_out = dict()
+
+    def add_session(self, session_id, provider_id, session):
+        """
+        Add a new session to the logged-in bucket.
+
+        Drop any existing sessions that might exist for this
+        provider. We have no control over the SP's so if it sends
+        us another login, accept it.
+        """
+        samlsession = SAMLSession(session_id, provider_id, session)
+
+        self.remove_session_by_provider(provider_id)
+        self.sessions[provider_id] = samlsession
+        self.dump()
+
+    def remove_session_by_provider(self, provider_id):
+        """
+        Remove all instances of this provider from either session
+        pool.
+        """
+        if provider_id in self.sessions:
+            self.sessions.pop(provider_id)
+        if provider_id in self.sessions_logging_out:
+            self.sessions_logging_out.pop(provider_id)
+
+    def find_session_by_provider(self, provider_id):
+        """
+        Return a given session from either pool.
+
+        Return None if no session for a provider is found.
+        """
+        if provider_id in self.sessions:
+            return self.sessions[provider_id]
+        if provider_id in self.sessions_logging_out:
+            return self.sessions_logging_out[provider_id]
+        return None
+
+    def start_logout(self, session):
+        """
+        Move a session into the logging_out state
+
+        No return value
+        """
+        if session.provider_id in self.sessions_logging_out:
+            return
+
+        session = self.sessions.pop(session.provider_id)
+
+        self.sessions_logging_out[session.provider_id] = session
+
+    def get_next_logout(self):
+        """
+        Get the next session in the logged-in state and move
+        it to the logging_out state.  Return the session that is
+        found.
+
+        Return None if no more sessions in login state.
+        """
+        try:
+            provider_id = self.sessions.keys()[0]
+        except IndexError:
+            return None
+
+        session = self.sessions.pop(provider_id)
+
+        if provider_id in self.sessions_logging_out:
+            self.sessions_logging_out.pop(provider_id)
+
+        self.sessions_logging_out[provider_id] = session
+
+        return session
+
+    def get_last_session(self):
+        if self.count() != 1:
+            raise ValueError('Not exactly one session left')
+
+        try:
+            provider_id = self.sessions_logging_out.keys()[0]
+        except IndexError:
+            return None
+
+        return self.sessions_logging_out.pop(provider_id)
+
+    def count(self):
+        """
+        Return number of active login/logging out sessions.
+        """
+        return len(self.sessions) + len(self.sessions_logging_out)
+
+    def dump(self):
+        count = 0
+        for s in self.sessions:
+            self.debug('Login Session: %d' % count)
+            session = self.sessions[s]
+            session.dump()
+            self.debug('-----------------------')
+            count += 1
+        for s in self.sessions_logging_out:
+            self.debug('Logging-out Session: %d' % count)
+            session = self.sessions_logging_out[s]
+            session.dump()
+            self.debug('-----------------------')
+            count += 1
+
+if __name__ == '__main__':
+    provider1 = "http://127.0.0.10/saml2"
+    provider2 = "http://127.0.0.11/saml2"
+
+    saml_sessions = SAMLSessionsContainer()
+
+    try:
+        testsession = saml_sessions.get_last_session()
+    except ValueError:
+        assert(saml_sessions.count() == 0)
+
+    saml_sessions.add_session("_123456",
+                              provider1,
+                              "sessiondata")
+
+    saml_sessions.add_session("_789012",
+                              provider2,
+                              "sessiondata")
+
+    try:
+        testsession = saml_sessions.get_last_session()
+    except ValueError:
+        assert(saml_sessions.count() == 2)
+
+    testsession = saml_sessions.find_session_by_provider(provider1)
+    assert(testsession.provider_id == provider1)
+    assert(testsession.session_id == "_123456")
+    assert(testsession.session == "sessiondata")
+
+    # Test get_next_logout() by fetching both values out. Do some
+    # basic accounting to ensure we get both values eventually.
+    providers = [provider1, provider2]
+    testsession = saml_sessions.get_next_logout()
+    providers.remove(testsession.provider_id)  # should be one of them
+
+    testsession = saml_sessions.get_next_logout()
+    assert(testsession.provider_id == providers[0])  # should be the other
+
+    saml_sessions.start_logout(testsession)
+    saml_sessions.remove_session_by_provider(provider2)
+
+    assert(saml_sessions.count() == 1)
+
+    testsession = saml_sessions.get_last_session()
+    assert(testsession.provider_id == provider1)
+
+    saml_sessions.remove_session_by_provider(provider1)
+    assert(saml_sessions.count() == 0)