Add Service Provider class
authorSimo Sorce <simo@redhat.com>
Wed, 26 Feb 2014 23:42:09 +0000 (18:42 -0500)
committerSimo Sorce <simo@redhat.com>
Thu, 27 Feb 2014 02:50:27 +0000 (21:50 -0500)
This class allows to represent a service provider and its associated policy

Signed-off-by: Simo Sorce <simo@redhat.com>
ipsilon/providers/saml2/provider.py [new file with mode: 0755]
ipsilon/providers/saml2idp.py
ipsilon/util/data.py
ipsilon/util/plugin.py

diff --git a/ipsilon/providers/saml2/provider.py b/ipsilon/providers/saml2/provider.py
new file mode 100755 (executable)
index 0000000..c738ac2
--- /dev/null
@@ -0,0 +1,108 @@
+#!/usr/bin/python
+#
+# Copyright (C) 2014  Simo Sorce <simo@redhat.com>
+#
+# see file 'COPYING' for use and warranty information
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+import cherrypy
+import lasso
+
+
+NAMEID_MAP = {
+    'email': lasso.SAML2_NAME_IDENTIFIER_FORMAT_EMAIL,
+    'encrypted': lasso.SAML2_NAME_IDENTIFIER_FORMAT_ENCRYPTED,
+    'entity': lasso.SAML2_NAME_IDENTIFIER_FORMAT_ENTITY,
+    'kerberos': lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS,
+    'persistent': lasso.SAML2_NAME_IDENTIFIER_FORMAT_PERSISTENT,
+    'transient': lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT,
+    'unspecified': lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED,
+    'windows': lasso.SAML2_NAME_IDENTIFIER_FORMAT_WINDOWS,
+    'x509': lasso.SAML2_NAME_IDENTIFIER_FORMAT_X509,
+}
+
+
+class InvalidProviderId(Exception):
+
+    def __init__(self, message):
+        msg = 'Invalid Provider ID: %s' % message
+        super(InvalidProviderId, self).__init__(msg)
+        self.message = msg
+
+    def __str__(self):
+        return repr(self.message)
+
+
+class NameIdNotAllowed(Exception):
+
+    def __init__(self):
+        message = 'The specified Name ID is not allowed'
+        super(NameIdNotAllowed, self).__init__(message)
+        self.message = message
+
+    def __str__(self):
+        return repr(self.message)
+
+
+class ServiceProvider(object):
+
+    def __init__(self, config, provider_id):
+        self.cfg = config
+        data = self.cfg.get_data(name='id', value=provider_id)
+        if len(data) != 1:
+            raise InvalidProviderId('multiple matches')
+        idval = data.keys()[0]
+        data = self.cfg.get_data(idval=idval)
+        self._properties = data[idval]
+
+    @property
+    def provider_id(self):
+        return self._properties['id']
+
+    @property
+    def name(self):
+        return self._properties['name']
+
+    @property
+    def allowed_namedids(self):
+        if 'allowed nameid' in self._properties:
+            return self._properties['allowed nameid']
+        else:
+            return self.cfg.default_allowed_nameids
+
+    @property
+    def default_nameid(self):
+        if 'default nameid' in self._properties:
+            return self._properties['default nameid']
+        else:
+            return self.cfg.default_nameid
+
+    def get_valid_nameid(self, nip):
+        self._debug('Requested NameId [%s]' % (nip.format,))
+        if nip.format == None:
+            return NAMEID_MAP[self.default_nameid]
+        elif nip.format == lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED:
+            return NAMEID_MAP[self.default_nameid]
+        else:
+            allowed = self.allowed_namedids
+            self._debug('Allowed NameIds %s' % (repr(allowed)))
+            for nameid in allowed:
+                if nip.format == NAMEID_MAP[nameid]:
+                    return nip.format
+        raise NameIdNotAllowed()
+
+    def _debug(self, fact):
+        if cherrypy.config.get('debug', False):
+            cherrypy.log(fact)
index a22a1f4..3dda9e8 100755 (executable)
@@ -162,6 +162,16 @@ Provides SAML 2.0 authentication infrastructure. """
                 """ Allow authenticated users to register applications. """,
                 'boolean',
                 True
                 """ Allow authenticated users to register applications. """,
                 'boolean',
                 True
+            ],
+            'default allowed nameids': [
+                """Default Allowed NameIDs for Service Providers. """,
+                'list',
+                ['transient', 'email', 'kerberos', 'x509']
+            ],
+            'default nameid': [
+                """Default NameID used by Service Providers. """,
+                'string',
+                'email'
             ]
         }
 
             ]
         }
 
@@ -188,6 +198,14 @@ Provides SAML 2.0 authentication infrastructure. """
         return os.path.join(self.idp_storage_path,
                             self.get_config_value('idp key file'))
 
         return os.path.join(self.idp_storage_path,
                             self.get_config_value('idp key file'))
 
+    @property
+    def default_allowed_nameids(self):
+        return self.get_config_value('default allowed nameids')
+
+    @property
+    def default_nameid(self):
+        return self.get_config_value('default nameid')
+
     def get_tree(self, site):
         self.page = SAML2(site, self)
         return self.page
     def get_tree(self, site):
         self.page = SAML2(site, self)
         return self.page
index cbd3b49..2a55bb2 100755 (executable)
@@ -189,21 +189,43 @@ class Store(object):
             if con:
                 con.close()
 
             if con:
                 con.close()
 
-    def get_data(self, plugin):
+    def get_data(self, plugin, idval=None, name=None, value=None):
         con = None
         rows = []
         con = None
         rows = []
+        names = None
+        values = ()
+        if idval or name or value:
+            names = ""
+            if idval:
+                names += " id=?"
+                values = values + (idval,)
+            if name:
+                if len(names) != 0:
+                    names += " AND"
+                names += " name=?"
+                values = values + (name,)
+            if value:
+                if len(names) != 0:
+                    names += " AND"
+                names += " value=?"
+                values = values + (value,)
         try:
             con = sqlite3.connect(self._admin_dbname)
             cur = con.cursor()
             cur.execute("CREATE TABLE IF NOT EXISTS " +
                         plugin + "_data (id INTEGER, name TEXT, value TEXT)")
         try:
             con = sqlite3.connect(self._admin_dbname)
             cur = con.cursor()
             cur.execute("CREATE TABLE IF NOT EXISTS " +
                         plugin + "_data (id INTEGER, name TEXT, value TEXT)")
-            cur.execute("SELECT * FROM " + plugin + "_data")
+            if not names:
+                cur.execute("SELECT * FROM " + plugin + "_data")
+            else:
+                cur.execute("SELECT * FROM " + plugin + "_data WHERE" +
+                            names, values)
             rows = cur.fetchall()
             con.commit()
         except sqlite3.Error, e:
             if con:
                 con.rollback()
             cherrypy.log.error("Failed to load %s data: [%s]" % (plugin, e))
             rows = cur.fetchall()
             con.commit()
         except sqlite3.Error, e:
             if con:
                 con.rollback()
             cherrypy.log.error("Failed to load %s data: [%s]" % (plugin, e))
+            cherrypy.log.error(repr([names, values]))
         finally:
             if con:
                 con.close()
         finally:
             if con:
                 con.close()
index 16a086a..045cc75 100755 (executable)
@@ -138,8 +138,9 @@ class PluginObject(object):
             self._config = dict()
         self._config[option] = value
 
             self._config = dict()
         self._config[option] = value
 
-    def get_data(self):
-        return self._data.get_data(self.name)
+    def get_data(self, idval=None, name=None, value=None):
+        return self._data.get_data(self.name, idval=idval, name=name,
+                                   value=value)
 
     def save_data(self, data):
         self._data.save_data(self.name, data)
 
     def save_data(self, data):
         self._data.save_data(self.name, data)