3 # Copyright (C) 2014 Simo Sorce <simo@redhat.com>
5 # see file 'COPYING' for use and warranty information
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 # GNU General Public License for more details.
17 # You should have received a copy of the GNU General Public License
18 # along with this program. If not, see <http://www.gnu.org/licenses/>.
26 from urllib import urlencode
29 class WrongPage(Exception):
33 class PageTree(object):
35 def __init__(self, result):
37 self.text = result.text
42 if self._tree is None:
43 self._tree = html.fromstring(self.text)
46 def first_value(self, rule):
47 result = self.tree.xpath(rule)
48 if type(result) is list:
55 def all_values(self, rule):
56 result = self.tree.xpath(rule)
57 if type(result) is list:
61 def make_referer(self):
62 return self.result.url
64 def expected_value(self, rule, expected):
65 value = self.first_value(rule)
67 raise ValueError("Expected [%s], got [%s]" % (expected, value))
70 class HttpSessions(object):
75 def add_server(self, name, baseuri, user=None, pwd=None):
76 new = {'baseuri': baseuri,
77 'session': requests.Session()}
82 self.servers[name] = new
84 def get_session(self, url):
85 for srv in self.servers:
87 if url.startswith(d['baseuri']):
90 raise ValueError("Unknown URL: %s" % url)
92 def get(self, url, **kwargs):
93 session = self.get_session(url)
94 return session.get(url, allow_redirects=False, **kwargs)
96 def post(self, url, **kwargs):
97 session = self.get_session(url)
98 return session.post(url, allow_redirects=False, **kwargs)
100 def access(self, action, url, **kwargs):
101 action = string.lower(action)
103 return self.get(url, **kwargs)
104 elif action == 'post':
105 return self.post(url, **kwargs)
107 raise ValueError("Unknown action type: [%s]" % action)
109 def new_url(self, referer, action):
110 if action.startswith('/'):
111 u = urlparse.urlparse(referer)
112 return '%s://%s%s' % (u.scheme, u.netloc, action)
115 def get_form_data(self, page, form_id, input_fields):
116 form_selector = '//form'
118 form_selector += '[@id="%s"]' % form_id
120 action = page.first_value('%s/@action' % form_selector)
121 values.append(action)
122 method = page.first_value('%s/@method' % form_selector)
123 values.append(method)
124 for field in input_fields:
125 value = page.all_values('%s/input/@%s' % (form_selector,
130 def handle_login_form(self, idp, page):
131 if type(page) != PageTree:
132 raise TypeError("Expected PageTree object")
134 srv = self.servers[idp]
137 results = self.get_form_data(page, "login_form", ["name", "value"])
138 action_url = results[0]
142 if action_url is None:
144 except Exception: # pylint: disable=broad-except
145 raise WrongPage("Not a Login Form Page")
147 referer = page.make_referer()
148 headers = {'referer': referer}
150 for i in range(0, len(names)):
151 payload[names[i]] = values[i]
153 # replace known values
154 payload['login_name'] = srv['user']
155 payload['login_password'] = srv['pwd']
157 return [method, self.new_url(referer, action_url),
158 {'headers': headers, 'data': payload}]
160 def handle_return_form(self, page):
161 if type(page) != PageTree:
162 raise TypeError("Expected PageTree object")
165 results = self.get_form_data(page, "saml-response",
167 action_url = results[0]
168 if action_url is None:
173 except Exception: # pylint: disable=broad-except
174 raise WrongPage("Not a Return Form Page")
176 referer = page.make_referer()
177 headers = {'referer': referer}
180 for i in range(0, len(names)):
181 payload[names[i]] = values[i]
183 return [method, self.new_url(referer, action_url),
184 {'headers': headers, 'data': payload}]
186 def handle_openid_form(self, page):
187 if type(page) != PageTree:
188 raise TypeError("Expected PageTree object")
190 if not page.first_value('//title/text()') == \
191 'OpenID transaction in progress':
192 raise WrongPage('Not OpenID autosubmit form')
195 results = self.get_form_data(page, None,
197 action_url = results[0]
198 if action_url is None:
203 except Exception: # pylint: disable=broad-except
204 raise WrongPage("Not OpenID autosubmit form")
206 referer = page.make_referer()
207 headers = {'referer': referer}
210 for i in range(0, len(names)):
211 payload[names[i]] = values[i]
213 return [method, self.new_url(referer, action_url),
214 {'headers': headers, 'data': payload}]
216 def handle_openid_consent_form(self, page):
217 if type(page) != PageTree:
218 raise TypeError("Expected PageTree object")
221 results = self.get_form_data(page, "consent_form",
223 action_url = results[0]
224 if action_url is None:
229 except Exception: # pylint: disable=broad-except
230 raise WrongPage("Not an OpenID Consent Form Page")
232 referer = page.make_referer()
233 headers = {'referer': referer}
236 for i in range(0, len(names)):
237 payload[names[i]] = values[i]
239 # Replace known values
240 payload['decided_allow'] = 'Allow'
242 return [method, self.new_url(referer, action_url),
243 {'headers': headers, 'data': payload}]
245 def fetch_page(self, idp, target_url, follow_redirect=True):
251 r = self.access(action, url, **args) # pylint: disable=star-args
252 if r.status_code == 303 or r.status_code == 302:
253 if not follow_redirect:
255 url = r.headers['location']
258 elif r.status_code == 200:
262 (action, url, args) = self.handle_login_form(idp, page)
268 (action, url, args) = self.handle_return_form(page)
274 (action, url, args) = self.handle_openid_consent_form(page)
280 (action, url, args) = self.handle_openid_form(page)
285 # Either we got what we wanted, or we have to stop anyway
288 raise ValueError("Unhandled status (%d) on url %s" % (
291 def auth_to_idp(self, idp):
293 srv = self.servers[idp]
294 target_url = '%s/%s/' % (srv['baseuri'], idp)
296 r = self.access('get', target_url)
297 if r.status_code != 200:
298 raise ValueError("Access to idp failed: %s" % repr(r))
301 page.expected_value('//div[@id="content"]/p/a/text()', 'Log In')
302 href = page.first_value('//div[@id="content"]/p/a/@href')
303 url = self.new_url(target_url, href)
305 page = self.fetch_page(idp, url)
306 page.expected_value('//div[@id="welcome"]/p/text()',
307 'Welcome %s!' % srv['user'])
309 def logout_from_idp(self, idp):
311 srv = self.servers[idp]
312 target_url = '%s/%s/logout' % (srv['baseuri'], idp)
314 r = self.access('get', target_url)
315 if r.status_code != 200:
316 raise ValueError("Logout from idp failed: %s" % repr(r))
318 def get_sp_metadata(self, idp, sp):
319 idpsrv = self.servers[idp]
320 idpuri = idpsrv['baseuri']
322 spuri = self.servers[sp]['baseuri']
324 return (idpuri, requests.get('%s/saml2/metadata' % spuri))
326 def add_sp_metadata(self, idp, sp, rest=False):
327 expected_status = 200
328 idpsrv = self.servers[idp]
329 (idpuri, m) = self.get_sp_metadata(idp, sp)
330 url = '%s/%s/admin/providers/saml2/admin/new' % (idpuri, idp)
331 headers = {'referer': url}
333 expected_status = 201
334 payload = {'metadata': m.content}
335 headers['content-type'] = 'application/x-www-form-urlencoded'
336 url = '%s/%s/rest/providers/saml2/SPS/%s' % (idpuri, idp, sp)
337 r = idpsrv['session'].post(url, headers=headers,
338 data=urlencode(payload))
340 metafile = {'metafile': m.content}
341 payload = {'name': sp}
342 r = idpsrv['session'].post(url, headers=headers,
343 data=payload, files=metafile)
344 if r.status_code != expected_status:
345 raise ValueError('Failed to post SP data [%s]' % repr(r))
349 page.expected_value('//div[@class="alert alert-success"]/p/text()',
350 'SP Successfully added')
352 def set_sp_default_nameids(self, idp, sp, nameids):
354 nameids is a list of Name ID formats to enable
356 idpsrv = self.servers[idp]
357 idpuri = idpsrv['baseuri']
358 url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (idpuri, idp, sp)
359 headers = {'referer': url}
360 headers['content-type'] = 'application/x-www-form-urlencoded'
361 payload = {'submit': 'Submit',
362 'allowed_nameids': ', '.join(nameids)}
363 r = idpsrv['session'].post(url, headers=headers,
365 if r.status_code != 200:
366 raise ValueError('Failed to post SP data [%s]' % repr(r))
368 # pylint: disable=dangerous-default-value
369 def set_attributes_and_mapping(self, idp, mapping=[], attrs=[],
372 Set allowed attributes and mapping in the IDP or the SP. In the
373 case of the SP both allowed attributes and the mapping need to
374 be provided. An empty option for either means delete all values.
376 mapping is a list of list of rules of the form:
377 [['from-1', 'to-1'], ['from-2', 'from-2']]
379 ex. [['*', '*'], ['fullname', 'namefull']]
381 attrs is the list of attributes that will be allowed:
382 ['fullname', 'givenname', 'surname']
384 idpsrv = self.servers[idp]
385 idpuri = idpsrv['baseuri']
386 if spname: # per-SP setting
387 url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (
389 mapname = 'Attribute Mapping'
390 attrname = 'Allowed Attributes'
391 else: # global default
392 url = '%s/%s/admin/providers/saml2' % (idpuri, idp)
393 mapname = 'default attribute mapping'
394 attrname = 'default allowed attributes'
396 headers = {'referer': url}
397 headers['content-type'] = 'application/x-www-form-urlencoded'
398 payload = {'submit': 'Submit'}
401 payload['%s %s-from' % (mapname, count)] = m[0]
402 payload['%s %s-to' % (mapname, count)] = m[1]
406 payload['%s %s-name' % (attrname, count)] = attr
408 r = idpsrv['session'].post(url, headers=headers,
410 if r.status_code != 200:
411 raise ValueError('Failed to post IDP data [%s]' % repr(r))
413 def fetch_rest_page(self, idpname, uri):
415 idpname - the name of the IDP to fetch the page from
416 uri - the URI of the page to retrieve
418 The URL for the request is built from known-information in
421 returns dict if successful
422 returns ValueError if the output is unparseable
424 baseurl = self.servers[idpname].get('baseuri')
425 page = self.fetch_page(
427 '%s%s' % (baseurl, uri)
429 return json.loads(page.text)
431 def get_rest_sp(self, idpname, spname=None):
433 uri = '/%s/rest/providers/saml2/SPS/' % idpname
435 uri = '/%s/rest/providers/saml2/SPS/%s' % (idpname, spname)
437 return self.fetch_rest_page(idpname, uri)