Drop usage of self._debug and use self.debug instead
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / sessions.py
1 # Copyright (C) 2015  Rob Crittenden <rcritten@redhat.com>
2 #
3 # see file 'COPYING' for use and warranty information
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 from ipsilon.util.log import Log
19
20
21 class SAMLSession(Log):
22     """
23     A SAML login session used to track login/logout state.
24
25        session_id - ID of the login session
26        provider_id - ID of the SP
27        session - the Login session object
28        logoutstate - dict containing logout state info
29        session_indexes - the IDs of any login session we've seen
30                          for this user
31
32     When a new session is seen for the same user any existing session
33     is thrown away. We keep the original session_id though and send
34     all that we've seen to the SP when performing a logout to ensure
35     that all sessions get logged out.
36
37     logout state is a dictionary containing (potentially)
38     these attributes:
39
40     relaystate - The relaystate from the Logout Request or Response
41     id         - The Logout request id that initiated the logout
42     request    - Dump of the initial logout request
43     """
44     def __init__(self, session_id, provider_id, session,
45                  logoutstate=None):
46
47         self.session_id = session_id
48         self.provider_id = provider_id
49         self.session = session
50         self.logoutstate = logoutstate
51         self.session_indexes = [session_id]
52
53     def set_logoutstate(self, relaystate, request_id, request=None):
54         self.logoutstate = dict(relaystate=relaystate,
55                                 id=request_id,
56                                 request=request)
57
58     def dump(self):
59         self.debug('session_id %s' % self.session_id)
60         self.debug('session_index %s' % self.session_indexes)
61         self.debug('provider_id %s' % self.provider_id)
62         self.debug('session %s' % self.session)
63         self.debug('logoutstate %s' % self.logoutstate)
64
65
66 class SAMLSessionsContainer(Log):
67     """
68     Store SAML session information.
69
70     The sessions are stored in two dicts which represent the state that
71     the session is in.
72
73     When a user logs in, add_session() is called and a new SAMLSession
74     created and added to the sessions dict, keyed on provider_id.
75
76     When a user logs out, the next login session is found and moved to
77     sessions_logging_out. remove_session() will look in both when trying
78     to remove a session.
79     """
80
81     def __init__(self):
82         self.sessions = dict()
83         self.sessions_logging_out = dict()
84
85     def add_session(self, session_id, provider_id, session):
86         """
87         Add a new session to the logged-in bucket.
88
89         Drop any existing sessions that might exist for this
90         provider. We have no control over the SP's so if it sends
91         us another login, accept it.
92
93         If an existing session exists drop it but keep a copy of
94         its session index. When we logout we send ALL session indexes
95         we've received to ensure that they are all logged out.
96         """
97         samlsession = SAMLSession(session_id, provider_id, session)
98
99         old_session = self.find_session_by_provider(provider_id)
100         if old_session is not None:
101             samlsession.session_indexes.extend(old_session.session_indexes)
102             self.debug("old session: %s" % old_session.session_indexes)
103             self.debug("new session: %s" % samlsession.session_indexes)
104             self.remove_session_by_provider(provider_id)
105         self.sessions[provider_id] = samlsession
106         self.dump()
107
108     def remove_session_by_provider(self, provider_id):
109         """
110         Remove all instances of this provider from either session
111         pool.
112         """
113         if provider_id in self.sessions:
114             self.sessions.pop(provider_id)
115         if provider_id in self.sessions_logging_out:
116             self.sessions_logging_out.pop(provider_id)
117
118     def find_session_by_provider(self, provider_id):
119         """
120         Return a given session from either pool.
121
122         Return None if no session for a provider is found.
123         """
124         if provider_id in self.sessions:
125             return self.sessions[provider_id]
126         if provider_id in self.sessions_logging_out:
127             return self.sessions_logging_out[provider_id]
128         return None
129
130     def start_logout(self, session):
131         """
132         Move a session into the logging_out state
133
134         No return value
135         """
136         if session.provider_id in self.sessions_logging_out:
137             return
138
139         session = self.sessions.pop(session.provider_id)
140
141         self.sessions_logging_out[session.provider_id] = session
142
143     def get_next_logout(self, remove=True):
144         """
145         Get the next session in the logged-in state and move
146         it to the logging_out state.  Return the session that is
147         found.
148
149         :param remove: for IdP-initiated logout we can't remove the
150                        session otherwise when the request comes back
151                        in the user won't be seen as being logged-on.
152
153         Return None if no more sessions in login state.
154         """
155         try:
156             provider_id = self.sessions.keys()[0]
157         except IndexError:
158             return None
159
160         if remove:
161             session = self.sessions.pop(provider_id)
162         else:
163             session = self.sessions.itervalues().next()
164
165         if provider_id in self.sessions_logging_out:
166             self.sessions_logging_out.pop(provider_id)
167
168         self.sessions_logging_out[provider_id] = session
169
170         return session
171
172     def get_last_session(self):
173         if self.count() != 1:
174             raise ValueError('Not exactly one session left')
175
176         try:
177             provider_id = self.sessions_logging_out.keys()[0]
178         except IndexError:
179             return None
180
181         return self.sessions_logging_out.pop(provider_id)
182
183     def count(self):
184         """
185         Return number of active login/logging out sessions.
186         """
187         return len(self.sessions) + len(self.sessions_logging_out)
188
189     def dump(self):
190         count = 0
191         for s in self.sessions:
192             self.debug('Login Session: %d' % count)
193             session = self.sessions[s]
194             session.dump()
195             self.debug('-----------------------')
196             count += 1
197         for s in self.sessions_logging_out:
198             self.debug('Logging-out Session: %d' % count)
199             session = self.sessions_logging_out[s]
200             session.dump()
201             self.debug('-----------------------')
202             count += 1
203
204 if __name__ == '__main__':
205     provider1 = "http://127.0.0.10/saml2"
206     provider2 = "http://127.0.0.11/saml2"
207
208     saml_sessions = SAMLSessionsContainer()
209
210     try:
211         testsession = saml_sessions.get_last_session()
212     except ValueError:
213         assert(saml_sessions.count() == 0)
214
215     saml_sessions.add_session("_123456",
216                               provider1,
217                               "sessiondata")
218
219     saml_sessions.add_session("_789012",
220                               provider2,
221                               "sessiondata")
222
223     try:
224         testsession = saml_sessions.get_last_session()
225     except ValueError:
226         assert(saml_sessions.count() == 2)
227
228     testsession = saml_sessions.find_session_by_provider(provider1)
229     assert(testsession.provider_id == provider1)
230     assert(testsession.session_id == "_123456")
231     assert(testsession.session == "sessiondata")
232
233     # Test get_next_logout() by fetching both values out. Do some
234     # basic accounting to ensure we get both values eventually.
235     providers = [provider1, provider2]
236     testsession = saml_sessions.get_next_logout()
237     providers.remove(testsession.provider_id)  # should be one of them
238
239     testsession = saml_sessions.get_next_logout()
240     assert(testsession.provider_id == providers[0])  # should be the other
241
242     saml_sessions.start_logout(testsession)
243     saml_sessions.remove_session_by_provider(provider2)
244
245     assert(saml_sessions.count() == 1)
246
247     testsession = saml_sessions.get_last_session()
248     assert(testsession.provider_id == provider1)
249
250     saml_sessions.remove_session_by_provider(provider1)
251     assert(saml_sessions.count() == 0)