Merge branch 'master' of git://git.kernel.org/pub/scm/linux/kernel/git/linville/wireless
[cascardo/linux.git] / net / ipv4 / fou.c
1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
5 #include <linux/ip.h>
6 #include <linux/udp.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
10 #include <net/gue.h>
11 #include <net/ip.h>
12 #include <net/protocol.h>
13 #include <net/udp.h>
14 #include <net/udp_tunnel.h>
15 #include <net/xfrm.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
18
19 static DEFINE_SPINLOCK(fou_lock);
20 static LIST_HEAD(fou_list);
21
22 struct fou {
23         struct socket *sock;
24         u8 protocol;
25         u16 port;
26         struct udp_offload udp_offloads;
27         struct list_head list;
28 };
29
30 struct fou_cfg {
31         u16 type;
32         u8 protocol;
33         struct udp_port_cfg udp_config;
34 };
35
36 static inline struct fou *fou_from_sock(struct sock *sk)
37 {
38         return sk->sk_user_data;
39 }
40
41 static int fou_udp_encap_recv_deliver(struct sk_buff *skb,
42                                       u8 protocol, size_t len)
43 {
44         struct iphdr *iph = ip_hdr(skb);
45
46         /* Remove 'len' bytes from the packet (UDP header and
47          * FOU header if present), modify the protocol to the one
48          * we found, and then call rcv_encap.
49          */
50         iph->tot_len = htons(ntohs(iph->tot_len) - len);
51         __skb_pull(skb, len);
52         skb_postpull_rcsum(skb, udp_hdr(skb), len);
53         skb_reset_transport_header(skb);
54
55         return -protocol;
56 }
57
58 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
59 {
60         struct fou *fou = fou_from_sock(sk);
61
62         if (!fou)
63                 return 1;
64
65         return fou_udp_encap_recv_deliver(skb, fou->protocol,
66                                           sizeof(struct udphdr));
67 }
68
69 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
70 {
71         struct fou *fou = fou_from_sock(sk);
72         size_t len;
73         struct guehdr *guehdr;
74         struct udphdr *uh;
75
76         if (!fou)
77                 return 1;
78
79         len = sizeof(struct udphdr) + sizeof(struct guehdr);
80         if (!pskb_may_pull(skb, len))
81                 goto drop;
82
83         uh = udp_hdr(skb);
84         guehdr = (struct guehdr *)&uh[1];
85
86         len += guehdr->hlen << 2;
87         if (!pskb_may_pull(skb, len))
88                 goto drop;
89
90         uh = udp_hdr(skb);
91         guehdr = (struct guehdr *)&uh[1];
92
93         if (guehdr->version != 0)
94                 goto drop;
95
96         if (guehdr->flags) {
97                 /* No support yet */
98                 goto drop;
99         }
100
101         return fou_udp_encap_recv_deliver(skb, guehdr->next_hdr, len);
102 drop:
103         kfree_skb(skb);
104         return 0;
105 }
106
107 static struct sk_buff **fou_gro_receive(struct sk_buff **head,
108                                         struct sk_buff *skb)
109 {
110         const struct net_offload *ops;
111         struct sk_buff **pp = NULL;
112         u8 proto = NAPI_GRO_CB(skb)->proto;
113         const struct net_offload **offloads;
114
115         rcu_read_lock();
116         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
117         ops = rcu_dereference(offloads[proto]);
118         if (!ops || !ops->callbacks.gro_receive)
119                 goto out_unlock;
120
121         pp = ops->callbacks.gro_receive(head, skb);
122
123 out_unlock:
124         rcu_read_unlock();
125
126         return pp;
127 }
128
129 static int fou_gro_complete(struct sk_buff *skb, int nhoff)
130 {
131         const struct net_offload *ops;
132         u8 proto = NAPI_GRO_CB(skb)->proto;
133         int err = -ENOSYS;
134         const struct net_offload **offloads;
135
136         rcu_read_lock();
137         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
138         ops = rcu_dereference(offloads[proto]);
139         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
140                 goto out_unlock;
141
142         err = ops->callbacks.gro_complete(skb, nhoff);
143
144 out_unlock:
145         rcu_read_unlock();
146
147         return err;
148 }
149
150 static struct sk_buff **gue_gro_receive(struct sk_buff **head,
151                                         struct sk_buff *skb)
152 {
153         const struct net_offload **offloads;
154         const struct net_offload *ops;
155         struct sk_buff **pp = NULL;
156         struct sk_buff *p;
157         u8 proto;
158         struct guehdr *guehdr;
159         unsigned int hlen, guehlen;
160         unsigned int off;
161         int flush = 1;
162
163         off = skb_gro_offset(skb);
164         hlen = off + sizeof(*guehdr);
165         guehdr = skb_gro_header_fast(skb, off);
166         if (skb_gro_header_hard(skb, hlen)) {
167                 guehdr = skb_gro_header_slow(skb, hlen, off);
168                 if (unlikely(!guehdr))
169                         goto out;
170         }
171
172         proto = guehdr->next_hdr;
173
174         rcu_read_lock();
175         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
176         ops = rcu_dereference(offloads[proto]);
177         if (WARN_ON(!ops || !ops->callbacks.gro_receive))
178                 goto out_unlock;
179
180         guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
181
182         hlen = off + guehlen;
183         if (skb_gro_header_hard(skb, hlen)) {
184                 guehdr = skb_gro_header_slow(skb, hlen, off);
185                 if (unlikely(!guehdr))
186                         goto out_unlock;
187         }
188
189         flush = 0;
190
191         for (p = *head; p; p = p->next) {
192                 const struct guehdr *guehdr2;
193
194                 if (!NAPI_GRO_CB(p)->same_flow)
195                         continue;
196
197                 guehdr2 = (struct guehdr *)(p->data + off);
198
199                 /* Compare base GUE header to be equal (covers
200                  * hlen, version, next_hdr, and flags.
201                  */
202                 if (guehdr->word != guehdr2->word) {
203                         NAPI_GRO_CB(p)->same_flow = 0;
204                         continue;
205                 }
206
207                 /* Compare optional fields are the same. */
208                 if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
209                                            guehdr->hlen << 2)) {
210                         NAPI_GRO_CB(p)->same_flow = 0;
211                         continue;
212                 }
213         }
214
215         skb_gro_pull(skb, guehlen);
216
217         /* Adjusted NAPI_GRO_CB(skb)->csum after skb_gro_pull()*/
218         skb_gro_postpull_rcsum(skb, guehdr, guehlen);
219
220         pp = ops->callbacks.gro_receive(head, skb);
221
222 out_unlock:
223         rcu_read_unlock();
224 out:
225         NAPI_GRO_CB(skb)->flush |= flush;
226
227         return pp;
228 }
229
230 static int gue_gro_complete(struct sk_buff *skb, int nhoff)
231 {
232         const struct net_offload **offloads;
233         struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
234         const struct net_offload *ops;
235         unsigned int guehlen;
236         u8 proto;
237         int err = -ENOENT;
238
239         proto = guehdr->next_hdr;
240
241         guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
242
243         rcu_read_lock();
244         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
245         ops = rcu_dereference(offloads[proto]);
246         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
247                 goto out_unlock;
248
249         err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
250
251 out_unlock:
252         rcu_read_unlock();
253         return err;
254 }
255
256 static int fou_add_to_port_list(struct fou *fou)
257 {
258         struct fou *fout;
259
260         spin_lock(&fou_lock);
261         list_for_each_entry(fout, &fou_list, list) {
262                 if (fou->port == fout->port) {
263                         spin_unlock(&fou_lock);
264                         return -EALREADY;
265                 }
266         }
267
268         list_add(&fou->list, &fou_list);
269         spin_unlock(&fou_lock);
270
271         return 0;
272 }
273
274 static void fou_release(struct fou *fou)
275 {
276         struct socket *sock = fou->sock;
277         struct sock *sk = sock->sk;
278
279         udp_del_offload(&fou->udp_offloads);
280
281         list_del(&fou->list);
282
283         /* Remove hooks into tunnel socket */
284         sk->sk_user_data = NULL;
285
286         sock_release(sock);
287
288         kfree(fou);
289 }
290
291 static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
292 {
293         udp_sk(sk)->encap_rcv = fou_udp_recv;
294         fou->protocol = cfg->protocol;
295         fou->udp_offloads.callbacks.gro_receive = fou_gro_receive;
296         fou->udp_offloads.callbacks.gro_complete = fou_gro_complete;
297         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
298         fou->udp_offloads.ipproto = cfg->protocol;
299
300         return 0;
301 }
302
303 static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
304 {
305         udp_sk(sk)->encap_rcv = gue_udp_recv;
306         fou->udp_offloads.callbacks.gro_receive = gue_gro_receive;
307         fou->udp_offloads.callbacks.gro_complete = gue_gro_complete;
308         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
309
310         return 0;
311 }
312
313 static int fou_create(struct net *net, struct fou_cfg *cfg,
314                       struct socket **sockp)
315 {
316         struct fou *fou = NULL;
317         int err;
318         struct socket *sock = NULL;
319         struct sock *sk;
320
321         /* Open UDP socket */
322         err = udp_sock_create(net, &cfg->udp_config, &sock);
323         if (err < 0)
324                 goto error;
325
326         /* Allocate FOU port structure */
327         fou = kzalloc(sizeof(*fou), GFP_KERNEL);
328         if (!fou) {
329                 err = -ENOMEM;
330                 goto error;
331         }
332
333         sk = sock->sk;
334
335         fou->port = cfg->udp_config.local_udp_port;
336
337         /* Initial for fou type */
338         switch (cfg->type) {
339         case FOU_ENCAP_DIRECT:
340                 err = fou_encap_init(sk, fou, cfg);
341                 if (err)
342                         goto error;
343                 break;
344         case FOU_ENCAP_GUE:
345                 err = gue_encap_init(sk, fou, cfg);
346                 if (err)
347                         goto error;
348                 break;
349         default:
350                 err = -EINVAL;
351                 goto error;
352         }
353
354         udp_sk(sk)->encap_type = 1;
355         udp_encap_enable();
356
357         sk->sk_user_data = fou;
358         fou->sock = sock;
359
360         udp_set_convert_csum(sk, true);
361
362         sk->sk_allocation = GFP_ATOMIC;
363
364         if (cfg->udp_config.family == AF_INET) {
365                 err = udp_add_offload(&fou->udp_offloads);
366                 if (err)
367                         goto error;
368         }
369
370         err = fou_add_to_port_list(fou);
371         if (err)
372                 goto error;
373
374         if (sockp)
375                 *sockp = sock;
376
377         return 0;
378
379 error:
380         kfree(fou);
381         if (sock)
382                 sock_release(sock);
383
384         return err;
385 }
386
387 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
388 {
389         struct fou *fou;
390         u16 port = cfg->udp_config.local_udp_port;
391         int err = -EINVAL;
392
393         spin_lock(&fou_lock);
394         list_for_each_entry(fou, &fou_list, list) {
395                 if (fou->port == port) {
396                         udp_del_offload(&fou->udp_offloads);
397                         fou_release(fou);
398                         err = 0;
399                         break;
400                 }
401         }
402         spin_unlock(&fou_lock);
403
404         return err;
405 }
406
407 static struct genl_family fou_nl_family = {
408         .id             = GENL_ID_GENERATE,
409         .hdrsize        = 0,
410         .name           = FOU_GENL_NAME,
411         .version        = FOU_GENL_VERSION,
412         .maxattr        = FOU_ATTR_MAX,
413         .netnsok        = true,
414 };
415
416 static struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
417         [FOU_ATTR_PORT] = { .type = NLA_U16, },
418         [FOU_ATTR_AF] = { .type = NLA_U8, },
419         [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
420         [FOU_ATTR_TYPE] = { .type = NLA_U8, },
421 };
422
423 static int parse_nl_config(struct genl_info *info,
424                            struct fou_cfg *cfg)
425 {
426         memset(cfg, 0, sizeof(*cfg));
427
428         cfg->udp_config.family = AF_INET;
429
430         if (info->attrs[FOU_ATTR_AF]) {
431                 u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
432
433                 if (family != AF_INET && family != AF_INET6)
434                         return -EINVAL;
435
436                 cfg->udp_config.family = family;
437         }
438
439         if (info->attrs[FOU_ATTR_PORT]) {
440                 u16 port = nla_get_u16(info->attrs[FOU_ATTR_PORT]);
441
442                 cfg->udp_config.local_udp_port = port;
443         }
444
445         if (info->attrs[FOU_ATTR_IPPROTO])
446                 cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);
447
448         if (info->attrs[FOU_ATTR_TYPE])
449                 cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
450
451         return 0;
452 }
453
454 static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
455 {
456         struct fou_cfg cfg;
457         int err;
458
459         err = parse_nl_config(info, &cfg);
460         if (err)
461                 return err;
462
463         return fou_create(&init_net, &cfg, NULL);
464 }
465
466 static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
467 {
468         struct fou_cfg cfg;
469
470         parse_nl_config(info, &cfg);
471
472         return fou_destroy(&init_net, &cfg);
473 }
474
475 static const struct genl_ops fou_nl_ops[] = {
476         {
477                 .cmd = FOU_CMD_ADD,
478                 .doit = fou_nl_cmd_add_port,
479                 .policy = fou_nl_policy,
480                 .flags = GENL_ADMIN_PERM,
481         },
482         {
483                 .cmd = FOU_CMD_DEL,
484                 .doit = fou_nl_cmd_rm_port,
485                 .policy = fou_nl_policy,
486                 .flags = GENL_ADMIN_PERM,
487         },
488 };
489
490 static int __init fou_init(void)
491 {
492         int ret;
493
494         ret = genl_register_family_with_ops(&fou_nl_family,
495                                             fou_nl_ops);
496
497         return ret;
498 }
499
500 static void __exit fou_fini(void)
501 {
502         struct fou *fou, *next;
503
504         genl_unregister_family(&fou_nl_family);
505
506         /* Close all the FOU sockets */
507
508         spin_lock(&fou_lock);
509         list_for_each_entry_safe(fou, next, &fou_list, list)
510                 fou_release(fou);
511         spin_unlock(&fou_lock);
512 }
513
514 module_init(fou_init);
515 module_exit(fou_fini);
516 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
517 MODULE_LICENSE("GPL");