1 # Copyright (C) 2015 Rob Crittenden <rcritten@redhat.com>
3 # see file 'COPYING' for use and warranty information
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 from ipsilon.util.log import Log
class SAMLSession(Log):
    """
    Track the login/logout state of a single SAML login session.

    Attributes:

       session_id      - ID of the login session
       provider_id     - ID of the SP
       session         - the Login session object
       logoutstate     - dict of logout state info (None until set)
       session_indexes - the IDs of every login session seen for this
                         user, starting with session_id

    When a new session is seen for the same user any existing session
    is thrown away. The original session_id is kept, though, and every
    ID seen is sent to the SP when performing a logout so that all
    sessions get logged out.

    The logout state dict holds these keys:

       relaystate - the relaystate from the Logout Request or Response
       id         - the Logout request id that initiated the logout
       request    - dump of the initial logout request (may be None)
    """

    def __init__(self, session_id, provider_id, session,
                 logoutstate=None):
        self.session_id = session_id
        self.provider_id = provider_id
        self.session = session
        self.logoutstate = logoutstate
        # Seeded with our own ID; IDs of replaced sessions get appended.
        self.session_indexes = [session_id]

    def set_logoutstate(self, relaystate, request_id, request=None):
        """Replace the logout state with the given relaystate/id/request."""
        self.logoutstate = {
            'relaystate': relaystate,
            'id': request_id,
            'request': request,
        }

    def dump(self):
        """Write every field of this session to the debug log."""
        fields = (
            ('session_id %s', self.session_id),
            ('session_index %s', self.session_indexes),
            ('provider_id %s', self.provider_id),
            ('session %s', self.session),
            ('logoutstate %s', self.logoutstate),
        )
        for template, value in fields:
            self.debug(template % value)
class SAMLSessionsContainer(Log):
    """
    Store SAML session information.

    Sessions are kept in two dicts, both keyed on provider_id, which
    represent the state a session is in:

       sessions             - logged-in sessions
       sessions_logging_out - sessions with a logout in progress

    When a user logs in, add_session() creates a new SAMLSession and
    adds it to the sessions dict.

    When a user logs out, the next logged-in session is found and moved
    to sessions_logging_out. remove_session_by_provider() looks in both
    dicts when removing a session.
    """

    def __init__(self):
        self.sessions = dict()
        self.sessions_logging_out = dict()

    def add_session(self, session_id, provider_id, session):
        """
        Add a new session to the logged-in bucket.

        Drop any existing session for this provider. We have no control
        over the SPs, so if one sends us another login, accept it.

        If an existing session is dropped, keep a copy of its session
        indexes. When we log out we send ALL session indexes we've
        received to ensure that they are all logged out.
        """
        samlsession = SAMLSession(session_id, provider_id, session)

        old_session = self.find_session_by_provider(provider_id)
        if old_session is not None:
            samlsession.session_indexes.extend(old_session.session_indexes)
            self.debug("old session: %s" % old_session.session_indexes)
            self.debug("new session: %s" % samlsession.session_indexes)
            self.remove_session_by_provider(provider_id)
        self.sessions[provider_id] = samlsession
        self.dump()

    def remove_session_by_provider(self, provider_id):
        """
        Remove all instances of this provider from either session
        pool. Providers with no session are silently ignored.
        """
        self.sessions.pop(provider_id, None)
        self.sessions_logging_out.pop(provider_id, None)

    def find_session_by_provider(self, provider_id):
        """
        Return the session for a provider, checking the logged-in pool
        first and then the logging-out pool.

        Return None if no session for the provider is found.
        """
        if provider_id in self.sessions:
            return self.sessions[provider_id]
        return self.sessions_logging_out.get(provider_id)

    def start_logout(self, session):
        """
        Move a session into the logging_out state.

        Does nothing if the provider already has a session logging out.
        Raises KeyError if the provider has no logged-in session.

        No return value.
        """
        if session.provider_id in self.sessions_logging_out:
            return

        # Move the object stored in the pool, keyed by provider.
        session = self.sessions.pop(session.provider_id)

        self.sessions_logging_out[session.provider_id] = session

    def get_next_logout(self, remove=True):
        """
        Get the next session in the logged-in state and move it to the
        logging_out state. Return the session that is found.

        :param remove: for IdP-initiated logout we can't remove the
                       session otherwise when the request comes back
                       in the user won't be seen as being logged-on.
                       With remove=False the session stays in the
                       logged-in pool as well as being copied into the
                       logging-out pool.

        Return None if no more sessions in login state.
        """
        try:
            # next(iter(...)) works on Python 2 and 3; keys()[0] raises
            # TypeError on Python 3 and would escape this handler.
            provider_id = next(iter(self.sessions))
        except StopIteration:
            return None

        if remove:
            session = self.sessions.pop(provider_id)
        else:
            # Look up by key so the session is guaranteed to belong to
            # provider_id.
            session = self.sessions[provider_id]

        # Drop any stale entry before inserting the current session.
        self.sessions_logging_out.pop(provider_id, None)
        self.sessions_logging_out[provider_id] = session

        return session

    def get_last_session(self):
        """
        Pop and return the single remaining logging-out session.

        Raises ValueError unless exactly one session (counting both
        pools) is left. Only the logging-out pool is searched, so None
        is returned when the remaining session is still logged in.
        """
        if self.count() != 1:
            raise ValueError('Not exactly one session left')

        try:
            provider_id = next(iter(self.sessions_logging_out))
        except StopIteration:
            return None

        return self.sessions_logging_out.pop(provider_id)

    def count(self):
        """
        Return number of active login/logging out sessions.
        """
        return len(self.sessions) + len(self.sessions_logging_out)

    def dump(self):
        """Write every session in both pools to the debug log."""
        index = 0
        for session in self.sessions.values():
            self.debug('Login Session: %d' % index)
            session.dump()
            self.debug('-----------------------')
            index += 1
        for session in self.sessions_logging_out.values():
            self.debug('Logging-out Session: %d' % index)
            session.dump()
            self.debug('-----------------------')
            index += 1
if __name__ == '__main__':
    provider1 = "http://127.0.0.10/saml2"
    provider2 = "http://127.0.0.11/saml2"

    saml_sessions = SAMLSessionsContainer()

    # get_last_session() must refuse to run with zero sessions. Fail
    # loudly if it returns instead of raising.
    try:
        saml_sessions.get_last_session()
    except ValueError:
        pass
    else:
        raise AssertionError('get_last_session() accepted 0 sessions')
    assert saml_sessions.count() == 0

    saml_sessions.add_session("_123456", provider1, "sessiondata")
    saml_sessions.add_session("_789012", provider2, "sessiondata")

    # ...and it must also refuse to run with two sessions.
    try:
        saml_sessions.get_last_session()
    except ValueError:
        pass
    else:
        raise AssertionError('get_last_session() accepted 2 sessions')
    assert saml_sessions.count() == 2

    testsession = saml_sessions.find_session_by_provider(provider1)
    assert testsession.provider_id == provider1
    assert testsession.session_id == "_123456"
    assert testsession.session == "sessiondata"

    # Test get_next_logout() by fetching both values out. Do some
    # basic accounting to ensure we get both values eventually.
    providers = [provider1, provider2]
    testsession = saml_sessions.get_next_logout()
    providers.remove(testsession.provider_id)  # should be one of them

    testsession = saml_sessions.get_next_logout()
    assert testsession.provider_id == providers[0]  # should be the other

    # The logged-in pool is now empty.
    assert saml_sessions.get_next_logout() is None

    saml_sessions.start_logout(testsession)
    saml_sessions.remove_session_by_provider(provider2)

    assert saml_sessions.count() == 1

    testsession = saml_sessions.get_last_session()
    assert testsession.provider_id == provider1

    saml_sessions.remove_session_by_provider(provider1)
    assert saml_sessions.count() == 0