Move some exceptions into provider.common
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / auth.py
index 65d795d..cbfeaaa 100755 (executable)
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 from ipsilon.providers.common import ProviderPageBase, ProviderException
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 from ipsilon.providers.common import ProviderPageBase, ProviderException
+from ipsilon.providers.common import AuthenticationError, InvalidRequest
 from ipsilon.providers.saml2.provider import ServiceProvider
 from ipsilon.providers.saml2.provider import InvalidProviderId
 from ipsilon.providers.saml2.provider import NameIdNotAllowed
 from ipsilon.util.user import UserSession
 from ipsilon.providers.saml2.provider import ServiceProvider
 from ipsilon.providers.saml2.provider import InvalidProviderId
 from ipsilon.providers.saml2.provider import NameIdNotAllowed
 from ipsilon.util.user import UserSession
+from ipsilon.util.trans import Transaction
 import cherrypy
 import datetime
 import lasso
 
 
 import cherrypy
 import datetime
 import lasso
 
 
-class AuthenticationError(ProviderException):
-
-    def __init__(self, message, code):
-        super(AuthenticationError, self).__init__(message)
-        self.code = code
-        self._debug('%s [%s]' % (message, code))
-
-
-class InvalidRequest(ProviderException):
-
-    def __init__(self, message):
-        super(InvalidRequest, self).__init__(message)
-        self._debug(message)
-
-
 class UnknownProvider(ProviderException):
 
     def __init__(self, message):
 class UnknownProvider(ProviderException):
 
     def __init__(self, message):
@@ -53,9 +40,25 @@ class AuthenticateRequest(ProviderPageBase):
 
     def __init__(self, *args, **kwargs):
         super(AuthenticateRequest, self).__init__(*args, **kwargs)
 
     def __init__(self, *args, **kwargs):
         super(AuthenticateRequest, self).__init__(*args, **kwargs)
-        self.STAGE_INIT = 0
-        self.STAGE_AUTH = 1
-        self.stage = self.STAGE_INIT
+        self.stage = 'init'
+        self.trans = None
+
+    def _preop(self, *args, **kwargs):
+        try:
+            # generate a new id or get current one
+            self.trans = Transaction('saml2', **kwargs)
+            if self.trans.cookie.value != self.trans.provider:
+                self.debug('Invalid transaction, %s != %s' % (
+                           self.trans.cookie.value, self.trans.provider))
+        except Exception, e:  # pylint: disable=broad-except
+            self.debug('Transaction initialization failed: %s' % repr(e))
+            raise cherrypy.HTTPError(400, 'Invalid transaction id')
+
+    def pre_GET(self, *args, **kwargs):
+        self._preop(*args, **kwargs)
+
+    def pre_POST(self, *args, **kwargs):
+        self._preop(*args, **kwargs)
 
     def auth(self, login):
         try:
 
     def auth(self, login):
         try:
@@ -116,21 +119,28 @@ class AuthenticateRequest(ProviderPageBase):
 
     def saml2checks(self, login):
 
 
     def saml2checks(self, login):
 
-        session = UserSession()
-        user = session.get_user()
+        us = UserSession()
+        user = us.get_user()
         if user.is_anonymous:
         if user.is_anonymous:
-            if self.stage < self.STAGE_AUTH:
-                session.save_data('saml2', 'stage', self.STAGE_AUTH)
-                session.save_data('saml2', 'Request', login.dump())
-                session.save_data('login', 'Return',
-                                  '%s/saml2/SSO/Continue' % self.basepath)
-                raise cherrypy.HTTPRedirect('%s/login' % self.basepath)
+            if self.stage == 'init':
+                returl = '%s/saml2/SSO/Continue?%s' % (
+                    self.basepath, self.trans.get_GET_arg())
+                data = {'saml2_stage': 'auth',
+                        'saml2_request': login.dump(),
+                        'login_return': returl}
+                self.trans.store(data)
+                redirect = '%s/login?%s' % (self.basepath,
+                                            self.trans.get_GET_arg())
+                raise cherrypy.HTTPRedirect(redirect)
             else:
                 raise AuthenticationError(
                     "Unknown user", lasso.SAML2_STATUS_CODE_AUTHN_FAILED)
 
         self._audit("Logged in user: %s [%s]" % (user.name, user.fullname))
 
             else:
                 raise AuthenticationError(
                     "Unknown user", lasso.SAML2_STATUS_CODE_AUTHN_FAILED)
 
         self._audit("Logged in user: %s [%s]" % (user.name, user.fullname))
 
+        # We can wipe the transaction now, as this is the last step
+        self.trans.wipe()
+
         # TODO: check if this is the first time this user access this SP
         # If required by user prefs, ask user for consent once and then
         # record it
         # TODO: check if this is the first time this user access this SP
         # If required by user prefs, ask user for consent once and then
         # record it
@@ -157,9 +167,6 @@ class AuthenticateRequest(ProviderPageBase):
         authtime_notbefore = authtime - skew
         authtime_notafter = authtime + skew
 
         authtime_notbefore = authtime - skew
         authtime_notafter = authtime + skew
 
-        us = UserSession()
-        user = us.get_user()
-
         # TODO: get authentication type fnd name format from session
         # need to save which login manager authenticated and map it to a
         # saml2 authentication context
         # TODO: get authentication type fnd name format from session
         # need to save which login manager authenticated and map it to a
         # saml2 authentication context
@@ -174,10 +181,10 @@ class AuthenticateRequest(ProviderPageBase):
 
         nameid = None
         if nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_PERSISTENT:
 
         nameid = None
         if nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_PERSISTENT:
-            ## TODO map to something else ?
+            # TODO map to something else ?
             nameid = provider.normalize_username(user.name)
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT:
             nameid = provider.normalize_username(user.name)
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT:
-            ## TODO map to something else ?
+            # TODO map to something else ?
             nameid = provider.normalize_username(user.name)
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS:
             nameid = us.get_data('user', 'krb_principal_name')
             nameid = provider.normalize_username(user.name)
         elif nameidfmt == lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS:
             nameid = us.get_data('user', 'krb_principal_name')
@@ -190,10 +197,33 @@ class AuthenticateRequest(ProviderPageBase):
             login.assertion.subject.nameId.format = nameidfmt
             login.assertion.subject.nameId.content = nameid
         else:
             login.assertion.subject.nameId.format = nameidfmt
             login.assertion.subject.nameId.content = nameid
         else:
+            self.trans.wipe()
             raise AuthenticationError("Unavailable Name ID type",
                                       lasso.SAML2_STATUS_CODE_AUTHN_FAILED)
 
             raise AuthenticationError("Unavailable Name ID type",
                                       lasso.SAML2_STATUS_CODE_AUTHN_FAILED)
 
-        # TODO: add user attributes as policy requires from 'usersession'
+        # TODO: filter user attributes as policy requires from 'usersession'
+        if not login.assertion.attributeStatement:
+            attrstat = lasso.Saml2AttributeStatement()
+            login.assertion.attributeStatement = [attrstat]
+        else:
+            attrstat = login.assertion.attributeStatement[0]
+        if not attrstat.attribute:
+            attrstat.attribute = ()
+
+        attributes = us.get_user_attrs()
+        for key in attributes:
+            attr = lasso.Saml2Attribute()
+            attr.name = key
+            attr.nameFormat = lasso.SAML2_ATTRIBUTE_NAME_FORMAT_BASIC
+            value = str(attributes[key]).encode('utf-8')
+            node = lasso.MiscTextNode.newWithString(value)
+            node.textChild = True
+            attrvalue = lasso.Saml2AttributeValue()
+            attrvalue.any = [node]
+            attr.attributeValue = [attrvalue]
+            attrstat.attribute = attrstat.attribute + (attr,)
+
+        self.debug('Assertion: %s' % login.assertion.dump())
 
     def saml2error(self, login, code, message):
         status = lasso.Samlp2Status()
 
     def saml2error(self, login, code, message):
         status = lasso.Samlp2Status()