Merge tag 'pinctrl-v3.19-1' of git://git.kernel.org/pub/scm/linux/kernel/git/linusw...
[cascardo/linux.git] / net / ipv4 / fou.c
1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
5 #include <linux/ip.h>
6 #include <linux/udp.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
10 #include <net/gue.h>
11 #include <net/ip.h>
12 #include <net/protocol.h>
13 #include <net/udp.h>
14 #include <net/udp_tunnel.h>
15 #include <net/xfrm.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
18
19 static DEFINE_SPINLOCK(fou_lock);
20 static LIST_HEAD(fou_list);
21
22 struct fou {
23         struct socket *sock;
24         u8 protocol;
25         u16 port;
26         struct udp_offload udp_offloads;
27         struct list_head list;
28 };
29
30 struct fou_cfg {
31         u16 type;
32         u8 protocol;
33         struct udp_port_cfg udp_config;
34 };
35
36 static inline struct fou *fou_from_sock(struct sock *sk)
37 {
38         return sk->sk_user_data;
39 }
40
41 static int fou_udp_encap_recv_deliver(struct sk_buff *skb,
42                                       u8 protocol, size_t len)
43 {
44         struct iphdr *iph = ip_hdr(skb);
45
46         /* Remove 'len' bytes from the packet (UDP header and
47          * FOU header if present), modify the protocol to the one
48          * we found, and then call rcv_encap.
49          */
50         iph->tot_len = htons(ntohs(iph->tot_len) - len);
51         __skb_pull(skb, len);
52         skb_postpull_rcsum(skb, udp_hdr(skb), len);
53         skb_reset_transport_header(skb);
54
55         return -protocol;
56 }
57
58 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
59 {
60         struct fou *fou = fou_from_sock(sk);
61
62         if (!fou)
63                 return 1;
64
65         return fou_udp_encap_recv_deliver(skb, fou->protocol,
66                                           sizeof(struct udphdr));
67 }
68
69 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
70 {
71         struct fou *fou = fou_from_sock(sk);
72         size_t len;
73         struct guehdr *guehdr;
74         struct udphdr *uh;
75
76         if (!fou)
77                 return 1;
78
79         len = sizeof(struct udphdr) + sizeof(struct guehdr);
80         if (!pskb_may_pull(skb, len))
81                 goto drop;
82
83         uh = udp_hdr(skb);
84         guehdr = (struct guehdr *)&uh[1];
85
86         len += guehdr->hlen << 2;
87         if (!pskb_may_pull(skb, len))
88                 goto drop;
89
90         uh = udp_hdr(skb);
91         guehdr = (struct guehdr *)&uh[1];
92
93         if (guehdr->version != 0)
94                 goto drop;
95
96         if (guehdr->flags) {
97                 /* No support yet */
98                 goto drop;
99         }
100
101         return fou_udp_encap_recv_deliver(skb, guehdr->next_hdr, len);
102 drop:
103         kfree_skb(skb);
104         return 0;
105 }
106
107 static struct sk_buff **fou_gro_receive(struct sk_buff **head,
108                                         struct sk_buff *skb)
109 {
110         const struct net_offload *ops;
111         struct sk_buff **pp = NULL;
112         u8 proto = NAPI_GRO_CB(skb)->proto;
113         const struct net_offload **offloads;
114
115         rcu_read_lock();
116         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
117         ops = rcu_dereference(offloads[proto]);
118         if (!ops || !ops->callbacks.gro_receive)
119                 goto out_unlock;
120
121         pp = ops->callbacks.gro_receive(head, skb);
122
123 out_unlock:
124         rcu_read_unlock();
125
126         return pp;
127 }
128
129 static int fou_gro_complete(struct sk_buff *skb, int nhoff)
130 {
131         const struct net_offload *ops;
132         u8 proto = NAPI_GRO_CB(skb)->proto;
133         int err = -ENOSYS;
134         const struct net_offload **offloads;
135
136         udp_tunnel_gro_complete(skb, nhoff);
137
138         rcu_read_lock();
139         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
140         ops = rcu_dereference(offloads[proto]);
141         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
142                 goto out_unlock;
143
144         err = ops->callbacks.gro_complete(skb, nhoff);
145
146 out_unlock:
147         rcu_read_unlock();
148
149         return err;
150 }
151
152 static struct sk_buff **gue_gro_receive(struct sk_buff **head,
153                                         struct sk_buff *skb)
154 {
155         const struct net_offload **offloads;
156         const struct net_offload *ops;
157         struct sk_buff **pp = NULL;
158         struct sk_buff *p;
159         u8 proto;
160         struct guehdr *guehdr;
161         unsigned int hlen, guehlen;
162         unsigned int off;
163         int flush = 1;
164
165         off = skb_gro_offset(skb);
166         hlen = off + sizeof(*guehdr);
167         guehdr = skb_gro_header_fast(skb, off);
168         if (skb_gro_header_hard(skb, hlen)) {
169                 guehdr = skb_gro_header_slow(skb, hlen, off);
170                 if (unlikely(!guehdr))
171                         goto out;
172         }
173
174         proto = guehdr->next_hdr;
175
176         rcu_read_lock();
177         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
178         ops = rcu_dereference(offloads[proto]);
179         if (WARN_ON(!ops || !ops->callbacks.gro_receive))
180                 goto out_unlock;
181
182         guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
183
184         hlen = off + guehlen;
185         if (skb_gro_header_hard(skb, hlen)) {
186                 guehdr = skb_gro_header_slow(skb, hlen, off);
187                 if (unlikely(!guehdr))
188                         goto out_unlock;
189         }
190
191         flush = 0;
192
193         for (p = *head; p; p = p->next) {
194                 const struct guehdr *guehdr2;
195
196                 if (!NAPI_GRO_CB(p)->same_flow)
197                         continue;
198
199                 guehdr2 = (struct guehdr *)(p->data + off);
200
201                 /* Compare base GUE header to be equal (covers
202                  * hlen, version, next_hdr, and flags.
203                  */
204                 if (guehdr->word != guehdr2->word) {
205                         NAPI_GRO_CB(p)->same_flow = 0;
206                         continue;
207                 }
208
209                 /* Compare optional fields are the same. */
210                 if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
211                                            guehdr->hlen << 2)) {
212                         NAPI_GRO_CB(p)->same_flow = 0;
213                         continue;
214                 }
215         }
216
217         skb_gro_pull(skb, guehlen);
218
219         /* Adjusted NAPI_GRO_CB(skb)->csum after skb_gro_pull()*/
220         skb_gro_postpull_rcsum(skb, guehdr, guehlen);
221
222         pp = ops->callbacks.gro_receive(head, skb);
223
224 out_unlock:
225         rcu_read_unlock();
226 out:
227         NAPI_GRO_CB(skb)->flush |= flush;
228
229         return pp;
230 }
231
232 static int gue_gro_complete(struct sk_buff *skb, int nhoff)
233 {
234         const struct net_offload **offloads;
235         struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
236         const struct net_offload *ops;
237         unsigned int guehlen;
238         u8 proto;
239         int err = -ENOENT;
240
241         proto = guehdr->next_hdr;
242
243         guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
244
245         rcu_read_lock();
246         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
247         ops = rcu_dereference(offloads[proto]);
248         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
249                 goto out_unlock;
250
251         err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
252
253 out_unlock:
254         rcu_read_unlock();
255         return err;
256 }
257
258 static int fou_add_to_port_list(struct fou *fou)
259 {
260         struct fou *fout;
261
262         spin_lock(&fou_lock);
263         list_for_each_entry(fout, &fou_list, list) {
264                 if (fou->port == fout->port) {
265                         spin_unlock(&fou_lock);
266                         return -EALREADY;
267                 }
268         }
269
270         list_add(&fou->list, &fou_list);
271         spin_unlock(&fou_lock);
272
273         return 0;
274 }
275
276 static void fou_release(struct fou *fou)
277 {
278         struct socket *sock = fou->sock;
279         struct sock *sk = sock->sk;
280
281         udp_del_offload(&fou->udp_offloads);
282
283         list_del(&fou->list);
284
285         /* Remove hooks into tunnel socket */
286         sk->sk_user_data = NULL;
287
288         sock_release(sock);
289
290         kfree(fou);
291 }
292
293 static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
294 {
295         udp_sk(sk)->encap_rcv = fou_udp_recv;
296         fou->protocol = cfg->protocol;
297         fou->udp_offloads.callbacks.gro_receive = fou_gro_receive;
298         fou->udp_offloads.callbacks.gro_complete = fou_gro_complete;
299         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
300         fou->udp_offloads.ipproto = cfg->protocol;
301
302         return 0;
303 }
304
305 static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
306 {
307         udp_sk(sk)->encap_rcv = gue_udp_recv;
308         fou->udp_offloads.callbacks.gro_receive = gue_gro_receive;
309         fou->udp_offloads.callbacks.gro_complete = gue_gro_complete;
310         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
311
312         return 0;
313 }
314
315 static int fou_create(struct net *net, struct fou_cfg *cfg,
316                       struct socket **sockp)
317 {
318         struct fou *fou = NULL;
319         int err;
320         struct socket *sock = NULL;
321         struct sock *sk;
322
323         /* Open UDP socket */
324         err = udp_sock_create(net, &cfg->udp_config, &sock);
325         if (err < 0)
326                 goto error;
327
328         /* Allocate FOU port structure */
329         fou = kzalloc(sizeof(*fou), GFP_KERNEL);
330         if (!fou) {
331                 err = -ENOMEM;
332                 goto error;
333         }
334
335         sk = sock->sk;
336
337         fou->port = cfg->udp_config.local_udp_port;
338
339         /* Initial for fou type */
340         switch (cfg->type) {
341         case FOU_ENCAP_DIRECT:
342                 err = fou_encap_init(sk, fou, cfg);
343                 if (err)
344                         goto error;
345                 break;
346         case FOU_ENCAP_GUE:
347                 err = gue_encap_init(sk, fou, cfg);
348                 if (err)
349                         goto error;
350                 break;
351         default:
352                 err = -EINVAL;
353                 goto error;
354         }
355
356         udp_sk(sk)->encap_type = 1;
357         udp_encap_enable();
358
359         sk->sk_user_data = fou;
360         fou->sock = sock;
361
362         udp_set_convert_csum(sk, true);
363
364         sk->sk_allocation = GFP_ATOMIC;
365
366         if (cfg->udp_config.family == AF_INET) {
367                 err = udp_add_offload(&fou->udp_offloads);
368                 if (err)
369                         goto error;
370         }
371
372         err = fou_add_to_port_list(fou);
373         if (err)
374                 goto error;
375
376         if (sockp)
377                 *sockp = sock;
378
379         return 0;
380
381 error:
382         kfree(fou);
383         if (sock)
384                 sock_release(sock);
385
386         return err;
387 }
388
389 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
390 {
391         struct fou *fou;
392         u16 port = cfg->udp_config.local_udp_port;
393         int err = -EINVAL;
394
395         spin_lock(&fou_lock);
396         list_for_each_entry(fou, &fou_list, list) {
397                 if (fou->port == port) {
398                         udp_del_offload(&fou->udp_offloads);
399                         fou_release(fou);
400                         err = 0;
401                         break;
402                 }
403         }
404         spin_unlock(&fou_lock);
405
406         return err;
407 }
408
409 static struct genl_family fou_nl_family = {
410         .id             = GENL_ID_GENERATE,
411         .hdrsize        = 0,
412         .name           = FOU_GENL_NAME,
413         .version        = FOU_GENL_VERSION,
414         .maxattr        = FOU_ATTR_MAX,
415         .netnsok        = true,
416 };
417
418 static struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
419         [FOU_ATTR_PORT] = { .type = NLA_U16, },
420         [FOU_ATTR_AF] = { .type = NLA_U8, },
421         [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
422         [FOU_ATTR_TYPE] = { .type = NLA_U8, },
423 };
424
425 static int parse_nl_config(struct genl_info *info,
426                            struct fou_cfg *cfg)
427 {
428         memset(cfg, 0, sizeof(*cfg));
429
430         cfg->udp_config.family = AF_INET;
431
432         if (info->attrs[FOU_ATTR_AF]) {
433                 u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
434
435                 if (family != AF_INET && family != AF_INET6)
436                         return -EINVAL;
437
438                 cfg->udp_config.family = family;
439         }
440
441         if (info->attrs[FOU_ATTR_PORT]) {
442                 u16 port = nla_get_u16(info->attrs[FOU_ATTR_PORT]);
443
444                 cfg->udp_config.local_udp_port = port;
445         }
446
447         if (info->attrs[FOU_ATTR_IPPROTO])
448                 cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);
449
450         if (info->attrs[FOU_ATTR_TYPE])
451                 cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
452
453         return 0;
454 }
455
456 static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
457 {
458         struct fou_cfg cfg;
459         int err;
460
461         err = parse_nl_config(info, &cfg);
462         if (err)
463                 return err;
464
465         return fou_create(&init_net, &cfg, NULL);
466 }
467
468 static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
469 {
470         struct fou_cfg cfg;
471
472         parse_nl_config(info, &cfg);
473
474         return fou_destroy(&init_net, &cfg);
475 }
476
477 static const struct genl_ops fou_nl_ops[] = {
478         {
479                 .cmd = FOU_CMD_ADD,
480                 .doit = fou_nl_cmd_add_port,
481                 .policy = fou_nl_policy,
482                 .flags = GENL_ADMIN_PERM,
483         },
484         {
485                 .cmd = FOU_CMD_DEL,
486                 .doit = fou_nl_cmd_rm_port,
487                 .policy = fou_nl_policy,
488                 .flags = GENL_ADMIN_PERM,
489         },
490 };
491
492 static int __init fou_init(void)
493 {
494         int ret;
495
496         ret = genl_register_family_with_ops(&fou_nl_family,
497                                             fou_nl_ops);
498
499         return ret;
500 }
501
502 static void __exit fou_fini(void)
503 {
504         struct fou *fou, *next;
505
506         genl_unregister_family(&fou_nl_family);
507
508         /* Close all the FOU sockets */
509
510         spin_lock(&fou_lock);
511         list_for_each_entry_safe(fou, next, &fou_list, list)
512                 fou_release(fou);
513         spin_unlock(&fou_lock);
514 }
515
516 module_init(fou_init);
517 module_exit(fou_fini);
518 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
519 MODULE_LICENSE("GPL");