Add test for per-SP allowed and mapping attributes
[cascardo/ipsilon.git] / tests / helpers / http.py
1 #!/usr/bin/python
2 #
3 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
4 #
5 # see file 'COPYING' for use and warranty information
6 #
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20
21 from lxml import html
22 import requests
23 import string
24 import urlparse
25 import json
26 from urllib import urlencode
27
28
29 class WrongPage(Exception):
30     pass
31
32
33 class PageTree(object):
34
35     def __init__(self, result):
36         self.result = result
37         self.text = result.text
38         self._tree = None
39
40     @property
41     def tree(self):
42         if self._tree is None:
43             self._tree = html.fromstring(self.text)
44         return self._tree
45
46     def first_value(self, rule):
47         result = self.tree.xpath(rule)
48         if type(result) is list:
49             if len(result) > 0:
50                 result = result[0]
51             else:
52                 result = None
53         return result
54
55     def all_values(self, rule):
56         result = self.tree.xpath(rule)
57         if type(result) is list:
58             return result
59         return [result]
60
61     def make_referer(self):
62         return self.result.url
63
64     def expected_value(self, rule, expected):
65         value = self.first_value(rule)
66         if value != expected:
67             raise ValueError("Expected [%s], got [%s]" % (expected, value))
68
69
70 class HttpSessions(object):
71
72     def __init__(self):
73         self.servers = dict()
74
75     def add_server(self, name, baseuri, user=None, pwd=None):
76         new = {'baseuri': baseuri,
77                'session': requests.Session()}
78         if user:
79             new['user'] = user
80         if pwd:
81             new['pwd'] = pwd
82         self.servers[name] = new
83
84     def get_session(self, url):
85         for srv in self.servers:
86             d = self.servers[srv]
87             if url.startswith(d['baseuri']):
88                 return d['session']
89
90         raise ValueError("Unknown URL: %s" % url)
91
92     def get(self, url, **kwargs):
93         session = self.get_session(url)
94         return session.get(url, allow_redirects=False, **kwargs)
95
96     def post(self, url, **kwargs):
97         session = self.get_session(url)
98         return session.post(url, allow_redirects=False, **kwargs)
99
100     def access(self, action, url, **kwargs):
101         action = string.lower(action)
102         if action == 'get':
103             return self.get(url, **kwargs)
104         elif action == 'post':
105             return self.post(url, **kwargs)
106         else:
107             raise ValueError("Unknown action type: [%s]" % action)
108
109     def new_url(self, referer, action):
110         if action.startswith('/'):
111             u = urlparse.urlparse(referer)
112             return '%s://%s%s' % (u.scheme, u.netloc, action)
113         return action
114
115     def get_form_data(self, page, form_id, input_fields):
116         values = []
117         action = page.first_value('//form[@id="%s"]/@action' % form_id)
118         values.append(action)
119         method = page.first_value('//form[@id="%s"]/@method' % form_id)
120         values.append(method)
121         for field in input_fields:
122             value = page.all_values('//form[@id="%s"]/input/@%s' % (form_id,
123                                                                     field))
124             values.append(value)
125         return values
126
127     def handle_login_form(self, idp, page):
128         if type(page) != PageTree:
129             raise TypeError("Expected PageTree object")
130
131         srv = self.servers[idp]
132
133         try:
134             results = self.get_form_data(page, "login_form", ["name", "value"])
135             action_url = results[0]
136             method = results[1]
137             names = results[2]
138             values = results[3]
139             if action_url is None:
140                 raise Exception
141         except Exception:  # pylint: disable=broad-except
142             raise WrongPage("Not a Login Form Page")
143
144         referer = page.make_referer()
145         headers = {'referer': referer}
146         payload = {}
147         for i in range(0, len(names)):
148             payload[names[i]] = values[i]
149
150         # replace known values
151         payload['login_name'] = srv['user']
152         payload['login_password'] = srv['pwd']
153
154         return [method, self.new_url(referer, action_url),
155                 {'headers': headers, 'data': payload}]
156
157     def handle_return_form(self, page):
158         if type(page) != PageTree:
159             raise TypeError("Expected PageTree object")
160
161         try:
162             results = self.get_form_data(page, "saml-response",
163                                          ["name", "value"])
164             action_url = results[0]
165             if action_url is None:
166                 raise Exception
167             method = results[1]
168             names = results[2]
169             values = results[3]
170         except Exception:  # pylint: disable=broad-except
171             raise WrongPage("Not a Return Form Page")
172
173         referer = page.make_referer()
174         headers = {'referer': referer}
175
176         payload = {}
177         for i in range(0, len(names)):
178             payload[names[i]] = values[i]
179
180         return [method, self.new_url(referer, action_url),
181                 {'headers': headers, 'data': payload}]
182
183     def fetch_page(self, idp, target_url, follow_redirect=True):
184         url = target_url
185         action = 'get'
186         args = {}
187
188         while True:
189             r = self.access(action, url, **args)  # pylint: disable=star-args
190             if r.status_code == 303:
191                 if not follow_redirect:
192                     return PageTree(r)
193                 url = r.headers['location']
194                 action = 'get'
195                 args = {}
196             elif r.status_code == 200:
197                 page = PageTree(r)
198
199                 try:
200                     (action, url, args) = self.handle_login_form(idp, page)
201                     continue
202                 except WrongPage:
203                     pass
204
205                 try:
206                     (action, url, args) = self.handle_return_form(page)
207                     continue
208                 except WrongPage:
209                     pass
210
211                 # Either we got what we wanted, or we have to stop anyway
212                 return page
213             else:
214                 raise ValueError("Unhandled status (%d) on url %s" % (
215                                  r.status_code, url))
216
217     def auth_to_idp(self, idp):
218
219         srv = self.servers[idp]
220         target_url = '%s/%s/' % (srv['baseuri'], idp)
221
222         r = self.access('get', target_url)
223         if r.status_code != 200:
224             raise ValueError("Access to idp failed: %s" % repr(r))
225
226         page = PageTree(r)
227         page.expected_value('//div[@id="content"]/p/a/text()', 'Log In')
228         href = page.first_value('//div[@id="content"]/p/a/@href')
229         url = self.new_url(target_url, href)
230
231         page = self.fetch_page(idp, url)
232         page.expected_value('//div[@id="welcome"]/p/text()',
233                             'Welcome %s!' % srv['user'])
234
235     def get_sp_metadata(self, idp, sp):
236         idpsrv = self.servers[idp]
237         idpuri = idpsrv['baseuri']
238
239         spuri = self.servers[sp]['baseuri']
240
241         return (idpuri, requests.get('%s/saml2/metadata' % spuri))
242
243     def add_sp_metadata(self, idp, sp, rest=False):
244         expected_status = 200
245         idpsrv = self.servers[idp]
246         (idpuri, m) = self.get_sp_metadata(idp, sp)
247         url = '%s/%s/admin/providers/saml2/admin/new' % (idpuri, idp)
248         headers = {'referer': url}
249         if rest:
250             expected_status = 201
251             payload = {'metadata': m.content}
252             headers['content-type'] = 'application/x-www-form-urlencoded'
253             url = '%s/%s/rest/providers/saml2/SPS/%s' % (idpuri, idp, sp)
254             r = idpsrv['session'].post(url, headers=headers,
255                                        data=urlencode(payload))
256         else:
257             metafile = {'metafile': m.content}
258             payload = {'name': sp}
259             r = idpsrv['session'].post(url, headers=headers,
260                                        data=payload, files=metafile)
261         if r.status_code != expected_status:
262             raise ValueError('Failed to post SP data [%s]' % repr(r))
263
264         if not rest:
265             page = PageTree(r)
266             page.expected_value('//div[@class="alert alert-success"]/p/text()',
267                                 'SP Successfully added')
268
269     def set_sp_default_nameids(self, idp, sp, nameids):
270         """
271         nameids is a list of Name ID formats to enable
272         """
273         idpsrv = self.servers[idp]
274         idpuri = idpsrv['baseuri']
275         url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (idpuri, idp, sp)
276         headers = {'referer': url}
277         headers['content-type'] = 'application/x-www-form-urlencoded'
278         payload = {'submit': 'Submit',
279                    'allowed_nameids': ', '.join(nameids)}
280         r = idpsrv['session'].post(url, headers=headers,
281                                    data=payload)
282         if r.status_code != 200:
283             raise ValueError('Failed to post SP data [%s]' % repr(r))
284
285     # pylint: disable=dangerous-default-value
286     def set_attributes_and_mapping(self, idp, mapping=[], attrs=[],
287                                    spname=None):
288         """
289         Set allowed attributes and mapping in the IDP or the SP. In the
290         case of the SP both allowed attributes and the mapping need to
291         be provided. An empty option for either means delete all values.
292
293         mapping is a list of list of rules of the form:
294            [['from-1', 'to-1'], ['from-2', 'from-2']]
295
296         ex. [['*', '*'], ['fullname', 'namefull']]
297
298         attrs is the list of attributes that will be allowed:
299            ['fullname', 'givenname', 'surname']
300         """
301         idpsrv = self.servers[idp]
302         idpuri = idpsrv['baseuri']
303         if spname:  # per-SP setting
304             url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (
305                 idpuri, idp, spname)
306             mapname = 'Attribute Mapping'
307             attrname = 'Allowed Attributes'
308         else:  # global default
309             url = '%s/%s/admin/providers/saml2' % (idpuri, idp)
310             mapname = 'default attribute mapping'
311             attrname = 'default allowed attributes'
312
313         headers = {'referer': url}
314         headers['content-type'] = 'application/x-www-form-urlencoded'
315         payload = {'submit': 'Submit'}
316         count = 0
317         for m in mapping:
318             payload['%s %s-from' % (mapname, count)] = m[0]
319             payload['%s %s-to' % (mapname, count)] = m[1]
320             count += 1
321         count = 0
322         for attr in attrs:
323             payload['%s %s-name' % (attrname, count)] = attr
324             count += 1
325         r = idpsrv['session'].post(url, headers=headers,
326                                    data=payload)
327         if r.status_code != 200:
328             raise ValueError('Failed to post IDP data [%s]' % repr(r))
329
330     def fetch_rest_page(self, idpname, uri):
331         """
332         idpname - the name of the IDP to fetch the page from
333         uri - the URI of the page to retrieve
334
335         The URL for the request is built from known-information in
336         the session.
337
338         returns dict if successful
339         returns ValueError if the output is unparseable
340         """
341         baseurl = self.servers[idpname].get('baseuri')
342         page = self.fetch_page(
343             idpname,
344             '%s%s' % (baseurl, uri)
345         )
346         return json.loads(page.text)
347
348     def get_rest_sp(self, idpname, spname=None):
349         if spname is None:
350             uri = '/%s/rest/providers/saml2/SPS/' % idpname
351         else:
352             uri = '/%s/rest/providers/saml2/SPS/%s' % (idpname, spname)
353
354         return self.fetch_rest_page(idpname, uri)