Providers can save properties back to the database
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / provider.py
1 #!/usr/bin/python
2 #
3 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
4 #
5 # see file 'COPYING' for use and warranty information
6 #
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20 from ipsilon.providers.common import ProviderException
21 import cherrypy
22 import lasso
23
24
# Short Name ID format names, as used in SP properties such as
# 'allowed nameids' and 'default nameid', mapped to the corresponding
# lasso SAML2_NAME_IDENTIFIER_FORMAT_* constants.
NAMEID_MAP = dict(
    email=lasso.SAML2_NAME_IDENTIFIER_FORMAT_EMAIL,
    encrypted=lasso.SAML2_NAME_IDENTIFIER_FORMAT_ENCRYPTED,
    entity=lasso.SAML2_NAME_IDENTIFIER_FORMAT_ENTITY,
    kerberos=lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS,
    persistent=lasso.SAML2_NAME_IDENTIFIER_FORMAT_PERSISTENT,
    transient=lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT,
    unspecified=lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED,
    windows=lasso.SAML2_NAME_IDENTIFIER_FORMAT_WINDOWS,
    x509=lasso.SAML2_NAME_IDENTIFIER_FORMAT_X509,
)
36
37
class InvalidProviderId(ProviderException):
    """Error raised when a provider ID cannot be resolved or registered.

    :param code: short reason for the failure; it is interpolated into
                 the message 'Invalid Provider ID: <code>'.
    """

    def __init__(self, code):
        text = 'Invalid Provider ID: %s' % code
        super(InvalidProviderId, self).__init__(text)
        # _debug is not defined in this class; presumably it is inherited
        # from ProviderException -- TODO confirm.
        self._debug(text)
44
45
class NameIdNotAllowed(Exception):
    """Raised when a requested SAML2 Name ID format is not permitted.

    :param nameid: optional; the rejected Name ID format. Callers such as
                   ServiceProvider.get_valid_nameid pass the offending
                   format, so it must be accepted here. It is stored as
                   ``self.nameid`` and included in the message. When it is
                   omitted, the generic message is used.
    """

    def __init__(self, nameid=None):
        if nameid is None:
            message = 'The specified Name ID is not allowed'
        else:
            message = 'The specified Name ID [%s] is not allowed' % (nameid,)
        super(NameIdNotAllowed, self).__init__(message)
        self.message = message
        self.nameid = nameid

    def __str__(self):
        # The message is returned quoted, via repr().
        return repr(self.message)
55
56
class ServiceProvider(object):
    """A SAML2 Service Provider whose properties live in the config store.

    Getters read from ``self._properties``, the stored values loaded at
    construction or after a save. Setters only stage changes in
    ``self._staging``; nothing is written until save_properties() is
    called.
    """

    def __init__(self, config, provider_id):
        """Load the stored properties of the SP whose 'id' is provider_id.

        :param config: object providing get_data(), save_data(),
                       default_allowed_nameids and default_nameid.
        :param provider_id: value stored under the 'id' name for the SP.
        :raises InvalidProviderId: if no stored entry, or more than one,
                                   matches provider_id.
        """
        self.cfg = config
        data = self.cfg.get_data(name='id', value=provider_id)
        # Distinguish "absent" from "ambiguous" so the error is accurate.
        if not data:
            raise InvalidProviderId('not found')
        if len(data) > 1:
            raise InvalidProviderId('multiple matches')
        # next(iter(...)) works on both Python 2 and 3; keys()[0] is py2-only.
        idval = next(iter(data))
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()

    @property
    def provider_id(self):
        return self._properties['id']

    @property
    def name(self):
        return self._properties['name']

    @name.setter
    def name(self, value):
        self._staging['name'] = value

    @property
    def owner(self):
        """Stored owner, or '' when none is set."""
        return self._properties.get('owner', '')

    @owner.setter
    def owner(self, value):
        self._staging['owner'] = value

    @property
    def allowed_nameids(self):
        """Allowed Name ID short names.

        The stored value is a comma-separated string; entries are
        whitespace-stripped. Falls back to cfg.default_allowed_nameids
        when the property is not stored.
        """
        if 'allowed nameids' in self._properties:
            allowed = self._properties['allowed nameids']
            return [x.strip() for x in allowed.split(',')]
        return self.cfg.default_allowed_nameids

    @allowed_nameids.setter
    def allowed_nameids(self, value):
        """Stage a new list of allowed Name ID short names.

        :raises ValueError: if value is not a list.
        """
        if not isinstance(value, list):
            raise ValueError("Must be a list")
        self._staging['allowed nameids'] = ','.join(value)

    @property
    def default_nameid(self):
        """Stored default Name ID short name, else cfg.default_nameid."""
        if 'default nameid' in self._properties:
            return self._properties['default nameid']
        return self.cfg.default_nameid

    @default_nameid.setter
    def default_nameid(self, value):
        self._staging['default nameid'] = value

    def save_properties(self):
        """Write staged changes to the store, then reload the properties.

        :raises InvalidProviderId: if the SP's stored entry cannot be
                                   found uniquely.
        """
        data = self.cfg.get_data(name='id', value=self.provider_id)
        if len(data) != 1:
            raise InvalidProviderId('Could not find SP data')
        idval = next(iter(data))
        self.cfg.save_data({idval: self._staging})
        # Re-read so _properties reflects what the store actually holds.
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()

    def get_valid_nameid(self, nip):
        """Choose the Name ID format to use for a request.

        :param nip: object with a ``format`` attribute holding the requested
                    Name ID format or None. Presumably a lasso NameIDPolicy;
                    confirm against the caller.
        :returns: the NAMEID_MAP entry for default_nameid when no format or
                  the 'unspecified' format was requested; otherwise the
                  requested format, if it is one of the allowed ones.
        :raises NameIdNotAllowed: if the requested format is not allowed.
        """
        self._debug('Requested NameId [%s]' % (nip.format,))
        if nip.format in (None,
                          lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED):
            return NAMEID_MAP[self.default_nameid]
        allowed = self.allowed_nameids
        self._debug('Allowed NameIds %s' % (repr(allowed)))
        for nameid in allowed:
            # Unknown short names (e.g. a typo or an empty entry in the
            # stored list) map to None and are skipped instead of raising
            # KeyError; nip.format cannot be None at this point.
            if NAMEID_MAP.get(nameid) == nip.format:
                return nip.format
        raise NameIdNotAllowed(nip.format)

    def _debug(self, fact):
        """Log fact via cherrypy when cherrypy's 'debug' config is set."""
        if cherrypy.config.get('debug', False):
            cherrypy.log(fact)

    def normalize_username(self, username):
        """Drop everything from the first '@' if 'strip domain' is stored.

        Only the presence of the 'strip domain' property is checked; its
        value is not inspected.
        """
        if 'strip domain' in self._properties:
            return username.split('@', 1)[0]
        return username
151
152
class ServiceProviderCreator(object):
    """Registers new Service Providers from SAML2 metadata."""

    def __init__(self, config):
        self.cfg = config

    def create_from_buffer(self, name, metabuf):
        '''Test and add data

        Load metabuf into a throwaway lasso.Server to check it and extract
        its provider ID, then reject the metadata if that ID is already
        stored. Otherwise store a new datum, register the SP with
        self.cfg.idp, and return a ServiceProvider for it.

        :param name: name to store for the new SP.
        :param metabuf: SP metadata buffer passed to lasso.
        :returns: ServiceProvider for the newly stored SP.
        :raises InvalidProviderId: if the metadata does not describe exactly
                                   one provider, if the provider already
                                   exists, or if the new datum cannot be
                                   found uniquely after being stored.
        '''
        validator = lasso.Server()
        validator.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metabuf)
        newsps = validator.get_providers()
        if len(newsps) != 1:
            raise InvalidProviderId("Metadata must contain one Provider")

        # next(iter(...)) works on both Python 2 and 3; keys()[0] is py2-only.
        spid = next(iter(newsps))
        if self.cfg.get_data(name='id', value=spid):
            raise InvalidProviderId("Provider Already Exists")
        datum = {'id': spid, 'name': name, 'type': 'SP', 'metadata': metabuf}
        self.cfg.new_datum(datum)

        # Re-read the stored entry so the IdP is fed what was persisted.
        data = self.cfg.get_data(name='id', value=spid)
        if len(data) != 1:
            raise InvalidProviderId("Internal Error")
        idval = next(iter(data))
        data = self.cfg.get_data(idval=idval)
        sp = data[idval]
        self.cfg.idp.add_provider(sp)

        return ServiceProvider(self.cfg, spid)
183
184
class IdentityProvider(object):
    """Thin wrapper around a lasso.Server configured in the IdP role."""

    def __init__(self, config):
        metadata_file = config.idp_metadata_file
        key_file = config.idp_key_file
        cert_file = config.idp_certificate_file
        # The third positional argument is passed as None; presumably an
        # optional private-key password -- TODO confirm against lasso docs.
        self.server = lasso.Server(metadata_file, key_file, None, cert_file)
        self.server.role = lasso.PROVIDER_ROLE_IDP

    def add_provider(self, sp):
        """Register an SP from its stored datum (uses 'metadata', 'name')."""
        role = lasso.PROVIDER_ROLE_SP
        self.server.addProviderFromBuffer(role, sp['metadata'])
        self._debug('Added SP %s' % sp['name'])

    def get_login_handler(self, dump=None):
        """Return a lasso.Login; restored from dump when one is truthy."""
        if not dump:
            return lasso.Login(self.server)
        return lasso.Login.newFromDump(self.server, dump)

    def get_providers(self):
        """Return whatever the underlying server's get_providers() returns."""
        return self.server.get_providers()

    def _debug(self, fact):
        """Log fact via cherrypy only when the 'debug' config is enabled."""
        if not cherrypy.config.get('debug', False):
            return
        cherrypy.log(fact)
208         if cherrypy.config.get('debug', False):
209             cherrypy.log(fact)