-# Copyright (C) 2015 Rob Crittenden <rcritten@redhat.com>
-#
-# see file 'COPYING' for use and warranty information
-#
-# This program is free software; you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <http://www.gnu.org/licenses/>.
+# Copyright (C) 2015 Ipsilon project Contributors, for license see COPYING
+from cherrypy import config as cherrypy_config
from ipsilon.util.log import Log
+from ipsilon.util.data import SAML2SessionStore
+import datetime
+from lasso import (
+ SAML2_METADATA_BINDING_SOAP,
+ SAML2_METADATA_BINDING_REDIRECT,
+)
+
+LOGGED_IN = 1
+INIT_LOGOUT = 2
+LOGGING_OUT = 4
+LOGGED_OUT = 8
class SAMLSession(Log):
"""
- A SAML login session used to track login/logout state.
+ A SAML login session.
+ uuidval - Unique ID stored in the database
session_id - ID of the login session
provider_id - ID of the SP
- session - the Login session object
- logoutstate - dict containing logout state info
-
- logout state is a dictionary containing (potentially)
- these attributes:
-
- relaystate - The relaystate from the Logout Request or Response
- id - The Logout request id that initiated the logout
- request - Dump of the initial logout request
+ user - the login name of the user that owns the session
+ login_session - the Login session object
+ logoutstate - an integer constant representing where in the
+ logout process this request is
+ relaystate - where the user will be redirected when logout is
+ complete
+ request_id - the logout request ID if initiated from IdP. The
+ logout response will include an InResponseTo value
+ which matches this.
+ logout_request - the Logout request object
+ expiration_time - the time the login session expires
+ supported_logout_mechs - logout mechanisms supported by this session
"""
- def __init__(self, session_id, provider_id, session,
- logoutstate=None):
+ def __init__(self, uuidval, session_id, provider_id, user,
+ login_session, logoutstate=None, relaystate=None,
+ logout_request=None, request_id=None,
+ expiration_time=None,
+ supported_logout_mechs=None):
+ self.uuidval = uuidval
self.session_id = session_id
self.provider_id = provider_id
- self.session = session
+ self.user = user
+ self.login_session = login_session
self.logoutstate = logoutstate
+ self.relaystate = relaystate
+ self.request_id = request_id
+ self.logout_request = logout_request
+ self.expiration_time = expiration_time
+ if supported_logout_mechs is None:
+ supported_logout_mechs = []
+ self.supported_logout_mechs = supported_logout_mechs
+
+ def set_logoutstate(self, relaystate=None, request=None, request_id=None):
+ """
+ Update attributes needed to determine the state of the session for
+ logout.
- def set_logoutstate(self, relaystate, request_id, request=None):
- self.logoutstate = dict(relaystate=relaystate,
- id=request_id,
- request=request)
+ The database is not updated when these are set. It is expected that
+ this is called prior to start_logout()
+ """
+ if relaystate:
+ self.relaystate = relaystate
+ if request:
+ self.logout_request = request
+ if request_id:
+ self.request_id = request_id
def dump(self):
self.debug('session_id %s' % self.session_id)
self.debug('provider_id %s' % self.provider_id)
- self.debug('session %s' % self.session)
+ self.debug('login session %s' % self.login_session)
self.debug('logoutstate %s' % self.logoutstate)
+ self.debug('logout mech %s' % self.supported_logout_mechs)
+
+ def convert(self):
+ """
+ Convert this object into something suitable to store in the
+ data backend.
+ """
+ data = dict()
+ data['session_id'] = self.session_id
+ data['provider_id'] = self.provider_id
+ data['user'] = self.user
+ data['login_session'] = self.login_session
+ data['logoutstate'] = self.logoutstate
+ data['relaystate'] = self.relaystate
+ data['logout_request'] = self.logout_request
+ data['request_id'] = self.request_id
+ data['expiration_time'] = self.expiration_time
+
+ return {self.uuidval: data}
-class SAMLSessionsContainer(Log):
+class SAMLSessionFactory(Log):
"""
- Store SAML session information.
+ Access SAML session information.
- The sessions are stored in two dicts which represent the state that
- the session is in.
+ The sessions are stored via the data backend.
When a user logs in, add_session() is called and a new SAMLSession
- created and added to the sessions dict, keyed on provider_id.
+ created and added to the table.
When a user logs out, the next login session is found and moved to
sessions_logging_out. remove_session() will look in both when trying
to remove a session.
- """
- def __init__(self):
- self.sessions = dict()
- self.sessions_logging_out = dict()
+ Returns a SAMLSession object representing the new session.
+ """
+ def __init__(self, database_url):
+ self._ss = SAML2SessionStore(database_url=database_url)
+ self.user = None
- def add_session(self, session_id, provider_id, session):
+ def _data_to_samlsession(self, uuidval, data):
"""
- Add a new session to the logged-in bucket.
-
- Drop any existing sessions that might exist for this
- provider. We have no control over the SP's so if it sends
- us another login, accept it.
+ Convert data from the data backend to a SAMLSession object.
"""
- samlsession = SAMLSession(session_id, provider_id, session)
-
- self.remove_session_by_provider(provider_id)
- self.sessions[provider_id] = samlsession
- self.dump()
-
- def remove_session_by_provider(self, provider_id):
+ return SAMLSession(uuidval,
+ data.get('session_id'),
+ data.get('provider_id'),
+ data.get('user'),
+ data.get('login_session'),
+ data.get('logoutstate'),
+ data.get('relaystate'),
+ data.get('logout_request'),
+ data.get('request_id'),
+ data.get('expiration_time'),
+ data.get('supported_logout_mechs'))
+
+ def add_session(self, session_id, provider_id, user, login_session,
+ request_id, supported_logout_mechs):
"""
- Remove all instances of this provider from either session
- pool.
+ Add a new login session to the table.
+
+ :param session_id: The login session ID
+ :param provider_id: The URL of the SP
+ :param user: The NameID username
+ :param login_session: The lasso Login session
+ :param request_id: The request ID of the Logout
+ :param supported_logout_mechs: A list of logout protocols supported
"""
- if provider_id in self.sessions:
- self.sessions.pop(provider_id)
- if provider_id in self.sessions_logging_out:
- self.sessions_logging_out.pop(provider_id)
+ self.user = user
- def find_session_by_provider(self, provider_id):
- """
- Return a given session from either pool.
+ timeout = cherrypy_config['tools.sessions.timeout']
+ t = datetime.timedelta(seconds=timeout * 60)
+ expiration_time = datetime.datetime.now() + t
- Return None if no session for a provider is found.
- """
- if provider_id in self.sessions:
- return self.sessions[provider_id]
- if provider_id in self.sessions_logging_out:
- return self.sessions_logging_out[provider_id]
- return None
+ data = {'session_id': session_id,
+ 'provider_id': provider_id,
+ 'user': user,
+ 'login_session': login_session,
+ 'logoutstate': LOGGED_IN,
+ 'expiration_time': expiration_time,
+ 'request_id': request_id,
+ 'supported_logout_mechs': supported_logout_mechs}
+
+ uuidval = self._ss.new_session(data)
+
+ return SAMLSession(uuidval, session_id, provider_id, user,
+ login_session, LOGGED_IN,
+ request_id=request_id,
+ expiration_time=expiration_time)
- def start_logout(self, session):
+ def get_session_by_id(self, session_id):
"""
- Move a session into the logging_out state
+ Retrieve a session by session ID
+ """
+ uuidval, data = self._ss.get_session(session_id=session_id)
+ if uuidval is None:
+ return None
- No return value
+ return self._data_to_samlsession(uuidval, data)
+
+ def get_session_id_by_provider_id(self, provider_id):
"""
- if session.provider_id in self.sessions_logging_out:
- return
+ Return a tuple of logged-in session IDs by provider_id
+ """
+ candidates = self._ss.get_user_sessions(self.user)
- session = self.sessions.pop(session.provider_id)
+ session_ids = []
+ for c in candidates:
+ key = c.keys()[0]
+ if c[key].get('provider_id') == provider_id:
+ samlsession = self._data_to_samlsession(key, c[key])
+ session_ids.append(samlsession.session_id.encode('utf-8'))
- self.sessions_logging_out[session.provider_id] = session
+ return tuple(session_ids)
- def get_next_logout(self):
+ def get_session_by_request_id(self, request_id):
"""
- Get the next session in the logged-in state and move
- it to the logging_out state. Return the session that is
- found.
-
- Return None if no more sessions in login state.
+ Retrieve a session by logout request ID
"""
- try:
- provider_id = self.sessions.keys()[0]
- except IndexError:
+ uuidval, data = self._ss.get_session(request_id=request_id)
+ if uuidval is None:
return None
- session = self.sessions.pop(provider_id)
-
- if provider_id in self.sessions_logging_out:
- self.sessions_logging_out.pop(provider_id)
+ return self._data_to_samlsession(uuidval, data)
- self.sessions_logging_out[provider_id] = session
+ def remove_session(self, samlsession):
+ return self._ss.remove_session(samlsession.uuidval)
- return session
+ def remove_session_by_session_id(self, session_id):
+ session = self.get_session_by_id(session_id)
+ return self._ss.remove_session(session.uuidval)
- def get_last_session(self):
- if self.count() != 1:
- raise ValueError('Not exactly one session left')
+ def start_logout(self, samlsession, relaystate=None, initial=True):
+ """
+ Move a session into the logging_out state
- try:
- provider_id = self.sessions_logging_out.keys()[0]
- except IndexError:
- return None
+ samlsession: the SAMLSession object to start logging out
+ relaystate: URL to redirect user to when logout is completed
+ initial: boolean to indicate if this session started logout.
+ Only the initial session's relaystate is used.
- return self.sessions_logging_out.pop(provider_id)
+ No return value
+ """
+ if initial:
+ samlsession.logoutstate = INIT_LOGOUT
+ else:
+ samlsession.logoutstate = LOGGING_OUT
+ if relaystate:
+ samlsession.relaystate = relaystate
+ datum = samlsession.convert()
+ self._ss.update_session(datum)
+
+ def get_next_logout(self, peek=False,
+ logout_mechs=None):
+ """
+ Get the next session in the logged-in state and move
+ it to the logging_out state. Return the session that is
+ found.
- def count(self):
+ :param peek: for IdP-initiated logout we can't remove the
+ session otherwise when the request comes back
+ in the user won't be seen as being logged-on.
+ :param logout_mechs: An ordered list of logout mechanisms
+ you're looking for. For each mechanism in order
+ loop through all sessions. If If no sessions of
+ this method are available then try the next mechanism
+ until exhausted. In that case None is returned.
+
+ Returns a tuple of (mechanism, session) or
+ (None, None) if no more sessions in LOGGED_IN state.
+ """
+ candidates = self._ss.get_user_sessions(self.user)
+ if logout_mechs is None:
+ logout_mechs = [SAML2_METADATA_BINDING_REDIRECT, ]
+
+ for mech in logout_mechs:
+ for c in candidates:
+ key = c.keys()[0]
+ if ((int(c[key].get('logoutstate', 0)) == LOGGED_IN) and
+ (mech in c[key].get('supported_logout_mechs'))):
+ samlsession = self._data_to_samlsession(key, c[key])
+ self.start_logout(samlsession, initial=False)
+ return (mech, samlsession)
+ return (None, None)
+
+ def get_initial_logout(self):
"""
- Return number of active login/logging out sessions.
+ Get the initial logout request.
+
+ Raises ValueError if no sessions in INIT_LOGOUT state.
"""
- return len(self.sessions) + len(self.sessions_logging_out)
+ candidates = self._ss.get_user_sessions(self.user)
+
+ # FIXME: what does it mean if there are multiple in init? We
+ # just return the first one for now. How do we know
+ # it's the "right" one if multiple logouts are started
+ # at the same time from different SPs?
+ for c in candidates:
+ key = c.keys()[0]
+ if int(c[key].get('logoutstate', 0)) == INIT_LOGOUT:
+ samlsession = self._data_to_samlsession(key, c[key])
+ return samlsession
+ raise ValueError()
+
+ def wipe_data(self):
+ self._ss.wipe_data()
def dump(self):
+ """
+ Dump all sessions to debug log
+ """
+ candidates = self._ss.get_user_sessions(self.user)
+
count = 0
- for s in self.sessions:
- self.debug('Login Session: %d' % count)
- session = self.sessions[s]
- session.dump()
- self.debug('-----------------------')
- count += 1
- for s in self.sessions_logging_out:
- self.debug('Logging-out Session: %d' % count)
- session = self.sessions_logging_out[s]
- session.dump()
- self.debug('-----------------------')
+ for c in candidates:
+ key = c.keys()[0]
+ samlsession = self._data_to_samlsession(key, c[key])
+ self.debug('session %d: %s' % (count, samlsession.convert()))
count += 1
if __name__ == '__main__':
provider1 = "http://127.0.0.10/saml2"
provider2 = "http://127.0.0.11/saml2"
- saml_sessions = SAMLSessionsContainer()
+ # temporary values to simulate cherrypy
+ cherrypy_config['tools.sessions.timeout'] = 60
+
+ factory = SAMLSessionFactory('/tmp/saml2sessions.sqlite')
+ factory.wipe_data()
- try:
- testsession = saml_sessions.get_last_session()
- except ValueError:
- assert(saml_sessions.count() == 0)
+ sess1 = factory.add_session('_123456', provider1, "admin",
+ "<Login/>", '_1234',
+ [SAML2_METADATA_BINDING_REDIRECT])
+ sess2 = factory.add_session('_789012', provider2, "testuser",
+ "<Login/>", '_7890',
+ [SAML2_METADATA_BINDING_SOAP,
+ SAML2_METADATA_BINDING_REDIRECT])
- saml_sessions.add_session("_123456",
- provider1,
- "sessiondata")
+ # Test finding sessions by provider
+ ids = factory.get_session_id_by_provider_id(provider2)
+ assert(len(ids) == 1)
- saml_sessions.add_session("_789012",
- provider2,
- "sessiondata")
+ sess3 = factory.add_session('_345678', provider2, "testuser",
+ "<Login/>", '_3456',
+ [SAML2_METADATA_BINDING_REDIRECT])
+ ids = factory.get_session_id_by_provider_id(provider2)
+ assert(len(ids) == 2)
- try:
- testsession = saml_sessions.get_last_session()
- except ValueError:
- assert(saml_sessions.count() == 2)
+ # Test finding sessions by session ID
+ test1 = factory.get_session_by_id('_123456')
+ assert(test1.user == 'admin')
+ assert(test1.provider_id == provider1)
- testsession = saml_sessions.find_session_by_provider(provider1)
- assert(testsession.provider_id == provider1)
- assert(testsession.session_id == "_123456")
- assert(testsession.session == "sessiondata")
+ # Log out and remove the first session
+ test1.set_logoutstate('http://www.example.com/idp')
+ factory.start_logout(test1, initial=True)
+ test1 = factory.get_session_by_id('_123456')
+ assert(test1.relaystate == 'http://www.example.com/idp')
- # Test get_next_logout() by fetching both values out. Do some
- # basic accounting to ensure we get both values eventually.
- providers = [provider1, provider2]
- testsession = saml_sessions.get_next_logout()
- providers.remove(testsession.provider_id) # should be one of them
+ factory.remove_session_by_session_id('_123456')
- testsession = saml_sessions.get_next_logout()
- assert(testsession.provider_id == providers[0]) # should be the other
+ # Make sure it is gone from the db
+ test1 = factory.get_session_by_id('_123456')
+ assert(test1 is None)
- saml_sessions.start_logout(testsession)
- saml_sessions.remove_session_by_provider(provider2)
+ test2 = factory.get_session_by_id('_789012')
+ factory.start_logout(test2, initial=True)
- assert(saml_sessions.count() == 1)
+ (lmech, test3) = factory.get_next_logout()
+ assert(test3.session_id == '_345678')
- testsession = saml_sessions.get_last_session()
- assert(testsession.provider_id == provider1)
+ test4 = factory.get_initial_logout()
+ assert(test4.session_id == '_789012')
- saml_sessions.remove_session_by_provider(provider1)
- assert(saml_sessions.count() == 0)
+ # Even though we've started logout, make sure we can still find
+ # all sessions for a provider.
+ ids = factory.get_session_id_by_provider_id(provider2)
+ assert(len(ids) == 2)