ip: fix error queue empty skb handling
[cascardo/linux.git] / net / ipv4 / ip_sockglue.c
index 6b85adb..5cd9927 100644 (file)
@@ -37,6 +37,7 @@
 #include <net/route.h>
 #include <net/xfrm.h>
 #include <net/compat.h>
+#include <net/checksum.h>
 #if IS_ENABLED(CONFIG_IPV6)
 #include <net/transp_v6.h>
 #endif
 #include <linux/errqueue.h>
 #include <asm/uaccess.h>
 
-#define IP_CMSG_PKTINFO                1
-#define IP_CMSG_TTL            2
-#define IP_CMSG_TOS            4
-#define IP_CMSG_RECVOPTS       8
-#define IP_CMSG_RETOPTS                16
-#define IP_CMSG_PASSSEC                32
-#define IP_CMSG_ORIGDSTADDR     64
-
 /*
  *     SOL_IP control messages.
  */
@@ -104,6 +97,20 @@ static void ip_cmsg_recv_retopts(struct msghdr *msg, struct sk_buff *skb)
        put_cmsg(msg, SOL_IP, IP_RETOPTS, opt->optlen, opt->__data);
 }
 
+static void ip_cmsg_recv_checksum(struct msghdr *msg, struct sk_buff *skb,
+                                 int offset)
+{
+       __wsum csum = skb->csum;
+
+       if (skb->ip_summed != CHECKSUM_COMPLETE)
+               return;
+
+       if (offset != 0)
+               csum = csum_sub(csum, csum_partial(skb->data, offset, 0));
+
+       put_cmsg(msg, SOL_IP, IP_CHECKSUM, sizeof(__wsum), &csum);
+}
+
 static void ip_cmsg_recv_security(struct msghdr *msg, struct sk_buff *skb)
 {
        char *secdata;
@@ -144,47 +151,73 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb)
        put_cmsg(msg, SOL_IP, IP_ORIGDSTADDR, sizeof(sin), &sin);
 }
 
-void ip_cmsg_recv(struct msghdr *msg, struct sk_buff *skb)
+void ip_cmsg_recv_offset(struct msghdr *msg, struct sk_buff *skb,
+                        int offset)
 {
        struct inet_sock *inet = inet_sk(skb->sk);
        unsigned int flags = inet->cmsg_flags;
 
        /* Ordered by supposed usage frequency */
-       if (flags & 1)
+       if (flags & IP_CMSG_PKTINFO) {
                ip_cmsg_recv_pktinfo(msg, skb);
-       if ((flags >>= 1) == 0)
-               return;
 
-       if (flags & 1)
+               flags &= ~IP_CMSG_PKTINFO;
+               if (!flags)
+                       return;
+       }
+
+       if (flags & IP_CMSG_TTL) {
                ip_cmsg_recv_ttl(msg, skb);
-       if ((flags >>= 1) == 0)
-               return;
 
-       if (flags & 1)
+               flags &= ~IP_CMSG_TTL;
+               if (!flags)
+                       return;
+       }
+
+       if (flags & IP_CMSG_TOS) {
                ip_cmsg_recv_tos(msg, skb);
-       if ((flags >>= 1) == 0)
-               return;
 
-       if (flags & 1)
+               flags &= ~IP_CMSG_TOS;
+               if (!flags)
+                       return;
+       }
+
+       if (flags & IP_CMSG_RECVOPTS) {
                ip_cmsg_recv_opts(msg, skb);
-       if ((flags >>= 1) == 0)
-               return;
 
-       if (flags & 1)
+               flags &= ~IP_CMSG_RECVOPTS;
+               if (!flags)
+                       return;
+       }
+
+       if (flags & IP_CMSG_RETOPTS) {
                ip_cmsg_recv_retopts(msg, skb);
-       if ((flags >>= 1) == 0)
-               return;
 
-       if (flags & 1)
+               flags &= ~IP_CMSG_RETOPTS;
+               if (!flags)
+                       return;
+       }
+
+       if (flags & IP_CMSG_PASSSEC) {
                ip_cmsg_recv_security(msg, skb);
 
-       if ((flags >>= 1) == 0)
-               return;
-       if (flags & 1)
+               flags &= ~IP_CMSG_PASSSEC;
+               if (!flags)
+                       return;
+       }
+
+       if (flags & IP_CMSG_ORIGDSTADDR) {
                ip_cmsg_recv_dstaddr(msg, skb);
 
+               flags &= ~IP_CMSG_ORIGDSTADDR;
+               if (!flags)
+                       return;
+       }
+
+       if (flags & IP_CMSG_CHECKSUM)
+               ip_cmsg_recv_checksum(msg, skb, offset);
 }
-EXPORT_SYMBOL(ip_cmsg_recv);
+EXPORT_SYMBOL(ip_cmsg_recv_offset);
 
 int ip_cmsg_send(struct net *net, struct msghdr *msg, struct ipcm_cookie *ipc,
                 bool allow_ipv6)
@@ -399,17 +432,32 @@ void ip_local_error(struct sock *sk, int err, __be32 daddr, __be16 port, u32 inf
                kfree_skb(skb);
 }
 
-static bool ipv4_pktinfo_prepare_errqueue(const struct sock *sk,
-                                         const struct sk_buff *skb,
-                                         int ee_origin)
+/* IPv4 supports cmsg on all imcp errors and some timestamps
+ *
+ * Timestamp code paths do not initialize the fields expected by cmsg:
+ * the PKTINFO fields in skb->cb[]. Fill those in here.
+ */
+static bool ipv4_datagram_support_cmsg(const struct sock *sk,
+                                      struct sk_buff *skb,
+                                      int ee_origin)
 {
-       struct in_pktinfo *info = PKTINFO_SKB_CB(skb);
+       struct in_pktinfo *info;
+
+       if (ee_origin == SO_EE_ORIGIN_ICMP)
+               return true;
+
+       if (ee_origin == SO_EE_ORIGIN_LOCAL)
+               return false;
 
-       if ((ee_origin != SO_EE_ORIGIN_TIMESTAMPING) ||
-           (!(sk->sk_tsflags & SOF_TIMESTAMPING_OPT_CMSG)) ||
+       /* Support IP_PKTINFO on tstamp packets if requested, to correlate
+        * timestamp with egress dev. Not possible for packets without dev
+        * or without payload (SOF_TIMESTAMPING_OPT_TSONLY).
+        */
+       if ((!(sk->sk_tsflags & SOF_TIMESTAMPING_OPT_CMSG)) ||
            (!skb->dev))
                return false;
 
+       info = PKTINFO_SKB_CB(skb);
        info->ipi_spec_dst.s_addr = ip_hdr(skb)->saddr;
        info->ipi_ifindex = skb->dev->ifindex;
        return true;
@@ -450,7 +498,7 @@ int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
 
        serr = SKB_EXT_ERR(skb);
 
-       if (sin) {
+       if (sin && serr->port) {
                sin->sin_family = AF_INET;
                sin->sin_addr.s_addr = *(__be32 *)(skb_network_header(skb) +
                                                   serr->addr_offset);
@@ -463,8 +511,7 @@ int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
        sin = &errhdr.offender;
        memset(sin, 0, sizeof(*sin));
 
-       if (serr->ee.ee_origin == SO_EE_ORIGIN_ICMP ||
-           ipv4_pktinfo_prepare_errqueue(sk, skb, serr->ee.ee_origin)) {
+       if (ipv4_datagram_support_cmsg(sk, skb, serr->ee.ee_origin)) {
                sin->sin_family = AF_INET;
                sin->sin_addr.s_addr = ip_hdr(skb)->saddr;
                if (inet_sk(sk)->cmsg_flags)
@@ -518,6 +565,7 @@ static int do_ip_setsockopt(struct sock *sk, int level,
        case IP_MULTICAST_ALL:
        case IP_MULTICAST_LOOP:
        case IP_RECVORIGDSTADDR:
+       case IP_CHECKSUM:
                if (optlen >= sizeof(int)) {
                        if (get_user(val, (int __user *) optval))
                                return -EFAULT;
@@ -615,6 +663,19 @@ static int do_ip_setsockopt(struct sock *sk, int level,
                else
                        inet->cmsg_flags &= ~IP_CMSG_ORIGDSTADDR;
                break;
+       case IP_CHECKSUM:
+               if (val) {
+                       if (!(inet->cmsg_flags & IP_CMSG_CHECKSUM)) {
+                               inet_inc_convert_csum(sk);
+                               inet->cmsg_flags |= IP_CMSG_CHECKSUM;
+                       }
+               } else {
+                       if (inet->cmsg_flags & IP_CMSG_CHECKSUM) {
+                               inet_dec_convert_csum(sk);
+                               inet->cmsg_flags &= ~IP_CMSG_CHECKSUM;
+                       }
+               }
+               break;
        case IP_TOS:    /* This sets both TOS and Precedence */
                if (sk->sk_type == SOCK_STREAM) {
                        val &= ~INET_ECN_MASK;
@@ -1218,6 +1279,9 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
        case IP_RECVORIGDSTADDR:
                val = (inet->cmsg_flags & IP_CMSG_ORIGDSTADDR) != 0;
                break;
+       case IP_CHECKSUM:
+               val = (inet->cmsg_flags & IP_CMSG_CHECKSUM) != 0;
+               break;
        case IP_TOS:
                val = inet->tos;
                break;