d3cc1449b3f21c4e6395d5e7f3334058c6520f25
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / provider.py
1 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
2 #
3 # see file 'COPYING' for use and warranty information
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 from ipsilon.providers.common import ProviderException
19 from ipsilon.util import config as pconfig
20 from ipsilon.util.config import ConfigHelper
21 from ipsilon.tools.saml2metadata import SAML2_NAMEID_MAP
22 from ipsilon.util.log import Log
23 import lasso
24 import re
25
26
27 VALID_IN_NAME = r'[^\ a-zA-Z0-9]'
28
29
30 class InvalidProviderId(ProviderException):
31
32     def __init__(self, code):
33         message = 'Invalid Provider ID: %s' % code
34         super(InvalidProviderId, self).__init__(message)
35         self._debug(message)
36
37
38 class NameIdNotAllowed(Exception):
39
40     def __init__(self, nid):
41         message = 'Name ID [%s] is not allowed' % nid
42         super(NameIdNotAllowed, self).__init__(message)
43         self.message = message
44
45     def __str__(self):
46         return repr(self.message)
47
48
49 class ServiceProviderConfig(ConfigHelper):
50     def __init__(self):
51         super(ServiceProviderConfig, self).__init__()
52
53
54 class ServiceProvider(ServiceProviderConfig):
55
56     def __init__(self, config, provider_id):
57         super(ServiceProvider, self).__init__()
58         self.cfg = config
59         data = self.cfg.get_data(name='id', value=provider_id)
60         if len(data) != 1:
61             raise InvalidProviderId('multiple matches')
62         idval = data.keys()[0]
63         data = self.cfg.get_data(idval=idval)
64         self._properties = data[idval]
65         self._staging = dict()
66         self.load_config()
67
68     def load_config(self):
69         self.new_config(
70             self.provider_id,
71             pconfig.String(
72                 'Name',
73                 'A nickname used to easily identify the Service Provider.'
74                 ' Only alphanumeric characters [A-Z,a-z,0-9] and spaces are'
75                 '  accepted.',
76                 self.name),
77             pconfig.Pick(
78                 'Default NameID',
79                 'Default NameID used by Service Providers.',
80                 SAML2_NAMEID_MAP.keys(),
81                 self.default_nameid),
82             pconfig.Choice(
83                 'Allowed NameIDs',
84                 'Allowed NameIDs for this Service Provider.',
85                 SAML2_NAMEID_MAP.keys(),
86                 self.allowed_nameids),
87             pconfig.String(
88                 'User Owner',
89                 'The user that owns this Service Provider',
90                 self.owner),
91             pconfig.MappingList(
92                 'Attribute Mapping',
93                 'Defines how to map attributes before returning them to'
94                 ' the SP. Setting this overrides the global values.',
95                 self.attribute_mappings),
96             pconfig.ComplexList(
97                 'Allowed Attributes',
98                 'Defines a list of allowed attributes, applied after mapping.'
99                 ' Setting this overrides the global values.',
100                 self.allowed_attributes),
101         )
102
103     @property
104     def provider_id(self):
105         return self._properties['id']
106
107     @property
108     def name(self):
109         return self._properties['name']
110
111     @name.setter
112     def name(self, value):
113         self._staging['name'] = value
114
115     @property
116     def owner(self):
117         if 'owner' in self._properties:
118             return self._properties['owner']
119         else:
120             return ''
121
122     @owner.setter
123     def owner(self, value):
124         self._staging['owner'] = value
125
126     @property
127     def allowed_nameids(self):
128         if 'allowed nameids' in self._properties:
129             allowed = self._properties['allowed nameids']
130             return [x.strip() for x in allowed.split(',')]
131         else:
132             return self.cfg.default_allowed_nameids
133
134     @allowed_nameids.setter
135     def allowed_nameids(self, value):
136         if type(value) is not list:
137             raise ValueError("Must be a list")
138         self._staging['allowed nameids'] = ','.join(value)
139
140     @property
141     def default_nameid(self):
142         if 'default nameid' in self._properties:
143             return self._properties['default nameid']
144         else:
145             return self.cfg.default_nameid
146
147     @default_nameid.setter
148     def default_nameid(self, value):
149         self._staging['default nameid'] = value
150
151     @property
152     def attribute_mappings(self):
153         if 'attribute mappings' in self._properties:
154             attr_map = pconfig.MappingList('temp', 'temp', None)
155             attr_map.import_value(str(self._properties['attribute mappings']))
156             return attr_map.get_value()
157         else:
158             return None
159
160     @attribute_mappings.setter
161     def attribute_mappings(self, attr_map):
162         if isinstance(attr_map, pconfig.MappingList):
163             value = attr_map.export_value()
164         else:
165             temp = pconfig.MappingList('temp', 'temp', None)
166             temp.set_value(attr_map)
167             value = temp.export_value()
168         self._staging['attribute mappings'] = value
169
170     @property
171     def allowed_attributes(self):
172         if 'allowed_attributes' in self._properties:
173             attr_map = pconfig.ComplexList('temp', 'temp', None)
174             attr_map.import_value(str(self._properties['allowed_attributes']))
175             return attr_map.get_value()
176         else:
177             return None
178
179     @allowed_attributes.setter
180     def allowed_attributes(self, attr_map):
181         if isinstance(attr_map, pconfig.ComplexList):
182             value = attr_map.export_value()
183         else:
184             temp = pconfig.ComplexList('temp', 'temp', None)
185             temp.set_value(attr_map)
186             value = temp.export_value()
187         self._staging['allowed_attributes'] = value
188
189     def save_properties(self):
190         data = self.cfg.get_data(name='id', value=self.provider_id)
191         if len(data) != 1:
192             raise InvalidProviderId('Could not find SP data')
193         idval = data.keys()[0]
194         data = dict()
195         data[idval] = self._staging
196         self.cfg.save_data(data)
197         data = self.cfg.get_data(idval=idval)
198         self._properties = data[idval]
199         self._staging = dict()
200
201     def refresh_config(self):
202         """
203         Create a new config object for displaying in the UI based on
204         the current set of properties.
205         """
206         del self._config
207         self.load_config()
208
209     def get_valid_nameid(self, nip):
210         self._debug('Requested NameId [%s]' % (nip.format,))
211         if nip.format is None:
212             return SAML2_NAMEID_MAP[self.default_nameid]
213         else:
214             allowed = self.allowed_nameids
215             self._debug('Allowed NameIds %s' % (repr(allowed)))
216             for nameid in allowed:
217                 if nip.format == SAML2_NAMEID_MAP[nameid]:
218                     return nip.format
219         raise NameIdNotAllowed(nip.format)
220
221     def permanently_delete(self):
222         data = self.cfg.get_data(name='id', value=self.provider_id)
223         if len(data) != 1:
224             raise InvalidProviderId('Could not find SP data')
225         idval = data.keys()[0]
226         self.cfg.del_datum(idval)
227
228     def normalize_username(self, username):
229         if 'strip domain' in self._properties:
230             return username.split('@', 1)[0]
231         return username
232
233     def is_valid_name(self, value):
234         if re.search(VALID_IN_NAME, value):
235             return False
236         return True
237
238     def is_valid_nameid(self, value):
239         if value in SAML2_NAMEID_MAP:
240             return True
241         return False
242
243     def valid_nameids(self):
244         return SAML2_NAMEID_MAP.keys()
245
246
247 class ServiceProviderCreator(object):
248
249     def __init__(self, config):
250         self.cfg = config
251
252     def create_from_buffer(self, name, metabuf):
253         '''Test and add data'''
254
255         if re.search(VALID_IN_NAME, name):
256             raise InvalidProviderId("Name must contain only "
257                                     "numbers and letters")
258
259         test = lasso.Server()
260         test.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metabuf)
261         newsps = test.get_providers()
262         if len(newsps) != 1:
263             raise InvalidProviderId("Metadata must contain one Provider")
264
265         spid = newsps.keys()[0]
266         data = self.cfg.get_data(name='id', value=spid)
267         if len(data) != 0:
268             raise InvalidProviderId("Provider Already Exists")
269         datum = {'id': spid, 'name': name, 'type': 'SP', 'metadata': metabuf}
270         self.cfg.new_datum(datum)
271
272         data = self.cfg.get_data(name='id', value=spid)
273         if len(data) != 1:
274             raise InvalidProviderId("Internal Error")
275         idval = data.keys()[0]
276         data = self.cfg.get_data(idval=idval)
277         sp = data[idval]
278         self.cfg.idp.add_provider(sp)
279
280         return ServiceProvider(self.cfg, spid)
281
282
283 class IdentityProvider(Log):
284     def __init__(self, config):
285         self.server = lasso.Server(config.idp_metadata_file,
286                                    config.idp_key_file,
287                                    None,
288                                    config.idp_certificate_file)
289         self.server.role = lasso.PROVIDER_ROLE_IDP
290
291     def add_provider(self, sp):
292         self.server.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP,
293                                           sp['metadata'])
294         self._debug('Added SP %s' % sp['name'])
295
296     def get_login_handler(self, dump=None):
297         if dump:
298             return lasso.Login.newFromDump(self.server, dump)
299         else:
300             return lasso.Login(self.server)
301
302     def get_providers(self):
303         return self.server.get_providers()
304
305     def get_logout_handler(self, dump=None):
306         if dump:
307             return lasso.Logout.newFromDump(self.server, dump)
308         else:
309             return lasso.Logout(self.server)