337a31d919549fd5761fa5bcfcfe1cb2aaa45b3c
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / provider.py
1 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
2 #
3 # see file 'COPYING' for use and warranty information
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 from ipsilon.providers.common import ProviderException
19 from ipsilon.tools.saml2metadata import SAML2_NAMEID_MAP
20 from ipsilon.util.log import Log
21 import lasso
22
23
class InvalidProviderId(ProviderException):
    """Error raised when a provider id cannot be used.

    The text passed as ``code`` is appended to a fixed prefix; the
    resulting message is handed to the base exception and also sent to
    the debug log.
    """

    def __init__(self, code):
        text = 'Invalid Provider ID: %s' % code
        super(InvalidProviderId, self).__init__(text)
        # Emit the same text through the debug logger.
        self._debug(text)
30
31
class NameIdNotAllowed(Exception):
    """Error raised when a requested Name ID format is not permitted."""

    def __init__(self, nid):
        text = 'Name ID [%s] is not allowed' % nid
        super(NameIdNotAllowed, self).__init__(text)
        # Keep the message on the instance so __str__ can report it.
        self.message = text

    def __str__(self):
        # The string form is the repr of the message (i.e. it is quoted).
        return repr(self.message)
41
42
class ServiceProvider(Log):
    """A SAML2 Service Provider entry backed by the configuration store.

    Current values are read from ``_properties``. Values assigned through
    the property setters are only collected in ``_staging``; they are
    handed to the store (and become visible through the getters) when
    save_properties() is called.
    """

    def __init__(self, config, provider_id):
        """Load the stored properties of the SP whose 'id' is provider_id.

        Raises InvalidProviderId unless exactly one stored entry matches.
        """
        self.cfg = config
        data = self.cfg.get_data(name='id', value=provider_id)
        if len(data) == 0:
            # An unknown id used to be reported as 'multiple matches'.
            raise InvalidProviderId('no matches')
        if len(data) != 1:
            raise InvalidProviderId('multiple matches')
        # list() keeps this working where keys() returns a view.
        idval = list(data.keys())[0]
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()

    def _stored_idval(self):
        """Return the store key under which this SP's entry is kept.

        Raises InvalidProviderId unless exactly one stored entry matches
        this provider's id.
        """
        data = self.cfg.get_data(name='id', value=self.provider_id)
        if len(data) != 1:
            raise InvalidProviderId('Could not find SP data')
        return list(data.keys())[0]

    @property
    def provider_id(self):
        """The 'id' property of the stored entry."""
        return self._properties['id']

    @property
    def name(self):
        """The 'name' property of the stored entry."""
        return self._properties['name']

    @name.setter
    def name(self, value):
        self._staging['name'] = value

    @property
    def owner(self):
        """The 'owner' property, or '' when the entry has none."""
        if 'owner' in self._properties:
            return self._properties['owner']
        return ''

    @owner.setter
    def owner(self, value):
        self._staging['owner'] = value

    @property
    def allowed_nameids(self):
        """List of allowed NameID names.

        Parsed from the comma separated 'allowed nameids' property when
        present; otherwise the configuration-wide default is returned.
        """
        if 'allowed nameids' in self._properties:
            allowed = self._properties['allowed nameids']
            # Drop empty entries: an empty list is stored as '' by the
            # setter, and '' is not a usable key for SAML2_NAMEID_MAP.
            stripped = [x.strip() for x in allowed.split(',')]
            return [x for x in stripped if x]
        return self.cfg.default_allowed_nameids

    @allowed_nameids.setter
    def allowed_nameids(self, value):
        """Stage a list of NameID names; raises ValueError for non-lists."""
        if not isinstance(value, list):
            raise ValueError("Must be a list")
        self._staging['allowed nameids'] = ','.join(value)

    @property
    def default_nameid(self):
        """The 'default nameid' property, or the configuration default."""
        if 'default nameid' in self._properties:
            return self._properties['default nameid']
        return self.cfg.default_nameid

    @default_nameid.setter
    def default_nameid(self, value):
        self._staging['default nameid'] = value

    def save_properties(self):
        """Pass the staged values to the store and reload the properties.

        Raises InvalidProviderId if the stored entry cannot be found.
        """
        idval = self._stored_idval()
        self.cfg.save_data({idval: self._staging})
        # Re-read so the getters reflect what the store now holds.
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]
        self._staging = dict()

    def get_valid_nameid(self, nip):
        """Return the NameID format to use for the policy object nip.

        When nip.format is None or the lasso 'unspecified' format, the
        format mapped from default_nameid is returned. Otherwise
        nip.format is returned if it matches one of the allowed NameIDs.

        Raises NameIdNotAllowed when the requested format is not allowed.
        """
        self._debug('Requested NameId [%s]' % (nip.format,))
        unspecified = lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED
        if nip.format is None or nip.format == unspecified:
            return SAML2_NAMEID_MAP[self.default_nameid]

        allowed = self.allowed_nameids
        self._debug('Allowed NameIds %s' % (repr(allowed)))
        for nameid in allowed:
            if nip.format == SAML2_NAMEID_MAP[nameid]:
                return nip.format
        raise NameIdNotAllowed(nip.format)

    def permanently_delete(self):
        """Delete this SP's entry from the store.

        Raises InvalidProviderId if the stored entry cannot be found.
        """
        self.cfg.del_datum(self._stored_idval())

    def normalize_username(self, username):
        """Return username, cut at the first '@' if 'strip domain' is set.

        NOTE(review): only the presence of the 'strip domain' key is
        checked, not its value -- confirm this is intended.
        """
        if 'strip domain' in self._properties:
            return username.split('@', 1)[0]
        return username

    def is_valid_nameid(self, value):
        """Return True when value is a key of SAML2_NAMEID_MAP."""
        return value in SAML2_NAMEID_MAP

    def valid_nameids(self):
        """Return the list of keys of SAML2_NAMEID_MAP."""
        return list(SAML2_NAMEID_MAP.keys())
148
149
class ServiceProviderCreator(object):
    """Creates new Service Provider entries from SAML2 metadata."""

    def __init__(self, config):
        self.cfg = config

    def create_from_buffer(self, name, metabuf):
        """Check the metadata, store a new SP entry and return it.

        name is stored as the entry's 'name' and metabuf as its
        'metadata'. The new entry is also added to the configuration's
        IdP (self.cfg.idp).

        Returns a ServiceProvider for the new entry.

        Raises InvalidProviderId when the metadata does not yield exactly
        one provider, when an entry with the same id already exists, or
        when the entry cannot be found again after it was created.
        """
        # Load the metadata into a scratch lasso server first, so the
        # provider id can be read before anything is stored.
        test = lasso.Server()
        test.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metabuf)
        newsps = test.get_providers()
        if len(newsps) != 1:
            raise InvalidProviderId("Metadata must contain one Provider")

        # list() keeps this working where keys() returns a view.
        spid = list(newsps.keys())[0]
        data = self.cfg.get_data(name='id', value=spid)
        if len(data) != 0:
            raise InvalidProviderId("Provider Already Exists")
        datum = {'id': spid, 'name': name, 'type': 'SP', 'metadata': metabuf}
        self.cfg.new_datum(datum)

        # Read the entry back to learn the key the store assigned to it.
        data = self.cfg.get_data(name='id', value=spid)
        if len(data) != 1:
            raise InvalidProviderId("Internal Error")
        idval = list(data.keys())[0]
        data = self.cfg.get_data(idval=idval)
        sp = data[idval]
        self.cfg.idp.add_provider(sp)

        return ServiceProvider(self.cfg, spid)
180
181
class IdentityProvider(Log):
    """Wrapper around a lasso.Server object whose role is set to IdP."""

    def __init__(self, config):
        metadata_file = config.idp_metadata_file
        key_file = config.idp_key_file
        certificate_file = config.idp_certificate_file
        # The third positional argument is passed as None.
        self.server = lasso.Server(metadata_file, key_file, None,
                                   certificate_file)
        self.server.role = lasso.PROVIDER_ROLE_IDP

    def add_provider(self, sp):
        """Add the SP described by sp, using its 'metadata' entry."""
        metadata = sp['metadata']
        self.server.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metadata)
        self._debug('Added SP %s' % sp['name'])

    def get_login_handler(self, dump=None):
        """Return a lasso.Login, built from dump when one is given."""
        if not dump:
            return lasso.Login(self.server)
        return lasso.Login.newFromDump(self.server, dump)

    def get_providers(self):
        """Return whatever the wrapped server reports as its providers."""
        return self.server.get_providers()
201     def get_providers(self):
202         return self.server.get_providers()