Use plugin-specific configuration, better expiration
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / sessions.py
1 # Copyright (C) 2015 Ipsilon project Contributors, for license see COPYING
2
3 from cherrypy import config as cherrypy_config
4 from ipsilon.util.log import Log
5 from ipsilon.util.data import SAML2SessionStore
6 import datetime
7
# Values of a stored session's ``logoutstate``; each is a distinct bit.
LOGGED_IN = 1 << 0    # set by SAMLSessionFactory.add_session()
INIT_LOGOUT = 1 << 1  # the session that started logout (start_logout initial=True)
LOGGING_OUT = 1 << 2  # other sessions swept into the logout (initial=False)
LOGGED_OUT = 1 << 3   # NOTE(review): not assigned anywhere in this module
12
13
class SAMLSession(Log):
    """
    A single SAML login session between the IdP and one SP.

    Attributes (all taken from the constructor arguments):

       uuidval - unique key of this session in the data backend
       session_id - ID of the login session
       provider_id - ID of the SP
       user - login name of the user who owns the session
       login_session - the Login session object
       logoutstate - integer constant recording how far this session
                     has progressed through logout
       relaystate - where the user is redirected once logout is
                    complete
       request_id - the logout request ID when logout was initiated
                    from the IdP; the logout response's InResponseTo
                    value will match it
       logout_request - the Logout request object
       expiration_time - when the login session expires
    """
    def __init__(self, uuidval, session_id, provider_id, user,
                 login_session, logoutstate=None, relaystate=None,
                 logout_request=None, request_id=None,
                 expiration_time=None):
        # Identity of the session and its owner.
        self.uuidval = uuidval
        self.session_id = session_id
        self.provider_id = provider_id
        self.user = user
        self.login_session = login_session
        self.expiration_time = expiration_time
        # Logout progress and the data needed to complete it.
        self.logoutstate = logoutstate
        self.relaystate = relaystate
        self.request_id = request_id
        self.logout_request = logout_request

    def set_logoutstate(self, relaystate=None, request=None, request_id=None):
        """
        Record the logout-related attributes of this session.

        Only truthy arguments overwrite the current values. Nothing is
        written to the database here; this is meant to be called before
        SAMLSessionFactory.start_logout(), which stores the session.
        """
        pending = (('relaystate', relaystate),
                   ('logout_request', request),
                   ('request_id', request_id))
        for attr, value in pending:
            if value:
                setattr(self, attr, value)

    def dump(self):
        """Write the identifying fields and logout state to the debug log."""
        for template, value in (('session_id %s', self.session_id),
                                ('provider_id %s', self.provider_id),
                                ('login session %s', self.login_session),
                                ('logoutstate %s', self.logoutstate)):
            self.debug(template % value)

    def convert(self):
        """
        Return this session as ``{uuidval: data}``, where ``data`` is a
        plain dict of the persisted fields, ready for the data backend.
        """
        data = dict(session_id=self.session_id,
                    provider_id=self.provider_id,
                    user=self.user,
                    login_session=self.login_session,
                    logoutstate=self.logoutstate,
                    relaystate=self.relaystate,
                    logout_request=self.logout_request,
                    request_id=self.request_id,
                    expiration_time=self.expiration_time)
        return {self.uuidval: data}
87
88
class SAMLSessionFactory(Log):
    """
    Access SAML session information.

    The sessions are stored via the data backend (SAML2SessionStore).
    Each stored session records its progress through logout in a
    ``logoutstate`` field holding one of the module-level state
    constants (LOGGED_IN, INIT_LOGOUT, LOGGING_OUT, ...).

    When a user logs in, add_session() creates a new SAMLSession, stores
    it in the LOGGED_IN state and returns it.

    When a user logs out, start_logout() (with initial=True, the
    default) moves the initiating session to INIT_LOGOUT, and
    get_next_logout() moves each remaining LOGGED_IN session to
    LOGGING_OUT. remove_session() and remove_session_by_session_id()
    delete a session from the store.

    The per-user lookups (get_session_id_by_provider_id,
    get_next_logout, get_initial_logout, dump) only see the sessions of
    ``self.user``, which add_session() sets to the user it was called
    for.
    """
    def __init__(self, database_url):
        self._ss = SAML2SessionStore(database_url=database_url)
        # Owner of the sessions that the per-user lookups operate on.
        self.user = None

    def _data_to_samlsession(self, uuidval, data):
        """
        Convert data from the data backend to a SAMLSession object.
        """
        return SAMLSession(uuidval,
                           data.get('session_id'),
                           data.get('provider_id'),
                           data.get('user'),
                           data.get('login_session'),
                           data.get('logoutstate'),
                           data.get('relaystate'),
                           data.get('logout_request'),
                           data.get('request_id'),
                           data.get('expiration_time'))

    def _user_sessions(self):
        """
        Yield (uuidval, data) pairs for each stored session of self.user.

        Each entry from the store is a dict keyed by the session uuid;
        only its first key is used, as before.
        """
        for entry in self._ss.get_user_sessions(self.user):
            # list(...)[0] works on Python 2 lists and Python 3 key views
            # alike (entry.keys()[0] fails on Python 3), and still raises
            # IndexError on an empty entry.
            uuidval = list(entry)[0]
            yield uuidval, entry[uuidval]

    def _get_session(self, **criteria):
        """
        Fetch one session matching ``criteria`` (passed straight to
        SAML2SessionStore.get_session) as a SAMLSession, or None when the
        store reports no match.
        """
        uuidval, data = self._ss.get_session(**criteria)
        if uuidval is None:
            return None
        return self._data_to_samlsession(uuidval, data)

    def add_session(self, session_id, provider_id, user, login_session,
                    request_id=None):
        """
        Add a new login session, in the LOGGED_IN state, to the store.

        The session expires ``tools.sessions.timeout`` minutes (read from
        the CherryPy config) from now. Also makes ``user`` the owner that
        the per-user lookups of this factory operate on.

        Returns a SAMLSession object representing the new session.
        """
        self.user = user

        timeout = cherrypy_config['tools.sessions.timeout']
        expiration_time = (datetime.datetime.now() +
                           datetime.timedelta(minutes=timeout))

        data = {'session_id': session_id,
                'provider_id': provider_id,
                'user': user,
                'login_session': login_session,
                'logoutstate': LOGGED_IN,
                'expiration_time': expiration_time}
        if request_id:
            data['request_id'] = request_id

        uuidval = self._ss.new_session(data)

        return SAMLSession(uuidval, session_id, provider_id, user,
                           login_session, LOGGED_IN,
                           request_id=request_id,
                           expiration_time=expiration_time)

    def get_session_by_id(self, session_id):
        """
        Retrieve a session by session ID; None if there is no such session.
        """
        return self._get_session(session_id=session_id)

    def get_session_id_by_provider_id(self, provider_id):
        """
        Return a tuple of the UTF-8 encoded session IDs of self.user's
        sessions with the given provider_id.

        Sessions are returned whatever their logout state, including ones
        that have already started logging out.
        """
        return tuple(data.get('session_id').encode('utf-8')
                     for _uuidval, data in self._user_sessions()
                     if data.get('provider_id') == provider_id)

    def get_session_by_request_id(self, request_id):
        """
        Retrieve a session by logout request ID; None if there is none.
        """
        return self._get_session(request_id=request_id)

    def remove_session(self, samlsession):
        """Delete ``samlsession`` from the store."""
        return self._ss.remove_session(samlsession.uuidval)

    def remove_session_by_session_id(self, session_id):
        """
        Delete the session with the given session ID from the store.

        Returns the store's remove_session() result, or None when no
        session with that ID exists.
        """
        session = self.get_session_by_id(session_id)
        if session is None:
            # Previously this dereferenced None and raised AttributeError.
            return None
        return self._ss.remove_session(session.uuidval)

    def start_logout(self, samlsession, relaystate=None, initial=True):
        """
        Move a session into a logout state and persist it.

        samlsession: the SAMLSession object to start logging out
        relaystate: URL to redirect user to when logout is completed
        initial: boolean to indicate if this session started logout.
                 Only the initial session's relaystate is used.
                 The session goes to INIT_LOGOUT when True, otherwise
                 to LOGGING_OUT.

        No return value
        """
        if initial:
            samlsession.logoutstate = INIT_LOGOUT
        else:
            samlsession.logoutstate = LOGGING_OUT
        if relaystate:
            samlsession.relaystate = relaystate
        datum = samlsession.convert()
        self._ss.update_session(datum)

    def get_next_logout(self, peek=False):
        """
        Get the next session of self.user in the LOGGED_IN state, move it
        to the LOGGING_OUT state (persisting the change) and return it.

        :param peek: for IdP-initiated logout we can't remove the
                     session otherwise when the request comes back
                     in the user won't be seen as being logged-on.
                     NOTE(review): this argument is currently ignored;
                     the session found is always moved to LOGGING_OUT.

        Return None if no more sessions in LOGGED_IN state.
        """
        for uuidval, data in self._user_sessions():
            # int() presumably because the store may hand the state back
            # as a string; a missing state never matches LOGGED_IN.
            if int(data.get('logoutstate', 0)) == LOGGED_IN:
                samlsession = self._data_to_samlsession(uuidval, data)
                self.start_logout(samlsession, initial=False)
                return samlsession
        return None

    def get_initial_logout(self):
        """
        Get the initial logout request.

        Return None if no sessions of self.user are in INIT_LOGOUT state.
        """
        # FIXME: what does it mean if there are multiple in init? We
        #        just return the first one for now. How do we know
        #        it's the "right" one if multiple logouts are started
        #        at the same time from different SPs?
        for uuidval, data in self._user_sessions():
            if int(data.get('logoutstate', 0)) == INIT_LOGOUT:
                return self._data_to_samlsession(uuidval, data)
        return None

    def wipe_data(self):
        """Wipe the session store via SAML2SessionStore.wipe_data()."""
        self._ss.wipe_data()

    def dump(self):
        """
        Dump every session of self.user to the debug log.
        """
        for count, (uuidval, data) in enumerate(self._user_sessions()):
            samlsession = self._data_to_samlsession(uuidval, data)
            self.debug('session %d: %s' % (count, samlsession.convert()))
266             self.debug('session %d: %s' % (count, samlsession.convert()))
267             count += 1
268
if __name__ == '__main__':
    # Ad-hoc self-test of SAMLSessionFactory against a scratch store.
    provider1 = "http://127.0.0.10/saml2"
    provider2 = "http://127.0.0.11/saml2"

    # CherryPy is not running here, so seed the config value that
    # add_session() reads to compute the expiration time.
    cherrypy_config['tools.sessions.timeout'] = 60

    factory = SAMLSessionFactory('/tmp/saml2sessions.sqlite')
    factory.wipe_data()

    factory.add_session('_123456', provider1, "admin", "<Login/>")
    factory.add_session('_789012', provider2, "testuser", "<Login/>")

    # Provider lookups only see factory.user, i.e. the user passed to the
    # most recent add_session() call ("testuser").
    found = factory.get_session_id_by_provider_id(provider2)
    assert len(found) == 1

    factory.add_session('_345678', provider2, "testuser", "<Login/>")
    found = factory.get_session_id_by_provider_id(provider2)
    assert len(found) == 2

    # Lookup by session ID.
    admin = factory.get_session_by_id('_123456')
    assert admin.user == 'admin'
    assert admin.provider_id == provider1

    # Start logging out the admin session; the relaystate must be stored.
    admin.set_logoutstate('http://www.example.com/idp')
    factory.start_logout(admin, initial=True)
    admin = factory.get_session_by_id('_123456')
    assert admin.relaystate == 'http://www.example.com/idp'

    # Removing it must delete it from the store.
    factory.remove_session_by_session_id('_123456')
    admin = factory.get_session_by_id('_123456')
    assert admin is None

    # testuser: one session initiates logout, the other is next in line.
    initiator = factory.get_session_by_id('_789012')
    factory.start_logout(initiator, initial=True)

    follower = factory.get_next_logout()
    assert follower.session_id == '_345678'

    initial = factory.get_initial_logout()
    assert initial.session_id == '_789012'

    # Sessions that have started logout are still found by provider.
    found = factory.get_session_id_by_provider_id(provider2)
    assert len(found) == 2