Validate SP names for admin pages and REST
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / provider.py
1 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
2 #
3 # see file 'COPYING' for use and warranty information
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 from ipsilon.providers.common import ProviderException
19 from ipsilon.tools.saml2metadata import SAML2_NAMEID_MAP
20 from ipsilon.util.log import Log
21 import lasso
22 import re
23
24
# Regex matching any single character that is NOT permitted in an SP name.
# Despite the constant's name, a successful re.search() means the name is
# INVALID: only ASCII letters, digits and spaces are allowed (used by
# ServiceProvider.is_valid_name and ServiceProviderCreator.create_from_buffer).
VALID_IN_NAME = r'[^\ a-zA-Z0-9]'
26
27
class InvalidProviderId(ProviderException):
    """Raised when a Service Provider cannot be identified or registered.

    The constructor argument is a short reason (e.g. 'multiple matches',
    'Provider Already Exists') that is folded into the exception message.
    """

    def __init__(self, code):
        reason_text = 'Invalid Provider ID: %s' % code
        super(InvalidProviderId, self).__init__(reason_text)
        # _debug is not defined here; it is expected to be provided by
        # ProviderException (not visible in this file).
        self._debug(reason_text)
34
35
class NameIdNotAllowed(Exception):
    """Raised when a requested NameID format is not allowed for an SP."""

    def __init__(self, nid):
        text = 'Name ID [%s] is not allowed' % nid
        super(NameIdNotAllowed, self).__init__(text)
        # Stored explicitly (BaseException.message does not exist on
        # Python 3); __str__ reads it.
        self.message = text

    def __str__(self):
        # str(exc) yields the *quoted* message, i.e. repr() of the text.
        quoted = repr(self.message)
        return quoted
45
46
class ServiceProvider(Log):
    """A SAML2 Service Provider stored in the IdP configuration.

    Property reads come from ``self._properties``, the SP's stored data as
    loaded from ``config``.  Property assignments are only staged in
    ``self._staging``; they are written out (and ``_properties`` reloaded)
    by :meth:`save_properties`.
    """

    def __init__(self, config, provider_id):
        """Load the stored data of the SP whose 'id' is ``provider_id``.

        :param config: configuration store; ``get_data``, ``save_data``,
                       ``del_datum`` and the ``default_*`` attributes are
                       used by this class.
        :param provider_id: the SP's provider ID, as stored under 'id'.
        :raises InvalidProviderId: if no stored SP, or more than one, has
                                   this provider ID.
        """
        self.cfg = config
        idval = self._find_idval(provider_id)
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()

    def _find_idval(self, provider_id):
        """Return the storage key of the single datum whose 'id' matches.

        :raises InvalidProviderId: if no datum, or more than one, matches.
        """
        data = self.cfg.get_data(name='id', value=provider_id)
        if len(data) == 0:
            raise InvalidProviderId('Could not find SP data')
        if len(data) > 1:
            raise InvalidProviderId('multiple matches')
        # next(iter(...)) rather than data.keys()[0]: works on Py2 and Py3.
        return next(iter(data))

    @property
    def provider_id(self):
        """The SP's provider ID (stored under 'id')."""
        return self._properties['id']

    @property
    def name(self):
        """The administrative name given to the SP."""
        return self._properties['name']

    @name.setter
    def name(self, value):
        # Staged only and not validated here; see is_valid_name().
        self._staging['name'] = value

    @property
    def owner(self):
        """The SP's owner, or '' when none is stored."""
        if 'owner' in self._properties:
            return self._properties['owner']
        return ''

    @owner.setter
    def owner(self, value):
        self._staging['owner'] = value

    @property
    def allowed_nameids(self):
        """NameID format keys (of SAML2_NAMEID_MAP) this SP may request.

        Stored as a comma-separated string.  Falls back to the config-wide
        ``default_allowed_nameids`` when the SP has no own setting.
        """
        if 'allowed nameids' in self._properties:
            allowed = self._properties['allowed nameids']
            # Drop empty entries so an empty stored setting (saved from an
            # empty list) yields [] instead of [''].
            return [x.strip() for x in allowed.split(',') if x.strip()]
        return self.cfg.default_allowed_nameids

    @allowed_nameids.setter
    def allowed_nameids(self, value):
        # Must be a list of format keys; staged as a comma-separated string.
        if not isinstance(value, list):
            raise ValueError("Must be a list")
        self._staging['allowed nameids'] = ','.join(value)

    @property
    def default_nameid(self):
        """NameID format key used when a request specifies none.

        Falls back to the config-wide ``default_nameid``.
        """
        if 'default nameid' in self._properties:
            return self._properties['default nameid']
        return self.cfg.default_nameid

    @default_nameid.setter
    def default_nameid(self, value):
        self._staging['default nameid'] = value

    def save_properties(self):
        """Persist the staged property changes and reload stored values.

        :raises InvalidProviderId: if this SP's datum cannot be found
                                   uniquely.
        """
        idval = self._find_idval(self.provider_id)
        self.cfg.save_data({idval: self._staging})
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()

    def get_valid_nameid(self, nip):
        """Pick the NameID format to use for a request.

        :param nip: object with a ``format`` attribute (presumably the
                    request's NameIDPolicy).
        :returns: the mapped ``default_nameid`` when ``nip.format`` is None,
                  otherwise ``nip.format`` if some entry of
                  ``allowed_nameids`` maps to it in SAML2_NAMEID_MAP.
        :raises NameIdNotAllowed: if the requested format is not allowed.
        """
        self._debug('Requested NameId [%s]' % (nip.format,))
        if nip.format is None:
            return SAML2_NAMEID_MAP[self.default_nameid]
        allowed = self.allowed_nameids
        self._debug('Allowed NameIds %s' % (repr(allowed)))
        for nameid in allowed:
            # .get(): an unknown/stale allowed entry is skipped instead of
            # aborting the request with a KeyError.
            if nip.format == SAML2_NAMEID_MAP.get(nameid):
                return nip.format
        raise NameIdNotAllowed(nip.format)

    def permanently_delete(self):
        """Remove this SP's datum from the configuration store.

        :raises InvalidProviderId: if the datum cannot be found uniquely.
        """
        self.cfg.del_datum(self._find_idval(self.provider_id))

    def normalize_username(self, username):
        """Return username cut at its first '@' when the SP has a
        'strip domain' property (only the key's presence is checked)."""
        if 'strip domain' in self._properties:
            return username.split('@', 1)[0]
        return username

    def is_valid_name(self, value):
        """Return True if value is a non-empty string made only of ASCII
        letters, digits and spaces."""
        if not value:
            return False
        return re.search(VALID_IN_NAME, value) is None

    def is_valid_nameid(self, value):
        """Return True if value is a known NameID format key."""
        return value in SAML2_NAMEID_MAP

    def valid_nameids(self):
        """Return a list of all known NameID format keys."""
        return list(SAML2_NAMEID_MAP.keys())
155
156
class ServiceProviderCreator(object):
    """Registers new Service Providers from their SAML2 metadata."""

    def __init__(self, config):
        self.cfg = config

    def create_from_buffer(self, name, metabuf):
        '''Test and add data

        Validate ``name`` and the SP metadata in ``metabuf``, store a new
        SP datum, register it with ``self.cfg.idp`` and return it as a
        ServiceProvider.

        :raises InvalidProviderId: if the name is empty or contains
            characters other than ASCII letters, digits and spaces; if the
            metadata does not describe exactly one provider; if that
            provider is already stored; or if the new datum cannot be read
            back uniquely.
        '''
        if not name:
            raise InvalidProviderId("Name must not be empty")
        if re.search(VALID_IN_NAME, name):
            # VALID_IN_NAME also permits spaces, so say so.
            raise InvalidProviderId("Name must contain only "
                                    "numbers, letters and spaces")

        # Load the metadata into a scratch lasso server so it is checked
        # before anything is stored.
        test = lasso.Server()
        test.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metabuf)
        newsps = test.get_providers()
        if len(newsps) != 1:
            raise InvalidProviderId("Metadata must contain one Provider")

        # next(iter(...)) rather than keys()[0]: works on Py2 and Py3.
        spid = next(iter(newsps))
        if len(self.cfg.get_data(name='id', value=spid)) != 0:
            raise InvalidProviderId("Provider Already Exists")
        datum = {'id': spid, 'name': name, 'type': 'SP', 'metadata': metabuf}
        self.cfg.new_datum(datum)

        # Read the datum back to learn its storage key.
        data = self.cfg.get_data(name='id', value=spid)
        if len(data) != 1:
            raise InvalidProviderId("Internal Error")
        idval = next(iter(data))
        sp = self.cfg.get_data(idval=idval)[idval]
        # Make the running IdP aware of the new SP.
        self.cfg.idp.add_provider(sp)

        return ServiceProvider(self.cfg, spid)
191
192
class IdentityProvider(Log):
    """Wrapper around the IdP's lasso.Server built from configured files."""

    def __init__(self, config):
        srv = lasso.Server(config.idp_metadata_file,
                           config.idp_key_file,
                           None,
                           config.idp_certificate_file)
        srv.role = lasso.PROVIDER_ROLE_IDP
        self.server = srv

    def add_provider(self, sp):
        """Register an SP with the server from its stored data.

        :param sp: stored SP data; its 'metadata' and 'name' keys are read.
        """
        metadata = sp['metadata']
        self.server.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metadata)
        self._debug('Added SP %s' % sp['name'])

    def get_login_handler(self, dump=None):
        """Return a lasso.Login, restored from ``dump`` when it is truthy."""
        if not dump:
            return lasso.Login(self.server)
        return lasso.Login.newFromDump(self.server, dump)

    def get_providers(self):
        """Return the providers registered with the server."""
        srv = self.server
        return srv.get_providers()

    def get_logout_handler(self, dump=None):
        """Return a lasso.Logout, restored from ``dump`` when it is truthy."""
        if not dump:
            return lasso.Logout(self.server)
        return lasso.Logout.newFromDump(self.server, dump)
218         else:
219             return lasso.Logout(self.server)