Merge remote-tracking branch 'asoc/topic/mtk' into asoc-next
[cascardo/linux.git] / net / netfilter / ipset / ip_set_list_set.c
1 /* Copyright (C) 2008-2013 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
2  *
3  * This program is free software; you can redistribute it and/or modify
4  * it under the terms of the GNU General Public License version 2 as
5  * published by the Free Software Foundation.
6  */
7
8 /* Kernel module implementing an IP set type: the list:set type */
9
10 #include <linux/module.h>
11 #include <linux/ip.h>
12 #include <linux/rculist.h>
13 #include <linux/skbuff.h>
14 #include <linux/errno.h>
15
16 #include <linux/netfilter/ipset/ip_set.h>
17 #include <linux/netfilter/ipset/ip_set_list.h>
18
19 #define IPSET_TYPE_REV_MIN      0
20 /*                              1    Counters support added */
21 /*                              2    Comments support added */
22 #define IPSET_TYPE_REV_MAX      3 /* skbinfo support added */
23
24 MODULE_LICENSE("GPL");
25 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
26 IP_SET_MODULE_DESC("list:set", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX);
27 MODULE_ALIAS("ip_set_list:set");
28
29 /* Member elements  */
30 struct set_elem {
31         struct rcu_head rcu;
32         struct list_head list;
33         struct ip_set *set;     /* Sigh, in order to cleanup reference */
34         ip_set_id_t id;
35 } __aligned(__alignof__(u64));
36
37 struct set_adt_elem {
38         ip_set_id_t id;
39         ip_set_id_t refid;
40         int before;
41 };
42
43 /* Type structure */
44 struct list_set {
45         u32 size;               /* size of set list array */
46         struct timer_list gc;   /* garbage collection */
47         struct net *net;        /* namespace */
48         struct list_head members; /* the set members */
49 };
50
51 static int
52 list_set_ktest(struct ip_set *set, const struct sk_buff *skb,
53                const struct xt_action_param *par,
54                struct ip_set_adt_opt *opt, const struct ip_set_ext *ext)
55 {
56         struct list_set *map = set->data;
57         struct set_elem *e;
58         u32 cmdflags = opt->cmdflags;
59         int ret;
60
61         /* Don't lookup sub-counters at all */
62         opt->cmdflags &= ~IPSET_FLAG_MATCH_COUNTERS;
63         if (opt->cmdflags & IPSET_FLAG_SKIP_SUBCOUNTER_UPDATE)
64                 opt->cmdflags &= ~IPSET_FLAG_SKIP_COUNTER_UPDATE;
65         list_for_each_entry_rcu(e, &map->members, list) {
66                 if (SET_WITH_TIMEOUT(set) &&
67                     ip_set_timeout_expired(ext_timeout(e, set)))
68                         continue;
69                 ret = ip_set_test(e->id, skb, par, opt);
70                 if (ret > 0) {
71                         if (SET_WITH_COUNTER(set))
72                                 ip_set_update_counter(ext_counter(e, set),
73                                                       ext, &opt->ext,
74                                                       cmdflags);
75                         if (SET_WITH_SKBINFO(set))
76                                 ip_set_get_skbinfo(ext_skbinfo(e, set),
77                                                    ext, &opt->ext,
78                                                    cmdflags);
79                         return ret;
80                 }
81         }
82         return 0;
83 }
84
85 static int
86 list_set_kadd(struct ip_set *set, const struct sk_buff *skb,
87               const struct xt_action_param *par,
88               struct ip_set_adt_opt *opt, const struct ip_set_ext *ext)
89 {
90         struct list_set *map = set->data;
91         struct set_elem *e;
92         int ret;
93
94         list_for_each_entry(e, &map->members, list) {
95                 if (SET_WITH_TIMEOUT(set) &&
96                     ip_set_timeout_expired(ext_timeout(e, set)))
97                         continue;
98                 ret = ip_set_add(e->id, skb, par, opt);
99                 if (ret == 0)
100                         return ret;
101         }
102         return 0;
103 }
104
105 static int
106 list_set_kdel(struct ip_set *set, const struct sk_buff *skb,
107               const struct xt_action_param *par,
108               struct ip_set_adt_opt *opt, const struct ip_set_ext *ext)
109 {
110         struct list_set *map = set->data;
111         struct set_elem *e;
112         int ret;
113
114         list_for_each_entry(e, &map->members, list) {
115                 if (SET_WITH_TIMEOUT(set) &&
116                     ip_set_timeout_expired(ext_timeout(e, set)))
117                         continue;
118                 ret = ip_set_del(e->id, skb, par, opt);
119                 if (ret == 0)
120                         return ret;
121         }
122         return 0;
123 }
124
125 static int
126 list_set_kadt(struct ip_set *set, const struct sk_buff *skb,
127               const struct xt_action_param *par,
128               enum ipset_adt adt, struct ip_set_adt_opt *opt)
129 {
130         struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set);
131         int ret = -EINVAL;
132
133         rcu_read_lock();
134         switch (adt) {
135         case IPSET_TEST:
136                 ret = list_set_ktest(set, skb, par, opt, &ext);
137                 break;
138         case IPSET_ADD:
139                 ret = list_set_kadd(set, skb, par, opt, &ext);
140                 break;
141         case IPSET_DEL:
142                 ret = list_set_kdel(set, skb, par, opt, &ext);
143                 break;
144         default:
145                 break;
146         }
147         rcu_read_unlock();
148
149         return ret;
150 }
151
152 /* Userspace interfaces: we are protected by the nfnl mutex */
153
154 static void
155 __list_set_del_rcu(struct rcu_head * rcu)
156 {
157         struct set_elem *e = container_of(rcu, struct set_elem, rcu);
158         struct ip_set *set = e->set;
159         struct list_set *map = set->data;
160
161         ip_set_put_byindex(map->net, e->id);
162         ip_set_ext_destroy(set, e);
163         kfree(e);
164 }
165
166 static inline void
167 list_set_del(struct ip_set *set, struct set_elem *e)
168 {
169         list_del_rcu(&e->list);
170         call_rcu(&e->rcu, __list_set_del_rcu);
171 }
172
173 static inline void
174 list_set_replace(struct set_elem *e, struct set_elem *old)
175 {
176         list_replace_rcu(&old->list, &e->list);
177         call_rcu(&old->rcu, __list_set_del_rcu);
178 }
179
180 static void
181 set_cleanup_entries(struct ip_set *set)
182 {
183         struct list_set *map = set->data;
184         struct set_elem *e, *n;
185
186         list_for_each_entry_safe(e, n, &map->members, list)
187                 if (ip_set_timeout_expired(ext_timeout(e, set)))
188                         list_set_del(set, e);
189 }
190
191 static int
192 list_set_utest(struct ip_set *set, void *value, const struct ip_set_ext *ext,
193                struct ip_set_ext *mext, u32 flags)
194 {
195         struct list_set *map = set->data;
196         struct set_adt_elem *d = value;
197         struct set_elem *e, *next, *prev = NULL;
198         int ret;
199
200         list_for_each_entry(e, &map->members, list) {
201                 if (SET_WITH_TIMEOUT(set) &&
202                     ip_set_timeout_expired(ext_timeout(e, set)))
203                         continue;
204                 else if (e->id != d->id) {
205                         prev = e;
206                         continue;
207                 }
208
209                 if (d->before == 0) {
210                         ret = 1;
211                 } else if (d->before > 0) {
212                         next = list_next_entry(e, list);
213                         ret = !list_is_last(&e->list, &map->members) &&
214                               next->id == d->refid;
215                 } else {
216                         ret = prev && prev->id == d->refid;
217                 }
218                 return ret;
219         }
220         return 0;
221 }
222
223 static void
224 list_set_init_extensions(struct ip_set *set, const struct ip_set_ext *ext,
225                          struct set_elem *e)
226 {
227         if (SET_WITH_COUNTER(set))
228                 ip_set_init_counter(ext_counter(e, set), ext);
229         if (SET_WITH_COMMENT(set))
230                 ip_set_init_comment(ext_comment(e, set), ext);
231         if (SET_WITH_SKBINFO(set))
232                 ip_set_init_skbinfo(ext_skbinfo(e, set), ext);
233         /* Update timeout last */
234         if (SET_WITH_TIMEOUT(set))
235                 ip_set_timeout_set(ext_timeout(e, set), ext->timeout);
236 }
237
238 static int
239 list_set_uadd(struct ip_set *set, void *value, const struct ip_set_ext *ext,
240               struct ip_set_ext *mext, u32 flags)
241 {
242         struct list_set *map = set->data;
243         struct set_adt_elem *d = value;
244         struct set_elem *e, *n, *prev, *next;
245         bool flag_exist = flags & IPSET_FLAG_EXIST;
246
247         /* Find where to add the new entry */
248         n = prev = next = NULL;
249         list_for_each_entry(e, &map->members, list) {
250                 if (SET_WITH_TIMEOUT(set) &&
251                     ip_set_timeout_expired(ext_timeout(e, set)))
252                         continue;
253                 else if (d->id == e->id)
254                         n = e;
255                 else if (d->before == 0 || e->id != d->refid)
256                         continue;
257                 else if (d->before > 0)
258                         next = e;
259                 else
260                         prev = e;
261         }
262         /* Re-add already existing element */
263         if (n) {
264                 if ((d->before > 0 && !next) ||
265                     (d->before < 0 && !prev))
266                         return -IPSET_ERR_REF_EXIST;
267                 if (!flag_exist)
268                         return -IPSET_ERR_EXIST;
269                 /* Update extensions */
270                 ip_set_ext_destroy(set, n);
271                 list_set_init_extensions(set, ext, n);
272
273                 /* Set is already added to the list */
274                 ip_set_put_byindex(map->net, d->id);
275                 return 0;
276         }
277         /* Add new entry */
278         if (d->before == 0) {
279                 /* Append  */
280                 n = list_empty(&map->members) ? NULL :
281                     list_last_entry(&map->members, struct set_elem, list);
282         } else if (d->before > 0) {
283                 /* Insert after next element */
284                 if (!list_is_last(&next->list, &map->members))
285                         n = list_next_entry(next, list);
286         } else {
287                 /* Insert before prev element */
288                 if (prev->list.prev != &map->members)
289                         n = list_prev_entry(prev, list);
290         }
291         /* Can we replace a timed out entry? */
292         if (n &&
293             !(SET_WITH_TIMEOUT(set) &&
294               ip_set_timeout_expired(ext_timeout(n, set))))
295                 n =  NULL;
296
297         e = kzalloc(set->dsize, GFP_ATOMIC);
298         if (!e)
299                 return -ENOMEM;
300         e->id = d->id;
301         e->set = set;
302         INIT_LIST_HEAD(&e->list);
303         list_set_init_extensions(set, ext, e);
304         if (n)
305                 list_set_replace(e, n);
306         else if (next)
307                 list_add_tail_rcu(&e->list, &next->list);
308         else if (prev)
309                 list_add_rcu(&e->list, &prev->list);
310         else
311                 list_add_tail_rcu(&e->list, &map->members);
312
313         return 0;
314 }
315
316 static int
317 list_set_udel(struct ip_set *set, void *value, const struct ip_set_ext *ext,
318               struct ip_set_ext *mext, u32 flags)
319 {
320         struct list_set *map = set->data;
321         struct set_adt_elem *d = value;
322         struct set_elem *e, *next, *prev = NULL;
323
324         list_for_each_entry(e, &map->members, list) {
325                 if (SET_WITH_TIMEOUT(set) &&
326                     ip_set_timeout_expired(ext_timeout(e, set)))
327                         continue;
328                 else if (e->id != d->id) {
329                         prev = e;
330                         continue;
331                 }
332
333                 if (d->before > 0) {
334                         next = list_next_entry(e, list);
335                         if (list_is_last(&e->list, &map->members) ||
336                             next->id != d->refid)
337                                 return -IPSET_ERR_REF_EXIST;
338                 } else if (d->before < 0) {
339                         if (!prev || prev->id != d->refid)
340                                 return -IPSET_ERR_REF_EXIST;
341                 }
342                 list_set_del(set, e);
343                 return 0;
344         }
345         return d->before != 0 ? -IPSET_ERR_REF_EXIST : -IPSET_ERR_EXIST;
346 }
347
348 static int
349 list_set_uadt(struct ip_set *set, struct nlattr *tb[],
350               enum ipset_adt adt, u32 *lineno, u32 flags, bool retried)
351 {
352         struct list_set *map = set->data;
353         ipset_adtfn adtfn = set->variant->adt[adt];
354         struct set_adt_elem e = { .refid = IPSET_INVALID_ID };
355         struct ip_set_ext ext = IP_SET_INIT_UEXT(set);
356         struct ip_set *s;
357         int ret = 0;
358
359         if (tb[IPSET_ATTR_LINENO])
360                 *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]);
361
362         if (unlikely(!tb[IPSET_ATTR_NAME] ||
363                      !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS)))
364                 return -IPSET_ERR_PROTOCOL;
365
366         ret = ip_set_get_extensions(set, tb, &ext);
367         if (ret)
368                 return ret;
369         e.id = ip_set_get_byname(map->net, nla_data(tb[IPSET_ATTR_NAME]), &s);
370         if (e.id == IPSET_INVALID_ID)
371                 return -IPSET_ERR_NAME;
372         /* "Loop detection" */
373         if (s->type->features & IPSET_TYPE_NAME) {
374                 ret = -IPSET_ERR_LOOP;
375                 goto finish;
376         }
377
378         if (tb[IPSET_ATTR_CADT_FLAGS]) {
379                 u32 f = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
380
381                 e.before = f & IPSET_FLAG_BEFORE;
382         }
383
384         if (e.before && !tb[IPSET_ATTR_NAMEREF]) {
385                 ret = -IPSET_ERR_BEFORE;
386                 goto finish;
387         }
388
389         if (tb[IPSET_ATTR_NAMEREF]) {
390                 e.refid = ip_set_get_byname(map->net,
391                                             nla_data(tb[IPSET_ATTR_NAMEREF]),
392                                             &s);
393                 if (e.refid == IPSET_INVALID_ID) {
394                         ret = -IPSET_ERR_NAMEREF;
395                         goto finish;
396                 }
397                 if (!e.before)
398                         e.before = -1;
399         }
400         if (adt != IPSET_TEST && SET_WITH_TIMEOUT(set))
401                 set_cleanup_entries(set);
402
403         ret = adtfn(set, &e, &ext, &ext, flags);
404
405 finish:
406         if (e.refid != IPSET_INVALID_ID)
407                 ip_set_put_byindex(map->net, e.refid);
408         if (adt != IPSET_ADD || ret)
409                 ip_set_put_byindex(map->net, e.id);
410
411         return ip_set_eexist(ret, flags) ? 0 : ret;
412 }
413
414 static void
415 list_set_flush(struct ip_set *set)
416 {
417         struct list_set *map = set->data;
418         struct set_elem *e, *n;
419
420         list_for_each_entry_safe(e, n, &map->members, list)
421                 list_set_del(set, e);
422 }
423
424 static void
425 list_set_destroy(struct ip_set *set)
426 {
427         struct list_set *map = set->data;
428         struct set_elem *e, *n;
429
430         if (SET_WITH_TIMEOUT(set))
431                 del_timer_sync(&map->gc);
432
433         list_for_each_entry_safe(e, n, &map->members, list) {
434                 list_del(&e->list);
435                 ip_set_put_byindex(map->net, e->id);
436                 ip_set_ext_destroy(set, e);
437                 kfree(e);
438         }
439         kfree(map);
440
441         set->data = NULL;
442 }
443
444 static int
445 list_set_head(struct ip_set *set, struct sk_buff *skb)
446 {
447         const struct list_set *map = set->data;
448         struct nlattr *nested;
449         struct set_elem *e;
450         u32 n = 0;
451
452         rcu_read_lock();
453         list_for_each_entry_rcu(e, &map->members, list)
454                 n++;
455         rcu_read_unlock();
456
457         nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
458         if (!nested)
459                 goto nla_put_failure;
460         if (nla_put_net32(skb, IPSET_ATTR_SIZE, htonl(map->size)) ||
461             nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref)) ||
462             nla_put_net32(skb, IPSET_ATTR_MEMSIZE,
463                           htonl(sizeof(*map) + n * set->dsize)))
464                 goto nla_put_failure;
465         if (unlikely(ip_set_put_flags(skb, set)))
466                 goto nla_put_failure;
467         ipset_nest_end(skb, nested);
468
469         return 0;
470 nla_put_failure:
471         return -EMSGSIZE;
472 }
473
474 static int
475 list_set_list(const struct ip_set *set,
476               struct sk_buff *skb, struct netlink_callback *cb)
477 {
478         const struct list_set *map = set->data;
479         struct nlattr *atd, *nested;
480         u32 i = 0, first = cb->args[IPSET_CB_ARG0];
481         struct set_elem *e;
482         int ret = 0;
483
484         atd = ipset_nest_start(skb, IPSET_ATTR_ADT);
485         if (!atd)
486                 return -EMSGSIZE;
487
488         rcu_read_lock();
489         list_for_each_entry_rcu(e, &map->members, list) {
490                 if (i < first ||
491                     (SET_WITH_TIMEOUT(set) &&
492                      ip_set_timeout_expired(ext_timeout(e, set)))) {
493                         i++;
494                         continue;
495                 }
496                 nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
497                 if (!nested)
498                         goto nla_put_failure;
499                 if (nla_put_string(skb, IPSET_ATTR_NAME,
500                                    ip_set_name_byindex(map->net, e->id)))
501                         goto nla_put_failure;
502                 if (ip_set_put_extensions(skb, set, e, true))
503                         goto nla_put_failure;
504                 ipset_nest_end(skb, nested);
505                 i++;
506         }
507
508         ipset_nest_end(skb, atd);
509         /* Set listing finished */
510         cb->args[IPSET_CB_ARG0] = 0;
511         goto out;
512
513 nla_put_failure:
514         nla_nest_cancel(skb, nested);
515         if (unlikely(i == first)) {
516                 nla_nest_cancel(skb, atd);
517                 cb->args[IPSET_CB_ARG0] = 0;
518                 ret = -EMSGSIZE;
519         } else {
520                 cb->args[IPSET_CB_ARG0] = i;
521         }
522         ipset_nest_end(skb, atd);
523 out:
524         rcu_read_unlock();
525         return ret;
526 }
527
528 static bool
529 list_set_same_set(const struct ip_set *a, const struct ip_set *b)
530 {
531         const struct list_set *x = a->data;
532         const struct list_set *y = b->data;
533
534         return x->size == y->size &&
535                a->timeout == b->timeout &&
536                a->extensions == b->extensions;
537 }
538
539 static const struct ip_set_type_variant set_variant = {
540         .kadt   = list_set_kadt,
541         .uadt   = list_set_uadt,
542         .adt    = {
543                 [IPSET_ADD] = list_set_uadd,
544                 [IPSET_DEL] = list_set_udel,
545                 [IPSET_TEST] = list_set_utest,
546         },
547         .destroy = list_set_destroy,
548         .flush  = list_set_flush,
549         .head   = list_set_head,
550         .list   = list_set_list,
551         .same_set = list_set_same_set,
552 };
553
554 static void
555 list_set_gc(unsigned long ul_set)
556 {
557         struct ip_set *set = (struct ip_set *)ul_set;
558         struct list_set *map = set->data;
559
560         spin_lock_bh(&set->lock);
561         set_cleanup_entries(set);
562         spin_unlock_bh(&set->lock);
563
564         map->gc.expires = jiffies + IPSET_GC_PERIOD(set->timeout) * HZ;
565         add_timer(&map->gc);
566 }
567
568 static void
569 list_set_gc_init(struct ip_set *set, void (*gc)(unsigned long ul_set))
570 {
571         struct list_set *map = set->data;
572
573         init_timer(&map->gc);
574         map->gc.data = (unsigned long)set;
575         map->gc.function = gc;
576         map->gc.expires = jiffies + IPSET_GC_PERIOD(set->timeout) * HZ;
577         add_timer(&map->gc);
578 }
579
580 /* Create list:set type of sets */
581
582 static bool
583 init_list_set(struct net *net, struct ip_set *set, u32 size)
584 {
585         struct list_set *map;
586
587         map = kzalloc(sizeof(*map), GFP_KERNEL);
588         if (!map)
589                 return false;
590
591         map->size = size;
592         map->net = net;
593         INIT_LIST_HEAD(&map->members);
594         set->data = map;
595
596         return true;
597 }
598
599 static int
600 list_set_create(struct net *net, struct ip_set *set, struct nlattr *tb[],
601                 u32 flags)
602 {
603         u32 size = IP_SET_LIST_DEFAULT_SIZE;
604
605         if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_SIZE) ||
606                      !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
607                      !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS)))
608                 return -IPSET_ERR_PROTOCOL;
609
610         if (tb[IPSET_ATTR_SIZE])
611                 size = ip_set_get_h32(tb[IPSET_ATTR_SIZE]);
612         if (size < IP_SET_LIST_MIN_SIZE)
613                 size = IP_SET_LIST_MIN_SIZE;
614
615         set->variant = &set_variant;
616         set->dsize = ip_set_elem_len(set, tb, sizeof(struct set_elem),
617                                      __alignof__(struct set_elem));
618         if (!init_list_set(net, set, size))
619                 return -ENOMEM;
620         if (tb[IPSET_ATTR_TIMEOUT]) {
621                 set->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
622                 list_set_gc_init(set, list_set_gc);
623         }
624         return 0;
625 }
626
627 static struct ip_set_type list_set_type __read_mostly = {
628         .name           = "list:set",
629         .protocol       = IPSET_PROTOCOL,
630         .features       = IPSET_TYPE_NAME | IPSET_DUMP_LAST,
631         .dimension      = IPSET_DIM_ONE,
632         .family         = NFPROTO_UNSPEC,
633         .revision_min   = IPSET_TYPE_REV_MIN,
634         .revision_max   = IPSET_TYPE_REV_MAX,
635         .create         = list_set_create,
636         .create_policy  = {
637                 [IPSET_ATTR_SIZE]       = { .type = NLA_U32 },
638                 [IPSET_ATTR_TIMEOUT]    = { .type = NLA_U32 },
639                 [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 },
640         },
641         .adt_policy     = {
642                 [IPSET_ATTR_NAME]       = { .type = NLA_STRING,
643                                             .len = IPSET_MAXNAMELEN },
644                 [IPSET_ATTR_NAMEREF]    = { .type = NLA_STRING,
645                                             .len = IPSET_MAXNAMELEN },
646                 [IPSET_ATTR_TIMEOUT]    = { .type = NLA_U32 },
647                 [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
648                 [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 },
649                 [IPSET_ATTR_BYTES]      = { .type = NLA_U64 },
650                 [IPSET_ATTR_PACKETS]    = { .type = NLA_U64 },
651                 [IPSET_ATTR_COMMENT]    = { .type = NLA_NUL_STRING,
652                                             .len  = IPSET_MAX_COMMENT_SIZE },
653                 [IPSET_ATTR_SKBMARK]    = { .type = NLA_U64 },
654                 [IPSET_ATTR_SKBPRIO]    = { .type = NLA_U32 },
655                 [IPSET_ATTR_SKBQUEUE]   = { .type = NLA_U16 },
656         },
657         .me             = THIS_MODULE,
658 };
659
660 static int __init
661 list_set_init(void)
662 {
663         return ip_set_type_register(&list_set_type);
664 }
665
666 static void __exit
667 list_set_fini(void)
668 {
669         rcu_barrier();
670         ip_set_type_unregister(&list_set_type);
671 }
672
673 module_init(list_set_init);
674 module_exit(list_set_fini);