Remove expired SAML2 sessions
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / sessions.py
1 # Copyright (C) 2015 Ipsilon project Contributors, for license see COPYING
2
3 from cherrypy import config as cherrypy_config
4 from ipsilon.util.log import Log
5 from ipsilon.util.data import SAML2SessionStore
6 import datetime
7
8 LOGGED_IN = 1
9 INIT_LOGOUT = 2
10 LOGGING_OUT = 4
11 LOGGED_OUT = 8
12
13
14 def expire_sessions():
15     """
16     Find all expired sessions and remove them. This is executed as a
17     background cherrypy task.
18     """
19     ss = SAML2SessionStore()
20     data = ss.get_data()
21     now = datetime.datetime.now()
22     for idval in data:
23         r = data[idval]
24         exp = r.get('expiration_time', None)
25         if exp is not None:
26             exp = datetime.datetime.strptime(exp, '%Y-%m-%d %H:%M:%S.%f')
27             if exp < now:
28                 ss.remove_session(idval)
29
30
31 class SAMLSession(Log):
32     """
33     A SAML login session.
34
35        uuidval - Unique ID stored in the database
36        session_id - ID of the login session
37        provider_id - ID of the SP
38        user - the login name of the user that owns the session
39        login_session - the Login session object
40        logoutstate - an integer constant representing where in the
41                      logout process this request is
42        relaystate - where the user will be redirected when logout is
43                     complete
44        request_id - the logout request ID if initiated from IdP. The
45                     logout response will include an InResponseTo value
46                     which matches this.
47        logout_request - the Logout request object
48        expiration_time - the time the login session expires
49     """
50     def __init__(self, uuidval, session_id, provider_id, user,
51                  login_session, logoutstate=None, relaystate=None,
52                  logout_request=None, request_id=None,
53                  expiration_time=None):
54
55         self.uuidval = uuidval
56         self.session_id = session_id
57         self.provider_id = provider_id
58         self.user = user
59         self.login_session = login_session
60         self.logoutstate = logoutstate
61         self.relaystate = relaystate
62         self.request_id = request_id
63         self.logout_request = logout_request
64         self.expiration_time = expiration_time
65
66     def set_logoutstate(self, relaystate=None, request=None, request_id=None):
67         """
68         Update attributes needed to determine the state of the session for
69         logout.
70
71         The database is not updated when these are set. It is expected that
72         this is called prior to start_logout()
73         """
74         if relaystate:
75             self.relaystate = relaystate
76         if request:
77             self.logout_request = request
78         if request_id:
79             self.request_id = request_id
80
81     def dump(self):
82         self.debug('session_id %s' % self.session_id)
83         self.debug('provider_id %s' % self.provider_id)
84         self.debug('login session %s' % self.login_session)
85         self.debug('logoutstate %s' % self.logoutstate)
86
87     def convert(self):
88         """
89         Convert this object into something suitable to store in the
90         data backend.
91         """
92         data = dict()
93         data['session_id'] = self.session_id
94         data['provider_id'] = self.provider_id
95         data['user'] = self.user
96         data['login_session'] = self.login_session
97         data['logoutstate'] = self.logoutstate
98         data['relaystate'] = self.relaystate
99         data['logout_request'] = self.logout_request
100         data['request_id'] = self.request_id
101         data['expiration_time'] = self.expiration_time
102
103         return {self.uuidval: data}
104
105
106 class SAMLSessionFactory(Log):
107     """
108     Access SAML session information.
109
110     The sessions are stored via the data backend.
111
112     When a user logs in, add_session() is called and a new SAMLSession
113     created and added to the table.
114
115     When a user logs out, the next login session is found and moved to
116     sessions_logging_out. remove_session() will look in both when trying
117     to remove a session.
118
119     Returns a SAMLSession object representing the new session.
120     """
121     def __init__(self):
122         self._ss = SAML2SessionStore()
123         self.user = None
124
125     def _data_to_samlsession(self, uuidval, data):
126         """
127         Convert data from the data backend to a SAMLSession object.
128         """
129         return SAMLSession(uuidval,
130                            data.get('session_id'),
131                            data.get('provider_id'),
132                            data.get('user'),
133                            data.get('login_session'),
134                            data.get('logoutstate'),
135                            data.get('relaystate'),
136                            data.get('logout_request'),
137                            data.get('request_id'),
138                            data.get('expiration_time'))
139
140     def add_session(self, session_id, provider_id, user, login_session,
141                     request_id=None):
142         """
143         Add a new login session to the table.
144         """
145         self.user = user
146
147         timeout = cherrypy_config['tools.sessions.timeout']
148         t = datetime.timedelta(seconds=timeout * 60)
149         expiration_time = datetime.datetime.now() + t
150
151         data = {'session_id': session_id,
152                 'provider_id': provider_id,
153                 'user': user,
154                 'login_session': login_session,
155                 'logoutstate': LOGGED_IN,
156                 'expiration_time': expiration_time}
157         if request_id:
158             data['request_id'] = request_id
159
160         uuidval = self._ss.new_session(data)
161
162         return SAMLSession(uuidval, session_id, provider_id, user,
163                            login_session, LOGGED_IN,
164                            request_id=request_id,
165                            expiration_time=expiration_time)
166
167     def get_session_by_id(self, session_id):
168         """
169         Retrieve a session by session ID
170         """
171         uuidval, data = self._ss.get_session(session_id=session_id)
172         if uuidval is None:
173             return None
174
175         return self._data_to_samlsession(uuidval, data)
176
177     def get_session_id_by_provider_id(self, provider_id):
178         """
179         Return a tuple of logged-in session IDs by provider_id
180         """
181         candidates = self._ss.get_user_sessions(self.user)
182
183         session_ids = []
184         for c in candidates:
185             key = c.keys()[0]
186             if c[key].get('provider_id') == provider_id:
187                 samlsession = self._data_to_samlsession(key, c[key])
188                 session_ids.append(samlsession.session_id.encode('utf-8'))
189
190         return tuple(session_ids)
191
192     def get_session_by_request_id(self, request_id):
193         """
194         Retrieve a session by logout request ID
195         """
196         uuidval, data = self._ss.get_session(request_id=request_id)
197         if uuidval is None:
198             return None
199
200         return self._data_to_samlsession(uuidval, data)
201
202     def remove_session(self, samlsession):
203         return self._ss.remove_session(samlsession.uuidval)
204
205     def remove_session_by_session_id(self, session_id):
206         session = self.get_session_by_id(session_id)
207         return self._ss.remove_session(session.uuidval)
208
209     def start_logout(self, samlsession, relaystate=None, initial=True):
210         """
211         Move a session into the logging_out state
212
213         samlsession: the SAMLSession object to start logging out
214         relaystate: URL to redirect user to when logout is completed
215         initial: boolean to indicate if this session started logout.
216                  Only the initial session's relaystate is used.
217
218         No return value
219         """
220         if initial:
221             samlsession.logoutstate = INIT_LOGOUT
222         else:
223             samlsession.logoutstate = LOGGING_OUT
224         if relaystate:
225             samlsession.relaystate = relaystate
226         datum = samlsession.convert()
227         self._ss.update_session(datum)
228
229     def get_next_logout(self, peek=False):
230         """
231         Get the next session in the logged-in state and move
232         it to the logging_out state.  Return the session that is
233         found.
234
235         :param peek: for IdP-initiated logout we can't remove the
236                      session otherwise when the request comes back
237                      in the user won't be seen as being logged-on.
238
239         Return None if no more sessions in LOGGED_IN state.
240         """
241         candidates = self._ss.get_user_sessions(self.user)
242
243         for c in candidates:
244             key = c.keys()[0]
245             if int(c[key].get('logoutstate', 0)) == LOGGED_IN:
246                 samlsession = self._data_to_samlsession(key, c[key])
247                 self.start_logout(samlsession, initial=False)
248                 return samlsession
249         return None
250
251     def get_initial_logout(self):
252         """
253         Get the initial logout request.
254
255         Return None if no sessions in INIT_LOGOUT state.
256         """
257         candidates = self._ss.get_user_sessions(self.user)
258
259         # FIXME: what does it mean if there are multiple in init? We
260         #        just return the first one for now. How do we know
261         #        it's the "right" one if multiple logouts are started
262         #        at the same time from different SPs?
263         for c in candidates:
264             key = c.keys()[0]
265             if int(c[key].get('logoutstate', 0)) == INIT_LOGOUT:
266                 samlsession = self._data_to_samlsession(key, c[key])
267                 return samlsession
268         return None
269
270     def wipe_data(self):
271         self._ss.wipe_data()
272
273     def dump(self):
274         """
275         Dump all sessions to debug log
276         """
277         candidates = self._ss.get_user_sessions(self.user)
278
279         count = 0
280         for c in candidates:
281             key = c.keys()[0]
282             samlsession = self._data_to_samlsession(key, c[key])
283             self.debug('session %d: %s' % (count, samlsession.convert()))
284             count += 1
285
286 if __name__ == '__main__':
287     provider1 = "http://127.0.0.10/saml2"
288     provider2 = "http://127.0.0.11/saml2"
289
290     # temporary values to simulate cherrypy
291     cherrypy_config['saml2.sessions.db'] = '/tmp/saml2sessions.sqlite'
292     cherrypy_config['tools.sessions.timeout'] = 60
293
294     factory = SAMLSessionFactory()
295     factory.wipe_data()
296
297     sess1 = factory.add_session('_123456', provider1, "admin", "<Login/>")
298     sess2 = factory.add_session('_789012', provider2, "testuser", "<Login/>")
299
300     # Test finding sessions by provider
301     ids = factory.get_session_id_by_provider_id(provider2)
302     assert(len(ids) == 1)
303
304     sess3 = factory.add_session('_345678', provider2, "testuser", "<Login/>")
305     ids = factory.get_session_id_by_provider_id(provider2)
306     assert(len(ids) == 2)
307
308     # Test finding sessions by session ID
309     test1 = factory.get_session_by_id('_123456')
310     assert(test1.user == 'admin')
311     assert(test1.provider_id == provider1)
312
313     # Log out and remove the first session
314     test1.set_logoutstate('http://www.example.com/idp')
315     factory.start_logout(test1, initial=True)
316     test1 = factory.get_session_by_id('_123456')
317     assert(test1.relaystate == 'http://www.example.com/idp')
318
319     factory.remove_session_by_session_id('_123456')
320
321     # Make sure it is gone from the db
322     test1 = factory.get_session_by_id('_123456')
323     assert(test1 is None)
324
325     test2 = factory.get_session_by_id('_789012')
326     factory.start_logout(test2, initial=True)
327
328     test3 = factory.get_next_logout()
329     assert(test3.session_id == '_345678')
330
331     test4 = factory.get_initial_logout()
332     assert(test4.session_id == '_789012')
333
334     # Even though we've started logout, make sure we can still find
335     # all sessions for a provider.
336     ids = factory.get_session_id_by_provider_id(provider2)
337     assert(len(ids) == 2)