pam: use a pam object method instead of pam module function
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / provider.py
1 # Copyright (C) 2014 Ipsilon project Contributors, for license see COPYING
2
3 from ipsilon.providers.common import ProviderException
4 from ipsilon.util import config as pconfig
5 from ipsilon.util.config import ConfigHelper
6 from ipsilon.tools.saml2metadata import SAML2_NAMEID_MAP, NSMAP
7 from ipsilon.util.log import Log
8 from lxml import etree
9 import lasso
10 import re
11
12
# Regex that matches any character NOT permitted in a provider name.
# Despite the name, a successful re.search() means the name is invalid:
# only ASCII letters, digits and the plain space character are accepted.
VALID_IN_NAME = r'[^ a-zA-Z0-9]'
14
15
class InvalidProviderId(ProviderException):
    """Raised when a provider id cannot be resolved or accepted.

    The exception text is 'Invalid Provider ID: ' followed by ``code``;
    the same text is also passed to ``self.debug`` at construction time.
    """

    def __init__(self, code):
        text = 'Invalid Provider ID: %s' % code
        super(InvalidProviderId, self).__init__(text)
        # Record the failure in the debug log as soon as it is raised.
        self.debug(text)
21         self.debug(message)
22
23
class NameIdNotAllowed(Exception):
    """Raised when a requested NameID format is not allowed for an SP.

    ``message`` holds the human-readable text; ``str()`` returns its repr.
    """

    def __init__(self, nid):
        text = 'Name ID [%s] is not allowed' % nid
        super(NameIdNotAllowed, self).__init__(text)
        self.message = text

    def __str__(self):
        return repr(self.message)
33
34
class ServiceProviderConfig(ConfigHelper):
    """ConfigHelper base class for ServiceProvider.

    It adds no behavior of its own; construction is inherited unchanged
    from ConfigHelper.
    """
38
39
class ServiceProvider(ServiceProviderConfig):
    """A SAML2 Service Provider backed by one row of the config store.

    Property getters read from ``self._properties`` (the stored row).
    Property setters only write into ``self._staging``; the staged values
    are persisted, and ``self._properties`` reloaded, by
    ``save_properties()``.
    """

    def __init__(self, config, provider_id):
        """Load the SP whose 'id' equals ``provider_id``.

        :param config: the provider config store; must offer
            ``get_data``/``save_data``/``del_datum`` and the default
            NameID settings read by the properties below.
        :param provider_id: the SP entity id.
        :raises InvalidProviderId: if no row, or more than one row,
            matches ``provider_id``.
        """
        super(ServiceProvider, self).__init__()
        self.cfg = config
        data = self.cfg.get_data(name='id', value=provider_id)
        # Report "not found" separately instead of always claiming that
        # there were multiple matches.
        if not data:
            raise InvalidProviderId('not found')
        if len(data) != 1:
            raise InvalidProviderId('multiple matches')
        idval = next(iter(data))
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()
        self.load_config()
        self.logout_mechs = []
        xmldoc = etree.XML(str(data[idval]['metadata']))
        logout = xmldoc.xpath('//md:EntityDescriptor'
                              '/md:SPSSODescriptor'
                              '/md:SingleLogoutService',
                              namespaces=NSMAP)
        for service in logout:
            # Look the binding up by attribute name: XML attribute order
            # is not significant, so the first attribute value is not
            # guaranteed to be the Binding.
            binding = service.get('Binding')
            if binding is not None:
                self.logout_mechs.append(binding)

    def load_config(self):
        """Build the UI config object from the current property values."""
        self.new_config(
            self.provider_id,
            pconfig.String(
                'Name',
                'A nickname used to easily identify the Service Provider.'
                ' Only alphanumeric characters [A-Z,a-z,0-9] and spaces are'
                '  accepted.',
                self.name),
            pconfig.String(
                'Description',
                'A description of the SP to show on the Portal.',
                self.description),
            pconfig.String(
                'Service Provider link',
                'A link to the Service Provider for the Portal.',
                self.splink),
            pconfig.Condition(
                'Visible in Portal',
                'This SP is visible in the Portal.',
                self.visible),
            pconfig.Image(
                'Image File',
                'Image to display for this SP in the Portal. Scale to '
                '100x200 for best results.',
                self.imagefile),
            pconfig.Pick(
                'Default NameID',
                'Default NameID used by Service Providers.',
                SAML2_NAMEID_MAP.keys(),
                self.default_nameid),
            pconfig.Choice(
                'Allowed NameIDs',
                'Allowed NameIDs for this Service Provider.',
                SAML2_NAMEID_MAP.keys(),
                self.allowed_nameids),
            pconfig.String(
                'User Owner',
                'The user that owns this Service Provider',
                self.owner),
            pconfig.MappingList(
                'Attribute Mapping',
                'Defines how to map attributes before returning them to'
                ' the SP. Setting this overrides the global values.',
                self.attribute_mappings),
            pconfig.ComplexList(
                'Allowed Attributes',
                'Defines a list of allowed attributes, applied after mapping.'
                ' Setting this overrides the global values.',
                self.allowed_attributes),
        )

    @property
    def provider_id(self):
        return self._properties['id']

    @property
    def name(self):
        return self._properties['name']

    @name.setter
    def name(self, value):
        self._staging['name'] = value

    @property
    def description(self):
        return self._properties.get('description', '')

    @description.setter
    def description(self, value):
        self._staging['description'] = value

    @property
    def visible(self):
        return self._properties.get('visible', True)

    @visible.setter
    def visible(self, value):
        self._staging['visible'] = value

    @property
    def imagefile(self):
        return self._properties.get('imagefile', '')

    @imagefile.setter
    def imagefile(self, value):
        self._staging['imagefile'] = value

    @property
    def imageurl(self):
        # NOTE(review): unlike ``imagefile`` this raises KeyError when the
        # row has no 'imagefile'; confirm how url_from_image handles ''
        # before switching to a defaulted lookup.
        return pconfig.url_from_image(self._properties['imagefile'])

    @property
    def splink(self):
        return self._properties.get('splink', '')

    @splink.setter
    def splink(self, value):
        self._staging['splink'] = value

    @property
    def owner(self):
        return self._properties.get('owner', '')

    @owner.setter
    def owner(self, value):
        self._staging['owner'] = value

    @property
    def allowed_nameids(self):
        """List of allowed NameID names (keys of SAML2_NAMEID_MAP).

        Stored as a comma-separated string. Falls back to the global
        default when the SP has no value of its own.
        """
        if 'allowed nameids' in self._properties:
            allowed = self._properties['allowed nameids']
            # An empty list is stored as '' by the setter; skip empty
            # fragments so it reads back as [] rather than [''].
            return [x.strip() for x in allowed.split(',') if x.strip()]
        return self.cfg.default_allowed_nameids

    @allowed_nameids.setter
    def allowed_nameids(self, value):
        if not isinstance(value, list):
            raise ValueError("Must be a list")
        self._staging['allowed nameids'] = ','.join(value)

    @property
    def default_nameid(self):
        if 'default nameid' in self._properties:
            return self._properties['default nameid']
        return self.cfg.default_nameid

    @default_nameid.setter
    def default_nameid(self, value):
        self._staging['default nameid'] = value

    @property
    def attribute_mappings(self):
        """Per-SP attribute mapping, or None when none is configured."""
        if 'attribute mappings' in self._properties:
            attr_map = pconfig.MappingList('temp', 'temp', None)
            attr_map.import_value(str(self._properties['attribute mappings']))
            return attr_map.get_value()
        return None

    @attribute_mappings.setter
    def attribute_mappings(self, attr_map):
        # Accept either a MappingList or a raw value and stage the
        # exported (serialized) form.
        if isinstance(attr_map, pconfig.MappingList):
            value = attr_map.export_value()
        else:
            temp = pconfig.MappingList('temp', 'temp', None)
            temp.set_value(attr_map)
            value = temp.export_value()
        self._staging['attribute mappings'] = value

    @property
    def allowed_attributes(self):
        """Per-SP allowed attribute list, or None when not configured."""
        if 'allowed_attributes' in self._properties:
            attr_map = pconfig.ComplexList('temp', 'temp', None)
            attr_map.import_value(str(self._properties['allowed_attributes']))
            return attr_map.get_value()
        return None

    @allowed_attributes.setter
    def allowed_attributes(self, attr_map):
        if isinstance(attr_map, pconfig.ComplexList):
            value = attr_map.export_value()
        else:
            temp = pconfig.ComplexList('temp', 'temp', None)
            temp.set_value(attr_map)
            value = temp.export_value()
        self._staging['allowed_attributes'] = value

    def save_properties(self):
        """Persist staged changes, then reload the stored properties.

        :raises InvalidProviderId: if the SP row cannot be found uniquely.
        """
        data = self.cfg.get_data(name='id', value=self.provider_id)
        if len(data) != 1:
            raise InvalidProviderId('Could not find SP data')
        idval = next(iter(data))
        data = dict()
        data[idval] = self._staging
        self.cfg.save_data(data)
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()

    def refresh_config(self):
        """
        Create a new config object for displaying in the UI based on
        the current set of properties.
        """
        del self._config
        self.load_config()

    def get_valid_nameid(self, nip):
        """Return the NameID format to use for a request.

        :param nip: the request's NameID policy (may be None).
        :returns: the default format when no format is requested,
            otherwise the requested format if it is allowed.
        :raises NameIdNotAllowed: if the requested format is not allowed.
        """
        if nip is None or nip.format is None:
            self.debug('No NameId requested, returning default [%s]'
                       % SAML2_NAMEID_MAP[self.default_nameid])
            return SAML2_NAMEID_MAP[self.default_nameid]
        else:
            self.debug('Requested NameId [%s]' % (nip.format,))
            allowed = self.allowed_nameids
            self.debug('Allowed NameIds %s' % (repr(allowed)))
            for nameid in allowed:
                # .get(): an unknown/stale configured name simply never
                # matches instead of raising KeyError.
                if nip.format == SAML2_NAMEID_MAP.get(nameid):
                    return nip.format
        raise NameIdNotAllowed(nip.format)

    def permanently_delete(self):
        """Remove this SP's row from the config store.

        :raises InvalidProviderId: if the SP row cannot be found uniquely.
        """
        data = self.cfg.get_data(name='id', value=self.provider_id)
        if len(data) != 1:
            raise InvalidProviderId('Could not find SP data')
        idval = next(iter(data))
        self.cfg.del_datum(idval)

    def normalize_username(self, username):
        # The mere presence of the 'strip domain' key (not its value)
        # enables stripping everything from the first '@'.
        if 'strip domain' in self._properties:
            return username.split('@', 1)[0]
        return username

    def is_valid_name(self, value):
        return not re.search(VALID_IN_NAME, value)

    def is_valid_nameid(self, value):
        return value in SAML2_NAMEID_MAP

    def valid_nameids(self):
        return SAML2_NAMEID_MAP.keys()
294
295
class ServiceProviderCreator(object):
    """Validates SAML2 SP metadata and registers new Service Providers."""

    def __init__(self, config):
        self.cfg = config

    def create_from_buffer(self, name, metabuf, description='',
                           visible=True, imagefile='', splink=''):
        """Validate ``metabuf``, store a new SP row and return it.

        :param name: display name; only letters, digits and spaces.
        :param metabuf: SP metadata XML; must describe exactly one
            provider.
        :returns: a ServiceProvider for the newly stored SP.
        :raises InvalidProviderId: on a bad name, metadata with other than
            one provider, an already registered provider id, or if the new
            row cannot be read back.
        """
        if re.search(VALID_IN_NAME, name):
            raise InvalidProviderId("Name must contain only "
                                    "numbers and letters")

        # Load the metadata into a throwaway lasso server purely to
        # validate it and learn the provider id it declares.
        scratch = lasso.Server()
        scratch.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metabuf)
        parsed = scratch.get_providers()
        if len(parsed) != 1:
            raise InvalidProviderId("Metadata must contain one Provider")
        spid = parsed.keys()[0]

        if self.cfg.get_data(name='id', value=spid):
            raise InvalidProviderId("Provider Already Exists")

        self.cfg.new_datum({
            'id': spid,
            'name': name,
            'type': 'SP',
            'metadata': metabuf,
            'description': description,
            'visible': visible,
            'imagefile': imagefile,
            'splink': splink,
        })

        # Read the row back and register it with the running IdP server.
        stored = self.cfg.get_data(name='id', value=spid)
        if len(stored) != 1:
            raise InvalidProviderId("Internal Error")
        idval = stored.keys()[0]
        row = self.cfg.get_data(idval=idval)[idval]
        self.cfg.idp.add_provider(row)

        return ServiceProvider(self.cfg, spid)
340
341
class IdentityProvider(Log):
    """Holds the lasso.Server that acts as this SAML2 Identity Provider."""

    def __init__(self, config, sessionfactory):
        """Create the IdP lasso server from the configured files.

        :param config: must expose ``idp_metadata_file``, ``idp_key_file``
            and ``idp_certificate_file``.
        :param sessionfactory: stored as-is on ``self.sessionfactory``.
        """
        server = lasso.Server(config.idp_metadata_file,
                              config.idp_key_file,
                              None,
                              config.idp_certificate_file)
        server.role = lasso.PROVIDER_ROLE_IDP
        self.server = server
        self.sessionfactory = sessionfactory

    def add_provider(self, sp):
        """Register an SP row (needs 'metadata' and 'name') with lasso."""
        metadata = sp['metadata']
        self.server.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metadata)
        self.debug('Added SP %s' % sp['name'])

    def get_login_handler(self, dump=None):
        """Return a lasso.Login, restored from ``dump`` when truthy."""
        if not dump:
            return lasso.Login(self.server)
        return lasso.Login.newFromDump(self.server, dump)

    def get_providers(self):
        """Return the providers known to the lasso server."""
        return self.server.get_providers()

    def get_logout_handler(self, dump=None):
        """Return a lasso.Logout, restored from ``dump`` when truthy."""
        if not dump:
            return lasso.Logout(self.server)
        return lasso.Logout.newFromDump(self.server, dump)
368         else:
369             return lasso.Logout(self.server)