Add test for REST Service Provider GET and POST
[cascardo/ipsilon.git] / tests / helpers / http.py
1 #!/usr/bin/python
2 #
3 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
4 #
5 # see file 'COPYING' for use and warranty information
6 #
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20
21 from lxml import html
22 import requests
23 import string
24 import urlparse
25 import json
26 from urllib import urlencode
27
28
class WrongPage(Exception):
    """Raised by a form handler when the page is not the form it handles."""
31
32
class PageTree(object):
    """Wrap an HTTP response and query its body with XPath.

    The body is parsed into an HTML tree lazily, the first time ``tree``
    is accessed, and the parsed tree is cached.
    """

    def __init__(self, result):
        # result: the response object; only .text and .url are used here
        self.result = result
        self.text = result.text
        self._tree = None

    @property
    def tree(self):
        """The parsed HTML tree of the response body (parsed once)."""
        if self._tree is None:
            self._tree = html.fromstring(self.text)
        return self._tree

    def first_value(self, rule):
        """Evaluate XPath *rule* and return a single value.

        A list result yields its first item, or None when it is empty; any
        non-list result (e.g. from an XPath function) is returned as is.
        """
        result = self.tree.xpath(rule)
        if isinstance(result, list):
            return result[0] if result else None
        return result

    def all_values(self, rule):
        """Evaluate XPath *rule* and always return the results as a list."""
        result = self.tree.xpath(rule)
        return result if isinstance(result, list) else [result]

    def make_referer(self):
        """Return the response URL, for use as the next request's referer."""
        return self.result.url

    def expected_value(self, rule, expected):
        """Check that ``first_value(rule)`` equals *expected*.

        Raises ValueError describing both values when they differ.
        """
        value = self.first_value(rule)
        if value != expected:
            raise ValueError("Expected [%s], got [%s]" % (expected, value))
68
69
class HttpSessions(object):
    """Drive HTTP conversations against a set of named test servers.

    Each registered server gets its own ``requests.Session``; a URL is
    routed to the session of the server whose base URI it starts with.
    Requests are always issued with ``allow_redirects=False`` and
    ``fetch_page`` walks redirects and forms by hand.
    """

    def __init__(self):
        # name -> {'baseuri': str, 'session': requests.Session,
        #          and optionally 'user' / 'pwd' (see add_server)}
        self.servers = dict()

    def add_server(self, name, baseuri, user=None, pwd=None):
        """Register server *name* reachable under *baseuri*.

        *user* and *pwd* are stored only when truthy; they are the
        credentials ``handle_login_form`` submits for this server.
        """
        new = {'baseuri': baseuri,
               'session': requests.Session()}
        if user:
            new['user'] = user
        if pwd:
            new['pwd'] = pwd
        self.servers[name] = new

    def get_session(self, url):
        """Return the session of the server whose base URI prefixes *url*.

        When several base URIs are prefixes of *url* (e.g. ``http://h:8``
        and ``http://h:80``) the longest, most specific one wins, instead of
        whichever the dict happens to yield first.

        Raises ValueError if no registered server matches.
        """
        matches = [srv for srv in self.servers.values()
                   if url.startswith(srv['baseuri'])]
        if not matches:
            raise ValueError("Unknown URL: %s" % url)
        return max(matches, key=lambda srv: len(srv['baseuri']))['session']

    def get(self, url, **kwargs):
        """GET *url* through its server's session, without following redirects."""
        session = self.get_session(url)
        return session.get(url, allow_redirects=False, **kwargs)

    def post(self, url, **kwargs):
        """POST to *url* through its server's session, without following redirects."""
        session = self.get_session(url)
        return session.post(url, allow_redirects=False, **kwargs)

    def access(self, action, url, **kwargs):
        """Perform HTTP *action* ('get' or 'post', case-insensitive) on *url*.

        Raises ValueError for any other action.
        """
        # str.lower() rather than string.lower(): the latter is deprecated
        # and does not exist at all in Python 3.
        action = action.lower()
        if action == 'get':
            return self.get(url, **kwargs)
        elif action == 'post':
            return self.post(url, **kwargs)
        raise ValueError("Unknown action type: [%s]" % action)

    def new_url(self, referer, action):
        """Resolve *action* (a form action or redirect target) against *referer*.

        Standard relative-reference resolution is used, so absolute URLs are
        returned unchanged, absolute paths take the referer's scheme and
        host, and relative or empty references resolve against the referer.
        """
        return urlparse.urljoin(referer, action)

    def get_form_data(self, page, form_id, input_fields):
        """Extract raw data from the form with id *form_id* on *page*.

        Returns a list: [action, method, values_1, values_2, ...] where each
        values_i is the list of attribute *input_fields[i]* across the
        form's direct <input> children. action/method are None when absent.
        """
        values = []
        action = page.first_value('//form[@id="%s"]/@action' % form_id)
        values.append(action)
        method = page.first_value('//form[@id="%s"]/@method' % form_id)
        values.append(method)
        for field in input_fields:
            value = page.all_values('//form[@id="%s"]/input/@%s' % (form_id,
                                                                    field))
            values.append(value)
        return values

    def _parse_form(self, page, form_id):
        """Return (action, method, payload) for the form with id *form_id*.

        payload maps the name of each named <input> directly under the form
        to its value ('' when it has no value attribute). Names and values
        are read per element so an input missing one attribute cannot shift
        the pairing of the others.
        """
        action, method = self.get_form_data(page, form_id, [])
        payload = {}
        for field in page.all_values('//form[@id="%s"]/input' % form_id):
            name = field.get('name')
            if name is not None:
                payload[name] = field.get('value', '')
        return action, method, payload

    def handle_login_form(self, idp, page):
        """Prepare submitting the login form ("login_form") found on *page*.

        The form's inputs are copied and login_name / login_password are
        set to the credentials registered for server *idp*.

        Returns [method, absolute_action_url, request_kwargs] for access().
        Raises TypeError if *page* is not a PageTree and WrongPage if it does
        not contain the login form.
        """
        if not isinstance(page, PageTree):
            raise TypeError("Expected PageTree object")

        srv = self.servers[idp]

        try:
            action_url, method, payload = self._parse_form(page, "login_form")
        except Exception:  # pylint: disable=broad-except
            # Any failure while inspecting the page means "not this form".
            raise WrongPage("Not a Login Form Page")
        if action_url is None:
            raise WrongPage("Not a Login Form Page")

        referer = page.make_referer()
        headers = {'referer': referer}

        # replace known values
        payload['login_name'] = srv['user']
        payload['login_password'] = srv['pwd']

        # HTML forms without a method attribute are submitted with GET.
        return [method or 'get', self.new_url(referer, action_url),
                {'headers': headers, 'data': payload}]

    def handle_return_form(self, page):
        """Prepare re-submitting the "saml-response" form found on *page*.

        All named inputs of the form are posted back unchanged.

        Returns [method, absolute_action_url, request_kwargs] for access().
        Raises TypeError if *page* is not a PageTree and WrongPage if it does
        not contain the form.
        """
        if not isinstance(page, PageTree):
            raise TypeError("Expected PageTree object")

        try:
            action_url, method, payload = self._parse_form(page,
                                                           "saml-response")
        except Exception:  # pylint: disable=broad-except
            # Any failure while inspecting the page means "not this form".
            raise WrongPage("Not a Return Form Page")
        if action_url is None:
            raise WrongPage("Not a Return Form Page")

        referer = page.make_referer()
        headers = {'referer': referer}

        # HTML forms without a method attribute are submitted with GET.
        return [method or 'get', self.new_url(referer, action_url),
                {'headers': headers, 'data': payload}]

    def fetch_page(self, idp, target_url, max_steps=30):
        """Fetch *target_url*, following redirects and auto-submitting forms.

        303 responses are followed with a GET (a relative Location is
        resolved against the current URL). On a 200 page, the login form is
        submitted with the credentials of server *idp*, or the
        "saml-response" form is posted back. The first 200 page that is
        neither form is returned as a PageTree.

        max_steps bounds the number of requests, so a flow that never
        settles (e.g. if the login form were served again after a failed
        login) raises instead of looping forever; None means no limit.

        Raises ValueError on any status other than 200/303 or when the step
        limit is reached.
        """
        url = target_url
        action = 'get'
        args = {}
        steps = 0

        while True:
            if max_steps is not None and steps >= max_steps:
                raise ValueError("Too many steps (%d) fetching %s" % (
                                 max_steps, target_url))
            steps += 1

            r = self.access(action, url, **args)  # pylint: disable=star-args
            if r.status_code == 303:
                # Location may be relative; resolve it like a browser would.
                url = self.new_url(url, r.headers['location'])
                action = 'get'
                args = {}
            elif r.status_code == 200:
                page = PageTree(r)

                try:
                    (action, url, args) = self.handle_login_form(idp, page)
                    continue
                except WrongPage:
                    pass

                try:
                    (action, url, args) = self.handle_return_form(page)
                    continue
                except WrongPage:
                    pass

                # Either we got what we wanted, or we have to stop anyway
                return page
            else:
                raise ValueError("Unhandled status (%d) on url %s" % (
                                 r.status_code, url))

    def auth_to_idp(self, idp):
        """Log in to the web UI of IdP server *idp* as its configured user.

        Follows the "Log In" link on the IdP front page through fetch_page
        and checks the resulting welcome message.

        Raises ValueError if the front page is not a 200, has no "Log In"
        link, or the welcome message is not the expected one.
        """
        srv = self.servers[idp]
        target_url = '%s/%s/' % (srv['baseuri'], idp)

        r = self.access('get', target_url)
        if r.status_code != 200:
            raise ValueError("Access to idp failed: %s" % repr(r))

        page = PageTree(r)
        page.expected_value('//div[@id="content"]/p/a/text()', 'Log In')
        href = page.first_value('//div[@id="content"]/p/a/@href')
        url = self.new_url(target_url, href)

        page = self.fetch_page(idp, url)
        page.expected_value('//div[@id="welcome"]/p/text()',
                            'Welcome %s!' % srv['user'])

    def get_sp_metadata(self, idp, sp):
        """Download the SAML2 metadata published by server *sp*.

        Returns (idp_base_uri, response). The metadata is fetched with a
        plain ``requests.get``, not through the SP's session.
        """
        idpsrv = self.servers[idp]
        idpuri = idpsrv['baseuri']

        spuri = self.servers[sp]['baseuri']

        return (idpuri, requests.get('%s/saml2/metadata' % spuri))

    def add_sp_metadata(self, idp, sp, rest=False):
        """Register SP *sp* with IdP *idp* using the SP's metadata.

        With rest=False the metadata is uploaded through the admin web form
        (expects 200 and the success alert); with rest=True it is POSTed
        form-encoded to the REST endpoint (expects 201).

        Raises ValueError if the status or the success message is wrong.
        """
        expected_status = 200
        idpsrv = self.servers[idp]
        (idpuri, m) = self.get_sp_metadata(idp, sp)
        url = '%s/%s/admin/providers/saml2/admin/new' % (idpuri, idp)
        headers = {'referer': url}
        if rest:
            expected_status = 201
            payload = {'metadata': m.content}
            headers['content-type'] = 'application/x-www-form-urlencoded'
            url = '%s/%s/rest/providers/saml2/SPS/%s' % (idpuri, idp, sp)
            r = idpsrv['session'].post(url, headers=headers,
                                       data=urlencode(payload))
        else:
            metafile = {'metafile': m.content}
            payload = {'name': sp}
            r = idpsrv['session'].post(url, headers=headers,
                                       data=payload, files=metafile)
        if r.status_code != expected_status:
            raise ValueError('Failed to post SP data [%s]' % repr(r))

        if not rest:
            page = PageTree(r)
            page.expected_value('//div[@class="alert alert-success"]/p/text()',
                                'SP Successfully added')

    def fetch_rest_page(self, idpname, uri):
        """
        idpname - the name of the IDP to fetch the page from
        uri - the URI of the page to retrieve

        The URL for the request is built from known-information in
        the session.

        returns the decoded JSON body if successful
        raises ValueError if the output is unparseable (or if fetch_page
        fails)
        """
        baseurl = self.servers[idpname].get('baseuri')
        page = self.fetch_page(
            idpname,
            '%s%s' % (baseurl, uri)
        )
        return json.loads(page.text)

    def get_rest_sp(self, idpname, spname=None):
        """Fetch the REST description of SP *spname* from IdP *idpname*.

        With spname=None the whole SP collection is requested.
        """
        if spname is None:
            uri = '/%s/rest/providers/saml2/SPS/' % idpname
        else:
            uri = '/%s/rest/providers/saml2/SPS/%s' % (idpname, spname)

        return self.fetch_rest_page(idpname, uri)
290
291         return self.fetch_rest_page(idpname, uri)