Add support for expiration in Metadata
authorSimo Sorce <simo@redhat.com>
Mon, 19 Jan 2015 20:15:03 +0000 (15:15 -0500)
committerPatrick Uiterwijk <puiterwijk@redhat.com>
Thu, 29 Jan 2015 19:06:45 +0000 (20:06 +0100)
Signed-off-by: Simo Sorce <simo@redhat.com>
Reviewed-by: Patrick Uiterwijk <puiterwijk@redhat.com>
ipsilon/tools/saml2metadata.py

index 27eddb9..f918a44 100755 (executable)
@@ -17,6 +17,7 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+import datetime
 from ipsilon.tools.certs import Certificate
 from lxml import etree
 import lasso
@@ -58,6 +59,10 @@ IDP_ROLE = 'idp'
 SP_ROLE = 'sp'
 
 
+# Expire metadata weekly by default
+MIN_EXP_DEFAULT = 7
+
+
 def mdElement(_parent, _tag, **kwargs):
     tag = '{%s}%s' % (lasso.SAML2_METADATA_HREF, _tag)
     return etree.SubElement(_parent, tag, **kwargs)
@@ -70,11 +75,12 @@ def dsElement(_parent, _tag, **kwargs):
 
 class Metadata(object):
 
-    def __init__(self, role=None):
+    def __init__(self, role=None, expiration=None):
         self.root = etree.Element(EDESC, nsmap=NSMAP)
         self.entityid = None
         self.role = None
         self.set_role(role)
+        self.set_expiration(expiration)
 
     def set_entity_id(self, url):
         self.entityid = url
@@ -93,6 +99,21 @@ class Metadata(object):
         self.role.set('protocolSupportEnumeration', lasso.SAML2_PROTOCOL_HREF)
         return self.role
 
+    def set_expiration(self, exp):
+        if exp is None:
+            self.root.set('cacheDuration', "P%dD" % (MIN_EXP_DEFAULT))
+            return
+        elif isinstance(exp, datetime.date):
+            d = datetime.datetime.combine(exp, datetime.date.min.time())
+        elif isinstance(exp, datetime.datetime):
+            d = exp
+        elif isinstance(exp, datetime.timedelta):
+            d = datetime.datetime.now() + exp
+        else:
+            raise TypeError('Invalid expiration date type')
+
+        self.root.set('validUntil', d.isoformat())
+
     def add_cert(self, certdata, use):
         desc = mdElement(self.role, 'KeyDescriptor')
         desc.set('use', use)
@@ -118,11 +139,14 @@ class Metadata(object):
         nameidfmt = mdElement(self.role, 'NameIDFormat')
         nameidfmt.text = name_format
 
-    def output(self, path):
+    def output(self, path=None):
         data = etree.tostring(self.root, xml_declaration=True,
                               encoding='UTF-8', pretty_print=True)
-        with open(path, 'w') as f:
-            f.write(data)
+        if path is None:
+            return data
+        else:
+            with open(path, 'w') as f:
+                f.write(data)
 
 
 if __name__ == '__main__':