packet: rollover only to socket with headroom
[cascardo/linux.git] / net / packet / af_packet.c
index f8db706..ffa6720 100644 (file)
@@ -216,10 +216,16 @@ static void prb_fill_vlan_info(struct tpacket_kbdq_core *,
 static void packet_flush_mclist(struct sock *sk);
 
 struct packet_skb_cb {
-       unsigned int origlen;
        union {
                struct sockaddr_pkt pkt;
-               struct sockaddr_ll ll;
+               union {
+                       /* Trick: alias skb original length with
+                        * ll.sll_family and ll.protocol in order
+                        * to save room.
+                        */
+                       unsigned int origlen;
+                       struct sockaddr_ll ll;
+               };
        } sa;
 };
 
@@ -1228,27 +1234,68 @@ static void packet_free_pending(struct packet_sock *po)
        free_percpu(po->tx_ring.pending_refcnt);
 }
 
-static bool packet_rcv_has_room(struct packet_sock *po, struct sk_buff *skb)
+#define ROOM_POW_OFF   2
+#define ROOM_NONE      0x0
+#define ROOM_LOW       0x1
+#define ROOM_NORMAL    0x2
+
+static bool __tpacket_has_room(struct packet_sock *po, int pow_off)
 {
-       struct sock *sk = &po->sk;
-       bool has_room;
+       int idx, len;
+
+       len = po->rx_ring.frame_max + 1;
+       idx = po->rx_ring.head;
+       if (pow_off)
+               idx += len >> pow_off;
+       if (idx >= len)
+               idx -= len;
+       return packet_lookup_frame(po, &po->rx_ring, idx, TP_STATUS_KERNEL);
+}
 
-       if (po->prot_hook.func != tpacket_rcv)
-               return (atomic_read(&sk->sk_rmem_alloc) + skb->truesize)
-                       <= sk->sk_rcvbuf;
+static bool __tpacket_v3_has_room(struct packet_sock *po, int pow_off)
+{
+       int idx, len;
+
+       len = po->rx_ring.prb_bdqc.knum_blocks;
+       idx = po->rx_ring.prb_bdqc.kactive_blk_num;
+       if (pow_off)
+               idx += len >> pow_off;
+       if (idx >= len)
+               idx -= len;
+       return prb_lookup_block(po, &po->rx_ring, idx, TP_STATUS_KERNEL);
+}
+
+static int packet_rcv_has_room(struct packet_sock *po, struct sk_buff *skb)
+{
+       struct sock *sk = &po->sk;
+       int ret = ROOM_NONE;
+
+       if (po->prot_hook.func != tpacket_rcv) {
+               int avail = sk->sk_rcvbuf - atomic_read(&sk->sk_rmem_alloc)
+                                         - skb->truesize;
+               if (avail > (sk->sk_rcvbuf >> ROOM_POW_OFF))
+                       return ROOM_NORMAL;
+               else if (avail > 0)
+                       return ROOM_LOW;
+               else
+                       return ROOM_NONE;
+       }
 
        spin_lock(&sk->sk_receive_queue.lock);
-       if (po->tp_version == TPACKET_V3)
-               has_room = prb_lookup_block(po, &po->rx_ring,
-                                           po->rx_ring.prb_bdqc.kactive_blk_num,
-                                           TP_STATUS_KERNEL);
-       else
-               has_room = packet_lookup_frame(po, &po->rx_ring,
-                                              po->rx_ring.head,
-                                              TP_STATUS_KERNEL);
+       if (po->tp_version == TPACKET_V3) {
+               if (__tpacket_v3_has_room(po, ROOM_POW_OFF))
+                       ret = ROOM_NORMAL;
+               else if (__tpacket_v3_has_room(po, 0))
+                       ret = ROOM_LOW;
+       } else {
+               if (__tpacket_has_room(po, ROOM_POW_OFF))
+                       ret = ROOM_NORMAL;
+               else if (__tpacket_has_room(po, 0))
+                       ret = ROOM_LOW;
+       }
        spin_unlock(&sk->sk_receive_queue.lock);
 
-       return has_room;
+       return ret;
 }
 
 static void packet_sock_destruct(struct sock *sk)
@@ -1312,18 +1359,25 @@ static unsigned int fanout_demux_rnd(struct packet_fanout *f,
 
 static unsigned int fanout_demux_rollover(struct packet_fanout *f,
                                          struct sk_buff *skb,
-                                         unsigned int idx, unsigned int skip,
+                                         unsigned int idx, bool try_self,
                                          unsigned int num)
 {
+       struct packet_sock *po;
        unsigned int i, j;
 
-       i = j = min_t(int, f->next[idx], num - 1);
+       po = pkt_sk(f->arr[idx]);
+       if (try_self && packet_rcv_has_room(po, skb) != ROOM_NONE)
+               return idx;
+
+       i = j = min_t(int, po->rollover->sock, num - 1);
        do {
-               if (i != skip && packet_rcv_has_room(pkt_sk(f->arr[i]), skb)) {
+               if (i != idx &&
+                   packet_rcv_has_room(pkt_sk(f->arr[i]), skb) == ROOM_NORMAL) {
                        if (i != j)
-                               f->next[idx] = i;
+                               po->rollover->sock = i;
                        return i;
                }
+
                if (++i == num)
                        i = 0;
        } while (i != j);
@@ -1380,17 +1434,14 @@ static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
                idx = fanout_demux_qm(f, skb, num);
                break;
        case PACKET_FANOUT_ROLLOVER:
-               idx = fanout_demux_rollover(f, skb, 0, (unsigned int) -1, num);
+               idx = fanout_demux_rollover(f, skb, 0, false, num);
                break;
        }
 
-       po = pkt_sk(f->arr[idx]);
-       if (fanout_has_flag(f, PACKET_FANOUT_FLAG_ROLLOVER) &&
-           unlikely(!packet_rcv_has_room(po, skb))) {
-               idx = fanout_demux_rollover(f, skb, idx, idx, num);
-               po = pkt_sk(f->arr[idx]);
-       }
+       if (fanout_has_flag(f, PACKET_FANOUT_FLAG_ROLLOVER))
+               idx = fanout_demux_rollover(f, skb, idx, true, num);
 
+       po = pkt_sk(f->arr[idx]);
        return po->prot_hook.func(skb, dev, &po->prot_hook, orig_dev);
 }
 
@@ -1461,6 +1512,12 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
        if (po->fanout)
                return -EALREADY;
 
+       if (type_flags & PACKET_FANOUT_FLAG_ROLLOVER) {
+               po->rollover = kzalloc(sizeof(*po->rollover), GFP_KERNEL);
+               if (!po->rollover)
+                       return -ENOMEM;
+       }
+
        mutex_lock(&fanout_mutex);
        match = NULL;
        list_for_each_entry(f, &fanout_list, list) {
@@ -1509,6 +1566,10 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
        }
 out:
        mutex_unlock(&fanout_mutex);
+       if (err) {
+               kfree(po->rollover);
+               po->rollover = NULL;
+       }
        return err;
 }
 
@@ -1530,6 +1591,8 @@ static void fanout_release(struct sock *sk)
                kfree(f);
        }
        mutex_unlock(&fanout_mutex);
+
+       kfree(po->rollover);
 }
 
 static const struct proto_ops packet_ops;
@@ -1608,8 +1671,8 @@ oom:
  *     protocol layers and you must therefore supply it with a complete frame
  */
 
-static int packet_sendmsg_spkt(struct kiocb *iocb, struct socket *sock,
-                              struct msghdr *msg, size_t len)
+static int packet_sendmsg_spkt(struct socket *sock, struct msghdr *msg,
+                              size_t len)
 {
        struct sock *sk = sock->sk;
        DECLARE_SOCKADDR(struct sockaddr_pkt *, saddr, msg->msg_name);
@@ -1818,13 +1881,10 @@ static int packet_rcv(struct sk_buff *skb, struct net_device *dev,
                skb = nskb;
        }
 
-       BUILD_BUG_ON(sizeof(*PACKET_SKB_CB(skb)) + MAX_ADDR_LEN - 8 >
-                    sizeof(skb->cb));
+       sock_skb_cb_check_size(sizeof(*PACKET_SKB_CB(skb)) + MAX_ADDR_LEN - 8);
 
        sll = &PACKET_SKB_CB(skb)->sa.ll;
-       sll->sll_family = AF_PACKET;
        sll->sll_hatype = dev->type;
-       sll->sll_protocol = skb->protocol;
        sll->sll_pkttype = skb->pkt_type;
        if (unlikely(po->origdev))
                sll->sll_ifindex = orig_dev->ifindex;
@@ -1833,7 +1893,10 @@ static int packet_rcv(struct sk_buff *skb, struct net_device *dev,
 
        sll->sll_halen = dev_parse_header(skb, sll->sll_addr);
 
-       PACKET_SKB_CB(skb)->origlen = skb->len;
+       /* sll->sll_family and sll->sll_protocol are set in packet_recvmsg().
+        * Use their space for storing the original skb length.
+        */
+       PACKET_SKB_CB(skb)->sa.origlen = skb->len;
 
        if (pskb_trim(skb, snaplen))
                goto drop_n_acct;
@@ -1847,7 +1910,7 @@ static int packet_rcv(struct sk_buff *skb, struct net_device *dev,
 
        spin_lock(&sk->sk_receive_queue.lock);
        po->stats.stats1.tp_packets++;
-       skb->dropcount = atomic_read(&sk->sk_drops);
+       sock_skb_set_dropcount(sk, skb);
        __skb_queue_tail(&sk->sk_receive_queue, skb);
        spin_unlock(&sk->sk_receive_queue.lock);
        sk->sk_data_ready(sk);
@@ -1910,14 +1973,19 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                }
        }
 
-       if (skb->ip_summed == CHECKSUM_PARTIAL)
-               status |= TP_STATUS_CSUMNOTREADY;
-
        snaplen = skb->len;
 
        res = run_filter(skb, sk, snaplen);
        if (!res)
                goto drop_n_restore;
+
+       if (skb->ip_summed == CHECKSUM_PARTIAL)
+               status |= TP_STATUS_CSUMNOTREADY;
+       else if (skb->pkt_type != PACKET_OUTGOING &&
+                (skb->ip_summed == CHECKSUM_COMPLETE ||
+                 skb_csum_unnecessary(skb)))
+               status |= TP_STATUS_CSUM_VALID;
+
        if (snaplen > res)
                snaplen = res;
 
@@ -2300,11 +2368,14 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
                tlen = dev->needed_tailroom;
                skb = sock_alloc_send_skb(&po->sk,
                                hlen + tlen + sizeof(struct sockaddr_ll),
-                               0, &err);
+                               !need_wait, &err);
 
-               if (unlikely(skb == NULL))
+               if (unlikely(skb == NULL)) {
+                       /* we assume the socket was initially writeable ... */
+                       if (likely(len_sum > 0))
+                               err = len_sum;
                        goto out_status;
-
+               }
                tp_len = tpacket_fill_skb(po, skb, ph, dev, size_max, proto,
                                          addr, hlen);
                if (tp_len > dev->mtu + dev->hard_header_len) {
@@ -2603,8 +2674,7 @@ out:
        return err;
 }
 
-static int packet_sendmsg(struct kiocb *iocb, struct socket *sock,
-               struct msghdr *msg, size_t len)
+static int packet_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 {
        struct sock *sk = sock->sk;
        struct packet_sock *po = pkt_sk(sk);
@@ -2822,7 +2892,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
        sock->state = SS_UNCONNECTED;
 
        err = -ENOBUFS;
-       sk = sk_alloc(net, PF_PACKET, GFP_KERNEL, &packet_proto);
+       sk = sk_alloc(net, PF_PACKET, GFP_KERNEL, &packet_proto, kern);
        if (sk == NULL)
                goto out;
 
@@ -2852,6 +2922,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
 
        spin_lock_init(&po->bind_lock);
        mutex_init(&po->pg_vec_lock);
+       po->rollover = NULL;
        po->prot_hook.func = packet_rcv;
 
        if (sock->type == SOCK_PACKET)
@@ -2884,13 +2955,14 @@ out:
  *     If necessary we block.
  */
 
-static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
-                         struct msghdr *msg, size_t len, int flags)
+static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
+                         int flags)
 {
        struct sock *sk = sock->sk;
        struct sk_buff *skb;
        int copied, err;
        int vnet_hdr_len = 0;
+       unsigned int origlen = 0;
 
        err = -EINVAL;
        if (flags & ~(MSG_PEEK|MSG_DONTWAIT|MSG_TRUNC|MSG_CMSG_COMPAT|MSG_ERRQUEUE))
@@ -2990,6 +3062,15 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
        if (err)
                goto out_free;
 
+       if (sock->type != SOCK_PACKET) {
+               struct sockaddr_ll *sll = &PACKET_SKB_CB(skb)->sa.ll;
+
+               /* Original length was stored in sockaddr_ll fields */
+               origlen = PACKET_SKB_CB(skb)->sa.origlen;
+               sll->sll_family = AF_PACKET;
+               sll->sll_protocol = skb->protocol;
+       }
+
        sock_recv_ts_and_drops(msg, sk, skb);
 
        if (msg->msg_name) {
@@ -3001,6 +3082,7 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
                        msg->msg_namelen = sizeof(struct sockaddr_pkt);
                } else {
                        struct sockaddr_ll *sll = &PACKET_SKB_CB(skb)->sa.ll;
+
                        msg->msg_namelen = sll->sll_halen +
                                offsetof(struct sockaddr_ll, sll_addr);
                }
@@ -3014,7 +3096,12 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
                aux.tp_status = TP_STATUS_USER;
                if (skb->ip_summed == CHECKSUM_PARTIAL)
                        aux.tp_status |= TP_STATUS_CSUMNOTREADY;
-               aux.tp_len = PACKET_SKB_CB(skb)->origlen;
+               else if (skb->pkt_type != PACKET_OUTGOING &&
+                        (skb->ip_summed == CHECKSUM_COMPLETE ||
+                         skb_csum_unnecessary(skb)))
+                       aux.tp_status |= TP_STATUS_CSUM_VALID;
+
+               aux.tp_len = origlen;
                aux.tp_snaplen = skb->len;
                aux.tp_mac = 0;
                aux.tp_net = skb_network_offset(skb);