Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[cascardo/linux.git] / net / ipv4 / fou.c
index 606c520..3dfe982 100644 (file)
@@ -38,21 +38,17 @@ static inline struct fou *fou_from_sock(struct sock *sk)
        return sk->sk_user_data;
 }
 
-static int fou_udp_encap_recv_deliver(struct sk_buff *skb,
-                                     u8 protocol, size_t len)
+static void fou_recv_pull(struct sk_buff *skb, size_t len)
 {
        struct iphdr *iph = ip_hdr(skb);
 
        /* Remove 'len' bytes from the packet (UDP header and
-        * FOU header if present), modify the protocol to the one
-        * we found, and then call rcv_encap.
+        * FOU header if present).
         */
        iph->tot_len = htons(ntohs(iph->tot_len) - len);
        __skb_pull(skb, len);
        skb_postpull_rcsum(skb, udp_hdr(skb), len);
        skb_reset_transport_header(skb);
-
-       return -protocol;
 }
 
 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
@@ -62,16 +58,78 @@ static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
        if (!fou)
                return 1;
 
-       return fou_udp_encap_recv_deliver(skb, fou->protocol,
-                                         sizeof(struct udphdr));
+       fou_recv_pull(skb, sizeof(struct udphdr));
+
+       return -fou->protocol;
+}
+
+static struct guehdr *gue_remcsum(struct sk_buff *skb, struct guehdr *guehdr,
+                                 void *data, int hdrlen, u8 ipproto)
+{
+       __be16 *pd = data;
+       u16 start = ntohs(pd[0]);
+       u16 offset = ntohs(pd[1]);
+       u16 poffset = 0;
+       u16 plen;
+       __wsum csum, delta;
+       __sum16 *psum;
+
+       if (skb->remcsum_offload) {
+               /* Already processed in GRO path */
+               skb->remcsum_offload = 0;
+               return guehdr;
+       }
+
+       if (start > skb->len - hdrlen ||
+           offset > skb->len - hdrlen - sizeof(u16))
+               return NULL;
+
+       if (unlikely(skb->ip_summed != CHECKSUM_COMPLETE))
+               __skb_checksum_complete(skb);
+
+       plen = hdrlen + offset + sizeof(u16);
+       if (!pskb_may_pull(skb, plen))
+               return NULL;
+       guehdr = (struct guehdr *)&udp_hdr(skb)[1];
+
+       if (ipproto == IPPROTO_IP && sizeof(struct iphdr) < plen) {
+               struct iphdr *ip = (struct iphdr *)(skb->data + hdrlen);
+
+               /* If next header happens to be IP we can skip that for the
+                * checksum calculation since the IP header checksum is zero
+                * if correct.
+                */
+               poffset = ip->ihl * 4;
+       }
+
+       csum = csum_sub(skb->csum, skb_checksum(skb, poffset + hdrlen,
+                                               start - poffset - hdrlen, 0));
+
+       /* Set derived checksum in packet */
+       psum = (__sum16 *)(skb->data + hdrlen + offset);
+       delta = csum_sub(csum_fold(csum), *psum);
+       *psum = csum_fold(csum);
+
+       /* Adjust skb->csum since we changed the packet */
+       skb->csum = csum_add(skb->csum, delta);
+
+       return guehdr;
+}
+
+static int gue_control_message(struct sk_buff *skb, struct guehdr *guehdr)
+{
+       /* No support yet */
+       kfree_skb(skb);
+       return 0;
 }
 
 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
 {
        struct fou *fou = fou_from_sock(sk);
-       size_t len;
+       size_t len, optlen, hdrlen;
        struct guehdr *guehdr;
-       struct udphdr *uh;
+       void *data;
+       u16 doffset = 0;
 
        if (!fou)
                return 1;
@@ -80,25 +138,61 @@ static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
        if (!pskb_may_pull(skb, len))
                goto drop;
 
-       uh = udp_hdr(skb);
-       guehdr = (struct guehdr *)&uh[1];
+       guehdr = (struct guehdr *)&udp_hdr(skb)[1];
+
+       optlen = guehdr->hlen << 2;
+       len += optlen;
 
-       len += guehdr->hlen << 2;
        if (!pskb_may_pull(skb, len))
                goto drop;
 
-       uh = udp_hdr(skb);
-       guehdr = (struct guehdr *)&uh[1];
+       /* guehdr may change after pull */
+       guehdr = (struct guehdr *)&udp_hdr(skb)[1];
 
-       if (guehdr->version != 0)
-               goto drop;
+       hdrlen = sizeof(struct guehdr) + optlen;
 
-       if (guehdr->flags) {
-               /* No support yet */
+       if (guehdr->version != 0 || validate_gue_flags(guehdr, optlen))
                goto drop;
+
+       hdrlen = sizeof(struct guehdr) + optlen;
+
+       ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
+
+       /* Pull UDP header now, skb->data points to guehdr */
+       __skb_pull(skb, sizeof(struct udphdr));
+
+       /* Pull csum through the guehdr now . This can be used if
+        * there is a remote checksum offload.
+        */
+       skb_postpull_rcsum(skb, udp_hdr(skb), len);
+
+       data = &guehdr[1];
+
+       if (guehdr->flags & GUE_FLAG_PRIV) {
+               __be32 flags = *(__be32 *)(data + doffset);
+
+               doffset += GUE_LEN_PRIV;
+
+               if (flags & GUE_PFLAG_REMCSUM) {
+                       guehdr = gue_remcsum(skb, guehdr, data + doffset,
+                                            hdrlen, guehdr->proto_ctype);
+                       if (!guehdr)
+                               goto drop;
+
+                       data = &guehdr[1];
+
+                       doffset += GUE_PLEN_REMCSUM;
+               }
        }
 
-       return fou_udp_encap_recv_deliver(skb, guehdr->next_hdr, len);
+       if (unlikely(guehdr->control))
+               return gue_control_message(skb, guehdr);
+
+       __skb_pull(skb, hdrlen);
+       skb_reset_transport_header(skb);
+
+       return -guehdr->proto_ctype;
+
 drop:
        kfree_skb(skb);
        return 0;
@@ -149,6 +243,66 @@ out_unlock:
        return err;
 }
 
+static struct guehdr *gue_gro_remcsum(struct sk_buff *skb, unsigned int off,
+                                     struct guehdr *guehdr, void *data,
+                                     size_t hdrlen, u8 ipproto)
+{
+       __be16 *pd = data;
+       u16 start = ntohs(pd[0]);
+       u16 offset = ntohs(pd[1]);
+       u16 poffset = 0;
+       u16 plen;
+       void *ptr;
+       __wsum csum, delta;
+       __sum16 *psum;
+
+       if (skb->remcsum_offload)
+               return guehdr;
+
+       if (start > skb_gro_len(skb) - hdrlen ||
+           offset > skb_gro_len(skb) - hdrlen - sizeof(u16) ||
+           !NAPI_GRO_CB(skb)->csum_valid || skb->remcsum_offload)
+               return NULL;
+
+       plen = hdrlen + offset + sizeof(u16);
+
+       /* Pull checksum that will be written */
+       if (skb_gro_header_hard(skb, off + plen)) {
+               guehdr = skb_gro_header_slow(skb, off + plen, off);
+               if (!guehdr)
+                       return NULL;
+       }
+
+       ptr = (void *)guehdr + hdrlen;
+
+       if (ipproto == IPPROTO_IP &&
+           (hdrlen + sizeof(struct iphdr) < plen)) {
+               struct iphdr *ip = (struct iphdr *)(ptr + hdrlen);
+
+               /* If next header happens to be IP we can skip
+                * that for the checksum calculation since the
+                * IP header checksum is zero if correct.
+                */
+               poffset = ip->ihl * 4;
+       }
+
+       csum = csum_sub(NAPI_GRO_CB(skb)->csum,
+                       csum_partial(ptr + poffset, start - poffset, 0));
+
+       /* Set derived checksum in packet */
+       psum = (__sum16 *)(ptr + offset);
+       delta = csum_sub(csum_fold(csum), *psum);
+       *psum = csum_fold(csum);
+
+       /* Adjust skb->csum since we changed the packet */
+       skb->csum = csum_add(skb->csum, delta);
+       NAPI_GRO_CB(skb)->csum = csum_add(NAPI_GRO_CB(skb)->csum, delta);
+
+       skb->remcsum_offload = 1;
+
+       return guehdr;
+}
+
 static struct sk_buff **gue_gro_receive(struct sk_buff **head,
                                        struct sk_buff *skb)
 {
@@ -156,38 +310,64 @@ static struct sk_buff **gue_gro_receive(struct sk_buff **head,
        const struct net_offload *ops;
        struct sk_buff **pp = NULL;
        struct sk_buff *p;
-       u8 proto;
        struct guehdr *guehdr;
-       unsigned int hlen, guehlen;
-       unsigned int off;
+       size_t len, optlen, hdrlen, off;
+       void *data;
+       u16 doffset = 0;
        int flush = 1;
 
        off = skb_gro_offset(skb);
-       hlen = off + sizeof(*guehdr);
+       len = off + sizeof(*guehdr);
+
        guehdr = skb_gro_header_fast(skb, off);
-       if (skb_gro_header_hard(skb, hlen)) {
-               guehdr = skb_gro_header_slow(skb, hlen, off);
+       if (skb_gro_header_hard(skb, len)) {
+               guehdr = skb_gro_header_slow(skb, len, off);
                if (unlikely(!guehdr))
                        goto out;
        }
 
-       proto = guehdr->next_hdr;
+       optlen = guehdr->hlen << 2;
+       len += optlen;
 
-       rcu_read_lock();
-       offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
-       ops = rcu_dereference(offloads[proto]);
-       if (WARN_ON(!ops || !ops->callbacks.gro_receive))
-               goto out_unlock;
+       if (skb_gro_header_hard(skb, len)) {
+               guehdr = skb_gro_header_slow(skb, len, off);
+               if (unlikely(!guehdr))
+                       goto out;
+       }
 
-       guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
+       if (unlikely(guehdr->control) || guehdr->version != 0 ||
+           validate_gue_flags(guehdr, optlen))
+               goto out;
 
-       hlen = off + guehlen;
-       if (skb_gro_header_hard(skb, hlen)) {
-               guehdr = skb_gro_header_slow(skb, hlen, off);
-               if (unlikely(!guehdr))
-                       goto out_unlock;
+       hdrlen = sizeof(*guehdr) + optlen;
+
+       /* Adjust NAPI_GRO_CB(skb)->csum to account for guehdr,
+        * this is needed if there is a remote checkcsum offload.
+        */
+       skb_gro_postpull_rcsum(skb, guehdr, hdrlen);
+
+       data = &guehdr[1];
+
+       if (guehdr->flags & GUE_FLAG_PRIV) {
+               __be32 flags = *(__be32 *)(data + doffset);
+
+               doffset += GUE_LEN_PRIV;
+
+               if (flags & GUE_PFLAG_REMCSUM) {
+                       guehdr = gue_gro_remcsum(skb, off, guehdr,
+                                                data + doffset, hdrlen,
+                                                guehdr->proto_ctype);
+                       if (!guehdr)
+                               goto out;
+
+                       data = &guehdr[1];
+
+                       doffset += GUE_PLEN_REMCSUM;
+               }
        }
 
+       skb_gro_pull(skb, hdrlen);
+
        flush = 0;
 
        for (p = *head; p; p = p->next) {
@@ -199,7 +379,7 @@ static struct sk_buff **gue_gro_receive(struct sk_buff **head,
                guehdr2 = (struct guehdr *)(p->data + off);
 
                /* Compare base GUE header to be equal (covers
-                * hlen, version, next_hdr, and flags.
+                * hlen, version, proto_ctype, and flags.
                 */
                if (guehdr->word != guehdr2->word) {
                        NAPI_GRO_CB(p)->same_flow = 0;
@@ -214,10 +394,11 @@ static struct sk_buff **gue_gro_receive(struct sk_buff **head,
                }
        }
 
-       skb_gro_pull(skb, guehlen);
-
-       /* Adjusted NAPI_GRO_CB(skb)->csum after skb_gro_pull()*/
-       skb_gro_postpull_rcsum(skb, guehdr, guehlen);
+       rcu_read_lock();
+       offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
+       ops = rcu_dereference(offloads[guehdr->proto_ctype]);
+       if (WARN_ON(!ops || !ops->callbacks.gro_receive))
+               goto out_unlock;
 
        pp = ops->callbacks.gro_receive(head, skb);
 
@@ -238,7 +419,7 @@ static int gue_gro_complete(struct sk_buff *skb, int nhoff)
        u8 proto;
        int err = -ENOENT;
 
-       proto = guehdr->next_hdr;
+       proto = guehdr->proto_ctype;
 
        guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
 
@@ -489,6 +670,200 @@ static const struct genl_ops fou_nl_ops[] = {
        },
 };
 
+size_t fou_encap_hlen(struct ip_tunnel_encap *e)
+{
+       return sizeof(struct udphdr);
+}
+EXPORT_SYMBOL(fou_encap_hlen);
+
+size_t gue_encap_hlen(struct ip_tunnel_encap *e)
+{
+       size_t len;
+       bool need_priv = false;
+
+       len = sizeof(struct udphdr) + sizeof(struct guehdr);
+
+       if (e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) {
+               len += GUE_PLEN_REMCSUM;
+               need_priv = true;
+       }
+
+       len += need_priv ? GUE_LEN_PRIV : 0;
+
+       return len;
+}
+EXPORT_SYMBOL(gue_encap_hlen);
+
+static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e,
+                         struct flowi4 *fl4, u8 *protocol, __be16 sport)
+{
+       struct udphdr *uh;
+
+       skb_push(skb, sizeof(struct udphdr));
+       skb_reset_transport_header(skb);
+
+       uh = udp_hdr(skb);
+
+       uh->dest = e->dport;
+       uh->source = sport;
+       uh->len = htons(skb->len);
+       uh->check = 0;
+       udp_set_csum(!(e->flags & TUNNEL_ENCAP_FLAG_CSUM), skb,
+                    fl4->saddr, fl4->daddr, skb->len);
+
+       *protocol = IPPROTO_UDP;
+}
+
+int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
+                    u8 *protocol, struct flowi4 *fl4)
+{
+       bool csum = !!(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
+       int type = csum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
+       __be16 sport;
+
+       skb = iptunnel_handle_offloads(skb, csum, type);
+
+       if (IS_ERR(skb))
+               return PTR_ERR(skb);
+
+       sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
+                                              skb, 0, 0, false);
+       fou_build_udp(skb, e, fl4, protocol, sport);
+
+       return 0;
+}
+EXPORT_SYMBOL(fou_build_header);
+
+int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
+                    u8 *protocol, struct flowi4 *fl4)
+{
+       bool csum = !!(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
+       int type = csum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
+       struct guehdr *guehdr;
+       size_t hdrlen, optlen = 0;
+       __be16 sport;
+       void *data;
+       bool need_priv = false;
+
+       if ((e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) &&
+           skb->ip_summed == CHECKSUM_PARTIAL) {
+               csum = false;
+               optlen += GUE_PLEN_REMCSUM;
+               type |= SKB_GSO_TUNNEL_REMCSUM;
+               need_priv = true;
+       }
+
+       optlen += need_priv ? GUE_LEN_PRIV : 0;
+
+       skb = iptunnel_handle_offloads(skb, csum, type);
+
+       if (IS_ERR(skb))
+               return PTR_ERR(skb);
+
+       /* Get source port (based on flow hash) before skb_push */
+       sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
+                                              skb, 0, 0, false);
+
+       hdrlen = sizeof(struct guehdr) + optlen;
+
+       skb_push(skb, hdrlen);
+
+       guehdr = (struct guehdr *)skb->data;
+
+       guehdr->control = 0;
+       guehdr->version = 0;
+       guehdr->hlen = optlen >> 2;
+       guehdr->flags = 0;
+       guehdr->proto_ctype = *protocol;
+
+       data = &guehdr[1];
+
+       if (need_priv) {
+               __be32 *flags = data;
+
+               guehdr->flags |= GUE_FLAG_PRIV;
+               *flags = 0;
+               data += GUE_LEN_PRIV;
+
+               if (type & SKB_GSO_TUNNEL_REMCSUM) {
+                       u16 csum_start = skb_checksum_start_offset(skb);
+                       __be16 *pd = data;
+
+                       if (csum_start < hdrlen)
+                               return -EINVAL;
+
+                       csum_start -= hdrlen;
+                       pd[0] = htons(csum_start);
+                       pd[1] = htons(csum_start + skb->csum_offset);
+
+                       if (!skb_is_gso(skb)) {
+                               skb->ip_summed = CHECKSUM_NONE;
+                               skb->encapsulation = 0;
+                       }
+
+                       *flags |= GUE_PFLAG_REMCSUM;
+                       data += GUE_PLEN_REMCSUM;
+               }
+
+       }
+
+       fou_build_udp(skb, e, fl4, protocol, sport);
+
+       return 0;
+}
+EXPORT_SYMBOL(gue_build_header);
+
+#ifdef CONFIG_NET_FOU_IP_TUNNELS
+
+static const struct ip_tunnel_encap_ops __read_mostly fou_iptun_ops = {
+       .encap_hlen = fou_encap_hlen,
+       .build_header = fou_build_header,
+};
+
+static const struct ip_tunnel_encap_ops __read_mostly gue_iptun_ops = {
+       .encap_hlen = gue_encap_hlen,
+       .build_header = gue_build_header,
+};
+
+static int ip_tunnel_encap_add_fou_ops(void)
+{
+       int ret;
+
+       ret = ip_tunnel_encap_add_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
+       if (ret < 0) {
+               pr_err("can't add fou ops\n");
+               return ret;
+       }
+
+       ret = ip_tunnel_encap_add_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
+       if (ret < 0) {
+               pr_err("can't add gue ops\n");
+               ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
+               return ret;
+       }
+
+       return 0;
+}
+
+static void ip_tunnel_encap_del_fou_ops(void)
+{
+       ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
+       ip_tunnel_encap_del_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
+}
+
+#else
+
+static int ip_tunnel_encap_add_fou_ops(void)
+{
+       return 0;
+}
+
+static void ip_tunnel_encap_del_fou_ops(void)
+{
+}
+
+#endif
+
 static int __init fou_init(void)
 {
        int ret;
@@ -496,6 +871,14 @@ static int __init fou_init(void)
        ret = genl_register_family_with_ops(&fou_nl_family,
                                            fou_nl_ops);
 
+       if (ret < 0)
+               goto exit;
+
+       ret = ip_tunnel_encap_add_fou_ops();
+       if (ret < 0)
+               genl_unregister_family(&fou_nl_family);
+
+exit:
        return ret;
 }
 
@@ -503,6 +886,8 @@ static void __exit fou_fini(void)
 {
        struct fou *fou, *next;
 
+       ip_tunnel_encap_del_fou_ops();
+
        genl_unregister_family(&fou_nl_family);
 
        /* Close all the FOU sockets */