pam: use a pam object method instead of pam module function
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / sessions.py
index 1000a87..d3ed7e2 100644 (file)
@@ -4,6 +4,10 @@ from cherrypy import config as cherrypy_config
 from ipsilon.util.log import Log
 from ipsilon.util.data import SAML2SessionStore
 import datetime
+from lasso import (
+    SAML2_METADATA_BINDING_SOAP,
+    SAML2_METADATA_BINDING_REDIRECT,
+)
 
 LOGGED_IN = 1
 INIT_LOGOUT = 2
@@ -29,11 +33,13 @@ class SAMLSession(Log):
                     which matches this.
        logout_request - the Logout request object
        expiration_time - the time the login session expires
+       supported_logout_mechs - logout mechanisms supported by this session
     """
     def __init__(self, uuidval, session_id, provider_id, user,
                  login_session, logoutstate=None, relaystate=None,
                  logout_request=None, request_id=None,
-                 expiration_time=None):
+                 expiration_time=None,
+                 supported_logout_mechs=None):
 
         self.uuidval = uuidval
         self.session_id = session_id
@@ -45,6 +51,9 @@ class SAMLSession(Log):
         self.request_id = request_id
         self.logout_request = logout_request
         self.expiration_time = expiration_time
+        if supported_logout_mechs is None:
+            supported_logout_mechs = []
+        self.supported_logout_mechs = supported_logout_mechs
 
     def set_logoutstate(self, relaystate=None, request=None, request_id=None):
         """
@@ -66,6 +75,7 @@ class SAMLSession(Log):
         self.debug('provider_id %s' % self.provider_id)
         self.debug('login session %s' % self.login_session)
         self.debug('logoutstate %s' % self.logoutstate)
+        self.debug('logout mech %s' % self.supported_logout_mechs)
 
     def convert(self):
         """
@@ -118,12 +128,20 @@ class SAMLSessionFactory(Log):
                            data.get('relaystate'),
                            data.get('logout_request'),
                            data.get('request_id'),
-                           data.get('expiration_time'))
+                           data.get('expiration_time'),
+                           data.get('supported_logout_mechs'))
 
     def add_session(self, session_id, provider_id, user, login_session,
-                    request_id=None):
+                    request_id, supported_logout_mechs):
         """
         Add a new login session to the table.
+
+        :param session_id: The login session ID
+        :param provider_id: The URL of the SP
+        :param user: The NameID username
+        :param login_session: The lasso Login session
+        :param request_id: The request ID of the Logout
+        :param supported_logout_mechs: A list of logout protocols supported
         """
         self.user = user
 
@@ -136,9 +154,9 @@ class SAMLSessionFactory(Log):
                 'user': user,
                 'login_session': login_session,
                 'logoutstate': LOGGED_IN,
-                'expiration_time': expiration_time}
-        if request_id:
-            data['request_id'] = request_id
+                'expiration_time': expiration_time,
+                'request_id': request_id,
+                'supported_logout_mechs': supported_logout_mechs}
 
         uuidval = self._ss.new_session(data)
 
@@ -209,7 +227,8 @@ class SAMLSessionFactory(Log):
         datum = samlsession.convert()
         self._ss.update_session(datum)
 
-    def get_next_logout(self, peek=False):
+    def get_next_logout(self, peek=False,
+                        logout_mechs=None):
         """
         Get the next session in the logged-in state and move
         it to the logging_out state.  Return the session that is
@@ -218,24 +237,34 @@ class SAMLSessionFactory(Log):
         :param peek: for IdP-initiated logout we can't remove the
                      session otherwise when the request comes back
                      in the user won't be seen as being logged-on.
-
-        Return None if no more sessions in LOGGED_IN state.
+        :param logout_mechs: An ordered list of logout mechanisms
+                     you're looking for. For each mechanism in order
+                     loop through all sessions. If If no sessions of
+                     this method are available then try the next mechanism
+                     until exhausted. In that case None is returned.
+
+        Returns a tuple of (mechanism, session) or
+        (None, None) if no more sessions in LOGGED_IN state.
         """
         candidates = self._ss.get_user_sessions(self.user)
-
-        for c in candidates:
-            key = c.keys()[0]
-            if int(c[key].get('logoutstate', 0)) == LOGGED_IN:
-                samlsession = self._data_to_samlsession(key, c[key])
-                self.start_logout(samlsession, initial=False)
-                return samlsession
-        return None
+        if logout_mechs is None:
+            logout_mechs = [SAML2_METADATA_BINDING_REDIRECT, ]
+
+        for mech in logout_mechs:
+            for c in candidates:
+                key = c.keys()[0]
+                if ((int(c[key].get('logoutstate', 0)) == LOGGED_IN) and
+                        (mech in c[key].get('supported_logout_mechs'))):
+                    samlsession = self._data_to_samlsession(key, c[key])
+                    self.start_logout(samlsession, initial=False)
+                    return (mech, samlsession)
+        return (None, None)
 
     def get_initial_logout(self):
         """
         Get the initial logout request.
 
-        Return None if no sessions in INIT_LOGOUT state.
+        Raises ValueError if no sessions in INIT_LOGOUT state.
         """
         candidates = self._ss.get_user_sessions(self.user)
 
@@ -248,7 +277,7 @@ class SAMLSessionFactory(Log):
             if int(c[key].get('logoutstate', 0)) == INIT_LOGOUT:
                 samlsession = self._data_to_samlsession(key, c[key])
                 return samlsession
-        return None
+        raise ValueError()
 
     def wipe_data(self):
         self._ss.wipe_data()
@@ -276,14 +305,21 @@ if __name__ == '__main__':
     factory = SAMLSessionFactory('/tmp/saml2sessions.sqlite')
     factory.wipe_data()
 
-    sess1 = factory.add_session('_123456', provider1, "admin", "<Login/>")
-    sess2 = factory.add_session('_789012', provider2, "testuser", "<Login/>")
+    sess1 = factory.add_session('_123456', provider1, "admin",
+                                "<Login/>", '_1234',
+                                [SAML2_METADATA_BINDING_REDIRECT])
+    sess2 = factory.add_session('_789012', provider2, "testuser",
+                                "<Login/>", '_7890',
+                                [SAML2_METADATA_BINDING_SOAP,
+                                 SAML2_METADATA_BINDING_REDIRECT])
 
     # Test finding sessions by provider
     ids = factory.get_session_id_by_provider_id(provider2)
     assert(len(ids) == 1)
 
-    sess3 = factory.add_session('_345678', provider2, "testuser", "<Login/>")
+    sess3 = factory.add_session('_345678', provider2, "testuser",
+                                "<Login/>", '_3456',
+                                [SAML2_METADATA_BINDING_REDIRECT])
     ids = factory.get_session_id_by_provider_id(provider2)
     assert(len(ids) == 2)
 
@@ -307,7 +343,7 @@ if __name__ == '__main__':
     test2 = factory.get_session_by_id('_789012')
     factory.start_logout(test2, initial=True)
 
-    test3 = factory.get_next_logout()
+    (lmech, test3) = factory.get_next_logout()
     assert(test3.session_id == '_345678')
 
     test4 = factory.get_initial_logout()