pylint 1.4.3 version fixes
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / auth.py
old mode 100755 (executable)
new mode 100644 (file)
index 861ef96..9d2bb7d
@@ -1,5 +1,3 @@
-#!/usr/bin/python
-#
 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
 #
 # see file 'COPYING' for use and warranty information
 # Copyright (C) 2014  Simo Sorce <simo@redhat.com>
 #
 # see file 'COPYING' for use and warranty information
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 from ipsilon.providers.common import ProviderPageBase, ProviderException
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 from ipsilon.providers.common import ProviderPageBase, ProviderException
+from ipsilon.providers.common import AuthenticationError, InvalidRequest
 from ipsilon.providers.saml2.provider import ServiceProvider
 from ipsilon.providers.saml2.provider import InvalidProviderId
 from ipsilon.providers.saml2.provider import NameIdNotAllowed
 from ipsilon.providers.saml2.provider import ServiceProvider
 from ipsilon.providers.saml2.provider import InvalidProviderId
 from ipsilon.providers.saml2.provider import NameIdNotAllowed
+from ipsilon.providers.saml2.sessions import SAMLSessionsContainer
+from ipsilon.util.policy import Policy
 from ipsilon.util.user import UserSession
 from ipsilon.util.trans import Transaction
 import cherrypy
 import datetime
 import lasso
 from ipsilon.util.user import UserSession
 from ipsilon.util.trans import Transaction
 import cherrypy
 import datetime
 import lasso
-
-
-class AuthenticationError(ProviderException):
-
-    def __init__(self, message, code):
-        super(AuthenticationError, self).__init__(message)
-        self.code = code
-        self._debug('%s [%s]' % (message, code))
-
-
-class InvalidRequest(ProviderException):
-
-    def __init__(self, message):
-        super(InvalidRequest, self).__init__(message)
-        self._debug(message)
+import uuid
+import hashlib
 
 
 class UnknownProvider(ProviderException):
 
     def __init__(self, message):
         super(UnknownProvider, self).__init__(message)
 
 
 class UnknownProvider(ProviderException):
 
     def __init__(self, message):
         super(UnknownProvider, self).__init__(message)
-        self._debug(message)
+        self.debug(message)
 
 
 class AuthenticateRequest(ProviderPageBase):
 
 
 class AuthenticateRequest(ProviderPageBase):
@@ -107,7 +95,7 @@ class AuthenticateRequest(ProviderPageBase):
                                                  e, message)
             raise UnknownProvider(msg)
 
                                                  e, message)
             raise UnknownProvider(msg)
 
-        self._debug('SP %s requested authentication' % login.remoteProviderId)
+        self.debug('SP %s requested authentication' % login.remoteProviderId)
 
         return login
 
 
         return login
 
@@ -120,13 +108,13 @@ class AuthenticateRequest(ProviderPageBase):
         try:
             login = self._parse_request(request)
         except InvalidRequest, e:
         try:
             login = self._parse_request(request)
         except InvalidRequest, e:
-            self._debug(str(e))
+            self.debug(str(e))
             raise cherrypy.HTTPError(400, 'Invalid SAML request token')
         except UnknownProvider, e:
             raise cherrypy.HTTPError(400, 'Invalid SAML request token')
         except UnknownProvider, e:
-            self._debug(str(e))
+            self.debug(str(e))
             raise cherrypy.HTTPError(400, 'Unknown Service Provider')
         except Exception, e:  # pylint: disable=broad-except
             raise cherrypy.HTTPError(400, 'Unknown Service Provider')
         except Exception, e:  # pylint: disable=broad-except
-            self._debug(str(e))
+            self.debug(str(e))
             raise cherrypy.HTTPError(500)
 
         return login
             raise cherrypy.HTTPError(500)
 
         return login
@@ -141,7 +129,8 @@ class AuthenticateRequest(ProviderPageBase):
                     self.basepath, self.trans.get_GET_arg())
                 data = {'saml2_stage': 'auth',
                         'saml2_request': login.dump(),
                     self.basepath, self.trans.get_GET_arg())
                 data = {'saml2_stage': 'auth',
                         'saml2_request': login.dump(),
-                        'login_return': returl}
+                        'login_return': returl,
+                        'login_target': login.remoteProviderId}
                 self.trans.store(data)
                 redirect = '%s/login?%s' % (self.basepath,
                                             self.trans.get_GET_arg())
                 self.trans.store(data)
                 redirect = '%s/login?%s' % (self.basepath,
                                             self.trans.get_GET_arg())
@@ -195,17 +184,27 @@ class AuthenticateRequest(ProviderPageBase):
 
         nameid = None
         if nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_PERSISTENT:
 
         nameid = None
         if nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_PERSISTENT:
-            # TODO map to something else ?
-            nameid = provider.normalize_username(user.name)
+            idpsalt = self.cfg.idp_nameid_salt
+            if idpsalt is None:
+                raise AuthenticationError(
+                    "idp nameid salt is not set in configuration"
+                )
+            value = hashlib.sha512()
+            value.update(idpsalt)
+            value.update(login.remoteProviderId)
+            value.update(user.name)
+            nameid = '_' + value.hexdigest()
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT:
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT:
-            # TODO map to something else ?
-            nameid = provider.normalize_username(user.name)
+            nameid = '_' + uuid.uuid4().hex
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS:
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS:
-            nameid = us.get_data('user', 'krb_principal_name')
+            userattrs = us.get_user_attrs()
+            nameid = userattrs.get('gssapi_principal_name')
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_EMAIL:
             nameid = us.get_user().email
             if not nameid:
                 nameid = '%s@%s' % (user.name, self.cfg.default_email_domain)
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_EMAIL:
             nameid = us.get_user().email
             if not nameid:
                 nameid = '%s@%s' % (user.name, self.cfg.default_email_domain)
+        elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED:
+            nameid = provider.normalize_username(user.name)
 
         if nameid:
             login.assertion.subject.nameId.format = nameidfmt
 
         if nameid:
             login.assertion.subject.nameId.format = nameidfmt
@@ -215,30 +214,84 @@ class AuthenticateRequest(ProviderPageBase):
             raise AuthenticationError("Unavailable Name ID type",
                                       lasso.SAML2_STATUS_CODE_AUTHN_FAILED)
 
             raise AuthenticationError("Unavailable Name ID type",
                                       lasso.SAML2_STATUS_CODE_AUTHN_FAILED)
 
-        # TODO: filter user attributes as policy requires from 'usersession'
-        if not login.assertion.attributeStatement:
-            attrstat = lasso.Saml2AttributeStatement()
-            login.assertion.attributeStatement = [attrstat]
+        # Check attribute policy and perform mapping and filtering.
+        # If the SP has its own mapping or filtering policy use that
+        # instead of the global policy.
+        if (provider.attribute_mappings is not None and
+                len(provider.attribute_mappings) > 0):
+            attribute_mappings = provider.attribute_mappings
+        else:
+            attribute_mappings = self.cfg.default_attribute_mapping
+        if (provider.allowed_attributes is not None and
+                len(provider.allowed_attributes) > 0):
+            allowed_attributes = provider.allowed_attributes
         else:
         else:
-            attrstat = login.assertion.attributeStatement[0]
-        if not attrstat.attribute:
-            attrstat.attribute = ()
+            allowed_attributes = self.cfg.default_allowed_attributes
+        self.debug("Allowed attrs: %s" % allowed_attributes)
+        self.debug("Mapping: %s" % attribute_mappings)
+        policy = Policy(attribute_mappings, allowed_attributes)
+        userattrs = us.get_user_attrs()
+        mappedattrs, _ = policy.map_attributes(userattrs)
+        attributes = policy.filter_attributes(mappedattrs)
+
+        if '_groups' in attributes and 'groups' not in attributes:
+            attributes['groups'] = attributes['_groups']
+
+        self.debug("%s's attributes: %s" % (user.name, attributes))
+
+        # The saml-core-2.0-os specification section 2.7.3 requires
+        # the AttributeStatement element to be non-empty.
+        if attributes:
+            if not login.assertion.attributeStatement:
+                attrstat = lasso.Saml2AttributeStatement()
+                login.assertion.attributeStatement = [attrstat]
+            else:
+                attrstat = login.assertion.attributeStatement[0]
+            if not attrstat.attribute:
+                attrstat.attribute = ()
 
 
-        attributes = us.get_user_attrs()
         for key in attributes:
         for key in attributes:
-            attr = lasso.Saml2Attribute()
-            attr.name = key
-            attr.nameFormat = lasso.SAML2_ATTRIBUTE_NAME_FORMAT_BASIC
-            value = str(attributes[key]).encode('utf-8')
-            node = lasso.MiscTextNode.newWithString(value)
-            node.textChild = True
-            attrvalue = lasso.Saml2AttributeValue()
-            attrvalue.any = [node]
-            attr.attributeValue = [attrvalue]
-            attrstat.attribute = attrstat.attribute + (attr,)
+            # skip internal info
+            if key[0] == '_':
+                continue
+            values = attributes[key]
+            if isinstance(values, dict):
+                continue
+            if not isinstance(values, list):
+                values = [values]
+            for value in values:
+                attr = lasso.Saml2Attribute()
+                attr.name = key
+                attr.nameFormat = lasso.SAML2_ATTRIBUTE_NAME_FORMAT_BASIC
+                value = str(value).encode('utf-8')
+                self.debug('value %s' % value)
+                node = lasso.MiscTextNode.newWithString(value)
+                node.textChild = True
+                attrvalue = lasso.Saml2AttributeValue()
+                attrvalue.any = [node]
+                attr.attributeValue = [attrvalue]
+                attrstat.attribute = attrstat.attribute + (attr,)
 
         self.debug('Assertion: %s' % login.assertion.dump())
 
 
         self.debug('Assertion: %s' % login.assertion.dump())
 
+        saml_sessions = us.get_provider_data('saml2')
+        if saml_sessions is None:
+            saml_sessions = SAMLSessionsContainer()
+
+        session = saml_sessions.find_session_by_provider(
+            login.remoteProviderId)
+        if session:
+            # TODO: something...
+            self.debug('Login session for this user already exists!?')
+            session.dump()
+
+        lasso_session = lasso.Session()
+        lasso_session.addAssertion(login.remoteProviderId, login.assertion)
+        saml_sessions.add_session(login.assertion.id,
+                                  login.remoteProviderId,
+                                  lasso_session)
+        us.save_provider_data('saml2', saml_sessions)
+
     def saml2error(self, login, code, message):
         status = lasso.Samlp2Status()
         status.statusCode = lasso.Samlp2StatusCode()
     def saml2error(self, login, code, message):
         status = lasso.Samlp2Status()
         status.statusCode = lasso.Samlp2StatusCode()
@@ -253,7 +306,7 @@ class AuthenticateRequest(ProviderPageBase):
             raise cherrypy.HTTPError(501)
         elif login.protocolProfile == lasso.LOGIN_PROTOCOL_PROFILE_BRWS_POST:
             login.buildAuthnResponseMsg()
             raise cherrypy.HTTPError(501)
         elif login.protocolProfile == lasso.LOGIN_PROTOCOL_PROFILE_BRWS_POST:
             login.buildAuthnResponseMsg()
-            self._debug('POSTing back to SP [%s]' % (login.msgUrl))
+            self.debug('POSTing back to SP [%s]' % (login.msgUrl))
             context = {
                 "title": 'Redirecting back to the web application',
                 "action": login.msgUrl,
             context = {
                 "title": 'Redirecting back to the web application',
                 "action": login.msgUrl,
@@ -263,7 +316,6 @@ class AuthenticateRequest(ProviderPageBase):
                 ],
                 "submit": 'Return to application',
             }
                 ],
                 "submit": 'Return to application',
             }
-            # pylint: disable=star-args
             return self._template('saml2/post_response.html', **context)
 
         else:
             return self._template('saml2/post_response.html', **context)
 
         else: