Update copyright header to point to COPYING file
[cascardo/ipsilon.git] / ipsilon / util / policy.py
1 # Copyright (C) 2015 Ipsilon project Contributors, for license see COPYING
2
3 from ipsilon.util.log import Log
4 import copy
5 import sys
6
7
class Policy(Log):
    """Map and filter user attributes according to a configured policy.

    map_attributes() renames/regroups attributes following the mappings;
    filter_attributes() then keeps (or, used as a blacklist, drops) the
    attributes named in the allowed list.
    """

    def __init__(self, mappings=None, allowed=None):
        """ A Policy engine to filter attributes.

        Mappings is a list of lists where the first value is a list itself
        and the second value is an attribute name or a list if the values
        should go in a sub dictionary.
        Note that mappings is a list and not a dictionary as this allows
        to map the same original attribute to different resulting attributes
        if wanted, by simply repeating the 'key list' with different values
        or 'value lists'.

            Example: [[['extras', 'shoes'], 'shoeNumber']]

        Each side of a mapping holds either one name (a top-level
        attribute) or two names ([prefix, name]: an entry inside the
        dictionary or list attribute called prefix).

        A '*' can be used to allow any attribute.

        The default mapping is [[['*'], '*']]
        This copies all attributes without transformation.

        Allowed is a list of allowed attributes.
        Normally mapping should be called before filtering, this means
        allowed attributes should name the mapped attributes.
        Allowed attributes can be multi-element lists

            Example: ['fullname', ['groups', 'domain users']]

        Allowed is '*' by default.

        :raises ValueError: if mappings or allowed is not a list, or if a
            mapping is not a 2-element list whose list parts hold 1 or 2
            values.
        """
        if mappings:
            if not isinstance(mappings, list):
                raise ValueError("Mappings should be a list not '%s'" %
                                 type(mappings))
            for el in mappings:
                if not isinstance(el, list):
                    raise ValueError("Mappings must be lists, not '%s'" %
                                     type(el))
                if len(el) != 2:
                    raise ValueError("Mappings must contain 2 elements, "
                                     "found %d" % len(el))
                for pos, part in (('1st', el[0]), ('2nd', el[1])):
                    if not isinstance(part, list):
                        continue
                    if len(part) > 2:
                        raise ValueError("%s Mapping element can contain at "
                                         "most 2 values, found %d" %
                                         (pos, len(part)))
                    if not part:
                        # An empty list would otherwise only fail later,
                        # with an IndexError inside map_attributes().
                        raise ValueError("%s Mapping element cannot be an "
                                         "empty list" % pos)
            self.mappings = mappings
        else:
            # Default: a single top-level '*' mapping, which copies every
            # attribute unchanged.
            self.mappings = [['*', '*']]

        self.allowed = ['*']
        if allowed:
            if not isinstance(allowed, list):
                raise ValueError("Allowed should be a list not '%s'" %
                                 type(allowed))
            self.allowed = allowed

    @staticmethod
    def _name_index(names, ignore_case):
        """Return a dict translating lookup keys to the real names.

        With ignore_case the keys are lowercased; names that differ only
        by case collapse onto a single key and the last one wins.
        """
        if ignore_case:
            return {name.lower(): name for name in names}
        return {name: name for name in names}

    def map_attributes(self, attributes, ignore_case=False):
        """Apply the policy mappings to a dictionary of attributes.

        :param attributes: dictionary of attributes; it is never modified.
        :param ignore_case: if True, mapping names and prefixes are matched
            against the attribute names case insensitively.
        :returns: a (mapped, not_mapped) tuple. mapped is a new dictionary
            with the mapped attributes. not_mapped is a deep copy of
            attributes with every mapped attribute (or mapped entry of a
            dictionary/list attribute) removed; sub-dictionaries and lists
            emptied this way are left in place. not_mapped is None once a
            top-level '*' mapping has run, since that maps everything.
        :raises ValueError: if attributes is not a dictionary.
        """
        if not isinstance(attributes, dict):
            raise ValueError("Attributes must be dictionary, not %s" %
                             type(attributes))

        not_mapped = copy.deepcopy(attributes)
        mapped = dict()

        # PD translates (lowercased, if ignore_case) names to the real
        # top-level attribute names.
        PD = self._name_index(attributes.keys(), ignore_case)

        for (key, value) in self.mappings:
            if not isinstance(key, list):
                key = [key]
            if len(key) == 2:
                prefix, name = key
            else:
                prefix, name = None, key[0]

            if not isinstance(value, list):
                value = [value]
            if len(value) == 2:
                mapprefix, mapname = value
            else:
                mapprefix, mapname = None, value[0]

            if ignore_case:
                if prefix:
                    prefix = prefix.lower()
                name = name.lower()

            if prefix:
                if prefix not in PD:
                    # '*' in a prefix matches nothing
                    continue
                attr = attributes[PD[prefix]]
                # ND translates names inside the prefix attribute: list
                # elements or dictionary keys.
                if isinstance(attr, list):
                    ND = self._name_index(attr, ignore_case)
                elif hasattr(attr, 'keys'):
                    ND = self._name_index(attr.keys(), ignore_case)
                else:
                    # A scalar attribute has no entries that a prefixed
                    # mapping could match.
                    continue
            else:
                attr = attributes
                ND = PD

            if name in ND and ND[name] in attr:
                realname = ND[name]
                if isinstance(attr, list):
                    # Map a single element of a list attribute.
                    if mapname == '*':
                        mapname = realname
                    if mapprefix:
                        mapped.setdefault(mapprefix, list()).append(mapname)
                    else:
                        mapped.setdefault(mapname, list()).append(realname)
                    # Only the matched element counts as mapped; the rest
                    # of the list stays in not_mapped.
                    if not_mapped and PD[prefix] in not_mapped:
                        remaining = not_mapped[PD[prefix]]
                        while realname in remaining:
                            remaining.remove(realname)
                else:
                    mapin = copy.deepcopy(attr[realname])
                    if mapname == '*':
                        mapname = realname
                    if mapprefix:
                        mapped.setdefault(mapprefix, dict())[mapname] = mapin
                    else:
                        mapped[mapname] = mapin
                    if not_mapped:
                        if prefix:
                            if (PD[prefix] in not_mapped and
                                    realname in not_mapped[PD[prefix]]):
                                del not_mapped[PD[prefix]][realname]
                        elif realname in not_mapped:
                            del not_mapped[realname]
            elif name == '*':
                # mapname is ignored if name == '*'
                mapin = copy.deepcopy(attr)
                # NOTE(review): if attr is a list, the update() calls below
                # are not meaningful (lists have no update() and
                # dict.update() expects pairs); confirm lists are only used
                # with a mapprefix not yet present in mapped.
                if mapprefix:
                    if mapprefix not in mapped:
                        mapped[mapprefix] = mapin
                    else:
                        mapped[mapprefix].update(mapin)
                else:
                    mapped.update(mapin)
                if not_mapped:
                    if prefix:
                        # Only this attribute is consumed; if it is already
                        # gone, the other unmapped attributes must survive.
                        if PD[prefix] in not_mapped:
                            del not_mapped[PD[prefix]]
                    else:
                        # A top-level '*' maps every attribute.
                        not_mapped = None

        return mapped, not_mapped

    def filter_attributes(self, attributes, whitelist=True):
        """Filter attributes through the allowed list.

        Each allowed entry is either a name, kept whole, or a
        [name, value] pair where value is a key of a dictionary attribute
        or an element of a list attribute; '*' as value keeps the whole
        attribute. A plain '*' entry keeps every attribute.

        :param attributes: dictionary of attributes, normally the mapped
            output of map_attributes(); it is never modified.
        :param whitelist: if True (default) return only the allowed
            attributes; if False use the allowed list as a blacklist and
            return a deep copy of attributes without them, dropping
            sub-dictionaries and lists that end up empty.
        :returns: a new dictionary (in whitelist mode its values are shared
            with attributes, not copied).
        """
        filtered = dict()

        for name in self.allowed:
            if isinstance(name, list):
                key = name[0]
                value = name[1]
                if key not in attributes:
                    continue
                attr = attributes[key]
                if value == '*':
                    filtered[key] = attr
                elif isinstance(attr, dict):
                    sub = filtered.setdefault(key, dict())
                    if value in attr:
                        sub[value] = attr[value]
                elif isinstance(attr, list):
                    sub = filtered.setdefault(key, list())
                    # Skip values already present: avoids duplicates and
                    # never appends into the caller's list when sub is
                    # attributes[key] itself (after a '*' entry).
                    if value in attr and value not in sub:
                        sub.append(value)
            elif name in attributes:
                filtered[name] = attributes[name]
            elif name == '*':
                # Copy, so the caller's dictionary is never returned as is.
                filtered = dict(attributes)

        if whitelist:
            return filtered

        # Used as a blacklist: filtered holds what must be removed.
        allowed = copy.deepcopy(attributes)
        for lvl1, attr in filtered.items():
            if isinstance(attr, dict):
                for lvl2 in attr:
                    del allowed[lvl1][lvl2]
            elif isinstance(attr, list):
                # Drop every occurrence, like map_attributes() does.
                allowed[lvl1] = [v for v in allowed[lvl1] if v not in attr]
            else:
                # A blacklisted scalar is removed outright.
                del allowed[lvl1]
                continue
            if not allowed[lvl1]:
                del allowed[lvl1]

        return allowed
240
# Unit tests
if __name__ == '__main__':

    def report(expected, obtained):
        """Print SUCCESS or a FAIL message; return the number of failures.

        print() is always called with a single argument so the output is
        the same under the Python 2 print statement and Python 3 function.
        """
        if obtained == expected:
            print('SUCCESS')
            return 0
        print('FAIL: Expected %s\nObtained %s' % (expected, obtained))
        return 1

    ret = 0

    # Policy
    t_attributes = {'onenameone': 'onevalueone',
                    'onenametwo': 'onevaluetwo',
                    'two': {'twonameone': 'twovalueone',
                            'twonametwo': 'twovaluetwo'},
                    'three': {'threenameone': 'threevalueone',
                              'threenametwo': 'threevaluetwo'},
                    'four': {'fournameone': 'fourvalueone',
                             'fournametwo': 'fourvaluetwo'},
                    'five': ['one', 'two', 'three'],
                    'six': ['one', 'two', 'three']}

    # test defaults first
    p = Policy()

    print('Default attribute mapping')
    m, n = p.map_attributes(t_attributes)
    # mapped and not_mapped are both checked, so report both
    ret += report((t_attributes, None), (m, n))

    print('Default attribute filtering')
    f = p.filter_attributes(t_attributes)
    ret += report(t_attributes, f)

    # test custom mappings and filters
    t_mappings = [[['onenameone'], 'onemappedone'],
                  [['onenametwo'], 'onemappedtwo'],
                  [['two', '*'], '*'],
                  [['three', 'threenameone'], 'threemappedone'],
                  [['three', 'threenameone'], 'threemappedbis'],
                  [['four', '*'], ['four', '*']],
                  [['five'], 'listfive'],
                  [['six', 'one'], ['six', 'mapone']]]

    m_result = {'onemappedone': 'onevalueone',
                'onemappedtwo': 'onevaluetwo',
                'twonameone': 'twovalueone',
                'twonametwo': 'twovaluetwo',
                'threemappedone': 'threevalueone',
                'threemappedbis': 'threevalueone',
                'four': {'fournameone': 'fourvalueone',
                         'fournametwo': 'fourvaluetwo'},
                'listfive': ['one', 'two', 'three'],
                'six': ['mapone']}

    n_result = {'three': {'threenametwo': 'threevaluetwo'},
                'six': ['two', 'three']}

    t_allowed = ['twonameone',
                 ['four', 'fournametwo'],
                 ['listfive', 'three'],
                 ['six', '*']]

    f_result = {'twonameone': 'twovalueone',
                'four': {'fournametwo': 'fourvaluetwo'},
                'listfive': ['three'],
                'six': ['mapone']}

    p = Policy(t_mappings, t_allowed)

    print('Custom attribute mapping')
    m, n = p.map_attributes(t_attributes)
    ret += report((m_result, n_result), (m, n))

    print('Custom attribute filtering')
    f = p.filter_attributes(m)
    ret += report(f_result, f)

    t2_allowed = ['onemappedone', 'twonametwo', 'threemappedone',
                  ['listfive', 'two']]

    f2_result = {'onemappedtwo': 'onevaluetwo',
                 'twonameone': 'twovalueone',
                 'threemappedbis': 'threevalueone',
                 'four': {'fournameone': 'fourvalueone',
                          'fournametwo': 'fourvaluetwo'},
                 'listfive': ['one', 'three'],
                 'six': ['mapone']}

    p = Policy(t_mappings, t2_allowed)

    print('Custom attribute filtering 2')
    m, _ = p.map_attributes(t_attributes)
    f = p.filter_attributes(m, whitelist=False)
    ret += report(f2_result, f)

    # Case Insensitive matching
    tci_attributes = {'oneNameone': 'onevalueone',
                      'onenamEtwo': 'onevaluetwo',
                      'Two': {'twonameone': 'twovalueone',
                              'twonameTwo': 'twovaluetwo'},
                      'thrEE': {'threeNAMEone': 'threevalueone',
                                'thrEEnametwo': 'threevaluetwo'},
                      'foUr': {'fournameone': 'fourvalueone',
                               'fournametwo': 'fourvaluetwo'},
                      'FIVE': ['one', 'two', 'three'],
                      'six': ['ONE', 'two', 'three']}

    tci_mappings = [[['onenameone'], 'onemappedone'],
                    [['onenametwo'], 'onemappedtwo'],
                    [['two', '*'], '*'],
                    [['three', 'threenameone'], 'threemappedone'],
                    [['three', 'threenameone'], 'threemappedbis'],
                    [['four', '*'], ['Four', '*']],
                    [['five'], 'listfive'],
                    [['six', 'one'], ['six', 'mapone']]]

    mci_result = {'onemappedone': 'onevalueone',
                  'onemappedtwo': 'onevaluetwo',
                  'twonameone': 'twovalueone',
                  'twonameTwo': 'twovaluetwo',
                  'threemappedone': 'threevalueone',
                  'threemappedbis': 'threevalueone',
                  'Four': {'fournameone': 'fourvalueone',
                           'fournametwo': 'fourvaluetwo'},
                  'listfive': ['one', 'two', 'three'],
                  'six': ['mapone']}

    nci_result = {'thrEE': {'thrEEnametwo': 'threevaluetwo'},
                  'six': ['two', 'three']}

    p = Policy(tci_mappings)
    print('Case insensitive attribute mapping')
    m, n = p.map_attributes(tci_attributes, ignore_case=True)
    ret += report((mci_result, nci_result), (m, n))

    sys.exit(ret)