Update Copyright header point to COPYING file
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / provider.py
1 # Copyright (C) 2014 Ipsilon project Contributors, for license see COPYING
2
3 from ipsilon.providers.common import ProviderException
4 from ipsilon.util import config as pconfig
5 from ipsilon.util.config import ConfigHelper
6 from ipsilon.tools.saml2metadata import SAML2_NAMEID_MAP
7 from ipsilon.util.log import Log
8 import lasso
9 import re
10
11
# Despite its name, this pattern matches any single character that is NOT
# allowed in a Service Provider name (anything other than a space, an ASCII
# letter or a digit). Callers use re.search() with it and treat a match as
# "the name is invalid".
VALID_IN_NAME = r'[^\ a-zA-Z0-9]'
13
14
class InvalidProviderId(ProviderException):
    """Raised when a Service Provider cannot be identified or accepted.

    The text handed to the constructor is appended to the fixed
    'Invalid Provider ID: ' prefix, passed to the parent exception and
    also emitted through self.debug().
    """

    def __init__(self, code):
        # 'code' is the free-form reason text supplied by the raiser.
        text = 'Invalid Provider ID: %s' % code
        super(InvalidProviderId, self).__init__(text)
        self.debug(text)
21
22
class NameIdNotAllowed(Exception):
    """Raised when a requested NameID format is not permitted.

    The formatted text is kept on the ``message`` attribute; str() of the
    exception returns the repr() of that text (i.e. it includes quotes).
    """

    def __init__(self, nid):
        self.message = 'Name ID [%s] is not allowed' % nid
        super(NameIdNotAllowed, self).__init__(self.message)

    def __str__(self):
        text = self.message
        return repr(text)
32
33
class ServiceProviderConfig(ConfigHelper):
    """Configuration base class for Service Provider objects.

    Adds nothing of its own: construction is delegated unchanged to
    ConfigHelper.
    """

    def __init__(self):
        # No extra state here; just run the parent initialisation.
        super(ServiceProviderConfig, self).__init__()
37
38
class ServiceProvider(ServiceProviderConfig):
    """A single SAML2 Service Provider record loaded from the config store.

    Current values are read from ``self._properties`` (fetched through
    ``config.get_data``). The property setters only write to
    ``self._staging``; nothing is persisted until ``save_properties()`` is
    called.
    """

    def __init__(self, config, provider_id):
        """Load the record whose 'id' field equals ``provider_id``.

        Raises InvalidProviderId when no record, or more than one record,
        matches.
        """
        super(ServiceProvider, self).__init__()
        self.cfg = config
        idval = self._lookup_idval(provider_id,
                                   'no matches',
                                   'multiple matches')
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()
        self.load_config()

    def _lookup_idval(self, provider_id, missing, ambiguous=None):
        """Return the store key of the single record matching provider_id.

        Raises InvalidProviderId(missing) when nothing matches and
        InvalidProviderId(ambiguous) when several records match; when
        ``ambiguous`` is None the ``missing`` text is used for both cases.
        """
        data = self.cfg.get_data(name='id', value=provider_id)
        found = len(data)
        if found != 1:
            if found == 0 or ambiguous is None:
                raise InvalidProviderId(missing)
            raise InvalidProviderId(ambiguous)
        # list(...) keeps this working on both Python 2 lists and
        # Python 3 key views.
        return list(data.keys())[0]

    def load_config(self):
        """Declare the configuration options from the current properties.

        The options are handed to new_config(), which is not defined in
        this class (presumably inherited from ConfigHelper - confirm).
        """
        self.new_config(
            self.provider_id,
            pconfig.String(
                'Name',
                'A nickname used to easily identify the Service Provider.'
                ' Only alphanumeric characters [A-Z,a-z,0-9] and spaces are'
                '  accepted.',
                self.name),
            pconfig.Pick(
                'Default NameID',
                'Default NameID used by Service Providers.',
                list(SAML2_NAMEID_MAP.keys()),
                self.default_nameid),
            pconfig.Choice(
                'Allowed NameIDs',
                'Allowed NameIDs for this Service Provider.',
                list(SAML2_NAMEID_MAP.keys()),
                self.allowed_nameids),
            pconfig.String(
                'User Owner',
                'The user that owns this Service Provider',
                self.owner),
            pconfig.MappingList(
                'Attribute Mapping',
                'Defines how to map attributes before returning them to'
                ' the SP. Setting this overrides the global values.',
                self.attribute_mappings),
            pconfig.ComplexList(
                'Allowed Attributes',
                'Defines a list of allowed attributes, applied after mapping.'
                ' Setting this overrides the global values.',
                self.allowed_attributes),
        )

    @property
    def provider_id(self):
        """The stored 'id' value of this provider (read-only)."""
        return self._properties['id']

    @property
    def name(self):
        """The stored 'name' value; assigning stages a new name."""
        return self._properties['name']

    @name.setter
    def name(self, value):
        self._staging['name'] = value

    @property
    def owner(self):
        """The stored 'owner' value, or '' when none is stored."""
        if 'owner' in self._properties:
            return self._properties['owner']
        return ''

    @owner.setter
    def owner(self, value):
        self._staging['owner'] = value

    @property
    def allowed_nameids(self):
        """List of allowed NameID names.

        Parsed from the comma separated 'allowed nameids' property when it
        is stored, otherwise taken from cfg.default_allowed_nameids.
        """
        if 'allowed nameids' in self._properties:
            allowed = self._properties['allowed nameids']
            return [x.strip() for x in allowed.split(',')]
        return self.cfg.default_allowed_nameids

    @allowed_nameids.setter
    def allowed_nameids(self, value):
        if not isinstance(value, list):
            raise ValueError("Must be a list")
        # Stored as one comma separated string.
        self._staging['allowed nameids'] = ','.join(value)

    @property
    def default_nameid(self):
        """The stored 'default nameid', or cfg.default_nameid if unset."""
        if 'default nameid' in self._properties:
            return self._properties['default nameid']
        return self.cfg.default_nameid

    @default_nameid.setter
    def default_nameid(self, value):
        self._staging['default nameid'] = value

    @property
    def attribute_mappings(self):
        """Attribute mappings decoded through pconfig.MappingList.

        Returns None when no 'attribute mappings' property is stored.
        """
        if 'attribute mappings' not in self._properties:
            return None
        attr_map = pconfig.MappingList('temp', 'temp', None)
        attr_map.import_value(str(self._properties['attribute mappings']))
        return attr_map.get_value()

    @attribute_mappings.setter
    def attribute_mappings(self, attr_map):
        # Accept either a ready MappingList or a raw value that is wrapped
        # in a temporary MappingList to obtain its exported form.
        if isinstance(attr_map, pconfig.MappingList):
            value = attr_map.export_value()
        else:
            temp = pconfig.MappingList('temp', 'temp', None)
            temp.set_value(attr_map)
            value = temp.export_value()
        self._staging['attribute mappings'] = value

    @property
    def allowed_attributes(self):
        """Allowed attributes decoded through pconfig.ComplexList.

        Returns None when no 'allowed_attributes' property is stored.
        """
        if 'allowed_attributes' not in self._properties:
            return None
        attr_map = pconfig.ComplexList('temp', 'temp', None)
        attr_map.import_value(str(self._properties['allowed_attributes']))
        return attr_map.get_value()

    @allowed_attributes.setter
    def allowed_attributes(self, attr_map):
        # Same convention as attribute_mappings, using ComplexList.
        if isinstance(attr_map, pconfig.ComplexList):
            value = attr_map.export_value()
        else:
            temp = pconfig.ComplexList('temp', 'temp', None)
            temp.set_value(attr_map)
            value = temp.export_value()
        self._staging['allowed_attributes'] = value

    def save_properties(self):
        """Persist the staged changes and reload the stored properties.

        Raises InvalidProviderId when the record cannot be located
        unambiguously.
        """
        idval = self._lookup_idval(self.provider_id,
                                   'Could not find SP data')
        self.cfg.save_data({idval: self._staging})
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()

    def refresh_config(self):
        """
        Create a new config object for displaying in the UI based on
        the current set of properties.
        """
        del self._config
        self.load_config()

    def get_valid_nameid(self, nip):
        """Return the NameID format to use for the request policy ``nip``.

        When nip.format is None the format mapped from default_nameid is
        returned. Otherwise nip.format is returned if it corresponds to one
        of the allowed NameIDs. Raises NameIdNotAllowed in every other
        case, including when the allowed list holds names that are not in
        SAML2_NAMEID_MAP.
        """
        self.debug('Requested NameId [%s]' % (nip.format,))
        if nip.format is None:
            return SAML2_NAMEID_MAP[self.default_nameid]

        allowed = self.allowed_nameids
        self.debug('Allowed NameIds %s' % (repr(allowed)))
        for nameid in allowed:
            # .get() so that an unknown or empty entry in the allowed list
            # is skipped instead of surfacing as a KeyError.
            if nip.format == SAML2_NAMEID_MAP.get(nameid):
                return nip.format
        raise NameIdNotAllowed(nip.format)

    def permanently_delete(self):
        """Delete this provider's record from the store.

        Raises InvalidProviderId when the record cannot be located
        unambiguously.
        """
        idval = self._lookup_idval(self.provider_id,
                                   'Could not find SP data')
        self.cfg.del_datum(idval)

    def normalize_username(self, username):
        """Drop everything from the first '@' on if 'strip domain' is set.

        Only the presence of the 'strip domain' property is checked, not
        its value.
        """
        if 'strip domain' in self._properties:
            return username.split('@', 1)[0]
        return username

    def is_valid_name(self, value):
        """Return True when ``value`` has no character outside the set
        accepted by VALID_IN_NAME (space, ASCII letters, digits)."""
        return re.search(VALID_IN_NAME, value) is None

    def is_valid_nameid(self, value):
        """Return True when ``value`` is a key of SAML2_NAMEID_MAP."""
        return value in SAML2_NAMEID_MAP

    def valid_nameids(self):
        """Return the list of NameID names known to SAML2_NAMEID_MAP."""
        return list(SAML2_NAMEID_MAP.keys())
230
231
class ServiceProviderCreator(object):
    """Validates SP metadata and registers a new Service Provider."""

    def __init__(self, config):
        self.cfg = config

    def create_from_buffer(self, name, metabuf):
        """Test the metadata, store a new SP record and return it.

        ``name`` is the nickname for the provider and ``metabuf`` the
        metadata buffer handed to lasso. Returns a ServiceProvider for the
        newly stored record.

        Raises InvalidProviderId when the name contains characters matched
        by VALID_IN_NAME, when the metadata does not yield exactly one
        provider, when a provider with the same id is already stored, or
        when the new record cannot be read back.
        """
        if re.search(VALID_IN_NAME, name):
            # Spaces are accepted by VALID_IN_NAME, so say so.
            raise InvalidProviderId("Name must contain only "
                                    "numbers, letters and spaces")

        # A throw-away lasso server is used only to test the metadata and
        # to learn the provider id it declares.
        test = lasso.Server()
        test.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metabuf)
        newsps = test.get_providers()
        if len(newsps) != 1:
            raise InvalidProviderId("Metadata must contain one Provider")

        # list(...) keeps the indexing valid for Python 3 key views too.
        spid = list(newsps.keys())[0]
        data = self.cfg.get_data(name='id', value=spid)
        if len(data) != 0:
            raise InvalidProviderId("Provider Already Exists")
        datum = {'id': spid, 'name': name, 'type': 'SP', 'metadata': metabuf}
        self.cfg.new_datum(datum)

        # Read the record back to obtain its store key and stored form.
        data = self.cfg.get_data(name='id', value=spid)
        if len(data) != 1:
            raise InvalidProviderId("Internal Error")
        idval = list(data.keys())[0]
        data = self.cfg.get_data(idval=idval)
        sp = data[idval]
        self.cfg.idp.add_provider(sp)

        return ServiceProvider(self.cfg, spid)
266
267
class IdentityProvider(Log):
    """Thin wrapper around a lasso.Server configured with the IdP role."""

    def __init__(self, config):
        """Create the lasso server from the metadata, key and certificate
        files named on ``config``."""
        idp_server = lasso.Server(
            config.idp_metadata_file,
            config.idp_key_file,
            None,  # NOTE(review): presumably the key password - confirm
            config.idp_certificate_file)
        idp_server.role = lasso.PROVIDER_ROLE_IDP
        self.server = idp_server

    def add_provider(self, sp):
        """Register the SP described by the mapping ``sp`` (reads its
        'metadata' and 'name' entries) with the lasso server."""
        metadata = sp['metadata']
        self.server.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metadata)
        self.debug('Added SP %s' % sp['name'])

    def get_login_handler(self, dump=None):
        """Return a lasso.Login, restored from ``dump`` when one is given."""
        if not dump:
            return lasso.Login(self.server)
        return lasso.Login.newFromDump(self.server, dump)

    def get_providers(self):
        """Return the providers known to the underlying lasso server."""
        providers = self.server.get_providers()
        return providers

    def get_logout_handler(self, dump=None):
        """Return a lasso.Logout, restored from ``dump`` when one is given."""
        if not dump:
            return lasso.Logout(self.server)
        return lasso.Logout.newFromDump(self.server, dump)