packet: parse tpacket header before skb alloc
[cascardo/linux.git] / net / packet / af_packet.c
index 992396a..89377bf 100644 (file)
@@ -1960,6 +1960,64 @@ static unsigned int run_filter(struct sk_buff *skb,
        return res;
 }
 
+/*
+ * Translate the offload state of @skb into a virtio_net_hdr.
+ *
+ * @skb:      packet whose GSO and checksum metadata is read
+ * @vnet_hdr: output header; zeroed first, then filled in
+ *
+ * Returns 0 on success, or -EINVAL if the skb is GSO with a type that
+ * has no virtio_net_hdr encoding here (FCoE or anything unrecognised).
+ * Unknown types are reported as an error instead of hitting BUG(), so
+ * an unexpected skb fails this one packet rather than the whole kernel.
+ */
+static int __packet_rcv_vnet(const struct sk_buff *skb,
+                            struct virtio_net_hdr *vnet_hdr)
+{
+       *vnet_hdr = (const struct virtio_net_hdr) { 0 };
+
+       if (skb_is_gso(skb)) {
+               struct skb_shared_info *sinfo = skb_shinfo(skb);
+
+               /* This is a hint as to how much should be linear. */
+               vnet_hdr->hdr_len =
+                       __cpu_to_virtio16(vio_le(), skb_headlen(skb));
+               vnet_hdr->gso_size =
+                       __cpu_to_virtio16(vio_le(), sinfo->gso_size);
+
+               if (sinfo->gso_type & SKB_GSO_TCPV4)
+                       vnet_hdr->gso_type = VIRTIO_NET_HDR_GSO_TCPV4;
+               else if (sinfo->gso_type & SKB_GSO_TCPV6)
+                       vnet_hdr->gso_type = VIRTIO_NET_HDR_GSO_TCPV6;
+               else if (sinfo->gso_type & SKB_GSO_UDP)
+                       vnet_hdr->gso_type = VIRTIO_NET_HDR_GSO_UDP;
+               else
+                       return -EINVAL;
+
+               /* ECN is a modifier bit on top of the base GSO type. */
+               if (sinfo->gso_type & SKB_GSO_TCP_ECN)
+                       vnet_hdr->gso_type |= VIRTIO_NET_HDR_GSO_ECN;
+       } else {
+               vnet_hdr->gso_type = VIRTIO_NET_HDR_GSO_NONE;
+       }
+
+       if (skb->ip_summed == CHECKSUM_PARTIAL) {
+               vnet_hdr->flags = VIRTIO_NET_HDR_F_NEEDS_CSUM;
+               vnet_hdr->csum_start = __cpu_to_virtio16(vio_le(),
+                                 skb_checksum_start_offset(skb));
+               vnet_hdr->csum_offset = __cpu_to_virtio16(vio_le(),
+                                                skb->csum_offset);
+       } else if (skb->ip_summed == CHECKSUM_UNNECESSARY) {
+               vnet_hdr->flags = VIRTIO_NET_HDR_F_DATA_VALID;
+       } /* else everything is zero */
+
+       return 0;
+}
+
+/*
+ * Emit a virtio_net_hdr for @skb at the front of the user message.
+ *
+ * @msg: destination message the header is copied into
+ * @skb: packet being delivered
+ * @len: in/out: remaining user buffer length; reduced by the header size
+ *
+ * Returns -EINVAL if the buffer cannot hold the header or the header
+ * cannot be built, otherwise the result of memcpy_to_msg().
+ */
+static int packet_rcv_vnet(struct msghdr *msg, const struct sk_buff *skb,
+                          size_t *len)
+{
+       const size_t hdr_sz = sizeof(struct virtio_net_hdr);
+       struct virtio_net_hdr hdr;
+
+       if (*len < hdr_sz)
+               return -EINVAL;
+       *len -= hdr_sz;
+
+       if (__packet_rcv_vnet(skb, &hdr) != 0)
+               return -EINVAL;
+
+       return memcpy_to_msg(msg, &hdr, hdr_sz);
+}
+
 /*
  * This function makes lazy skb cloning in hope that most of packets
  * are discarded by BPF.
@@ -2148,7 +2206,9 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                unsigned int maclen = skb_network_offset(skb);
                netoff = TPACKET_ALIGN(po->tp_hdrlen +
                                       (maclen < 16 ? 16 : maclen)) +
-                       po->tp_reserve;
+                                      po->tp_reserve;
+               if (po->has_vnet_hdr)
+                       netoff += sizeof(struct virtio_net_hdr);
                macoff = netoff - maclen;
        }
        if (po->tp_version <= TPACKET_V2) {
@@ -2185,7 +2245,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
        h.raw = packet_current_rx_frame(po, skb,
                                        TP_STATUS_KERNEL, (macoff+snaplen));
        if (!h.raw)
-               goto ring_is_full;
+               goto drop_n_account;
        if (po->tp_version <= TPACKET_V2) {
                packet_increment_rx_head(po, &po->rx_ring);
        /*
@@ -2204,6 +2264,14 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
        }
        spin_unlock(&sk->sk_receive_queue.lock);
 
+       if (po->has_vnet_hdr) {
+               if (__packet_rcv_vnet(skb, h.raw + macoff -
+                                          sizeof(struct virtio_net_hdr))) {
+                       spin_lock(&sk->sk_receive_queue.lock);
+                       goto drop_n_account;
+               }
+       }
+
        skb_copy_bits(skb, 0, h.raw + macoff, snaplen);
 
        if (!(ts_status = tpacket_get_timestamp(skb, &ts, po->tp_tstamp)))
@@ -2299,7 +2367,7 @@ drop:
        kfree_skb(skb);
        return 0;
 
-ring_is_full:
+drop_n_account:
        po->stats.stats1.tp_drops++;
        spin_unlock(&sk->sk_receive_queue.lock);
 
@@ -2347,15 +2415,92 @@ static void tpacket_set_protocol(const struct net_device *dev,
        }
 }
 
+/*
+ * Validate a user-supplied virtio_net_hdr against a payload of @len bytes.
+ *
+ * When a checksum is requested, hdr_len is raised so that it covers the
+ * 2-byte checksum field at csum_start + csum_offset.  On success the
+ * gso_type field is overwritten with the matching SKB_GSO_* flags
+ * (0 when no GSO was requested).
+ *
+ * Returns 0 on success, -EINVAL if hdr_len exceeds @len, the GSO type is
+ * not one of TCPv4/TCPv6/UDP, or GSO is requested with gso_size of zero.
+ */
+static int __packet_snd_vnet_parse(struct virtio_net_hdr *vnet_hdr, size_t len)
+{
+       unsigned short skb_gso = 0;
+
+       if (vnet_hdr->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) {
+               /* int arithmetic: the two 16-bit operands cannot wrap here */
+               int csum_end =
+                       __virtio16_to_cpu(vio_le(), vnet_hdr->csum_start) +
+                       __virtio16_to_cpu(vio_le(), vnet_hdr->csum_offset) +
+                       2;
+
+               if (csum_end > __virtio16_to_cpu(vio_le(), vnet_hdr->hdr_len))
+                       vnet_hdr->hdr_len =
+                               __cpu_to_virtio16(vio_le(), csum_end);
+       }
+
+       if (__virtio16_to_cpu(vio_le(), vnet_hdr->hdr_len) > len)
+               return -EINVAL;
+
+       if (vnet_hdr->gso_type != VIRTIO_NET_HDR_GSO_NONE) {
+               int base = vnet_hdr->gso_type & ~VIRTIO_NET_HDR_GSO_ECN;
+
+               if (base == VIRTIO_NET_HDR_GSO_TCPV4)
+                       skb_gso = SKB_GSO_TCPV4;
+               else if (base == VIRTIO_NET_HDR_GSO_TCPV6)
+                       skb_gso = SKB_GSO_TCPV6;
+               else if (base == VIRTIO_NET_HDR_GSO_UDP)
+                       skb_gso = SKB_GSO_UDP;
+               else
+                       return -EINVAL;
+
+               if (vnet_hdr->gso_type & VIRTIO_NET_HDR_GSO_ECN)
+                       skb_gso |= SKB_GSO_TCP_ECN;
+
+               if (vnet_hdr->gso_size == 0)
+                       return -EINVAL;
+       }
+
+       /* changes type, temporary storage */
+       vnet_hdr->gso_type = skb_gso;
+       return 0;
+}
+
+/*
+ * Pull a virtio_net_hdr off the front of @msg and validate it.
+ *
+ * @msg:      message whose iterator supplies the header bytes
+ * @len:      in/out: bytes left in the message; reduced by the header size
+ * @vnet_hdr: receives the copied (and then parsed) header
+ *
+ * Returns -EINVAL if the message is shorter than the header, -EFAULT if
+ * the copy comes up short, otherwise the result of
+ * __packet_snd_vnet_parse() on the remaining length.
+ */
+static int packet_snd_vnet_parse(struct msghdr *msg, size_t *len,
+                                struct virtio_net_hdr *vnet_hdr)
+{
+       const size_t hdr_sz = sizeof(*vnet_hdr);
+
+       if (*len < hdr_sz)
+               return -EINVAL;
+       *len -= hdr_sz;
+
+       if (copy_from_iter(vnet_hdr, hdr_sz, &msg->msg_iter) != hdr_sz)
+               return -EFAULT;
+
+       return __packet_snd_vnet_parse(vnet_hdr, *len);
+}
+
+/*
+ * Apply the checksum and GSO requests from a parsed virtio_net_hdr to @skb.
+ *
+ * @vnet_hdr->gso_type is stored into the skb as-is, so it is expected to
+ * already hold SKB_GSO_* flags as left by __packet_snd_vnet_parse().
+ *
+ * Returns 0 on success, -EINVAL if skb_partial_csum_set() rejects the
+ * requested checksum start/offset.
+ */
+static int packet_snd_vnet_gso(struct sk_buff *skb,
+                              struct virtio_net_hdr *vnet_hdr)
+{
+       struct skb_shared_info *shinfo;
+
+       if (vnet_hdr->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) {
+               u16 start = __virtio16_to_cpu(vio_le(), vnet_hdr->csum_start);
+               u16 off = __virtio16_to_cpu(vio_le(), vnet_hdr->csum_offset);
+
+               if (!skb_partial_csum_set(skb, start, off))
+                       return -EINVAL;
+       }
+
+       shinfo = skb_shinfo(skb);
+       shinfo->gso_size = __virtio16_to_cpu(vio_le(), vnet_hdr->gso_size);
+
+       /* Header must be checked, and gso_segs computed. */
+       shinfo->gso_type = vnet_hdr->gso_type | SKB_GSO_DODGY;
+       shinfo->gso_segs = 0;
+
+       return 0;
+}
+
 static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
-               void *frame, struct net_device *dev, int size_max,
+               void *frame, struct net_device *dev, void *data, int tp_len,
                __be16 proto, unsigned char *addr, int hlen)
 {
        union tpacket_uhdr ph;
-       int to_write, offset, len, tp_len, nr_frags, len_max;
+       int to_write, offset, len, nr_frags, len_max;
        struct socket *sock = po->sk.sk_socket;
        struct page *page;
-       void *data;
        int err;
 
        ph.raw = frame;
@@ -2367,51 +2512,9 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
        sock_tx_timestamp(&po->sk, &skb_shinfo(skb)->tx_flags);
        skb_shinfo(skb)->destructor_arg = ph.raw;
 
-       switch (po->tp_version) {
-       case TPACKET_V2:
-               tp_len = ph.h2->tp_len;
-               break;
-       default:
-               tp_len = ph.h1->tp_len;
-               break;
-       }
-       if (unlikely(tp_len > size_max)) {
-               pr_err("packet size is too long (%d > %d)\n", tp_len, size_max);
-               return -EMSGSIZE;
-       }
-
        skb_reserve(skb, hlen);
        skb_reset_network_header(skb);
 
-       if (unlikely(po->tp_tx_has_off)) {
-               int off_min, off_max, off;
-               off_min = po->tp_hdrlen - sizeof(struct sockaddr_ll);
-               off_max = po->tx_ring.frame_size - tp_len;
-               if (sock->type == SOCK_DGRAM) {
-                       switch (po->tp_version) {
-                       case TPACKET_V2:
-                               off = ph.h2->tp_net;
-                               break;
-                       default:
-                               off = ph.h1->tp_net;
-                               break;
-                       }
-               } else {
-                       switch (po->tp_version) {
-                       case TPACKET_V2:
-                               off = ph.h2->tp_mac;
-                               break;
-                       default:
-                               off = ph.h1->tp_mac;
-                               break;
-                       }
-               }
-               if (unlikely((off < off_min) || (off_max < off)))
-                       return -EINVAL;
-               data = ph.raw + off;
-       } else {
-               data = ph.raw + po->tp_hdrlen - sizeof(struct sockaddr_ll);
-       }
        to_write = tp_len;
 
        if (sock->type == SOCK_DGRAM) {
@@ -2469,6 +2572,61 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
        return tp_len;
 }
 
+/*
+ * Read the packet length and data offset out of a TX ring frame header.
+ *
+ * @po:       packet socket owning the ring
+ * @frame:    start of the ring frame
+ * @size_max: largest packet length accepted
+ * @data:     out: set to the start of the packet data inside @frame
+ *
+ * Returns the frame's tp_len on success, -EMSGSIZE if tp_len exceeds
+ * @size_max, or -EINVAL if a user-supplied offset falls outside
+ * [header end, frame_size - tp_len].
+ */
+static int tpacket_parse_header(struct packet_sock *po, void *frame,
+                               int size_max, void **data)
+{
+       const bool is_v2 = po->tp_version == TPACKET_V2;
+       union tpacket_uhdr hdr;
+       int tp_len, off;
+
+       hdr.raw = frame;
+
+       tp_len = is_v2 ? hdr.h2->tp_len : hdr.h1->tp_len;
+       if (unlikely(tp_len > size_max)) {
+               pr_err("packet size is too long (%d > %d)\n", tp_len, size_max);
+               return -EMSGSIZE;
+       }
+
+       /* Default placement: data directly follows the frame header. */
+       off = po->tp_hdrlen - sizeof(struct sockaddr_ll);
+
+       if (unlikely(po->tp_tx_has_off)) {
+               const int lowest = off;
+               const int highest = po->tx_ring.frame_size - tp_len;
+
+               /* SOCK_DGRAM uses the network offset, others the MAC one. */
+               if (po->sk.sk_type == SOCK_DGRAM)
+                       off = is_v2 ? hdr.h2->tp_net : hdr.h1->tp_net;
+               else
+                       off = is_v2 ? hdr.h2->tp_mac : hdr.h1->tp_mac;
+
+               if (unlikely(off < lowest || highest < off))
+                       return -EINVAL;
+       }
+
+       *data = frame + off;
+       return tp_len;
+}
+
 static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
 {
        struct sk_buff *skb;
@@ -2480,6 +2638,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
        bool need_wait = !(msg->msg_flags & MSG_DONTWAIT);
        int tp_len, size_max;
        unsigned char *addr;
+       void *data;
        int len_sum = 0;
        int status = TP_STATUS_AVAILABLE;
        int hlen, tlen;
@@ -2527,6 +2686,11 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
                        continue;
                }
 
+               skb = NULL;
+               tp_len = tpacket_parse_header(po, ph, size_max, &data);
+               if (tp_len < 0)
+                       goto tpacket_error;
+
                status = TP_STATUS_SEND_REQUEST;
                hlen = LL_RESERVED_SPACE(dev);
                tlen = dev->needed_tailroom;
@@ -2540,7 +2704,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
                                err = len_sum;
                        goto out_status;
                }
-               tp_len = tpacket_fill_skb(po, skb, ph, dev, size_max, proto,
+               tp_len = tpacket_fill_skb(po, skb, ph, dev, data, tp_len, proto,
                                          addr, hlen);
                if (likely(tp_len >= 0) &&
                    tp_len > dev->mtu + reserve &&
@@ -2548,6 +2712,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
                        tp_len = -EMSGSIZE;
 
                if (unlikely(tp_len < 0)) {
+tpacket_error:
                        if (po->tp_loss) {
                                __packet_set_status(po, ph,
                                                TP_STATUS_AVAILABLE);
@@ -2643,12 +2808,9 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
        struct sockcm_cookie sockc;
        struct virtio_net_hdr vnet_hdr = { 0 };
        int offset = 0;
-       int vnet_hdr_len;
        struct packet_sock *po = pkt_sk(sk);
-       unsigned short gso_type = 0;
        int hlen, tlen;
        int extra_len = 0;
-       ssize_t n;
 
        /*
         *      Get and verify the address.
@@ -2686,53 +2848,9 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
        if (sock->type == SOCK_RAW)
                reserve = dev->hard_header_len;
        if (po->has_vnet_hdr) {
-               vnet_hdr_len = sizeof(vnet_hdr);
-
-               err = -EINVAL;
-               if (len < vnet_hdr_len)
-                       goto out_unlock;
-
-               len -= vnet_hdr_len;
-
-               err = -EFAULT;
-               n = copy_from_iter(&vnet_hdr, vnet_hdr_len, &msg->msg_iter);
-               if (n != vnet_hdr_len)
-                       goto out_unlock;
-
-               if ((vnet_hdr.flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
-                   (__virtio16_to_cpu(vio_le(), vnet_hdr.csum_start) +
-                    __virtio16_to_cpu(vio_le(), vnet_hdr.csum_offset) + 2 >
-                     __virtio16_to_cpu(vio_le(), vnet_hdr.hdr_len)))
-                       vnet_hdr.hdr_len = __cpu_to_virtio16(vio_le(),
-                                __virtio16_to_cpu(vio_le(), vnet_hdr.csum_start) +
-                               __virtio16_to_cpu(vio_le(), vnet_hdr.csum_offset) + 2);
-
-               err = -EINVAL;
-               if (__virtio16_to_cpu(vio_le(), vnet_hdr.hdr_len) > len)
+               err = packet_snd_vnet_parse(msg, &len, &vnet_hdr);
+               if (err)
                        goto out_unlock;
-
-               if (vnet_hdr.gso_type != VIRTIO_NET_HDR_GSO_NONE) {
-                       switch (vnet_hdr.gso_type & ~VIRTIO_NET_HDR_GSO_ECN) {
-                       case VIRTIO_NET_HDR_GSO_TCPV4:
-                               gso_type = SKB_GSO_TCPV4;
-                               break;
-                       case VIRTIO_NET_HDR_GSO_TCPV6:
-                               gso_type = SKB_GSO_TCPV6;
-                               break;
-                       case VIRTIO_NET_HDR_GSO_UDP:
-                               gso_type = SKB_GSO_UDP;
-                               break;
-                       default:
-                               goto out_unlock;
-                       }
-
-                       if (vnet_hdr.gso_type & VIRTIO_NET_HDR_GSO_ECN)
-                               gso_type |= SKB_GSO_TCP_ECN;
-
-                       if (vnet_hdr.gso_size == 0)
-                               goto out_unlock;
-
-               }
        }
 
        if (unlikely(sock_flag(sk, SOCK_NOFCS))) {
@@ -2744,7 +2862,8 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
        }
 
        err = -EMSGSIZE;
-       if (!gso_type && (len > dev->mtu + reserve + VLAN_HLEN + extra_len))
+       if (!vnet_hdr.gso_type &&
+           (len > dev->mtu + reserve + VLAN_HLEN + extra_len))
                goto out_unlock;
 
        err = -ENOBUFS;
@@ -2775,7 +2894,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
 
        sock_tx_timestamp(sk, &skb_shinfo(skb)->tx_flags);
 
-       if (!gso_type && (len > dev->mtu + reserve + extra_len) &&
+       if (!vnet_hdr.gso_type && (len > dev->mtu + reserve + extra_len) &&
            !packet_extra_vlan_len_allowed(dev, skb)) {
                err = -EMSGSIZE;
                goto out_free;
@@ -2789,24 +2908,10 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
        packet_pick_tx_queue(dev, skb);
 
        if (po->has_vnet_hdr) {
-               if (vnet_hdr.flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) {
-                       u16 s = __virtio16_to_cpu(vio_le(), vnet_hdr.csum_start);
-                       u16 o = __virtio16_to_cpu(vio_le(), vnet_hdr.csum_offset);
-                       if (!skb_partial_csum_set(skb, s, o)) {
-                               err = -EINVAL;
-                               goto out_free;
-                       }
-               }
-
-               skb_shinfo(skb)->gso_size =
-                       __virtio16_to_cpu(vio_le(), vnet_hdr.gso_size);
-               skb_shinfo(skb)->gso_type = gso_type;
-
-               /* Header must be checked, and gso_segs computed. */
-               skb_shinfo(skb)->gso_type |= SKB_GSO_DODGY;
-               skb_shinfo(skb)->gso_segs = 0;
-
-               len += vnet_hdr_len;
+               err = packet_snd_vnet_gso(skb, &vnet_hdr);
+               if (err)
+                       goto out_free;
+               len += sizeof(vnet_hdr);
        }
 
        skb_probe_transport_header(skb, reserve);
@@ -3177,51 +3282,10 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
                packet_rcv_has_room(pkt_sk(sk), NULL);
 
        if (pkt_sk(sk)->has_vnet_hdr) {
-               struct virtio_net_hdr vnet_hdr = { 0 };
-
-               err = -EINVAL;
-               vnet_hdr_len = sizeof(vnet_hdr);
-               if (len < vnet_hdr_len)
-                       goto out_free;
-
-               len -= vnet_hdr_len;
-
-               if (skb_is_gso(skb)) {
-                       struct skb_shared_info *sinfo = skb_shinfo(skb);
-
-                       /* This is a hint as to how much should be linear. */
-                       vnet_hdr.hdr_len =
-                               __cpu_to_virtio16(vio_le(), skb_headlen(skb));
-                       vnet_hdr.gso_size =
-                               __cpu_to_virtio16(vio_le(), sinfo->gso_size);
-                       if (sinfo->gso_type & SKB_GSO_TCPV4)
-                               vnet_hdr.gso_type = VIRTIO_NET_HDR_GSO_TCPV4;
-                       else if (sinfo->gso_type & SKB_GSO_TCPV6)
-                               vnet_hdr.gso_type = VIRTIO_NET_HDR_GSO_TCPV6;
-                       else if (sinfo->gso_type & SKB_GSO_UDP)
-                               vnet_hdr.gso_type = VIRTIO_NET_HDR_GSO_UDP;
-                       else if (sinfo->gso_type & SKB_GSO_FCOE)
-                               goto out_free;
-                       else
-                               BUG();
-                       if (sinfo->gso_type & SKB_GSO_TCP_ECN)
-                               vnet_hdr.gso_type |= VIRTIO_NET_HDR_GSO_ECN;
-               } else
-                       vnet_hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE;
-
-               if (skb->ip_summed == CHECKSUM_PARTIAL) {
-                       vnet_hdr.flags = VIRTIO_NET_HDR_F_NEEDS_CSUM;
-                       vnet_hdr.csum_start = __cpu_to_virtio16(vio_le(),
-                                         skb_checksum_start_offset(skb));
-                       vnet_hdr.csum_offset = __cpu_to_virtio16(vio_le(),
-                                                        skb->csum_offset);
-               } else if (skb->ip_summed == CHECKSUM_UNNECESSARY) {
-                       vnet_hdr.flags = VIRTIO_NET_HDR_F_DATA_VALID;
-               } /* else everything is zero */
-
-               err = memcpy_to_msg(msg, (void *)&vnet_hdr, vnet_hdr_len);
-               if (err < 0)
+               err = packet_rcv_vnet(msg, skb, &len);
+               if (err)
                        goto out_free;
+               vnet_hdr_len = sizeof(struct virtio_net_hdr);
        }
 
        /* You lose any data beyond the buffer you gave. If it worries
@@ -3552,7 +3616,8 @@ packet_setsockopt(struct socket *sock, int level, int optname, char __user *optv
                }
                if (optlen < len)
                        return -EINVAL;
-               if (pkt_sk(sk)->has_vnet_hdr)
+               if (pkt_sk(sk)->has_vnet_hdr &&
+                   optname == PACKET_TX_RING)
                        return -EINVAL;
                if (copy_from_user(&req_u.req, optval, len))
                        return -EFAULT;