net: sched: rcu'ify cls_bpf
[cascardo/linux.git] / net / sched / cls_bpf.c
1 /*
2  * Berkeley Packet Filter based traffic classifier
3  *
4  * Might be used to classify traffic through flexible, user-defined and
5  * possibly JIT-ed BPF filters for traffic control as an alternative to
6  * ematches.
7  *
8  * (C) 2013 Daniel Borkmann <dborkman@redhat.com>
9  *
10  * This program is free software; you can redistribute it and/or modify
11  * it under the terms of the GNU General Public License version 2 as
12  * published by the Free Software Foundation.
13  */
14
15 #include <linux/module.h>
16 #include <linux/types.h>
17 #include <linux/skbuff.h>
18 #include <linux/filter.h>
19 #include <net/rtnetlink.h>
20 #include <net/pkt_cls.h>
21 #include <net/sock.h>
22
23 MODULE_LICENSE("GPL");
24 MODULE_AUTHOR("Daniel Borkmann <dborkman@redhat.com>");
25 MODULE_DESCRIPTION("TC BPF based classifier");
26
/* Per-tcf_proto classifier state: the list of attached BPF programs.
 * Readers traverse @plist under RCU; writers are serialized by RTNL.
 */
struct cls_bpf_head {
	struct list_head plist;	/* RCU-protected list of cls_bpf_prog */
	u32 hgen;		/* last auto-generated filter handle */
	struct rcu_head rcu;	/* for kfree_rcu() of this head */
};
32
/* One attached classifier program. Freed via call_rcu() so RCU readers
 * in the classify fast path never see a half-torn-down entry.
 */
struct cls_bpf_prog {
	struct bpf_prog *filter;	/* kernel-internal (possibly JITed) program */
	struct sock_filter *bpf_ops;	/* private copy of the opcodes, kept for dump */
	struct tcf_exts exts;		/* attached actions */
	struct tcf_result res;		/* classification result (classid) */
	struct list_head link;		/* membership in cls_bpf_head::plist */
	u32 handle;			/* unique filter handle */
	u16 bpf_len;			/* number of sock_filter insns in bpf_ops */
	struct tcf_proto *tp;		/* back-pointer for deferred destroy */
	struct rcu_head rcu;		/* for call_rcu(__cls_bpf_delete_prog) */
};
44
/* Netlink attribute policy: bounds-checks user-supplied attributes before
 * cls_bpf_modify_existing() consumes them. TCA_BPF_OPS is capped at the
 * maximum classic-BPF program length.
 */
static const struct nla_policy bpf_policy[TCA_BPF_MAX + 1] = {
	[TCA_BPF_CLASSID]	= { .type = NLA_U32 },
	[TCA_BPF_OPS_LEN]	= { .type = NLA_U16 },
	[TCA_BPF_OPS]		= { .type = NLA_BINARY,
				    .len = sizeof(struct sock_filter) * BPF_MAXINSNS },
};
51
/* Classify @skb by running each attached BPF program in list order.
 *
 * Runs in the data path under rcu_read_lock(). Program result semantics:
 * 0 means "no match, try the next program"; -1 means "match, keep the
 * classid configured on the program"; any other value overrides the
 * classid directly. Returns the tcf_exts_exec() action code of the
 * first matching program whose extensions do not veto it (exec < 0),
 * or -1 if no program matched.
 */
static int cls_bpf_classify(struct sk_buff *skb, const struct tcf_proto *tp,
			    struct tcf_result *res)
{
	struct cls_bpf_head *head = rcu_dereference(tp->root);
	struct cls_bpf_prog *prog;
	int ret;

	list_for_each_entry_rcu(prog, &head->plist, link) {
		int filter_res = BPF_PROG_RUN(prog->filter, skb);

		if (filter_res == 0)
			continue;

		*res = prog->res;
		if (filter_res != -1)
			res->classid = filter_res;

		ret = tcf_exts_exec(skb, &prog->exts, res);
		if (ret < 0)
			continue;

		return ret;
	}

	return -1;
}
78
79 static int cls_bpf_init(struct tcf_proto *tp)
80 {
81         struct cls_bpf_head *head;
82
83         head = kzalloc(sizeof(*head), GFP_KERNEL);
84         if (head == NULL)
85                 return -ENOBUFS;
86
87         INIT_LIST_HEAD_RCU(&head->plist);
88         rcu_assign_pointer(tp->root, head);
89
90         return 0;
91 }
92
/* Release everything a program owns: the class binding, the extension
 * actions, the BPF filter itself, and finally the opcode copy and the
 * program struct. Must only run once no RCU reader can still reach
 * @prog (callers go through call_rcu(), or own the last reference).
 */
static void cls_bpf_delete_prog(struct tcf_proto *tp, struct cls_bpf_prog *prog)
{
	tcf_unbind_filter(tp, &prog->res);
	tcf_exts_destroy(tp, &prog->exts);

	bpf_prog_destroy(prog->filter);

	kfree(prog->bpf_ops);
	kfree(prog);
}
103
104 static void __cls_bpf_delete_prog(struct rcu_head *rcu)
105 {
106         struct cls_bpf_prog *prog = container_of(rcu, struct cls_bpf_prog, rcu);
107
108         cls_bpf_delete_prog(prog->tp, prog);
109 }
110
111 static int cls_bpf_delete(struct tcf_proto *tp, unsigned long arg)
112 {
113         struct cls_bpf_head *head = rtnl_dereference(tp->root);
114         struct cls_bpf_prog *prog, *todel = (struct cls_bpf_prog *) arg;
115
116         list_for_each_entry(prog, &head->plist, link) {
117                 if (prog == todel) {
118                         list_del_rcu(&prog->link);
119                         call_rcu(&prog->rcu, __cls_bpf_delete_prog);
120                         return 0;
121                 }
122         }
123
124         return -ENOENT;
125 }
126
/* Destroy the whole classifier instance (RTNL held).
 * Every program is unlinked and queued for RCU-deferred teardown,
 * tp->root is cleared, and the head itself is freed after a grace
 * period via kfree_rcu().
 */
static void cls_bpf_destroy(struct tcf_proto *tp)
{
	struct cls_bpf_head *head = rtnl_dereference(tp->root);
	struct cls_bpf_prog *prog, *tmp;

	list_for_each_entry_safe(prog, tmp, &head->plist, link) {
		list_del_rcu(&prog->link);
		call_rcu(&prog->rcu, __cls_bpf_delete_prog);
	}

	RCU_INIT_POINTER(tp->root, NULL);
	kfree_rcu(head, rcu);
}
140
141 static unsigned long cls_bpf_get(struct tcf_proto *tp, u32 handle)
142 {
143         struct cls_bpf_head *head = rtnl_dereference(tp->root);
144         struct cls_bpf_prog *prog;
145         unsigned long ret = 0UL;
146
147         if (head == NULL)
148                 return 0UL;
149
150         list_for_each_entry_rcu(prog, &head->plist, link) {
151                 if (prog->handle == handle) {
152                         ret = (unsigned long) prog;
153                         break;
154                 }
155         }
156
157         return ret;
158 }
159
/* Programs are not reference counted per-get; nothing to release. */
static void cls_bpf_put(struct tcf_proto *tp, unsigned long f)
{
}
163
/* Validate netlink attributes in @tb and fill in @prog (RTNL held).
 *
 * Despite the name, the caller always passes a freshly allocated @prog,
 * also on replace. On success the program owns the built BPF filter,
 * the copied opcode array and the validated extensions. Returns 0 or a
 * negative errno; on failure all partially acquired resources are
 * released via the goto-cleanup ladder.
 */
static int cls_bpf_modify_existing(struct net *net, struct tcf_proto *tp,
				   struct cls_bpf_prog *prog,
				   unsigned long base, struct nlattr **tb,
				   struct nlattr *est, bool ovr)
{
	struct sock_filter *bpf_ops;
	struct tcf_exts exts;
	struct sock_fprog_kern tmp;
	struct bpf_prog *fp;
	u16 bpf_size, bpf_len;
	u32 classid;
	int ret;

	/* All three attributes are mandatory for this classifier. */
	if (!tb[TCA_BPF_OPS_LEN] || !tb[TCA_BPF_OPS] || !tb[TCA_BPF_CLASSID])
		return -EINVAL;

	tcf_exts_init(&exts, TCA_BPF_ACT, TCA_BPF_POLICE);
	ret = tcf_exts_validate(net, tp, tb, est, &exts, ovr);
	if (ret < 0)
		return ret;

	classid = nla_get_u32(tb[TCA_BPF_CLASSID]);
	bpf_len = nla_get_u16(tb[TCA_BPF_OPS_LEN]);
	if (bpf_len > BPF_MAXINSNS || bpf_len == 0) {
		ret = -EINVAL;
		goto errout;
	}

	/* Keep a private copy of the opcodes: it must outlive the netlink
	 * message and is needed later by cls_bpf_dump().
	 */
	bpf_size = bpf_len * sizeof(*bpf_ops);
	bpf_ops = kzalloc(bpf_size, GFP_KERNEL);
	if (bpf_ops == NULL) {
		ret = -ENOMEM;
		goto errout;
	}

	memcpy(bpf_ops, nla_data(tb[TCA_BPF_OPS]), bpf_size);

	tmp.len = bpf_len;
	tmp.filter = bpf_ops;

	/* Build (and possibly JIT) the kernel-internal program. */
	ret = bpf_prog_create(&fp, &tmp);
	if (ret)
		goto errout_free;

	prog->bpf_len = bpf_len;
	prog->bpf_ops = bpf_ops;
	prog->filter = fp;
	prog->res.classid = classid;

	tcf_bind_filter(tp, &prog->res, base);
	tcf_exts_change(tp, &prog->exts, &exts);

	return 0;
errout_free:
	kfree(bpf_ops);
errout:
	tcf_exts_destroy(tp, &exts);
	return ret;
}
223
224 static u32 cls_bpf_grab_new_handle(struct tcf_proto *tp,
225                                    struct cls_bpf_head *head)
226 {
227         unsigned int i = 0x80000000;
228
229         do {
230                 if (++head->hgen == 0x7FFFFFFF)
231                         head->hgen = 1;
232         } while (--i > 0 && cls_bpf_get(tp, head->hgen));
233         if (i == 0)
234                 pr_err("Insufficient number of handles\n");
235
236         return i;
237 }
238
239 static int cls_bpf_change(struct net *net, struct sk_buff *in_skb,
240                           struct tcf_proto *tp, unsigned long base,
241                           u32 handle, struct nlattr **tca,
242                           unsigned long *arg, bool ovr)
243 {
244         struct cls_bpf_head *head = rtnl_dereference(tp->root);
245         struct cls_bpf_prog *oldprog = (struct cls_bpf_prog *) *arg;
246         struct nlattr *tb[TCA_BPF_MAX + 1];
247         struct cls_bpf_prog *prog;
248         int ret;
249
250         if (tca[TCA_OPTIONS] == NULL)
251                 return -EINVAL;
252
253         ret = nla_parse_nested(tb, TCA_BPF_MAX, tca[TCA_OPTIONS], bpf_policy);
254         if (ret < 0)
255                 return ret;
256
257         prog = kzalloc(sizeof(*prog), GFP_KERNEL);
258         if (!prog)
259                 return -ENOBUFS;
260
261         tcf_exts_init(&prog->exts, TCA_BPF_ACT, TCA_BPF_POLICE);
262
263         if (oldprog) {
264                 if (handle && oldprog->handle != handle) {
265                         ret = -EINVAL;
266                         goto errout;
267                 }
268         }
269
270         if (handle == 0)
271                 prog->handle = cls_bpf_grab_new_handle(tp, head);
272         else
273                 prog->handle = handle;
274         if (prog->handle == 0) {
275                 ret = -EINVAL;
276                 goto errout;
277         }
278
279         ret = cls_bpf_modify_existing(net, tp, prog, base, tb, tca[TCA_RATE], ovr);
280         if (ret < 0)
281                 goto errout;
282
283         if (oldprog) {
284                 list_replace_rcu(&prog->link, &oldprog->link);
285                 call_rcu(&oldprog->rcu, __cls_bpf_delete_prog);
286         } else {
287                 list_add_rcu(&prog->link, &head->plist);
288         }
289
290         *arg = (unsigned long) prog;
291         return 0;
292 errout:
293         kfree(prog);
294
295         return ret;
296 }
297
/* Dump one filter (@fh) as netlink attributes into @skb.
 *
 * Emits classid, opcode length and the raw opcode array inside a nested
 * TCA_OPTIONS attribute, followed by extension dumps. Returns skb->len
 * on success; on any put failure the partially written nest is
 * cancelled and -1 is returned so the caller retries with a larger skb.
 */
static int cls_bpf_dump(struct net *net, struct tcf_proto *tp, unsigned long fh,
			struct sk_buff *skb, struct tcmsg *tm)
{
	struct cls_bpf_prog *prog = (struct cls_bpf_prog *) fh;
	struct nlattr *nest, *nla;

	if (prog == NULL)
		return skb->len;

	tm->tcm_handle = prog->handle;

	nest = nla_nest_start(skb, TCA_OPTIONS);
	if (nest == NULL)
		goto nla_put_failure;

	if (nla_put_u32(skb, TCA_BPF_CLASSID, prog->res.classid))
		goto nla_put_failure;
	if (nla_put_u16(skb, TCA_BPF_OPS_LEN, prog->bpf_len))
		goto nla_put_failure;

	/* Reserve space first, then copy the saved opcode array in. */
	nla = nla_reserve(skb, TCA_BPF_OPS, prog->bpf_len *
			  sizeof(struct sock_filter));
	if (nla == NULL)
		goto nla_put_failure;

	memcpy(nla_data(nla), prog->bpf_ops, nla_len(nla));

	if (tcf_exts_dump(skb, &prog->exts) < 0)
		goto nla_put_failure;

	nla_nest_end(skb, nest);

	if (tcf_exts_dump_stats(skb, &prog->exts) < 0)
		goto nla_put_failure;

	return skb->len;

nla_put_failure:
	nla_nest_cancel(skb, nest);
	return -1;
}
339
/* Iterate all programs for the walker API (RTNL held), honouring the
 * walker's skip/count bookkeeping. A negative return from arg->fn
 * stops the walk and sets arg->stop for the caller.
 */
static void cls_bpf_walk(struct tcf_proto *tp, struct tcf_walker *arg)
{
	struct cls_bpf_head *head = rtnl_dereference(tp->root);
	struct cls_bpf_prog *prog;

	list_for_each_entry_rcu(prog, &head->plist, link) {
		if (arg->count < arg->skip)
			goto skip;
		if (arg->fn(tp, (unsigned long) prog, arg) < 0) {
			arg->stop = 1;
			break;
		}
skip:
		arg->count++;
	}
}
356
/* Classifier ops table registered with the TC core under kind "bpf". */
static struct tcf_proto_ops cls_bpf_ops __read_mostly = {
	.kind		=	"bpf",
	.owner		=	THIS_MODULE,
	.classify	=	cls_bpf_classify,
	.init		=	cls_bpf_init,
	.destroy	=	cls_bpf_destroy,
	.get		=	cls_bpf_get,
	.put		=	cls_bpf_put,
	.change		=	cls_bpf_change,
	.delete		=	cls_bpf_delete,
	.walk		=	cls_bpf_walk,
	.dump		=	cls_bpf_dump,
};
370
/* Register the "bpf" classifier with the TC core on module load. */
static int __init cls_bpf_init_mod(void)
{
	return register_tcf_proto_ops(&cls_bpf_ops);
}

/* Unregister on module unload; the TC core handles remaining filters. */
static void __exit cls_bpf_exit_mod(void)
{
	unregister_tcf_proto_ops(&cls_bpf_ops);
}

module_init(cls_bpf_init_mod);
module_exit(cls_bpf_exit_mod);