Use new Log class everywhere
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / provider.py
index 73ff005..58ffbfe 100755 (executable)
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 from ipsilon.providers.common import ProviderException
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 from ipsilon.providers.common import ProviderException
-import cherrypy
+from ipsilon.tools.saml2metadata import SAML2_NAMEID_MAP
+from ipsilon.util.log import Log
 import lasso
 
 
 import lasso
 
 
-NAMEID_MAP = {
-    'email': lasso.SAML2_NAME_IDENTIFIER_FORMAT_EMAIL,
-    'encrypted': lasso.SAML2_NAME_IDENTIFIER_FORMAT_ENCRYPTED,
-    'entity': lasso.SAML2_NAME_IDENTIFIER_FORMAT_ENTITY,
-    'kerberos': lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS,
-    'persistent': lasso.SAML2_NAME_IDENTIFIER_FORMAT_PERSISTENT,
-    'transient': lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT,
-    'unspecified': lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED,
-    'windows': lasso.SAML2_NAME_IDENTIFIER_FORMAT_WINDOWS,
-    'x509': lasso.SAML2_NAME_IDENTIFIER_FORMAT_X509,
-}
-
-
 class InvalidProviderId(ProviderException):
 
     def __init__(self, code):
 class InvalidProviderId(ProviderException):
 
     def __init__(self, code):
@@ -45,8 +33,8 @@ class InvalidProviderId(ProviderException):
 
 class NameIdNotAllowed(Exception):
 
 
 class NameIdNotAllowed(Exception):
 
-    def __init__(self):
-        message = 'The specified Name ID is not allowed'
+    def __init__(self, nid):
+        message = 'Name ID [%s] is not allowed' % nid
         super(NameIdNotAllowed, self).__init__(message)
         self.message = message
 
         super(NameIdNotAllowed, self).__init__(message)
         self.message = message
 
@@ -54,7 +42,7 @@ class NameIdNotAllowed(Exception):
         return repr(self.message)
 
 
         return repr(self.message)
 
 
-class ServiceProvider(object):
+class ServiceProvider(Log):
 
     def __init__(self, config, provider_id):
         self.cfg = config
 
     def __init__(self, config, provider_id):
         self.cfg = config
@@ -129,14 +117,14 @@ class ServiceProvider(object):
     def get_valid_nameid(self, nip):
         self._debug('Requested NameId [%s]' % (nip.format,))
         if nip.format is None:
     def get_valid_nameid(self, nip):
         self._debug('Requested NameId [%s]' % (nip.format,))
         if nip.format is None:
-            return NAMEID_MAP[self.default_nameid]
+            return SAML2_NAMEID_MAP[self.default_nameid]
         elif nip.format == lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED:
         elif nip.format == lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED:
-            return NAMEID_MAP[self.default_nameid]
+            return SAML2_NAMEID_MAP[self.default_nameid]
         else:
             allowed = self.allowed_nameids
             self._debug('Allowed NameIds %s' % (repr(allowed)))
             for nameid in allowed:
         else:
             allowed = self.allowed_nameids
             self._debug('Allowed NameIds %s' % (repr(allowed)))
             for nameid in allowed:
-                if nip.format == NAMEID_MAP[nameid]:
+                if nip.format == SAML2_NAMEID_MAP[nameid]:
                     return nip.format
         raise NameIdNotAllowed(nip.format)
 
                     return nip.format
         raise NameIdNotAllowed(nip.format)
 
@@ -147,15 +135,19 @@ class ServiceProvider(object):
         idval = data.keys()[0]
         self.cfg.del_datum(idval)
 
         idval = data.keys()[0]
         self.cfg.del_datum(idval)
 
-    def _debug(self, fact):
-        if cherrypy.config.get('debug', False):
-            cherrypy.log(fact)
-
     def normalize_username(self, username):
         if 'strip domain' in self._properties:
             return username.split('@', 1)[0]
         return username
 
     def normalize_username(self, username):
         if 'strip domain' in self._properties:
             return username.split('@', 1)[0]
         return username
 
+    def is_valid_nameid(self, value):
+        if value in SAML2_NAMEID_MAP:
+            return True
+        return False
+
+    def valid_nameids(self):
+        return SAML2_NAMEID_MAP.keys()
+
 
 class ServiceProviderCreator(object):
 
 
 class ServiceProviderCreator(object):
 
@@ -189,7 +181,7 @@ class ServiceProviderCreator(object):
         return ServiceProvider(self.cfg, spid)
 
 
         return ServiceProvider(self.cfg, spid)
 
 
-class IdentityProvider(object):
+class IdentityProvider(Log):
     def __init__(self, config):
         self.server = lasso.Server(config.idp_metadata_file,
                                    config.idp_key_file,
     def __init__(self, config):
         self.server = lasso.Server(config.idp_metadata_file,
                                    config.idp_key_file,
@@ -210,7 +202,3 @@ class IdentityProvider(object):
 
     def get_providers(self):
         return self.server.get_providers()
 
     def get_providers(self):
         return self.server.get_providers()
-
-    def _debug(self, fact):
-        if cherrypy.config.get('debug', False):
-            cherrypy.log(fact)