Add support for logout over SOAP
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / provider.py
old mode 100755 (executable)
new mode 100644 (file)
index 58ffbfe..b70582e
@@ -1,26 +1,16 @@
-#!/usr/bin/python
-#
-# Copyright (C) 2014  Simo Sorce <simo@redhat.com>
-#
-# see file 'COPYING' for use and warranty information
-#
-# This program is free software; you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+# Copyright (C) 2014 Ipsilon project Contributors, for license see COPYING
 
 from ipsilon.providers.common import ProviderException
-from ipsilon.tools.saml2metadata import SAML2_NAMEID_MAP
+from ipsilon.util import config as pconfig
+from ipsilon.util.config import ConfigHelper
+from ipsilon.tools.saml2metadata import SAML2_NAMEID_MAP, NSMAP
 from ipsilon.util.log import Log
+from lxml import etree
 import lasso
+import re
+
+
+VALID_IN_NAME = r'[^\ a-zA-Z0-9]'
 
 
 class InvalidProviderId(ProviderException):
@@ -28,7 +18,7 @@ class InvalidProviderId(ProviderException):
     def __init__(self, code):
         message = 'Invalid Provider ID: %s' % code
         super(InvalidProviderId, self).__init__(message)
-        self._debug(message)
+        self.debug(message)
 
 
 class NameIdNotAllowed(Exception):
@@ -42,9 +32,15 @@ class NameIdNotAllowed(Exception):
         return repr(self.message)
 
 
-class ServiceProvider(Log):
+class ServiceProviderConfig(ConfigHelper):
+    def __init__(self):
+        super(ServiceProviderConfig, self).__init__()
+
+
+class ServiceProvider(ServiceProviderConfig):
 
     def __init__(self, config, provider_id):
+        super(ServiceProvider, self).__init__()
         self.cfg = config
         data = self.cfg.get_data(name='id', value=provider_id)
         if len(data) != 1:
@@ -53,6 +49,50 @@ class ServiceProvider(Log):
         data = self.cfg.get_data(idval=idval)
         self._properties = data[idval]
         self._staging = dict()
+        self.load_config()
+        self.logout_mechs = []
+        xmldoc = etree.XML(str(data[idval]['metadata']))
+        logout = xmldoc.xpath('//md:EntityDescriptor'
+                              '/md:SPSSODescriptor'
+                              '/md:SingleLogoutService',
+                              namespaces=NSMAP)
+        for service in logout:
+            self.logout_mechs.append(service.values()[0])
+
+    def load_config(self):
+        self.new_config(
+            self.provider_id,
+            pconfig.String(
+                'Name',
+                'A nickname used to easily identify the Service Provider.'
+                ' Only alphanumeric characters [A-Z,a-z,0-9] and spaces are'
+                '  accepted.',
+                self.name),
+            pconfig.Pick(
+                'Default NameID',
+                'Default NameID used by Service Providers.',
+                SAML2_NAMEID_MAP.keys(),
+                self.default_nameid),
+            pconfig.Choice(
+                'Allowed NameIDs',
+                'Allowed NameIDs for this Service Provider.',
+                SAML2_NAMEID_MAP.keys(),
+                self.allowed_nameids),
+            pconfig.String(
+                'User Owner',
+                'The user that owns this Service Provider',
+                self.owner),
+            pconfig.MappingList(
+                'Attribute Mapping',
+                'Defines how to map attributes before returning them to'
+                ' the SP. Setting this overrides the global values.',
+                self.attribute_mappings),
+            pconfig.ComplexList(
+                'Allowed Attributes',
+                'Defines a list of allowed attributes, applied after mapping.'
+                ' Setting this overrides the global values.',
+                self.allowed_attributes),
+        )
 
     @property
     def provider_id(self):
@@ -87,7 +127,7 @@ class ServiceProvider(Log):
 
     @allowed_nameids.setter
     def allowed_nameids(self, value):
-        if type(value) is not list:
+        if not isinstance(value, list):
             raise ValueError("Must be a list")
         self._staging['allowed nameids'] = ','.join(value)
 
@@ -102,6 +142,44 @@ class ServiceProvider(Log):
     def default_nameid(self, value):
         self._staging['default nameid'] = value
 
+    @property
+    def attribute_mappings(self):
+        if 'attribute mappings' in self._properties:
+            attr_map = pconfig.MappingList('temp', 'temp', None)
+            attr_map.import_value(str(self._properties['attribute mappings']))
+            return attr_map.get_value()
+        else:
+            return None
+
+    @attribute_mappings.setter
+    def attribute_mappings(self, attr_map):
+        if isinstance(attr_map, pconfig.MappingList):
+            value = attr_map.export_value()
+        else:
+            temp = pconfig.MappingList('temp', 'temp', None)
+            temp.set_value(attr_map)
+            value = temp.export_value()
+        self._staging['attribute mappings'] = value
+
+    @property
+    def allowed_attributes(self):
+        if 'allowed_attributes' in self._properties:
+            attr_map = pconfig.ComplexList('temp', 'temp', None)
+            attr_map.import_value(str(self._properties['allowed_attributes']))
+            return attr_map.get_value()
+        else:
+            return None
+
+    @allowed_attributes.setter
+    def allowed_attributes(self, attr_map):
+        if isinstance(attr_map, pconfig.ComplexList):
+            value = attr_map.export_value()
+        else:
+            temp = pconfig.ComplexList('temp', 'temp', None)
+            temp.set_value(attr_map)
+            value = temp.export_value()
+        self._staging['allowed_attributes'] = value
+
     def save_properties(self):
         data = self.cfg.get_data(name='id', value=self.provider_id)
         if len(data) != 1:
@@ -114,15 +192,21 @@ class ServiceProvider(Log):
         self._properties = data[idval]
         self._staging = dict()
 
+    def refresh_config(self):
+        """
+        Create a new config object for displaying in the UI based on
+        the current set of properties.
+        """
+        del self._config
+        self.load_config()
+
     def get_valid_nameid(self, nip):
-        self._debug('Requested NameId [%s]' % (nip.format,))
+        self.debug('Requested NameId [%s]' % (nip.format,))
         if nip.format is None:
             return SAML2_NAMEID_MAP[self.default_nameid]
-        elif nip.format == lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED:
-            return SAML2_NAMEID_MAP[self.default_nameid]
         else:
             allowed = self.allowed_nameids
-            self._debug('Allowed NameIds %s' % (repr(allowed)))
+            self.debug('Allowed NameIds %s' % (repr(allowed)))
             for nameid in allowed:
                 if nip.format == SAML2_NAMEID_MAP[nameid]:
                     return nip.format
@@ -140,6 +224,11 @@ class ServiceProvider(Log):
             return username.split('@', 1)[0]
         return username
 
+    def is_valid_name(self, value):
+        if re.search(VALID_IN_NAME, value):
+            return False
+        return True
+
     def is_valid_nameid(self, value):
         if value in SAML2_NAMEID_MAP:
             return True
@@ -157,6 +246,10 @@ class ServiceProviderCreator(object):
     def create_from_buffer(self, name, metabuf):
         '''Test and add data'''
 
+        if re.search(VALID_IN_NAME, name):
+            raise InvalidProviderId("Name must contain only "
+                                    "numbers and letters")
+
         test = lasso.Server()
         test.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP, metabuf)
         newsps = test.get_providers()
@@ -182,17 +275,18 @@ class ServiceProviderCreator(object):
 
 
 class IdentityProvider(Log):
-    def __init__(self, config):
+    def __init__(self, config, sessionfactory):
         self.server = lasso.Server(config.idp_metadata_file,
                                    config.idp_key_file,
                                    None,
                                    config.idp_certificate_file)
         self.server.role = lasso.PROVIDER_ROLE_IDP
+        self.sessionfactory = sessionfactory
 
     def add_provider(self, sp):
         self.server.addProviderFromBuffer(lasso.PROVIDER_ROLE_SP,
                                           sp['metadata'])
-        self._debug('Added SP %s' % sp['name'])
+        self.debug('Added SP %s' % sp['name'])
 
     def get_login_handler(self, dump=None):
         if dump:
@@ -202,3 +296,9 @@ class IdentityProvider(Log):
 
     def get_providers(self):
         return self.server.get_providers()
+
+    def get_logout_handler(self, dump=None):
+        if dump:
+            return lasso.Logout.newFromDump(self.server, dump)
+        else:
+            return lasso.Logout(self.server)