17363aa424ee5ad77b7f7550cae4c59ae0a2e888
[cascardo/ipsilon.git] / ipsilon / util / policy.py
# Copyright (C) 2015  Ipsilon project Contributors, for license see COPYING
2
3 from ipsilon.util.log import Log
4 import copy
5 import sys
6
7
class Policy(Log):
    """Attribute policy engine.

    map_attributes() renames and regroups the entries of an attribute
    dictionary according to the configured mappings; filter_attributes()
    keeps (or removes) the entries named in the allowed list.
    """

    def __init__(self, mappings=None, allowed=None):
        """ A Policy engine to filter attributes.
        Mappings is a list of two-element lists: [key, value].

        The key selects a source attribute. It is a name, a one-element
        list [name], or a two-element list [prefix, name] where prefix
        names a sub dictionary (or list) of the attributes and name an
        entry of it.
        The value names the result, in the same forms; a two-element list
        [prefix, name] places the result in a sub dictionary (or list).
        Note that mappings is a list and not a dictionary as this allows
        to map the same original attribute to different resulting attributes
        if wanted, by simply repeating the key with different values.

            Example: [[['extras', 'shoes'], 'shoeNumber']]

        A '*' as the key name selects every entry (of the attributes, or of
        the prefixed sub dictionary/list). A '*' as the value name keeps the
        original name. A prefix is always looked up literally, '*' is not a
        wildcard there.

        The default mapping is [['*', '*']]
        This copies all attributes without transformation.

        Allowed is a list of allowed attributes.
        Normally mapping should be called before filtering, this means
        allowed attributes should name the mapped attributes.
        Allowed attributes can be two-element lists [attribute, entry]
        naming one entry of a sub dictionary or list.

            Example: ['fullname', ['groups', 'domain users']]

        Allowed is ['*'] by default.

        Raises ValueError if mappings or allowed are malformed.
        """

        self.mappings = None
        if mappings:
            if not isinstance(mappings, list):
                raise ValueError("Mappings should be a list not '%s'" %
                                 type(mappings))
            for el in mappings:
                if not isinstance(el, list):
                    raise ValueError("Mappings must be lists, not '%s'" %
                                     type(el))
                if len(el) != 2:
                    raise ValueError("Mappings must contain 2 elements, "
                                     "found %d" % len(el))
                for position, part in (('1st', el[0]), ('2nd', el[1])):
                    if not isinstance(part, list):
                        continue
                    if len(part) > 2:
                        raise ValueError("%s Mapping element can contain at "
                                         "most 2 values, found %d" %
                                         (position, len(part)))
                    if len(part) == 0:
                        # Reject it here: map_attributes() reads part[0]
                        # and would otherwise fail with an IndexError.
                        raise ValueError("%s Mapping element must not be "
                                         "an empty list" % position)
            self.mappings = mappings
        else:
            # default mapping, copy every attribute unchanged
            self.mappings = [['*', '*']]

        self.allowed = ['*']
        if allowed:
            if not isinstance(allowed, list):
                raise ValueError("Allowed should be a list not '%s'" %
                                 type(allowed))
            self.allowed = allowed

    @staticmethod
    def _split_spec(spec):
        """Return (prefix, name) for one side of a mapping.

        A bare value or a one-element list has no prefix (None); a
        two-element list is [prefix, name].
        """
        if not isinstance(spec, list):
            return None, spec
        if len(spec) == 2:
            return spec[0], spec[1]
        return None, spec[0]

    def map_attributes(self, attributes):
        """Apply the configured mappings to an attribute dictionary.

        attributes must be a dictionary, otherwise ValueError is raised.
        It is not modified; mapped values are deep copies.

        Returns a tuple (mapped, not_mapped):
          mapped is a new dictionary holding the renamed attributes.
          not_mapped is a deep copy of attributes with every mapped entry
          removed, or None once a top level '*' mapping has been applied
          to a not yet empty remainder.

        Mappings whose source is not found are skipped.
        """

        if not isinstance(attributes, dict):
            raise ValueError("Attributes must be dictionary, not %s" %
                             type(attributes))

        not_mapped = copy.deepcopy(attributes)
        mapped = dict()

        for (key, value) in self.mappings:
            prefix, name = self._split_spec(key)
            mapprefix, mapname = self._split_spec(value)

            if prefix:
                if prefix not in attributes:
                    # '*' in a prefix matches nothing
                    continue
                attr = attributes[prefix]
            else:
                attr = attributes

            # Only dictionaries and lists are searched. 'in' on a string
            # is a substring test and on other scalars raises TypeError,
            # neither of which means "has an entry called name".
            if isinstance(attr, (dict, list)) and name in attr:
                if mapname == '*':
                    mapname = name
                if isinstance(attr, list):
                    # A list element has no value of its own: with a target
                    # prefix the renamed element is added to that list,
                    # without one the element is added to the list stored
                    # under the target name.
                    if mapprefix:
                        target, element = mapprefix, mapname
                    else:
                        target, element = mapname, name
                    mapped.setdefault(target, list()).append(element)
                    if not_mapped and prefix in not_mapped:
                        remaining = not_mapped[prefix]
                        while name in remaining:
                            remaining.remove(name)
                else:
                    mapin = copy.deepcopy(attr[name])
                    if mapprefix:
                        mapped.setdefault(mapprefix, dict())[mapname] = mapin
                    else:
                        mapped[mapname] = mapin
                    if not_mapped:
                        if prefix:
                            if prefix in not_mapped:
                                if name in not_mapped[prefix]:
                                    del not_mapped[prefix][name]
                        elif name in not_mapped:
                            del not_mapped[name]
            elif name == '*':
                # mapname is ignored if name == '*'
                mapin = copy.deepcopy(attr)
                if mapprefix:
                    if mapprefix not in mapped:
                        mapped[mapprefix] = mapin
                    else:
                        current = mapped[mapprefix]
                        if isinstance(current, dict) and \
                           isinstance(mapin, dict):
                            current.update(mapin)
                        elif isinstance(current, list) and \
                                isinstance(mapin, list):
                            current.extend(mapin)
                        else:
                            # Shapes cannot be merged, leave the source in
                            # not_mapped so the caller can see it.
                            continue
                elif isinstance(mapin, dict):
                    mapped.update(mapin)
                else:
                    # A list or scalar cannot be spread over the top level
                    # as there are no names to store its content under.
                    continue
                if not_mapped:
                    if prefix:
                        # Only this sub tree has been consumed. It may be
                        # gone already if the same wildcard is repeated to
                        # feed another target.
                        not_mapped.pop(prefix, None)
                    else:
                        not_mapped = None
            else:
                continue

        return mapped, not_mapped

    def _select_allowed(self, attributes):
        """Return the part of attributes named by self.allowed.

        Values are referenced, not copied. When a bare '*' entry applies,
        the attributes dictionary itself is returned.
        """
        selected = dict()
        # Attributes already selected as a whole. Narrower entries for them
        # are redundant, and acting on them would write into containers
        # that belong to the caller.
        complete = set()

        for entry in self.allowed:
            if isinstance(entry, list) and len(entry) != 1:
                if not entry:
                    continue
                key = entry[0]
                value = entry[1]
                if key not in attributes or key in complete:
                    continue
                attr = attributes[key]
                if value == '*':
                    selected[key] = attr
                    complete.add(key)
                elif isinstance(attr, dict):
                    chosen = selected.setdefault(key, dict())
                    if value in attr:
                        chosen[value] = attr[value]
                elif isinstance(attr, list):
                    chosen = selected.setdefault(key, list())
                    if value in attr and value not in chosen:
                        chosen.append(value)
                continue

            # A plain name; a one-element list is treated the same way,
            # like in mappings.
            name = entry[0] if isinstance(entry, list) else entry
            if name in attributes:
                selected[name] = attributes[name]
                complete.add(name)
            elif name == '*':
                # Everything is selected, remaining entries add nothing.
                return attributes

        return selected

    @staticmethod
    def _remove_selected(attributes, selected):
        """Return a deep copy of attributes without what selected holds.

        Attributes left empty by the removal are dropped entirely.
        """
        allowed = copy.deepcopy(attributes)
        for key, banned in selected.items():
            if isinstance(banned, dict):
                remaining = allowed[key]
                for entry in banned:
                    remaining.pop(entry, None)
            elif isinstance(banned, list):
                # Drop every occurrence, the source list may repeat values.
                remaining = [v for v in allowed[key] if v not in banned]
                allowed[key] = remaining
            else:
                remaining = None
            if not remaining:
                del allowed[key]
        return allowed

    def filter_attributes(self, attributes, whitelist=True):
        """Filter attributes according to the allowed list.

        With whitelist=True (default) only what the allowed list names is
        returned. The result references the values of attributes (they are
        not copied) and is the attributes dictionary itself when a '*'
        entry allows everything.

        With whitelist=False the allowed list is treated as a blacklist:
        a deep copy of attributes is returned with the named entries
        removed, and attributes left empty are dropped.

        The attributes passed in are never modified.
        """
        selected = self._select_allowed(attributes)
        if whitelist:
            return selected
        # selected contains the blacklisted
        return self._remove_selected(attributes, selected)
208
209 # Unit tests
if __name__ == '__main__':
    # Self-test for Policy. The exit status is the number of failed checks.
    # print is always called with a single parenthesized argument so the
    # module parses, and prints the same text, on both Python 2 and 3.

    ret = 0

    # Policy
    t_attributes = {'onenameone': 'onevalueone',
                    'onenametwo': 'onevaluetwo',
                    'two': {'twonameone': 'twovalueone',
                            'twonametwo': 'twovaluetwo'},
                    'three': {'threenameone': 'threevalueone',
                              'threenametwo': 'threevaluetwo'},
                    'four': {'fournameone': 'fourvalueone',
                             'fournametwo': 'fourvaluetwo'},
                    'five': ['one', 'two', 'three'],
                    'six': ['one', 'two', 'three']}

    # test defaults first
    p = Policy()

    print('Default attribute mapping')
    m, n = p.map_attributes(t_attributes)
    if m == t_attributes and n is None:
        print('SUCCESS')
    else:
        ret += 1
        print('FAIL: Expected %s\nObtained %s' % (t_attributes, m))
        # The check also covers the not mapped part, report it as well.
        print('Not mapped: Expected %s\nObtained %s' % (None, n))

    print('Default attribute filtering')
    f = p.filter_attributes(t_attributes)
    if f == t_attributes:
        print('SUCCESS')
    else:
        ret += 1
        print('Expected %s\nObtained %s' % (t_attributes, f))

    # test custom mappings and filters
    t_mappings = [[['onenameone'], 'onemappedone'],
                  [['onenametwo'], 'onemappedtwo'],
                  [['two', '*'], '*'],
                  [['three', 'threenameone'], 'threemappedone'],
                  [['three', 'threenameone'], 'threemappedbis'],
                  [['four', '*'], ['four', '*']],
                  [['five'], 'listfive'],
                  [['six', 'one'], ['six', 'mapone']]]

    m_result = {'onemappedone': 'onevalueone',
                'onemappedtwo': 'onevaluetwo',
                'twonameone': 'twovalueone',
                'twonametwo': 'twovaluetwo',
                'threemappedone': 'threevalueone',
                'threemappedbis': 'threevalueone',
                'four': {'fournameone': 'fourvalueone',
                         'fournametwo': 'fourvaluetwo'},
                'listfive': ['one', 'two', 'three'],
                'six': ['mapone']}

    n_result = {'three': {'threenametwo': 'threevaluetwo'},
                'six': ['two', 'three']}

    t_allowed = ['twonameone',
                 ['four', 'fournametwo'],
                 ['listfive', 'three'],
                 ['six', '*']]

    f_result = {'twonameone': 'twovalueone',
                'four': {'fournametwo': 'fourvaluetwo'},
                'listfive': ['three'],
                'six': ['mapone']}

    p = Policy(t_mappings, t_allowed)

    print('Custom attribute mapping')
    m, n = p.map_attributes(t_attributes)
    if m == m_result and n == n_result:
        print('SUCCESS')
    else:
        ret += 1
        print('Expected %s\nObtained %s' % (m_result, m))
        # The check also covers the not mapped part, report it as well.
        print('Not mapped: Expected %s\nObtained %s' % (n_result, n))

    print('Custom attribute filtering')
    f = p.filter_attributes(m)
    if f == f_result:
        print('SUCCESS')
    else:
        ret += 1
        print('Expected %s\nObtained %s' % (f_result, f))

    t2_allowed = ['onemappedone', 'twonametwo', 'threemappedone',
                  ['listfive', 'two']]

    f2_result = {'onemappedtwo': 'onevaluetwo',
                 'twonameone': 'twovalueone',
                 'threemappedbis': 'threevalueone',
                 'four': {'fournameone': 'fourvalueone',
                          'fournametwo': 'fourvaluetwo'},
                 'listfive': ['one', 'three'],
                 'six': ['mapone']}

    p = Policy(t_mappings, t2_allowed)

    print('Custom attribute filtering 2')
    m, _ = p.map_attributes(t_attributes)
    f = p.filter_attributes(m, whitelist=False)
    if f == f2_result:
        print('SUCCESS')
    else:
        ret += 1
        print('Expected %s\nObtained %s' % (f2_result, f))

    sys.exit(ret)
319     sys.exit(ret)