Add OpenID test suite
[cascardo/ipsilon.git] / tests / helpers / http.py
1 #!/usr/bin/python
2 #
3 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
4 #
5 # see file 'COPYING' for use and warranty information
6 #
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20
21 from lxml import html
22 import requests
23 import string
24 import urlparse
25 import json
26 from urllib import urlencode
27
28
29 class WrongPage(Exception):
30     pass
31
32
33 class PageTree(object):
34
35     def __init__(self, result):
36         self.result = result
37         self.text = result.text
38         self._tree = None
39
40     @property
41     def tree(self):
42         if self._tree is None:
43             self._tree = html.fromstring(self.text)
44         return self._tree
45
46     def first_value(self, rule):
47         result = self.tree.xpath(rule)
48         if type(result) is list:
49             if len(result) > 0:
50                 result = result[0]
51             else:
52                 result = None
53         return result
54
55     def all_values(self, rule):
56         result = self.tree.xpath(rule)
57         if type(result) is list:
58             return result
59         return [result]
60
61     def make_referer(self):
62         return self.result.url
63
64     def expected_value(self, rule, expected):
65         value = self.first_value(rule)
66         if value != expected:
67             raise ValueError("Expected [%s], got [%s]" % (expected, value))
68
69
70 class HttpSessions(object):
71
72     def __init__(self):
73         self.servers = dict()
74
75     def add_server(self, name, baseuri, user=None, pwd=None):
76         new = {'baseuri': baseuri,
77                'session': requests.Session()}
78         if user:
79             new['user'] = user
80         if pwd:
81             new['pwd'] = pwd
82         self.servers[name] = new
83
84     def get_session(self, url):
85         for srv in self.servers:
86             d = self.servers[srv]
87             if url.startswith(d['baseuri']):
88                 return d['session']
89
90         raise ValueError("Unknown URL: %s" % url)
91
92     def get(self, url, **kwargs):
93         session = self.get_session(url)
94         return session.get(url, allow_redirects=False, **kwargs)
95
96     def post(self, url, **kwargs):
97         session = self.get_session(url)
98         return session.post(url, allow_redirects=False, **kwargs)
99
100     def access(self, action, url, **kwargs):
101         action = string.lower(action)
102         if action == 'get':
103             return self.get(url, **kwargs)
104         elif action == 'post':
105             return self.post(url, **kwargs)
106         else:
107             raise ValueError("Unknown action type: [%s]" % action)
108
109     def new_url(self, referer, action):
110         if action.startswith('/'):
111             u = urlparse.urlparse(referer)
112             return '%s://%s%s' % (u.scheme, u.netloc, action)
113         return action
114
115     def get_form_data(self, page, form_id, input_fields):
116         form_selector = '//form'
117         if form_id:
118             form_selector += '[@id="%s"]' % form_id
119         values = []
120         action = page.first_value('%s/@action' % form_selector)
121         values.append(action)
122         method = page.first_value('%s/@method' % form_selector)
123         values.append(method)
124         for field in input_fields:
125             value = page.all_values('%s/input/@%s' % (form_selector,
126                                                       field))
127             values.append(value)
128         return values
129
130     def handle_login_form(self, idp, page):
131         if type(page) != PageTree:
132             raise TypeError("Expected PageTree object")
133
134         srv = self.servers[idp]
135
136         try:
137             results = self.get_form_data(page, "login_form", ["name", "value"])
138             action_url = results[0]
139             method = results[1]
140             names = results[2]
141             values = results[3]
142             if action_url is None:
143                 raise Exception
144         except Exception:  # pylint: disable=broad-except
145             raise WrongPage("Not a Login Form Page")
146
147         referer = page.make_referer()
148         headers = {'referer': referer}
149         payload = {}
150         for i in range(0, len(names)):
151             payload[names[i]] = values[i]
152
153         # replace known values
154         payload['login_name'] = srv['user']
155         payload['login_password'] = srv['pwd']
156
157         return [method, self.new_url(referer, action_url),
158                 {'headers': headers, 'data': payload}]
159
160     def handle_return_form(self, page):
161         if type(page) != PageTree:
162             raise TypeError("Expected PageTree object")
163
164         try:
165             results = self.get_form_data(page, "saml-response",
166                                          ["name", "value"])
167             action_url = results[0]
168             if action_url is None:
169                 raise Exception
170             method = results[1]
171             names = results[2]
172             values = results[3]
173         except Exception:  # pylint: disable=broad-except
174             raise WrongPage("Not a Return Form Page")
175
176         referer = page.make_referer()
177         headers = {'referer': referer}
178
179         payload = {}
180         for i in range(0, len(names)):
181             payload[names[i]] = values[i]
182
183         return [method, self.new_url(referer, action_url),
184                 {'headers': headers, 'data': payload}]
185
186     def handle_openid_form(self, page):
187         if type(page) != PageTree:
188             raise TypeError("Expected PageTree object")
189
190         if not page.first_value('//title/text()') == \
191                 'OpenID transaction in progress':
192             raise WrongPage('Not OpenID autosubmit form')
193
194         try:
195             results = self.get_form_data(page, None,
196                                          ["name", "value"])
197             action_url = results[0]
198             if action_url is None:
199                 raise Exception
200             method = results[1]
201             names = results[2]
202             values = results[3]
203         except Exception:  # pylint: disable=broad-except
204             raise WrongPage("Not OpenID autosubmit form")
205
206         referer = page.make_referer()
207         headers = {'referer': referer}
208
209         payload = {}
210         for i in range(0, len(names)):
211             payload[names[i]] = values[i]
212
213         return [method, self.new_url(referer, action_url),
214                 {'headers': headers, 'data': payload}]
215
216     def handle_openid_consent_form(self, page):
217         if type(page) != PageTree:
218             raise TypeError("Expected PageTree object")
219
220         try:
221             results = self.get_form_data(page, "consent_form",
222                                          ['name', 'value'])
223             action_url = results[0]
224             if action_url is None:
225                 raise Exception
226             method = results[1]
227             names = results[2]
228             values = results[3]
229         except Exception:  # pylint: disable=broad-except
230             raise WrongPage("Not an OpenID Consent Form Page")
231
232         referer = page.make_referer()
233         headers = {'referer': referer}
234
235         payload = {}
236         for i in range(0, len(names)):
237             payload[names[i]] = values[i]
238
239         # Replace known values
240         payload['decided_allow'] = 'Allow'
241
242         return [method, self.new_url(referer, action_url),
243                 {'headers': headers, 'data': payload}]
244
245     def fetch_page(self, idp, target_url, follow_redirect=True):
246         url = target_url
247         action = 'get'
248         args = {}
249
250         while True:
251             r = self.access(action, url, **args)  # pylint: disable=star-args
252             if r.status_code == 303 or r.status_code == 302:
253                 if not follow_redirect:
254                     return PageTree(r)
255                 url = r.headers['location']
256                 action = 'get'
257                 args = {}
258             elif r.status_code == 200:
259                 page = PageTree(r)
260
261                 try:
262                     (action, url, args) = self.handle_login_form(idp, page)
263                     continue
264                 except WrongPage:
265                     pass
266
267                 try:
268                     (action, url, args) = self.handle_return_form(page)
269                     continue
270                 except WrongPage:
271                     pass
272
273                 try:
274                     (action, url, args) = self.handle_openid_consent_form(page)
275                     continue
276                 except WrongPage:
277                     pass
278
279                 try:
280                     (action, url, args) = self.handle_openid_form(page)
281                     continue
282                 except WrongPage:
283                     pass
284
285                 # Either we got what we wanted, or we have to stop anyway
286                 return page
287             else:
288                 raise ValueError("Unhandled status (%d) on url %s" % (
289                                  r.status_code, url))
290
291     def auth_to_idp(self, idp):
292
293         srv = self.servers[idp]
294         target_url = '%s/%s/' % (srv['baseuri'], idp)
295
296         r = self.access('get', target_url)
297         if r.status_code != 200:
298             raise ValueError("Access to idp failed: %s" % repr(r))
299
300         page = PageTree(r)
301         page.expected_value('//div[@id="content"]/p/a/text()', 'Log In')
302         href = page.first_value('//div[@id="content"]/p/a/@href')
303         url = self.new_url(target_url, href)
304
305         page = self.fetch_page(idp, url)
306         page.expected_value('//div[@id="welcome"]/p/text()',
307                             'Welcome %s!' % srv['user'])
308
309     def logout_from_idp(self, idp):
310
311         srv = self.servers[idp]
312         target_url = '%s/%s/logout' % (srv['baseuri'], idp)
313
314         r = self.access('get', target_url)
315         if r.status_code != 200:
316             raise ValueError("Logout from idp failed: %s" % repr(r))
317
318     def get_sp_metadata(self, idp, sp):
319         idpsrv = self.servers[idp]
320         idpuri = idpsrv['baseuri']
321
322         spuri = self.servers[sp]['baseuri']
323
324         return (idpuri, requests.get('%s/saml2/metadata' % spuri))
325
326     def add_sp_metadata(self, idp, sp, rest=False):
327         expected_status = 200
328         idpsrv = self.servers[idp]
329         (idpuri, m) = self.get_sp_metadata(idp, sp)
330         url = '%s/%s/admin/providers/saml2/admin/new' % (idpuri, idp)
331         headers = {'referer': url}
332         if rest:
333             expected_status = 201
334             payload = {'metadata': m.content}
335             headers['content-type'] = 'application/x-www-form-urlencoded'
336             url = '%s/%s/rest/providers/saml2/SPS/%s' % (idpuri, idp, sp)
337             r = idpsrv['session'].post(url, headers=headers,
338                                        data=urlencode(payload))
339         else:
340             metafile = {'metafile': m.content}
341             payload = {'name': sp}
342             r = idpsrv['session'].post(url, headers=headers,
343                                        data=payload, files=metafile)
344         if r.status_code != expected_status:
345             raise ValueError('Failed to post SP data [%s]' % repr(r))
346
347         if not rest:
348             page = PageTree(r)
349             page.expected_value('//div[@class="alert alert-success"]/p/text()',
350                                 'SP Successfully added')
351
352     def set_sp_default_nameids(self, idp, sp, nameids):
353         """
354         nameids is a list of Name ID formats to enable
355         """
356         idpsrv = self.servers[idp]
357         idpuri = idpsrv['baseuri']
358         url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (idpuri, idp, sp)
359         headers = {'referer': url}
360         headers['content-type'] = 'application/x-www-form-urlencoded'
361         payload = {'submit': 'Submit',
362                    'allowed_nameids': ', '.join(nameids)}
363         r = idpsrv['session'].post(url, headers=headers,
364                                    data=payload)
365         if r.status_code != 200:
366             raise ValueError('Failed to post SP data [%s]' % repr(r))
367
368     # pylint: disable=dangerous-default-value
369     def set_attributes_and_mapping(self, idp, mapping=[], attrs=[],
370                                    spname=None):
371         """
372         Set allowed attributes and mapping in the IDP or the SP. In the
373         case of the SP both allowed attributes and the mapping need to
374         be provided. An empty option for either means delete all values.
375
376         mapping is a list of list of rules of the form:
377            [['from-1', 'to-1'], ['from-2', 'from-2']]
378
379         ex. [['*', '*'], ['fullname', 'namefull']]
380
381         attrs is the list of attributes that will be allowed:
382            ['fullname', 'givenname', 'surname']
383         """
384         idpsrv = self.servers[idp]
385         idpuri = idpsrv['baseuri']
386         if spname:  # per-SP setting
387             url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (
388                 idpuri, idp, spname)
389             mapname = 'Attribute Mapping'
390             attrname = 'Allowed Attributes'
391         else:  # global default
392             url = '%s/%s/admin/providers/saml2' % (idpuri, idp)
393             mapname = 'default attribute mapping'
394             attrname = 'default allowed attributes'
395
396         headers = {'referer': url}
397         headers['content-type'] = 'application/x-www-form-urlencoded'
398         payload = {'submit': 'Submit'}
399         count = 0
400         for m in mapping:
401             payload['%s %s-from' % (mapname, count)] = m[0]
402             payload['%s %s-to' % (mapname, count)] = m[1]
403             count += 1
404         count = 0
405         for attr in attrs:
406             payload['%s %s-name' % (attrname, count)] = attr
407             count += 1
408         r = idpsrv['session'].post(url, headers=headers,
409                                    data=payload)
410         if r.status_code != 200:
411             raise ValueError('Failed to post IDP data [%s]' % repr(r))
412
413     def fetch_rest_page(self, idpname, uri):
414         """
415         idpname - the name of the IDP to fetch the page from
416         uri - the URI of the page to retrieve
417
418         The URL for the request is built from known-information in
419         the session.
420
421         returns dict if successful
422         returns ValueError if the output is unparseable
423         """
424         baseurl = self.servers[idpname].get('baseuri')
425         page = self.fetch_page(
426             idpname,
427             '%s%s' % (baseurl, uri)
428         )
429         return json.loads(page.text)
430
431     def get_rest_sp(self, idpname, spname=None):
432         if spname is None:
433             uri = '/%s/rest/providers/saml2/SPS/' % idpname
434         else:
435             uri = '/%s/rest/providers/saml2/SPS/%s' % (idpname, spname)
436
437         return self.fetch_rest_page(idpname, uri)