2478e2a590093719953067ee040e9c2d3bff3f8c
[cascardo/ipsilon.git] / tests / helpers / http.py
1 #!/usr/bin/python
2 #
3 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
4 #
5 # see file 'COPYING' for use and warranty information
6 #
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20
21 from lxml import html
22 import requests
23 import string
24 import urlparse
25 import json
26 from urllib import urlencode
27
28
29 class WrongPage(Exception):
30     pass
31
32
33 class PageTree(object):
34
35     def __init__(self, result):
36         self.result = result
37         self.text = result.text
38         self._tree = None
39
40     @property
41     def tree(self):
42         if self._tree is None:
43             self._tree = html.fromstring(self.text)
44         return self._tree
45
46     def first_value(self, rule):
47         result = self.tree.xpath(rule)
48         if type(result) is list:
49             if len(result) > 0:
50                 result = result[0]
51             else:
52                 result = None
53         return result
54
55     def all_values(self, rule):
56         result = self.tree.xpath(rule)
57         if type(result) is list:
58             return result
59         return [result]
60
61     def make_referer(self):
62         return self.result.url
63
64     def expected_value(self, rule, expected):
65         value = self.first_value(rule)
66         if value != expected:
67             raise ValueError("Expected [%s], got [%s]" % (expected, value))
68
69
70 class HttpSessions(object):
71
72     def __init__(self):
73         self.servers = dict()
74
75     def add_server(self, name, baseuri, user=None, pwd=None):
76         new = {'baseuri': baseuri,
77                'session': requests.Session()}
78         if user:
79             new['user'] = user
80         if pwd:
81             new['pwd'] = pwd
82         self.servers[name] = new
83
84     def get_session(self, url):
85         for srv in self.servers:
86             d = self.servers[srv]
87             if url.startswith(d['baseuri']):
88                 return d['session']
89
90         raise ValueError("Unknown URL: %s" % url)
91
92     def get(self, url, **kwargs):
93         session = self.get_session(url)
94         return session.get(url, allow_redirects=False, **kwargs)
95
96     def post(self, url, **kwargs):
97         session = self.get_session(url)
98         return session.post(url, allow_redirects=False, **kwargs)
99
100     def access(self, action, url, **kwargs):
101         action = string.lower(action)
102         if action == 'get':
103             return self.get(url, **kwargs)
104         elif action == 'post':
105             return self.post(url, **kwargs)
106         else:
107             raise ValueError("Unknown action type: [%s]" % action)
108
109     def new_url(self, referer, action):
110         if action.startswith('/'):
111             u = urlparse.urlparse(referer)
112             return '%s://%s%s' % (u.scheme, u.netloc, action)
113         return action
114
115     def get_form_data(self, page, form_id, input_fields):
116         values = []
117         action = page.first_value('//form[@id="%s"]/@action' % form_id)
118         values.append(action)
119         method = page.first_value('//form[@id="%s"]/@method' % form_id)
120         values.append(method)
121         for field in input_fields:
122             value = page.all_values('//form[@id="%s"]/input/@%s' % (form_id,
123                                                                     field))
124             values.append(value)
125         return values
126
127     def handle_login_form(self, idp, page):
128         if type(page) != PageTree:
129             raise TypeError("Expected PageTree object")
130
131         srv = self.servers[idp]
132
133         try:
134             results = self.get_form_data(page, "login_form", ["name", "value"])
135             action_url = results[0]
136             method = results[1]
137             names = results[2]
138             values = results[3]
139             if action_url is None:
140                 raise Exception
141         except Exception:  # pylint: disable=broad-except
142             raise WrongPage("Not a Login Form Page")
143
144         referer = page.make_referer()
145         headers = {'referer': referer}
146         payload = {}
147         for i in range(0, len(names)):
148             payload[names[i]] = values[i]
149
150         # replace known values
151         payload['login_name'] = srv['user']
152         payload['login_password'] = srv['pwd']
153
154         return [method, self.new_url(referer, action_url),
155                 {'headers': headers, 'data': payload}]
156
157     def handle_return_form(self, page):
158         if type(page) != PageTree:
159             raise TypeError("Expected PageTree object")
160
161         try:
162             results = self.get_form_data(page, "saml-response",
163                                          ["name", "value"])
164             action_url = results[0]
165             if action_url is None:
166                 raise Exception
167             method = results[1]
168             names = results[2]
169             values = results[3]
170         except Exception:  # pylint: disable=broad-except
171             raise WrongPage("Not a Return Form Page")
172
173         referer = page.make_referer()
174         headers = {'referer': referer}
175
176         payload = {}
177         for i in range(0, len(names)):
178             payload[names[i]] = values[i]
179
180         return [method, self.new_url(referer, action_url),
181                 {'headers': headers, 'data': payload}]
182
183     def fetch_page(self, idp, target_url, follow_redirect=True):
184         url = target_url
185         action = 'get'
186         args = {}
187
188         while True:
189             r = self.access(action, url, **args)  # pylint: disable=star-args
190             if r.status_code == 303:
191                 if not follow_redirect:
192                     return PageTree(r)
193                 url = r.headers['location']
194                 action = 'get'
195                 args = {}
196             elif r.status_code == 200:
197                 page = PageTree(r)
198
199                 try:
200                     (action, url, args) = self.handle_login_form(idp, page)
201                     continue
202                 except WrongPage:
203                     pass
204
205                 try:
206                     (action, url, args) = self.handle_return_form(page)
207                     continue
208                 except WrongPage:
209                     pass
210
211                 # Either we got what we wanted, or we have to stop anyway
212                 return page
213             else:
214                 raise ValueError("Unhandled status (%d) on url %s" % (
215                                  r.status_code, url))
216
217     def auth_to_idp(self, idp):
218
219         srv = self.servers[idp]
220         target_url = '%s/%s/' % (srv['baseuri'], idp)
221
222         r = self.access('get', target_url)
223         if r.status_code != 200:
224             raise ValueError("Access to idp failed: %s" % repr(r))
225
226         page = PageTree(r)
227         page.expected_value('//div[@id="content"]/p/a/text()', 'Log In')
228         href = page.first_value('//div[@id="content"]/p/a/@href')
229         url = self.new_url(target_url, href)
230
231         page = self.fetch_page(idp, url)
232         page.expected_value('//div[@id="welcome"]/p/text()',
233                             'Welcome %s!' % srv['user'])
234
235     def logout_from_idp(self, idp):
236
237         srv = self.servers[idp]
238         target_url = '%s/%s/logout' % (srv['baseuri'], idp)
239
240         r = self.access('get', target_url)
241         if r.status_code != 200:
242             raise ValueError("Logout from idp failed: %s" % repr(r))
243
244     def get_sp_metadata(self, idp, sp):
245         idpsrv = self.servers[idp]
246         idpuri = idpsrv['baseuri']
247
248         spuri = self.servers[sp]['baseuri']
249
250         return (idpuri, requests.get('%s/saml2/metadata' % spuri))
251
252     def add_sp_metadata(self, idp, sp, rest=False):
253         expected_status = 200
254         idpsrv = self.servers[idp]
255         (idpuri, m) = self.get_sp_metadata(idp, sp)
256         url = '%s/%s/admin/providers/saml2/admin/new' % (idpuri, idp)
257         headers = {'referer': url}
258         if rest:
259             expected_status = 201
260             payload = {'metadata': m.content}
261             headers['content-type'] = 'application/x-www-form-urlencoded'
262             url = '%s/%s/rest/providers/saml2/SPS/%s' % (idpuri, idp, sp)
263             r = idpsrv['session'].post(url, headers=headers,
264                                        data=urlencode(payload))
265         else:
266             metafile = {'metafile': m.content}
267             payload = {'name': sp}
268             r = idpsrv['session'].post(url, headers=headers,
269                                        data=payload, files=metafile)
270         if r.status_code != expected_status:
271             raise ValueError('Failed to post SP data [%s]' % repr(r))
272
273         if not rest:
274             page = PageTree(r)
275             page.expected_value('//div[@class="alert alert-success"]/p/text()',
276                                 'SP Successfully added')
277
278     def set_sp_default_nameids(self, idp, sp, nameids):
279         """
280         nameids is a list of Name ID formats to enable
281         """
282         idpsrv = self.servers[idp]
283         idpuri = idpsrv['baseuri']
284         url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (idpuri, idp, sp)
285         headers = {'referer': url}
286         headers['content-type'] = 'application/x-www-form-urlencoded'
287         payload = {'submit': 'Submit',
288                    'allowed_nameids': ', '.join(nameids)}
289         r = idpsrv['session'].post(url, headers=headers,
290                                    data=payload)
291         if r.status_code != 200:
292             raise ValueError('Failed to post SP data [%s]' % repr(r))
293
294     # pylint: disable=dangerous-default-value
295     def set_attributes_and_mapping(self, idp, mapping=[], attrs=[],
296                                    spname=None):
297         """
298         Set allowed attributes and mapping in the IDP or the SP. In the
299         case of the SP both allowed attributes and the mapping need to
300         be provided. An empty option for either means delete all values.
301
302         mapping is a list of list of rules of the form:
303            [['from-1', 'to-1'], ['from-2', 'from-2']]
304
305         ex. [['*', '*'], ['fullname', 'namefull']]
306
307         attrs is the list of attributes that will be allowed:
308            ['fullname', 'givenname', 'surname']
309         """
310         idpsrv = self.servers[idp]
311         idpuri = idpsrv['baseuri']
312         if spname:  # per-SP setting
313             url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (
314                 idpuri, idp, spname)
315             mapname = 'Attribute Mapping'
316             attrname = 'Allowed Attributes'
317         else:  # global default
318             url = '%s/%s/admin/providers/saml2' % (idpuri, idp)
319             mapname = 'default attribute mapping'
320             attrname = 'default allowed attributes'
321
322         headers = {'referer': url}
323         headers['content-type'] = 'application/x-www-form-urlencoded'
324         payload = {'submit': 'Submit'}
325         count = 0
326         for m in mapping:
327             payload['%s %s-from' % (mapname, count)] = m[0]
328             payload['%s %s-to' % (mapname, count)] = m[1]
329             count += 1
330         count = 0
331         for attr in attrs:
332             payload['%s %s-name' % (attrname, count)] = attr
333             count += 1
334         r = idpsrv['session'].post(url, headers=headers,
335                                    data=payload)
336         if r.status_code != 200:
337             raise ValueError('Failed to post IDP data [%s]' % repr(r))
338
339     def fetch_rest_page(self, idpname, uri):
340         """
341         idpname - the name of the IDP to fetch the page from
342         uri - the URI of the page to retrieve
343
344         The URL for the request is built from known-information in
345         the session.
346
347         returns dict if successful
348         returns ValueError if the output is unparseable
349         """
350         baseurl = self.servers[idpname].get('baseuri')
351         page = self.fetch_page(
352             idpname,
353             '%s%s' % (baseurl, uri)
354         )
355         return json.loads(page.text)
356
357     def get_rest_sp(self, idpname, spname=None):
358         if spname is None:
359             uri = '/%s/rest/providers/saml2/SPS/' % idpname
360         else:
361             uri = '/%s/rest/providers/saml2/SPS/%s' % (idpname, spname)
362
363         return self.fetch_rest_page(idpname, uri)