# Copyright (C) 2014 Simo Sorce <simo@redhat.com>
#
# see file 'COPYING' for use and warranty information
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
class WrongPage(Exception):
    """Raised when a fetched page is not the form a handler expects."""


class PageTree(object):
    """HTML page parsed lazily from an HTTP response.

    ``result`` must expose ``text`` (the body to parse) and ``url`` (used
    as the referer for follow-up requests).
    """

    def __init__(self, result):
        self.result = result
        self.text = result.text
        # Parsed on first access to ``tree`` and cached afterwards.
        self._tree = None

    @property
    def tree(self):
        """Parsed lxml document, built on first use and then cached."""
        if self._tree is None:
            self._tree = html.fromstring(self.text)
        return self._tree

    def first_value(self, rule):
        """Evaluate XPath ``rule`` and return its first match.

        Returns None when the rule yields an empty node-set; non-list
        XPath results (strings, numbers, booleans) are returned as-is.
        """
        result = self.tree.xpath(rule)
        if isinstance(result, list):
            return result[0] if result else None
        return result

    def all_values(self, rule):
        """Evaluate XPath ``rule`` and always return a list of results."""
        result = self.tree.xpath(rule)
        if isinstance(result, list):
            return result
        return [result]

    def make_referer(self):
        """Return the response URL, for use as a Referer header."""
        return self.result.url

    def expected_value(self, rule, expected):
        """Raise ValueError unless first_value(rule) equals ``expected``."""
        value = self.first_value(rule)
        if value != expected:
            raise ValueError("Expected [%s], got [%s]" % (expected, value))


class HttpSessions(object):
    """Named HTTP sessions plus helpers that drive IdP/SP browser flows.

    ``servers`` maps a server name to a dict holding its ``baseuri``, a
    dedicated ``requests.Session`` and, when supplied, ``user``/``pwd``
    credentials. Requests go through the session whose ``baseuri``
    prefixes the URL, and redirects are never followed implicitly.
    """

    def __init__(self):
        self.servers = {}

    def add_server(self, name, baseuri, user=None, pwd=None):
        """Register server ``name`` at ``baseuri`` with a fresh session.

        Credentials are stored only when truthy; handle_login_form() reads
        them back as ``srv['user']`` and ``srv['pwd']``.
        """
        new = {'baseuri': baseuri,
               'session': requests.Session()}
        if user:
            new['user'] = user
        if pwd:
            new['pwd'] = pwd
        self.servers[name] = new

    def get_session(self, url):
        """Return the session of the first server whose baseuri prefixes url.

        Raises ValueError when no registered server matches.
        """
        for srv in self.servers.values():
            if url.startswith(srv['baseuri']):
                return srv['session']
        raise ValueError("Unknown URL: %s" % url)

    def get(self, url, **kwargs):
        """GET ``url`` via its server's session; redirects not followed."""
        session = self.get_session(url)
        return session.get(url, allow_redirects=False, **kwargs)

    def post(self, url, **kwargs):
        """POST ``url`` via its server's session; redirects not followed."""
        session = self.get_session(url)
        return session.post(url, allow_redirects=False, **kwargs)

    def access(self, action, url, **kwargs):
        """Dispatch to get() or post() by case-insensitive ``action``.

        Raises ValueError for any other action.
        """
        # str.lower() works on Python 2 and 3; string.lower() is Py2-only.
        action = action.lower()
        if action == 'get':
            return self.get(url, **kwargs)
        if action == 'post':
            return self.post(url, **kwargs)
        raise ValueError("Unknown action type: [%s]" % action)

    def new_url(self, referer, action):
        """Resolve a form ``action`` against the ``referer`` page URL.

        Absolute paths ('/...') are rebased onto the referer's scheme and
        host; any other value is returned unchanged.
        """
        if action.startswith('/'):
            u = urlparse.urlparse(referer)
            return '%s://%s%s' % (u.scheme, u.netloc, action)
        return action

    def get_form_data(self, page, form_id, input_fields):
        """Return [action, method, lists...] for form ``form_id`` on page.

        For each attribute name in ``input_fields`` one list is appended
        holding that attribute across the form's direct <input> children.
        Each list is queried separately, so an input lacking one of the
        attributes shifts the lists out of alignment with each other.
        """
        form = '//form[@id="%s"]' % form_id
        values = [page.first_value('%s/@action' % form),
                  page.first_value('%s/@method' % form)]
        for field in input_fields:
            values.append(page.all_values('%s/input/@%s' % (form, field)))
        return values

    @staticmethod
    def _form_inputs(page, form_id):
        """Map name -> value for the named direct <input>s of a form.

        Name and value come from the same element, so an input without a
        value (sent as '') or without a name (skipped) cannot misalign the
        other pairs. Later inputs with a repeated name win.
        """
        payload = {}
        for element in page.all_values('//form[@id="%s"]/input' % form_id):
            name = element.get('name')
            if name is not None:
                payload[name] = element.get('value', '')
        return payload

    def _parse_form(self, page, form_id, error):
        """Return (method, url, headers, payload) to submit form ``form_id``.

        Raises WrongPage(error) if the form cannot be read or has no action.
        """
        try:
            action_url, method = self.get_form_data(page, form_id, [])
        except Exception:  # pylint: disable=broad-except
            # Failing to read the form at all means "not the expected page".
            raise WrongPage(error)
        if action_url is None:
            raise WrongPage(error)

        referer = page.make_referer()
        headers = {'referer': referer}
        payload = self._form_inputs(page, form_id)
        # A form without a method attribute is submitted with GET in HTML.
        return (method or 'get', self.new_url(referer, action_url), headers,
                payload)

    def handle_login_form(self, idp, page):
        """Build the submission of ``page``'s login form for server ``idp``.

        Returns [method, url, {'headers': ..., 'data': ...}] where the data
        is the form's own inputs with login_name/login_password replaced by
        the server's stored credentials.

        Raises TypeError if ``page`` is not a PageTree and WrongPage if it
        has no 'login_form' form with an action.
        """
        if not isinstance(page, PageTree):
            raise TypeError("Expected PageTree object")

        srv = self.servers[idp]
        method, url, headers, payload = self._parse_form(
            page, "login_form", "Not a Login Form Page")

        # replace known values
        payload['login_name'] = srv['user']
        payload['login_password'] = srv['pwd']

        return [method, url, {'headers': headers, 'data': payload}]

    def handle_return_form(self, page):
        """Build the re-submission of ``page``'s 'saml-response' form.

        Returns [method, url, {'headers': ..., 'data': ...}] carrying the
        form's own inputs unchanged.

        Raises TypeError if ``page`` is not a PageTree and WrongPage if it
        has no 'saml-response' form with an action.
        """
        if not isinstance(page, PageTree):
            raise TypeError("Expected PageTree object")

        method, url, headers, payload = self._parse_form(
            page, "saml-response", "Not a Return Form Page")
        return [method, url, {'headers': headers, 'data': payload}]

    def _next_form_step(self, idp, page):
        """Return the login or SAML-return form submission, else None."""
        try:
            return self.handle_login_form(idp, page)
        except WrongPage:
            pass
        try:
            return self.handle_return_form(page)
        except WrongPage:
            return None

    def fetch_page(self, idp, target_url, max_steps=50):
        """Fetch ``target_url``, following redirects and submitting forms.

        303 responses are followed with a GET. On a 200 response a login
        form (filled with ``idp``'s credentials) or a SAML return form is
        submitted automatically; any other 200 page is returned as a
        PageTree.

        Raises ValueError on any other status, or when ``max_steps``
        requests are made without reaching a final page (for instance a
        redirect loop or a login form that keeps coming back).
        """
        url = target_url
        action = 'get'
        args = {}

        for _ in range(max_steps):
            r = self.access(action, url, **args)  # pylint: disable=star-args
            if r.status_code == 303:
                url = r.headers['location']
                action = 'get'
                args = {}
                continue
            if r.status_code != 200:
                raise ValueError("Unhandled status (%d) on url %s" % (
                                 r.status_code, url))

            page = PageTree(r)
            step = self._next_form_step(idp, page)
            if step is None:
                # Either we got what we wanted, or we have to stop anyway
                return page
            (action, url, args) = step

        raise ValueError("Too many requests (%d) fetching %s" % (
                         max_steps, target_url))

    def auth_to_idp(self, idp):
        """Log in to ``idp`` with its stored credentials.

        Raises ValueError if the IdP page is not reachable (non-200), has
        no 'Log In' link, or the final page does not welcome the user.
        """
        srv = self.servers[idp]
        target_url = '%s/%s/' % (srv['baseuri'], idp)

        r = self.access('get', target_url)
        if r.status_code != 200:
            raise ValueError("Access to idp failed: %s" % repr(r))

        page = PageTree(r)
        page.expected_value('//div[@id="content"]/p/a/text()', 'Log In')
        href = page.first_value('//div[@id="content"]/p/a/@href')
        url = self.new_url(target_url, href)

        page = self.fetch_page(idp, url)
        page.expected_value('//div[@id="welcome"]/p/text()',
                            'Welcome %s!' % srv['user'])

    def get_sp_metadata(self, idp, sp):
        """Return (idp baseuri, response of GET <sp baseuri>/saml2/metadata).

        The metadata is fetched with a plain requests.get(), outside the
        per-server sessions.
        """
        idpuri = self.servers[idp]['baseuri']
        spuri = self.servers[sp]['baseuri']
        return (idpuri, requests.get('%s/saml2/metadata' % spuri))

    def add_sp_metadata(self, idp, sp):
        """Register SP ``sp`` with IdP ``idp`` by uploading its metadata.

        Raises ValueError if the metadata cannot be fetched, the upload is
        rejected (non-200), or the IdP does not report success.
        """
        idpsrv = self.servers[idp]
        (idpuri, m) = self.get_sp_metadata(idp, sp)
        # Uploading an error page as metadata would only fail later, and
        # less clearly, on the IdP side.
        if m.status_code != 200:
            raise ValueError('Failed to fetch SP metadata [%s]' % repr(m))
        url = '%s/%s/admin/providers/saml2/admin/new' % (idpuri, idp)
        metafile = {'metafile': m.content}
        headers = {'referer': url}
        payload = {'name': sp}
        r = idpsrv['session'].post(url, headers=headers,
                                   data=payload, files=metafile)
        if r.status_code != 200:
            raise ValueError('Failed to post SP data [%s]' % repr(r))

        page = PageTree(r)
        page.expected_value('//div[@class="alert alert-success"]/p/text()',
                            'SP Successfully added')