d01bb6ed64520d0ef7321eb5a24e40c104aee7db
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / sessions.py
1 # Copyright (C) 2015 Ipsilon project Contributors, for license see COPYING
2
3 from ipsilon.util.log import Log
4 from ipsilon.util.data import SAML2SessionStore
5
6 LOGGED_IN = 1
7 INIT_LOGOUT = 2
8 LOGGING_OUT = 4
9 LOGGED_OUT = 8
10
11
12 class SAMLSession(Log):
13     """
14     A SAML login session.
15
16        uuidval - Unique ID stored in the database
17        session_id - ID of the login session
18        provider_id - ID of the SP
19        user - the login name of the user that owns the session
20        login_session - the Login session object
21        logoutstate - an integer constant representing where in the
22                      logout process this request is
23        relaystate - where the user will be redirected when logout is
24                     complete
25        request_id - the logout request ID if initiated from IdP. The
26                     logout response will include an InResponseTo value
27                     which matches this.
28        logout_request - the Logout request object
29     """
30     def __init__(self, uuidval, session_id, provider_id, user,
31                  login_session, logoutstate=None, relaystate=None,
32                  logout_request=None, request_id=None):
33
34         self.uuidval = uuidval
35         self.session_id = session_id
36         self.provider_id = provider_id
37         self.user = user
38         self.login_session = login_session
39         self.logoutstate = logoutstate
40         self.relaystate = relaystate
41         self.request_id = request_id
42         self.logout_request = logout_request
43
44     def set_logoutstate(self, relaystate=None, request=None, request_id=None):
45         """
46         Update attributes needed to determine the state of the session for
47         logout.
48
49         The database is not updated when these are set. It is expected that
50         this is called prior to start_logout()
51         """
52         if relaystate:
53             self.relaystate = relaystate
54         if request:
55             self.logout_request = request
56         if request_id:
57             self.request_id = request_id
58
59     def dump(self):
60         self.debug('session_id %s' % self.session_id)
61         self.debug('provider_id %s' % self.provider_id)
62         self.debug('login session %s' % self.login_session)
63         self.debug('logoutstate %s' % self.logoutstate)
64
65     def convert(self):
66         """
67         Convert this object into something suitable to store in the
68         data backend.
69         """
70         data = dict()
71         data['session_id'] = self.session_id
72         data['provider_id'] = self.provider_id
73         data['user'] = self.user
74         data['login_session'] = self.login_session
75         data['logoutstate'] = self.logoutstate
76         data['relaystate'] = self.relaystate
77         data['logout_request'] = self.logout_request
78         data['request_id'] = self.request_id
79
80         return {self.uuidval: data}
81
82
83 class SAMLSessionFactory(Log):
84     """
85     Access SAML session information.
86
87     The sessions are stored via the data backend.
88
89     When a user logs in, add_session() is called and a new SAMLSession
90     created and added to the table.
91
92     When a user logs out, the next login session is found and moved to
93     sessions_logging_out. remove_session() will look in both when trying
94     to remove a session.
95
96     Returns a SAMLSession object representing the new session.
97     """
98     def __init__(self):
99         self._ss = SAML2SessionStore()
100         self.user = None
101
102     def _data_to_samlsession(self, uuidval, data):
103         """
104         Convert data from the data backend to a SAMLSession object.
105         """
106         return SAMLSession(uuidval,
107                            data.get('session_id'),
108                            data.get('provider_id'),
109                            data.get('user'),
110                            data.get('login_session'),
111                            data.get('logoutstate'),
112                            data.get('relaystate'),
113                            data.get('logout_request'),
114                            data.get('request_id'))
115
116     def add_session(self, session_id, provider_id, user, login_session,
117                     request_id=None):
118         """
119         Add a new login session to the table.
120         """
121         self.user = user
122
123         data = {'session_id': session_id,
124                 'provider_id': provider_id,
125                 'user': user,
126                 'login_session': login_session,
127                 'logoutstate': LOGGED_IN}
128         if request_id:
129             data['request_id'] = request_id
130
131         uuidval = self._ss.new_session(data)
132
133         return SAMLSession(uuidval, session_id, provider_id, user,
134                            login_session, LOGGED_IN,
135                            request_id=request_id)
136
137     def get_session_by_id(self, session_id):
138         """
139         Retrieve a session by session ID
140         """
141         uuidval, data = self._ss.get_session(session_id=session_id)
142         if uuidval is None:
143             return None
144
145         return self._data_to_samlsession(uuidval, data)
146
147     def get_session_id_by_provider_id(self, provider_id):
148         """
149         Return a tuple of logged-in session IDs by provider_id
150         """
151         candidates = self._ss.get_user_sessions(self.user)
152
153         session_ids = []
154         for c in candidates:
155             key = c.keys()[0]
156             if c[key].get('provider_id') == provider_id:
157                 samlsession = self._data_to_samlsession(key, c[key])
158                 session_ids.append(samlsession.session_id.encode('utf-8'))
159
160         return tuple(session_ids)
161
162     def get_session_by_request_id(self, request_id):
163         """
164         Retrieve a session by logout request ID
165         """
166         uuidval, data = self._ss.get_session(request_id=request_id)
167         if uuidval is None:
168             return None
169
170         return self._data_to_samlsession(uuidval, data)
171
172     def remove_session(self, samlsession):
173         return self._ss.remove_session(samlsession.uuidval)
174
175     def remove_session_by_session_id(self, session_id):
176         session = self.get_session_by_id(session_id)
177         return self._ss.remove_session(session.uuidval)
178
179     def start_logout(self, samlsession, relaystate=None, initial=True):
180         """
181         Move a session into the logging_out state
182
183         samlsession: the SAMLSession object to start logging out
184         relaystate: URL to redirect user to when logout is completed
185         initial: boolean to indicate if this session started logout.
186                  Only the initial session's relaystate is used.
187
188         No return value
189         """
190         if initial:
191             samlsession.logoutstate = INIT_LOGOUT
192         else:
193             samlsession.logoutstate = LOGGING_OUT
194         if relaystate:
195             samlsession.relaystate = relaystate
196         datum = samlsession.convert()
197         self._ss.update_session(datum)
198
199     def get_next_logout(self, peek=False):
200         """
201         Get the next session in the logged-in state and move
202         it to the logging_out state.  Return the session that is
203         found.
204
205         :param peek: for IdP-initiated logout we can't remove the
206                      session otherwise when the request comes back
207                      in the user won't be seen as being logged-on.
208
209         Return None if no more sessions in LOGGED_IN state.
210         """
211         candidates = self._ss.get_user_sessions(self.user)
212
213         for c in candidates:
214             key = c.keys()[0]
215             if int(c[key].get('logoutstate', 0)) == LOGGED_IN:
216                 samlsession = self._data_to_samlsession(key, c[key])
217                 self.start_logout(samlsession, initial=False)
218                 return samlsession
219         return None
220
221     def get_initial_logout(self):
222         """
223         Get the initial logout request.
224
225         Return None if no sessions in INIT_LOGOUT state.
226         """
227         candidates = self._ss.get_user_sessions(self.user)
228
229         # FIXME: what does it mean if there are multiple in init? We
230         #        just return the first one for now. How do we know
231         #        it's the "right" one if multiple logouts are started
232         #        at the same time from different SPs?
233         for c in candidates:
234             key = c.keys()[0]
235             if int(c[key].get('logoutstate', 0)) == INIT_LOGOUT:
236                 samlsession = self._data_to_samlsession(key, c[key])
237                 return samlsession
238         return None
239
240     def wipe_data(self):
241         self._ss.wipe_data()
242
243     def dump(self):
244         """
245         Dump all sessions to debug log
246         """
247         candidates = self._ss.get_user_sessions(self.user)
248
249         count = 0
250         for c in candidates:
251             key = c.keys()[0]
252             samlsession = self._data_to_samlsession(key, c[key])
253             self.debug('session %d: %s' % (count, samlsession.convert()))
254             count += 1
255
256 if __name__ == '__main__':
257     import cherrypy
258
259     provider1 = "http://127.0.0.10/saml2"
260     provider2 = "http://127.0.0.11/saml2"
261
262     # temporary database location for testing
263     cherrypy.config['saml2.sessions.db'] = '/tmp/saml2sessions.sqlite'
264
265     factory = SAMLSessionFactory()
266     factory.wipe_data()
267
268     sess1 = factory.add_session('_123456', provider1, "admin", "<Login/>")
269     sess2 = factory.add_session('_789012', provider2, "testuser", "<Login/>")
270
271     # Test finding sessions by provider
272     ids = factory.get_session_id_by_provider_id(provider2)
273     assert(len(ids) == 1)
274
275     sess3 = factory.add_session('_345678', provider2, "testuser", "<Login/>")
276     ids = factory.get_session_id_by_provider_id(provider2)
277     assert(len(ids) == 2)
278
279     # Test finding sessions by session ID
280     test1 = factory.get_session_by_id('_123456')
281     assert(test1.user == 'admin')
282     assert(test1.provider_id == provider1)
283
284     # Log out and remove the first session
285     test1.set_logoutstate('http://www.example.com/idp')
286     factory.start_logout(test1, initial=True)
287     test1 = factory.get_session_by_id('_123456')
288     assert(test1.relaystate == 'http://www.example.com/idp')
289
290     factory.remove_session_by_session_id('_123456')
291
292     # Make sure it is gone from the db
293     test1 = factory.get_session_by_id('_123456')
294     assert(test1 is None)
295
296     test2 = factory.get_session_by_id('_789012')
297     factory.start_logout(test2, initial=True)
298
299     test3 = factory.get_next_logout()
300     assert(test3.session_id == '_345678')
301
302     test4 = factory.get_initial_logout()
303     assert(test4.session_id == '_789012')
304
305     # Even though we've started logout, make sure we can still find
306     # all sessions for a provider.
307     ids = factory.get_session_id_by_provider_id(provider2)
308     assert(len(ids) == 2)