Assertion AttributeStatements must be non-empty
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / auth.py
old mode 100755 (executable)
new mode 100644 (file)
index 49f73a9..f5e8f0f
@@ -1,5 +1,3 @@
-#!/usr/bin/python
-#
 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
 #
 # see file 'COPYING' for use and warranty information
@@ -22,6 +20,8 @@ from ipsilon.providers.common import AuthenticationError, InvalidRequest
 from ipsilon.providers.saml2.provider import ServiceProvider
 from ipsilon.providers.saml2.provider import InvalidProviderId
 from ipsilon.providers.saml2.provider import NameIdNotAllowed
+from ipsilon.providers.saml2.sessions import SAMLSessionsContainer
+from ipsilon.util.policy import Policy
 from ipsilon.util.user import UserSession
 from ipsilon.util.trans import Transaction
 import cherrypy
@@ -202,29 +202,37 @@ class AuthenticateRequest(ProviderPageBase):
             raise AuthenticationError("Unavailable Name ID type",
                                       lasso.SAML2_STATUS_CODE_AUTHN_FAILED)
 
-        # TODO: filter user attributes as policy requires from 'usersession'
-        if not login.assertion.attributeStatement:
-            attrstat = lasso.Saml2AttributeStatement()
-            login.assertion.attributeStatement = [attrstat]
-        else:
-            attrstat = login.assertion.attributeStatement[0]
-        if not attrstat.attribute:
-            attrstat.attribute = ()
-
-        attributes = dict()
+        # Check attribute policy and perform mapping and filtering
+        policy = Policy(self.cfg.default_attribute_mapping,
+                        self.cfg.default_allowed_attributes)
         userattrs = us.get_user_attrs()
-        for key, value in userattrs.get('userdata', {}).iteritems():
-            if type(value) is str:
-                attributes[key] = value
-        if 'groups' in userattrs:
-            attributes['group'] = userattrs['groups']
-        for _, info in userattrs.get('extras', {}).iteritems():
-            for key, value in info.items():
-                attributes[key] = value
+        mappedattrs, _ = policy.map_attributes(userattrs)
+        attributes = policy.filter_attributes(mappedattrs)
+
+        if '_groups' in attributes and 'groups' not in attributes:
+            attributes['groups'] = attributes['_groups']
+
+        self.debug("%s's attributes: %s" % (user.name, attributes))
+
+        # The saml-core-2.0-os specification section 2.7.3 requires
+        # the AttributeStatement element to be non-empty.
+        if attributes:
+            if not login.assertion.attributeStatement:
+                attrstat = lasso.Saml2AttributeStatement()
+                login.assertion.attributeStatement = [attrstat]
+            else:
+                attrstat = login.assertion.attributeStatement[0]
+            if not attrstat.attribute:
+                attrstat.attribute = ()
 
         for key in attributes:
+            # skip internal info
+            if key[0] == '_':
+                continue
             values = attributes[key]
-            if type(values) is not list:
+            if isinstance(values, dict):
+                continue
+            if not isinstance(values, list):
                 values = [values]
             for value in values:
                 attr = lasso.Saml2Attribute()
@@ -241,6 +249,24 @@ class AuthenticateRequest(ProviderPageBase):
 
         self.debug('Assertion: %s' % login.assertion.dump())
 
+        saml_sessions = us.get_provider_data('saml2')
+        if saml_sessions is None:
+            saml_sessions = SAMLSessionsContainer()
+
+        session = saml_sessions.find_session_by_provider(
+            login.remoteProviderId)
+        if session:
+            # TODO: something...
+            self.debug('Login session for this user already exists!?')
+            session.dump()
+
+        lasso_session = lasso.Session()
+        lasso_session.addAssertion(login.remoteProviderId, login.assertion)
+        saml_sessions.add_session(login.assertion.id,
+                                  login.remoteProviderId,
+                                  lasso_session)
+        us.save_provider_data('saml2', saml_sessions)
+
     def saml2error(self, login, code, message):
         status = lasso.Samlp2Status()
         status.statusCode = lasso.Samlp2StatusCode()