Implement urn:oasis:names:tc:SAML:2.0:nameid-format:transient
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / auth.py
old mode 100755 (executable)
new mode 100644 (file)
index cbfeaaa..71bfc9a
@@ -1,5 +1,3 @@
-#!/usr/bin/python
-#
 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
 #
 # see file 'COPYING' for use and warranty information
 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
 #
 # see file 'COPYING' for use and warranty information
@@ -22,11 +20,14 @@ from ipsilon.providers.common import AuthenticationError, InvalidRequest
 from ipsilon.providers.saml2.provider import ServiceProvider
 from ipsilon.providers.saml2.provider import InvalidProviderId
 from ipsilon.providers.saml2.provider import NameIdNotAllowed
 from ipsilon.providers.saml2.provider import ServiceProvider
 from ipsilon.providers.saml2.provider import InvalidProviderId
 from ipsilon.providers.saml2.provider import NameIdNotAllowed
+from ipsilon.providers.saml2.sessions import SAMLSessionsContainer
+from ipsilon.util.policy import Policy
 from ipsilon.util.user import UserSession
 from ipsilon.util.trans import Transaction
 import cherrypy
 import datetime
 import lasso
 from ipsilon.util.user import UserSession
 from ipsilon.util.trans import Transaction
 import cherrypy
 import datetime
 import lasso
+import uuid
 
 
 class UnknownProvider(ProviderException):
 
 
 class UnknownProvider(ProviderException):
@@ -127,7 +128,8 @@ class AuthenticateRequest(ProviderPageBase):
                     self.basepath, self.trans.get_GET_arg())
                 data = {'saml2_stage': 'auth',
                         'saml2_request': login.dump(),
                     self.basepath, self.trans.get_GET_arg())
                 data = {'saml2_stage': 'auth',
                         'saml2_request': login.dump(),
-                        'login_return': returl}
+                        'login_return': returl,
+                        'login_target': login.remoteProviderId}
                 self.trans.store(data)
                 redirect = '%s/login?%s' % (self.basepath,
                                             self.trans.get_GET_arg())
                 self.trans.store(data)
                 redirect = '%s/login?%s' % (self.basepath,
                                             self.trans.get_GET_arg())
@@ -184,8 +186,7 @@ class AuthenticateRequest(ProviderPageBase):
             # TODO map to something else ?
             nameid = provider.normalize_username(user.name)
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT:
             # TODO map to something else ?
             nameid = provider.normalize_username(user.name)
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT:
-            # TODO map to something else ?
-            nameid = provider.normalize_username(user.name)
+            nameid = '_' + uuid.uuid4().hex
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS:
             nameid = us.get_data('user', 'krb_principal_name')
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_EMAIL:
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS:
             nameid = us.get_data('user', 'krb_principal_name')
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_EMAIL:
@@ -201,30 +202,71 @@ class AuthenticateRequest(ProviderPageBase):
             raise AuthenticationError("Unavailable Name ID type",
                                       lasso.SAML2_STATUS_CODE_AUTHN_FAILED)
 
             raise AuthenticationError("Unavailable Name ID type",
                                       lasso.SAML2_STATUS_CODE_AUTHN_FAILED)
 
-        # TODO: filter user attributes as policy requires from 'usersession'
-        if not login.assertion.attributeStatement:
-            attrstat = lasso.Saml2AttributeStatement()
-            login.assertion.attributeStatement = [attrstat]
-        else:
-            attrstat = login.assertion.attributeStatement[0]
-        if not attrstat.attribute:
-            attrstat.attribute = ()
+        # Check attribute policy and perform mapping and filtering
+        policy = Policy(self.cfg.default_attribute_mapping,
+                        self.cfg.default_allowed_attributes)
+        userattrs = us.get_user_attrs()
+        mappedattrs, _ = policy.map_attributes(userattrs)
+        attributes = policy.filter_attributes(mappedattrs)
+
+        if '_groups' in attributes and 'groups' not in attributes:
+            attributes['groups'] = attributes['_groups']
+
+        self.debug("%s's attributes: %s" % (user.name, attributes))
+
+        # The saml-core-2.0-os specification section 2.7.3 requires
+        # the AttributeStatement element to be non-empty.
+        if attributes:
+            if not login.assertion.attributeStatement:
+                attrstat = lasso.Saml2AttributeStatement()
+                login.assertion.attributeStatement = [attrstat]
+            else:
+                attrstat = login.assertion.attributeStatement[0]
+            if not attrstat.attribute:
+                attrstat.attribute = ()
 
 
-        attributes = us.get_user_attrs()
         for key in attributes:
         for key in attributes:
-            attr = lasso.Saml2Attribute()
-            attr.name = key
-            attr.nameFormat = lasso.SAML2_ATTRIBUTE_NAME_FORMAT_BASIC
-            value = str(attributes[key]).encode('utf-8')
-            node = lasso.MiscTextNode.newWithString(value)
-            node.textChild = True
-            attrvalue = lasso.Saml2AttributeValue()
-            attrvalue.any = [node]
-            attr.attributeValue = [attrvalue]
-            attrstat.attribute = attrstat.attribute + (attr,)
+            # skip internal info
+            if key[0] == '_':
+                continue
+            values = attributes[key]
+            if isinstance(values, dict):
+                continue
+            if not isinstance(values, list):
+                values = [values]
+            for value in values:
+                attr = lasso.Saml2Attribute()
+                attr.name = key
+                attr.nameFormat = lasso.SAML2_ATTRIBUTE_NAME_FORMAT_BASIC
+                value = str(value).encode('utf-8')
+                self.debug('value %s' % value)
+                node = lasso.MiscTextNode.newWithString(value)
+                node.textChild = True
+                attrvalue = lasso.Saml2AttributeValue()
+                attrvalue.any = [node]
+                attr.attributeValue = [attrvalue]
+                attrstat.attribute = attrstat.attribute + (attr,)
 
         self.debug('Assertion: %s' % login.assertion.dump())
 
 
         self.debug('Assertion: %s' % login.assertion.dump())
 
+        saml_sessions = us.get_provider_data('saml2')
+        if saml_sessions is None:
+            saml_sessions = SAMLSessionsContainer()
+
+        session = saml_sessions.find_session_by_provider(
+            login.remoteProviderId)
+        if session:
+            # TODO: something...
+            self.debug('Login session for this user already exists!?')
+            session.dump()
+
+        lasso_session = lasso.Session()
+        lasso_session.addAssertion(login.remoteProviderId, login.assertion)
+        saml_sessions.add_session(login.assertion.id,
+                                  login.remoteProviderId,
+                                  lasso_session)
+        us.save_provider_data('saml2', saml_sessions)
+
     def saml2error(self, login, code, message):
         status = lasso.Samlp2Status()
         status.statusCode = lasso.Samlp2StatusCode()
     def saml2error(self, login, code, message):
         status = lasso.Samlp2Status()
         status.statusCode = lasso.Samlp2StatusCode()