Add SAML-specific session data for tracking login/logout sessions
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / sessions.py
1 # Copyright (C) 2015  Rob Crittenden <rcritten@redhat.com>
2 #
3 # see file 'COPYING' for use and warranty information
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 from ipsilon.util.log import Log
19
20
21 class SAMLSession(Log):
22     """
23     A SAML login session used to track login/logout state.
24
25        session_id - ID of the login session
26        provider_id - ID of the SP
27        session - the Login session object
28        logoutstate - dict containing logout state info
29
30     logout state is a dictionary containing (potentially)
31     these attributes:
32
33     relaystate - The relaystate from the Logout Request or Response
34     id         - The Logout request id that initiated the logout
35     request    - Dump of the initial logout request
36     """
37     def __init__(self, session_id, provider_id, session,
38                  logoutstate=None):
39
40         self.session_id = session_id
41         self.provider_id = provider_id
42         self.session = session
43         self.logoutstate = logoutstate
44
45     def set_logoutstate(self, relaystate, request_id, request=None):
46         self.logoutstate = dict(relaystate=relaystate,
47                                 id=request_id,
48                                 request=request)
49
50     def dump(self):
51         self.debug('session_id %s' % self.session_id)
52         self.debug('provider_id %s' % self.provider_id)
53         self.debug('session %s' % self.session)
54         self.debug('logoutstate %s' % self.logoutstate)
55
56
57 class SAMLSessionsContainer(Log):
58     """
59     Store SAML session information.
60
61     The sessions are stored in two dicts which represent the state that
62     the session is in.
63
64     When a user logs in, add_session() is called and a new SAMLSession
65     created and added to the sessions dict, keyed on provider_id.
66
67     When a user logs out, the next login session is found and moved to
68     sessions_logging_out. remove_session() will look in both when trying
69     to remove a session.
70     """
71
72     def __init__(self):
73         self.sessions = dict()
74         self.sessions_logging_out = dict()
75
76     def add_session(self, session_id, provider_id, session):
77         """
78         Add a new session to the logged-in bucket.
79
80         Drop any existing sessions that might exist for this
81         provider. We have no control over the SP's so if it sends
82         us another login, accept it.
83         """
84         samlsession = SAMLSession(session_id, provider_id, session)
85
86         self.remove_session_by_provider(provider_id)
87         self.sessions[provider_id] = samlsession
88         self.dump()
89
90     def remove_session_by_provider(self, provider_id):
91         """
92         Remove all instances of this provider from either session
93         pool.
94         """
95         if provider_id in self.sessions:
96             self.sessions.pop(provider_id)
97         if provider_id in self.sessions_logging_out:
98             self.sessions_logging_out.pop(provider_id)
99
100     def find_session_by_provider(self, provider_id):
101         """
102         Return a given session from either pool.
103
104         Return None if no session for a provider is found.
105         """
106         if provider_id in self.sessions:
107             return self.sessions[provider_id]
108         if provider_id in self.sessions_logging_out:
109             return self.sessions_logging_out[provider_id]
110         return None
111
112     def start_logout(self, session):
113         """
114         Move a session into the logging_out state
115
116         No return value
117         """
118         if session.provider_id in self.sessions_logging_out:
119             return
120
121         session = self.sessions.pop(session.provider_id)
122
123         self.sessions_logging_out[session.provider_id] = session
124
125     def get_next_logout(self):
126         """
127         Get the next session in the logged-in state and move
128         it to the logging_out state.  Return the session that is
129         found.
130
131         Return None if no more sessions in login state.
132         """
133         try:
134             provider_id = self.sessions.keys()[0]
135         except IndexError:
136             return None
137
138         session = self.sessions.pop(provider_id)
139
140         if provider_id in self.sessions_logging_out:
141             self.sessions_logging_out.pop(provider_id)
142
143         self.sessions_logging_out[provider_id] = session
144
145         return session
146
147     def get_last_session(self):
148         if self.count() != 1:
149             raise ValueError('Not exactly one session left')
150
151         try:
152             provider_id = self.sessions_logging_out.keys()[0]
153         except IndexError:
154             return None
155
156         return self.sessions_logging_out.pop(provider_id)
157
158     def count(self):
159         """
160         Return number of active login/logging out sessions.
161         """
162         return len(self.sessions) + len(self.sessions_logging_out)
163
164     def dump(self):
165         count = 0
166         for s in self.sessions:
167             self.debug('Login Session: %d' % count)
168             session = self.sessions[s]
169             session.dump()
170             self.debug('-----------------------')
171             count += 1
172         for s in self.sessions_logging_out:
173             self.debug('Logging-out Session: %d' % count)
174             session = self.sessions_logging_out[s]
175             session.dump()
176             self.debug('-----------------------')
177             count += 1
178
179 if __name__ == '__main__':
180     provider1 = "http://127.0.0.10/saml2"
181     provider2 = "http://127.0.0.11/saml2"
182
183     saml_sessions = SAMLSessionsContainer()
184
185     try:
186         testsession = saml_sessions.get_last_session()
187     except ValueError:
188         assert(saml_sessions.count() == 0)
189
190     saml_sessions.add_session("_123456",
191                               provider1,
192                               "sessiondata")
193
194     saml_sessions.add_session("_789012",
195                               provider2,
196                               "sessiondata")
197
198     try:
199         testsession = saml_sessions.get_last_session()
200     except ValueError:
201         assert(saml_sessions.count() == 2)
202
203     testsession = saml_sessions.find_session_by_provider(provider1)
204     assert(testsession.provider_id == provider1)
205     assert(testsession.session_id == "_123456")
206     assert(testsession.session == "sessiondata")
207
208     # Test get_next_logout() by fetching both values out. Do some
209     # basic accounting to ensure we get both values eventually.
210     providers = [provider1, provider2]
211     testsession = saml_sessions.get_next_logout()
212     providers.remove(testsession.provider_id)  # should be one of them
213
214     testsession = saml_sessions.get_next_logout()
215     assert(testsession.provider_id == providers[0])  # should be the other
216
217     saml_sessions.start_logout(testsession)
218     saml_sessions.remove_session_by_provider(provider2)
219
220     assert(saml_sessions.count() == 1)
221
222     testsession = saml_sessions.get_last_session()
223     assert(testsession.provider_id == provider1)
224
225     saml_sessions.remove_session_by_provider(provider1)
226     assert(saml_sessions.count() == 0)