When a new logout session is received, save old session ids
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / sessions.py
index 50b9a14..fb1f646 100644 (file)
@@ -26,6 +26,13 @@ class SAMLSession(Log):
        provider_id - ID of the SP
        session - the Login session object
        logoutstate - dict containing logout state info
+       session_indexes - the IDs of any login session we've seen
+                         for this user
+
+    When a new session is seen for the same user any existing session
+    is thrown away. We keep the original session_id though and send
+    all that we've seen to the SP when performing a logout to ensure
+    that all sessions get logged out.
 
     logout state is a dictionary containing (potentially)
     these attributes:
@@ -41,6 +48,7 @@ class SAMLSession(Log):
         self.provider_id = provider_id
         self.session = session
         self.logoutstate = logoutstate
+        self.session_indexes = [session_id]
 
     def set_logoutstate(self, relaystate, request_id, request=None):
         self.logoutstate = dict(relaystate=relaystate,
@@ -49,6 +57,7 @@ class SAMLSession(Log):
 
     def dump(self):
         self.debug('session_id %s' % self.session_id)
+        self.debug('session_index %s' % self.session_indexes)
         self.debug('provider_id %s' % self.provider_id)
         self.debug('session %s' % self.session)
         self.debug('logoutstate %s' % self.logoutstate)
@@ -80,10 +89,19 @@ class SAMLSessionsContainer(Log):
         Drop any existing sessions that might exist for this
         provider. We have no control over the SP's so if it sends
         us another login, accept it.
+
+        If an existing session exists drop it but keep a copy of
+        its session index. When we logout we send ALL session indexes
+        we've received to ensure that they are all logged out.
         """
         samlsession = SAMLSession(session_id, provider_id, session)
 
-        self.remove_session_by_provider(provider_id)
+        old_session = self.find_session_by_provider(provider_id)
+        if old_session is not None:
+            samlsession.session_indexes.extend(old_session.session_indexes)
+            self.debug("old session: %s" % old_session.session_indexes)
+            self.debug("new session: %s" % samlsession.session_indexes)
+            self.remove_session_by_provider(provider_id)
         self.sessions[provider_id] = samlsession
         self.dump()