fb1f646bf4efb5988cf9c32515eafe9d00233da1
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / sessions.py
1 # Copyright (C) 2015  Rob Crittenden <rcritten@redhat.com>
2 #
3 # see file 'COPYING' for use and warranty information
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 from ipsilon.util.log import Log
19
20
21 class SAMLSession(Log):
22     """
23     A SAML login session used to track login/logout state.
24
25        session_id - ID of the login session
26        provider_id - ID of the SP
27        session - the Login session object
28        logoutstate - dict containing logout state info
29        session_indexes - the IDs of any login session we've seen
30                          for this user
31
32     When a new session is seen for the same user any existing session
33     is thrown away. We keep the original session_id though and send
34     all that we've seen to the SP when performing a logout to ensure
35     that all sessions get logged out.
36
37     logout state is a dictionary containing (potentially)
38     these attributes:
39
40     relaystate - The relaystate from the Logout Request or Response
41     id         - The Logout request id that initiated the logout
42     request    - Dump of the initial logout request
43     """
44     def __init__(self, session_id, provider_id, session,
45                  logoutstate=None):
46
47         self.session_id = session_id
48         self.provider_id = provider_id
49         self.session = session
50         self.logoutstate = logoutstate
51         self.session_indexes = [session_id]
52
53     def set_logoutstate(self, relaystate, request_id, request=None):
54         self.logoutstate = dict(relaystate=relaystate,
55                                 id=request_id,
56                                 request=request)
57
58     def dump(self):
59         self.debug('session_id %s' % self.session_id)
60         self.debug('session_index %s' % self.session_indexes)
61         self.debug('provider_id %s' % self.provider_id)
62         self.debug('session %s' % self.session)
63         self.debug('logoutstate %s' % self.logoutstate)
64
65
66 class SAMLSessionsContainer(Log):
67     """
68     Store SAML session information.
69
70     The sessions are stored in two dicts which represent the state that
71     the session is in.
72
73     When a user logs in, add_session() is called and a new SAMLSession
74     created and added to the sessions dict, keyed on provider_id.
75
76     When a user logs out, the next login session is found and moved to
77     sessions_logging_out. remove_session() will look in both when trying
78     to remove a session.
79     """
80
81     def __init__(self):
82         self.sessions = dict()
83         self.sessions_logging_out = dict()
84
85     def add_session(self, session_id, provider_id, session):
86         """
87         Add a new session to the logged-in bucket.
88
89         Drop any existing sessions that might exist for this
90         provider. We have no control over the SP's so if it sends
91         us another login, accept it.
92
93         If an existing session exists drop it but keep a copy of
94         its session index. When we logout we send ALL session indexes
95         we've received to ensure that they are all logged out.
96         """
97         samlsession = SAMLSession(session_id, provider_id, session)
98
99         old_session = self.find_session_by_provider(provider_id)
100         if old_session is not None:
101             samlsession.session_indexes.extend(old_session.session_indexes)
102             self.debug("old session: %s" % old_session.session_indexes)
103             self.debug("new session: %s" % samlsession.session_indexes)
104             self.remove_session_by_provider(provider_id)
105         self.sessions[provider_id] = samlsession
106         self.dump()
107
108     def remove_session_by_provider(self, provider_id):
109         """
110         Remove all instances of this provider from either session
111         pool.
112         """
113         if provider_id in self.sessions:
114             self.sessions.pop(provider_id)
115         if provider_id in self.sessions_logging_out:
116             self.sessions_logging_out.pop(provider_id)
117
118     def find_session_by_provider(self, provider_id):
119         """
120         Return a given session from either pool.
121
122         Return None if no session for a provider is found.
123         """
124         if provider_id in self.sessions:
125             return self.sessions[provider_id]
126         if provider_id in self.sessions_logging_out:
127             return self.sessions_logging_out[provider_id]
128         return None
129
130     def start_logout(self, session):
131         """
132         Move a session into the logging_out state
133
134         No return value
135         """
136         if session.provider_id in self.sessions_logging_out:
137             return
138
139         session = self.sessions.pop(session.provider_id)
140
141         self.sessions_logging_out[session.provider_id] = session
142
143     def get_next_logout(self):
144         """
145         Get the next session in the logged-in state and move
146         it to the logging_out state.  Return the session that is
147         found.
148
149         Return None if no more sessions in login state.
150         """
151         try:
152             provider_id = self.sessions.keys()[0]
153         except IndexError:
154             return None
155
156         session = self.sessions.pop(provider_id)
157
158         if provider_id in self.sessions_logging_out:
159             self.sessions_logging_out.pop(provider_id)
160
161         self.sessions_logging_out[provider_id] = session
162
163         return session
164
165     def get_last_session(self):
166         if self.count() != 1:
167             raise ValueError('Not exactly one session left')
168
169         try:
170             provider_id = self.sessions_logging_out.keys()[0]
171         except IndexError:
172             return None
173
174         return self.sessions_logging_out.pop(provider_id)
175
176     def count(self):
177         """
178         Return number of active login/logging out sessions.
179         """
180         return len(self.sessions) + len(self.sessions_logging_out)
181
182     def dump(self):
183         count = 0
184         for s in self.sessions:
185             self.debug('Login Session: %d' % count)
186             session = self.sessions[s]
187             session.dump()
188             self.debug('-----------------------')
189             count += 1
190         for s in self.sessions_logging_out:
191             self.debug('Logging-out Session: %d' % count)
192             session = self.sessions_logging_out[s]
193             session.dump()
194             self.debug('-----------------------')
195             count += 1
196
197 if __name__ == '__main__':
198     provider1 = "http://127.0.0.10/saml2"
199     provider2 = "http://127.0.0.11/saml2"
200
201     saml_sessions = SAMLSessionsContainer()
202
203     try:
204         testsession = saml_sessions.get_last_session()
205     except ValueError:
206         assert(saml_sessions.count() == 0)
207
208     saml_sessions.add_session("_123456",
209                               provider1,
210                               "sessiondata")
211
212     saml_sessions.add_session("_789012",
213                               provider2,
214                               "sessiondata")
215
216     try:
217         testsession = saml_sessions.get_last_session()
218     except ValueError:
219         assert(saml_sessions.count() == 2)
220
221     testsession = saml_sessions.find_session_by_provider(provider1)
222     assert(testsession.provider_id == provider1)
223     assert(testsession.session_id == "_123456")
224     assert(testsession.session == "sessiondata")
225
226     # Test get_next_logout() by fetching both values out. Do some
227     # basic accounting to ensure we get both values eventually.
228     providers = [provider1, provider2]
229     testsession = saml_sessions.get_next_logout()
230     providers.remove(testsession.provider_id)  # should be one of them
231
232     testsession = saml_sessions.get_next_logout()
233     assert(testsession.provider_id == providers[0])  # should be the other
234
235     saml_sessions.start_logout(testsession)
236     saml_sessions.remove_session_by_provider(provider2)
237
238     assert(saml_sessions.count() == 1)
239
240     testsession = saml_sessions.get_last_session()
241     assert(testsession.provider_id == provider1)
242
243     saml_sessions.remove_session_by_provider(provider1)
244     assert(saml_sessions.count() == 0)