Move accessory functions to a generic tools module
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / provider.py
index c738ac2..7d47363 100755 (executable)
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+from ipsilon.providers.common import ProviderException
+from ipsilon.tools.saml2metadata import SAML2_NAMEID_MAP
 import cherrypy
 import lasso
 
 
-NAMEID_MAP = {
-    'email': lasso.SAML2_NAME_IDENTIFIER_FORMAT_EMAIL,
-    'encrypted': lasso.SAML2_NAME_IDENTIFIER_FORMAT_ENCRYPTED,
-    'entity': lasso.SAML2_NAME_IDENTIFIER_FORMAT_ENTITY,
-    'kerberos': lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS,
-    'persistent': lasso.SAML2_NAME_IDENTIFIER_FORMAT_PERSISTENT,
-    'transient': lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT,
-    'unspecified': lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED,
-    'windows': lasso.SAML2_NAME_IDENTIFIER_FORMAT_WINDOWS,
-    'x509': lasso.SAML2_NAME_IDENTIFIER_FORMAT_X509,
-}
+class InvalidProviderId(ProviderException):
 
-
-class InvalidProviderId(Exception):
-
-    def __init__(self, message):
-        msg = 'Invalid Provider ID: %s' % message
-        super(InvalidProviderId, self).__init__(msg)
-        self.message = msg
-
-    def __str__(self):
-        return repr(self.message)
+    def __init__(self, code):
+        message = 'Invalid Provider ID: %s' % code
+        super(InvalidProviderId, self).__init__(message)
+        self._debug(message)
 
 
 class NameIdNotAllowed(Exception):
@@ -66,6 +52,7 @@ class ServiceProvider(object):
         idval = data.keys()[0]
         data = self.cfg.get_data(idval=idval)
         self._properties = data[idval]
+        self._staging = dict()
 
     @property
     def provider_id(self):
@@ -75,13 +62,35 @@ class ServiceProvider(object):
     def name(self):
         return self._properties['name']
 
+    @name.setter
+    def name(self, value):
+        self._staging['name'] = value
+
     @property
-    def allowed_namedids(self):
-        if 'allowed nameid' in self._properties:
-            return self._properties['allowed nameid']
+    def owner(self):
+        if 'owner' in self._properties:
+            return self._properties['owner']
+        else:
+            return ''
+
+    @owner.setter
+    def owner(self, value):
+        self._staging['owner'] = value
+
+    @property
+    def allowed_nameids(self):
+        if 'allowed nameids' in self._properties:
+            allowed = self._properties['allowed nameids']
+            return [x.strip() for x in allowed.split(',')]
         else:
             return self.cfg.default_allowed_nameids
 
+    @allowed_nameids.setter
+    def allowed_nameids(self, value):
+        if type(value) is not list:
+            raise ValueError("Must be a list")
+        self._staging['allowed nameids'] = ','.join(value)
+
     @property
     def default_nameid(self):
         if 'default nameid' in self._properties:
@@ -89,19 +98,106 @@ class ServiceProvider(object):
         else:
             return self.cfg.default_nameid
 
+    @default_nameid.setter
+    def default_nameid(self, value):
+        self._staging['default nameid'] = value
+
+    def save_properties(self):
+        data = self.cfg.get_data(name='id', value=self.provider_id)
+        if len(data) != 1:
+            raise InvalidProviderId('Could not find SP data')
+        idval = data.keys()[0]
+        data = dict()
+        data[idval] = self._staging
+        self.cfg.save_data(data)
+        data = self.cfg.get_data(idval=idval)
+        self._properties = data[idval]
+        self._staging = dict()
+
     def get_valid_nameid(self, nip):
         self._debug('Requested NameId [%s]' % (nip.format,))
-        if nip.format == None:
-            return NAMEID_MAP[self.default_nameid]
+        if nip.format is None:
+            return SAML2_NAMEID_MAP[self.default_nameid]
         elif nip.format == lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED:
-            return NAMEID_MAP[self.default_nameid]
+            return SAML2_NAMEID_MAP[self.default_nameid]
         else:
-            allowed = self.allowed_namedids
+            allowed = self.allowed_nameids
             self._debug('Allowed NameIds %s' % (repr(allowed)))
             for nameid in allowed:
-                if nip.format == NAMEID_MAP[nameid]:
+                if nip.format == SAML2_NAMEID_MAP[nameid]:
                     return nip.format
-        raise NameIdNotAllowed()
+        raise NameIdNotAllowed(nip.format)
+
+    def permanently_delete(self):
+        data = self.cfg.get_data(name='id', value=self.provider_id)
+        if len(data) != 1:
+            raise InvalidProviderId('Could not find SP data')
+        idval = data.keys()[0]
+        self.cfg.del_datum(idval)
+
+    def _debug(self, fact):
+        if cherrypy.config.get('debug', False):
+            cherrypy.log(fact)
+
+    def normalize_username(self, username):
+        if 'strip domain' in self._properties:
+            return username.split('@', 1)[0]
+        return username
+
+
+class ServiceProviderCreator(object):
+
+    def __init__(self, config):
+        self.cfg = config
+
+    def create_from_buffer(self, name, metabuf):
+        '''Test and add data'''
+
+        test = lasso.Server()
+        test.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metabuf)
+        newsps = test.get_providers()
+        if len(newsps) != 1:
+            raise InvalidProviderId("Metadata must contain one Provider")
+
+        spid = newsps.keys()[0]
+        data = self.cfg.get_data(name='id', value=spid)
+        if len(data) != 0:
+            raise InvalidProviderId("Provider Already Exists")
+        datum = {'id': spid, 'name': name, 'type': 'SP', 'metadata': metabuf}
+        self.cfg.new_datum(datum)
+
+        data = self.cfg.get_data(name='id', value=spid)
+        if len(data) != 1:
+            raise InvalidProviderId("Internal Error")
+        idval = data.keys()[0]
+        data = self.cfg.get_data(idval=idval)
+        sp = data[idval]
+        self.cfg.idp.add_provider(sp)
+
+        return ServiceProvider(self.cfg, spid)
+
+
+class IdentityProvider(object):
+    def __init__(self, config):
+        self.server = lasso.Server(config.idp_metadata_file,
+                                   config.idp_key_file,
+                                   None,
+                                   config.idp_certificate_file)
+        self.server.role = lasso.PROVIDER_ROLE_IDP
+
+    def add_provider(self, sp):
+        self.server.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP,
+                                          sp['metadata'])
+        self._debug('Added SP %s' % sp['name'])
+
+    def get_login_handler(self, dump=None):
+        if dump:
+            return lasso.Login.newFromDump(self.server, dump)
+        else:
+            return lasso.Login(self.server)
+
+    def get_providers(self):
+        return self.server.get_providers()
 
     def _debug(self, fact):
         if cherrypy.config.get('debug', False):