Fix NameId exception
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / provider.py
1 #!/usr/bin/python
2 #
3 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
4 #
5 # see file 'COPYING' for use and warranty information
6 #
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20 from ipsilon.providers.common import ProviderException
21 from ipsilon.tools.saml2metadata import SAML2_NAMEID_MAP
22 import cherrypy
23 import lasso
24
25
class InvalidProviderId(ProviderException):
    """Raised when a provider ID cannot be resolved to a usable provider.

    The formatted text is passed both to the base exception and to the
    inherited ``_debug`` helper (defined on ProviderException, not here).
    """

    def __init__(self, code):
        # ``code`` is a short reason string describing why lookup failed.
        text = 'Invalid Provider ID: %s' % code
        super(InvalidProviderId, self).__init__(text)
        self._debug(text)
32
33
class NameIdNotAllowed(Exception):
    """Raised when a requested SAML2 NameID format is not permitted.

    The formatted text is available both as ``str(exc)`` and, for existing
    callers, as the ``message`` attribute.
    """

    def __init__(self, id):
        """Build the exception for the rejected NameID format ``id``.

        The parameter keeps its historical name (shadowing the builtin) so
        keyword callers keep working.
        """
        # (id,) so a tuple argument is formatted instead of raising TypeError.
        message = 'Name ID [%s] is not allowed' % (id,)
        super(NameIdNotAllowed, self).__init__(message)
        self.message = message

    def __str__(self):
        # Return the text itself; repr() used to wrap it in quotes
        # (e.g. "'Name ID [x] is not allowed'") wherever it was printed.
        return self.message
43
44
class ServiceProvider(object):
    """A SAML2 Service Provider stored in the IdP configuration store.

    Getters read from ``self._properties`` (the datum loaded from
    ``config``).  Setters only record values in ``self._staging``; nothing
    is persisted until :meth:`save_properties` is called.
    """

    def __init__(self, config, provider_id):
        """Load the stored datum for ``provider_id``.

        :param config: configuration store exposing get_data / save_data /
            del_datum and the default_allowed_nameids / default_nameid
            attributes used below (presumably the SAML2 provider config --
            TODO confirm).
        :param provider_id: the SP entity ID, stored under the 'id' name.
        :raises InvalidProviderId: unless exactly one datum matches.
        """
        self.cfg = config
        idval = self._find_idval(provider_id, 'multiple matches')
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()

    def _find_idval(self, provider_id, error):
        """Return the datum key whose 'id' equals ``provider_id``.

        :raises InvalidProviderId: with ``error`` unless exactly one datum
            matches.
        """
        data = self.cfg.get_data(name='id', value=provider_id)
        if len(data) != 1:
            raise InvalidProviderId(error)
        # list() keeps this working on Python 3, where keys() is a view
        # that cannot be indexed.
        return list(data.keys())[0]

    @property
    def provider_id(self):
        """The SP entity ID (the stored 'id' property)."""
        return self._properties['id']

    @property
    def name(self):
        """The SP display name."""
        return self._properties['name']

    @name.setter
    def name(self, value):
        self._staging['name'] = value

    @property
    def owner(self):
        """The SP owner, or '' when none is stored."""
        return self._properties.get('owner', '')

    @owner.setter
    def owner(self, value):
        self._staging['owner'] = value

    @property
    def allowed_nameids(self):
        """List of NameID names this SP may request.

        Falls back to ``cfg.default_allowed_nameids`` when the SP has no
        own setting.
        """
        if 'allowed nameids' in self._properties:
            allowed = self._properties['allowed nameids']
            # Skip empty entries so '' (an empty list was saved) or a
            # trailing comma does not produce a bogus '' name.
            return [x.strip() for x in allowed.split(',') if x.strip()]
        return self.cfg.default_allowed_nameids

    @allowed_nameids.setter
    def allowed_nameids(self, value):
        if not isinstance(value, list):
            raise ValueError("Must be a list")
        self._staging['allowed nameids'] = ','.join(value)

    @property
    def default_nameid(self):
        """NameID name used when the request does not specify a format."""
        if 'default nameid' in self._properties:
            return self._properties['default nameid']
        return self.cfg.default_nameid

    @default_nameid.setter
    def default_nameid(self, value):
        self._staging['default nameid'] = value

    def save_properties(self):
        """Persist staged changes, then reload the stored properties.

        :raises InvalidProviderId: if the SP datum can no longer be found.
        """
        idval = self._find_idval(self.provider_id, 'Could not find SP data')
        self.cfg.save_data({idval: self._staging})
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()

    def get_valid_nameid(self, nip):
        """Return the NameID format URI to use for a NameID policy ``nip``.

        A missing or 'unspecified' format maps to this SP's default NameID.
        Any other format is returned only if it matches one of the allowed
        NameID names.

        :raises NameIdNotAllowed: if the requested format is not allowed.
        """
        self._debug('Requested NameId [%s]' % (nip.format,))
        if nip.format in (None, lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED):
            return SAML2_NAMEID_MAP[self.default_nameid]
        allowed = self.allowed_nameids
        self._debug('Allowed NameIds %s' % (repr(allowed)))
        for nameid in allowed:
            # .get(): an unknown name in the configuration must not turn
            # the request into an unhandled KeyError.
            if nip.format == SAML2_NAMEID_MAP.get(nameid):
                return nip.format
        raise NameIdNotAllowed(nip.format)

    def permanently_delete(self):
        """Delete this SP's datum from the configuration store.

        :raises InvalidProviderId: if the SP datum cannot be found.
        """
        idval = self._find_idval(self.provider_id, 'Could not find SP data')
        self.cfg.del_datum(idval)

    def _debug(self, fact):
        """Log ``fact`` via cherrypy when the 'debug' config flag is set."""
        if cherrypy.config.get('debug', False):
            cherrypy.log(fact)

    def normalize_username(self, username):
        """Return ``username`` with any '@domain' suffix removed if configured.

        NOTE(review): only the presence of the 'strip domain' property is
        checked, not its value -- confirm that is intended.
        """
        if 'strip domain' in self._properties:
            return username.split('@', 1)[0]
        return username

    def is_valid_nameid(self, value):
        """Return True if ``value`` is a known NameID name."""
        return value in SAML2_NAMEID_MAP

    def valid_nameids(self):
        """Return the list of known NameID names."""
        return list(SAML2_NAMEID_MAP.keys())
154
155
class ServiceProviderCreator(object):
    """Registers new Service Providers from SAML2 metadata."""

    def __init__(self, config):
        # config: the provider configuration store; must also expose an
        # ``idp`` attribute with add_provider() (see create_from_buffer).
        self.cfg = config

    def create_from_buffer(self, name, metabuf):
        '''Test and add data

        Load ``metabuf`` into a throwaway lasso.Server before storing
        anything, so that metadata lasso rejects (presumably by raising from
        addProviderFromBuffer -- confirm) is not saved.  Then store a new
        'SP' datum, register it with ``self.cfg.idp`` and return it as a
        ServiceProvider.

        :param name: display name stored with the SP.
        :param metabuf: the SP metadata buffer.
        :raises InvalidProviderId: if the metadata does not describe exactly
            one provider, the provider is already stored, or the new datum
            cannot be read back.
        '''
        test = lasso.Server()
        test.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metabuf)
        newsps = test.get_providers()
        if len(newsps) != 1:
            raise InvalidProviderId("Metadata must contain one Provider")

        # list() so this also works on Python 3, where keys() is a view
        # that cannot be indexed.
        spid = list(newsps.keys())[0]
        data = self.cfg.get_data(name='id', value=spid)
        if len(data) != 0:
            raise InvalidProviderId("Provider Already Exists")
        datum = {'id': spid, 'name': name, 'type': 'SP', 'metadata': metabuf}
        self.cfg.new_datum(datum)

        # Read the datum back to learn the key the store assigned to it.
        data = self.cfg.get_data(name='id', value=spid)
        if len(data) != 1:
            raise InvalidProviderId("Internal Error")
        idval = list(data.keys())[0]
        data = self.cfg.get_data(idval=idval)
        sp = data[idval]
        self.cfg.idp.add_provider(sp)

        return ServiceProvider(self.cfg, spid)
186
187
class IdentityProvider(object):
    """Wraps the lasso.Server that acts as this SAML2 Identity Provider."""

    def __init__(self, config):
        """Build the lasso server from the IdP metadata, key and certificate.

        :param config: object exposing idp_metadata_file, idp_key_file and
            idp_certificate_file.
        """
        metadata = config.idp_metadata_file
        key = config.idp_key_file
        certificate = config.idp_certificate_file
        # The third positional argument (left as None) sits between the key
        # and certificate arguments of lasso.Server.
        self.server = lasso.Server(metadata, key, None, certificate)
        self.server.role = lasso.PROVIDER_ROLE_IDP

    def add_provider(self, sp):
        """Register an SP datum (needs 'metadata' and 'name') with the server."""
        metadata = sp['metadata']
        self.server.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metadata)
        self._debug('Added SP %s' % sp['name'])

    def get_login_handler(self, dump=None):
        """Return a lasso.Login, restored from ``dump`` when it is truthy."""
        if not dump:
            return lasso.Login(self.server)
        return lasso.Login.newFromDump(self.server, dump)

    def get_providers(self):
        """Return whatever the lasso server reports as its providers."""
        providers = self.server.get_providers()
        return providers

    def _debug(self, fact):
        """Log ``fact`` via cherrypy when the 'debug' config flag is set."""
        if not cherrypy.config.get('debug', False):
            return
        cherrypy.log(fact)