1 # Copyright (C) 2015 Rob Crittenden <rcritten@redhat.com>
3 # see file 'COPYING' for use and warranty information
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 from ipsilon.util.log import Log
class SAMLSession(Log):
    """
    Login/logout state for one SAML service provider (SP).

    Attributes:

       session_id      - ID of the login session
       provider_id     - ID of the SP
       session         - the Login session object
       logoutstate     - dict containing logout state info, or None
       session_indexes - every login session ID seen for this user
                         with this SP, starting with session_id

    Whenever the same user logs in to the same SP again, the previous
    session object is discarded, but its session IDs are carried over.
    A logout therefore sends every ID we have ever seen to the SP so
    that all of its sessions are terminated.

    When set, the logout state dictionary holds these keys:

    relaystate - The relaystate from the Logout Request or Response
    id         - The Logout request id that initiated the logout
    request    - Dump of the initial logout request
    """
    def __init__(self, session_id, provider_id, session,
                 logoutstate=None):
        self.session_id = session_id
        self.provider_id = provider_id
        self.session = session
        self.logoutstate = logoutstate
        # The first index is always this session's own ID.
        self.session_indexes = [self.session_id]

    def set_logoutstate(self, relaystate, request_id, request=None):
        """Record the logout in progress, replacing any previous state."""
        self.logoutstate = {
            'relaystate': relaystate,
            'id': request_id,
            'request': request,
        }

    def dump(self):
        """Emit each tracked field as a debug message (via Log.debug)."""
        fields = (
            ('session_id %s', self.session_id),
            ('session_index %s', self.session_indexes),
            ('provider_id %s', self.provider_id),
            ('session %s', self.session),
            ('logoutstate %s', self.logoutstate),
        )
        for template, value in fields:
            self.debug(template % value)
class SAMLSessionsContainer(Log):
    """
    Store SAML session information.

    The sessions are stored in two dicts, both keyed on provider_id,
    which represent the state that the session is in:

       sessions             - sessions that are logged in
       sessions_logging_out - sessions that are being logged out

    When a user logs in, add_session() is called and a new SAMLSession
    created and added to the sessions dict, keyed on provider_id.

    When a user logs out, the next login session is found and moved to
    sessions_logging_out. remove_session_by_provider() will look in both
    when trying to remove a session.
    """

    def __init__(self):
        self.sessions = dict()
        self.sessions_logging_out = dict()

    def add_session(self, session_id, provider_id, session):
        """
        Add a new session to the logged-in bucket.

        Drop any existing sessions that might exist for this
        provider. We have no control over the SP's so if it sends
        us another login, accept it.

        If an existing session exists drop it but keep a copy of
        its session index. When we logout we send ALL session indexes
        we've received to ensure that they are all logged out.
        """
        samlsession = SAMLSession(session_id, provider_id, session)

        old_session = self.find_session_by_provider(provider_id)
        if old_session is not None:
            # Carry forward every index seen so far for this provider.
            samlsession.session_indexes.extend(old_session.session_indexes)
            self.debug("old session: %s" % old_session.session_indexes)
            self.debug("new session: %s" % samlsession.session_indexes)
            self.remove_session_by_provider(provider_id)
        self.sessions[provider_id] = samlsession
        self.dump()

    def remove_session_by_provider(self, provider_id):
        """
        Remove all instances of this provider from either session
        pool. A provider found in neither pool is silently ignored.
        """
        self.sessions.pop(provider_id, None)
        self.sessions_logging_out.pop(provider_id, None)

    def find_session_by_provider(self, provider_id):
        """
        Return a given session from either pool.

        The logged-in pool is checked first.
        Return None if no session for a provider is found.
        """
        if provider_id in self.sessions:
            return self.sessions[provider_id]
        if provider_id in self.sessions_logging_out:
            return self.sessions_logging_out[provider_id]
        return None

    def start_logout(self, session):
        """
        Move a session into the logging_out state.

        Does nothing if the session's provider is already logging out.
        Raises KeyError if the provider is in neither pool.

        No return value
        """
        if session.provider_id in self.sessions_logging_out:
            return

        # Move the object stored in the pool, keyed by the provider.
        session = self.sessions.pop(session.provider_id)

        self.sessions_logging_out[session.provider_id] = session

    def get_next_logout(self):
        """
        Get the next session in the logged-in state and move
        it to the logging_out state. Return the session that is
        found.

        Return None if no more sessions in login state.
        """
        # next(iter(d)) works on Python 2 and 3; d.keys()[0] raises
        # TypeError on Python 3 because keys() returns a view there.
        try:
            provider_id = next(iter(self.sessions))
        except StopIteration:
            return None

        session = self.sessions.pop(provider_id)

        # Drop any stale logging-out entry for this provider first.
        self.sessions_logging_out.pop(provider_id, None)

        self.sessions_logging_out[provider_id] = session

        return session

    def get_last_session(self):
        """
        Remove and return the single remaining logging-out session.

        Raises ValueError unless exactly one session remains across
        both pools. Returns None if that one session is still in the
        logged-in pool rather than the logging-out pool.
        NOTE(review): the None case may be unintended; callers appear
        to use this only after moving sessions to logging-out.
        """
        if self.count() != 1:
            raise ValueError('Not exactly one session left')

        try:
            provider_id = next(iter(self.sessions_logging_out))
        except StopIteration:
            return None

        return self.sessions_logging_out.pop(provider_id)

    def count(self):
        """
        Return number of active login/logging out sessions.
        """
        return len(self.sessions) + len(self.sessions_logging_out)

    def dump(self):
        """
        Debug-log every session, logged-in ones first. Numbering
        continues across both pools.
        """
        count = 0
        for session in self.sessions.values():
            self.debug('Login Session: %d' % count)
            session.dump()
            self.debug('-----------------------')
            count += 1
        for session in self.sessions_logging_out.values():
            self.debug('Logging-out Session: %d' % count)
            session.dump()
            self.debug('-----------------------')
            count += 1
if __name__ == '__main__':
    # Self-test of the container bookkeeping. Each "expect ValueError"
    # check fails loudly if no exception is raised instead of silently
    # skipping its assertion.
    provider1 = "http://127.0.0.10/saml2"
    provider2 = "http://127.0.0.11/saml2"

    saml_sessions = SAMLSessionsContainer()

    # With no sessions, get_last_session() must refuse.
    try:
        saml_sessions.get_last_session()
    except ValueError:
        assert saml_sessions.count() == 0
    else:
        raise AssertionError('expected ValueError with 0 sessions')

    saml_sessions.add_session("_123456",
                              provider1,
                              "sessiondata")

    saml_sessions.add_session("_789012",
                              provider2,
                              "sessiondata")

    # With two sessions, get_last_session() must also refuse.
    try:
        saml_sessions.get_last_session()
    except ValueError:
        assert saml_sessions.count() == 2
    else:
        raise AssertionError('expected ValueError with 2 sessions')

    testsession = saml_sessions.find_session_by_provider(provider1)
    assert testsession.provider_id == provider1
    assert testsession.session_id == "_123456"
    assert testsession.session == "sessiondata"

    # Test get_next_logout() by fetching both values out. Do some
    # basic accounting to ensure we get both values eventually.
    providers = [provider1, provider2]
    testsession = saml_sessions.get_next_logout()
    providers.remove(testsession.provider_id)  # should be one of them

    testsession = saml_sessions.get_next_logout()
    assert testsession.provider_id == providers[0]  # should be the other

    # The logged-in pool is now empty.
    assert saml_sessions.get_next_logout() is None

    saml_sessions.start_logout(testsession)
    saml_sessions.remove_session_by_provider(provider2)

    assert saml_sessions.count() == 1

    testsession = saml_sessions.get_last_session()
    assert testsession.provider_id == provider1

    saml_sessions.remove_session_by_provider(provider1)
    assert saml_sessions.count() == 0