packet: tx timestamping on tpacket ring
[cascardo/linux.git] / net / packet / af_packet.c
index 1d6793d..ec8ea27 100644 (file)
@@ -158,10 +158,16 @@ struct packet_mreq_max {
        unsigned char   mr_address[MAX_ADDR_LEN];
 };
 
+union tpacket_uhdr {
+       struct tpacket_hdr  *h1;
+       struct tpacket2_hdr *h2;
+       struct tpacket3_hdr *h3;
+       void *raw;
+};
+
 static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
                int closing, int tx_ring);
 
-
 #define V3_ALIGNMENT   (8)
 
 #define BLK_HDR_LEN    (ALIGN(sizeof(struct tpacket_block_desc), V3_ALIGNMENT))
@@ -181,6 +187,8 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
 
 struct packet_sock;
 static int tpacket_snd(struct packet_sock *po, struct msghdr *msg);
+static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
+                      struct packet_type *pt, struct net_device *orig_dev);
 
 static void *packet_previous_frame(struct packet_sock *po,
                struct packet_ring_buffer *rb,
@@ -288,11 +296,7 @@ static inline __pure struct page *pgv_to_page(void *addr)
 
 static void __packet_set_status(struct packet_sock *po, void *frame, int status)
 {
-       union {
-               struct tpacket_hdr *h1;
-               struct tpacket2_hdr *h2;
-               void *raw;
-       } h;
+       union tpacket_uhdr h;
 
        h.raw = frame;
        switch (po->tp_version) {
@@ -315,11 +319,7 @@ static void __packet_set_status(struct packet_sock *po, void *frame, int status)
 
 static int __packet_get_status(struct packet_sock *po, void *frame)
 {
-       union {
-               struct tpacket_hdr *h1;
-               struct tpacket2_hdr *h2;
-               void *raw;
-       } h;
+       union tpacket_uhdr h;
 
        smp_rmb();
 
@@ -339,17 +339,44 @@ static int __packet_get_status(struct packet_sock *po, void *frame)
        }
 }
 
+static void __packet_set_timestamp(struct packet_sock *po, void *frame,
+                                  ktime_t tstamp)
+{
+       union tpacket_uhdr h;
+       struct timespec ts;
+
+       if (!ktime_to_timespec_cond(tstamp, &ts) ||
+           !sock_flag(&po->sk, SOCK_TIMESTAMPING_SOFTWARE))
+               return;
+
+       h.raw = frame;
+       switch (po->tp_version) {
+       case TPACKET_V1:
+               h.h1->tp_sec = ts.tv_sec;
+               h.h1->tp_usec = ts.tv_nsec / NSEC_PER_USEC;
+               break;
+       case TPACKET_V2:
+               h.h2->tp_sec = ts.tv_sec;
+               h.h2->tp_nsec = ts.tv_nsec;
+               break;
+       case TPACKET_V3:
+       default:
+               WARN(1, "TPACKET version not supported.\n");
+               BUG();
+       }
+
+       /* one flush is safe, as both fields always lie on the same cacheline */
+       flush_dcache_page(pgv_to_page(&h.h1->tp_sec));
+       smp_wmb();
+}
+
 static void *packet_lookup_frame(struct packet_sock *po,
                struct packet_ring_buffer *rb,
                unsigned int position,
                int status)
 {
        unsigned int pg_vec_pos, frame_offset;
-       union {
-               struct tpacket_hdr *h1;
-               struct tpacket2_hdr *h2;
-               void *raw;
-       } h;
+       union tpacket_uhdr h;
 
        pg_vec_pos = position / rb->frames_per_block;
        frame_offset = position % rb->frames_per_block;
@@ -973,11 +1000,11 @@ static void *packet_current_rx_frame(struct packet_sock *po,
 
 static void *prb_lookup_block(struct packet_sock *po,
                                     struct packet_ring_buffer *rb,
-                                    unsigned int previous,
+                                    unsigned int idx,
                                     int status)
 {
        struct tpacket_kbdq_core *pkc  = GET_PBDQC_FROM_RB(rb);
-       struct tpacket_block_desc *pbd = GET_PBLOCK_DESC(pkc, previous);
+       struct tpacket_block_desc *pbd = GET_PBLOCK_DESC(pkc, idx);
 
        if (status != BLOCK_STATUS(pbd))
                return NULL;
@@ -1041,6 +1068,29 @@ static void packet_increment_head(struct packet_ring_buffer *buff)
        buff->head = buff->head != buff->frame_max ? buff->head+1 : 0;
 }
 
+static bool packet_rcv_has_room(struct packet_sock *po, struct sk_buff *skb)
+{
+       struct sock *sk = &po->sk;
+       bool has_room;
+
+       if (po->prot_hook.func != tpacket_rcv)
+               return (atomic_read(&sk->sk_rmem_alloc) + skb->truesize)
+                       <= sk->sk_rcvbuf;
+
+       spin_lock(&sk->sk_receive_queue.lock);
+       if (po->tp_version == TPACKET_V3)
+               has_room = prb_lookup_block(po, &po->rx_ring,
+                                           po->rx_ring.prb_bdqc.kactive_blk_num,
+                                           TP_STATUS_KERNEL);
+       else
+               has_room = packet_lookup_frame(po, &po->rx_ring,
+                                              po->rx_ring.head,
+                                              TP_STATUS_KERNEL);
+       spin_unlock(&sk->sk_receive_queue.lock);
+
+       return has_room;
+}
+
 static void packet_sock_destruct(struct sock *sk)
 {
        skb_queue_purge(&sk->sk_error_queue);
@@ -1066,16 +1116,16 @@ static int fanout_rr_next(struct packet_fanout *f, unsigned int num)
        return x;
 }
 
-static struct sock *fanout_demux_hash(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+static unsigned int fanout_demux_hash(struct packet_fanout *f,
+                                     struct sk_buff *skb,
+                                     unsigned int num)
 {
-       u32 idx, hash = skb->rxhash;
-
-       idx = ((u64)hash * num) >> 32;
-
-       return f->arr[idx];
+       return (((u64)skb->rxhash) * num) >> 32;
 }
 
-static struct sock *fanout_demux_lb(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+static unsigned int fanout_demux_lb(struct packet_fanout *f,
+                                   struct sk_buff *skb,
+                                   unsigned int num)
 {
        int cur, old;
 
@@ -1083,14 +1133,40 @@ static struct sock *fanout_demux_lb(struct packet_fanout *f, struct sk_buff *skb
        while ((old = atomic_cmpxchg(&f->rr_cur, cur,
                                     fanout_rr_next(f, num))) != cur)
                cur = old;
-       return f->arr[cur];
+       return cur;
+}
+
+static unsigned int fanout_demux_cpu(struct packet_fanout *f,
+                                    struct sk_buff *skb,
+                                    unsigned int num)
+{
+       return smp_processor_id() % num;
 }
 
-static struct sock *fanout_demux_cpu(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+static unsigned int fanout_demux_rollover(struct packet_fanout *f,
+                                         struct sk_buff *skb,
+                                         unsigned int idx, unsigned int skip,
+                                         unsigned int num)
 {
-       unsigned int cpu = smp_processor_id();
+       unsigned int i, j;
 
-       return f->arr[cpu % num];
+       i = j = min_t(int, f->next[idx], num - 1);
+       do {
+               if (i != skip && packet_rcv_has_room(pkt_sk(f->arr[i]), skb)) {
+                       if (i != j)
+                               f->next[idx] = i;
+                       return i;
+               }
+               if (++i == num)
+                       i = 0;
+       } while (i != j);
+
+       return idx;
+}
+
+static bool fanout_has_flag(struct packet_fanout *f, u16 flag)
+{
+       return f->flags & (flag >> 8);
 }
 
 static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
@@ -1099,7 +1175,7 @@ static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
        struct packet_fanout *f = pt->af_packet_priv;
        unsigned int num = f->num_members;
        struct packet_sock *po;
-       struct sock *sk;
+       unsigned int idx;
 
        if (!net_eq(dev_net(dev), read_pnet(&f->net)) ||
            !num) {
@@ -1110,23 +1186,31 @@ static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
        switch (f->type) {
        case PACKET_FANOUT_HASH:
        default:
-               if (f->defrag) {
+               if (fanout_has_flag(f, PACKET_FANOUT_FLAG_DEFRAG)) {
                        skb = ip_check_defrag(skb, IP_DEFRAG_AF_PACKET);
                        if (!skb)
                                return 0;
                }
                skb_get_rxhash(skb);
-               sk = fanout_demux_hash(f, skb, num);
+               idx = fanout_demux_hash(f, skb, num);
                break;
        case PACKET_FANOUT_LB:
-               sk = fanout_demux_lb(f, skb, num);
+               idx = fanout_demux_lb(f, skb, num);
                break;
        case PACKET_FANOUT_CPU:
-               sk = fanout_demux_cpu(f, skb, num);
+               idx = fanout_demux_cpu(f, skb, num);
+               break;
+       case PACKET_FANOUT_ROLLOVER:
+               idx = fanout_demux_rollover(f, skb, 0, (unsigned int) -1, num);
                break;
        }
 
-       po = pkt_sk(sk);
+       po = pkt_sk(f->arr[idx]);
+       if (fanout_has_flag(f, PACKET_FANOUT_FLAG_ROLLOVER) &&
+           unlikely(!packet_rcv_has_room(po, skb))) {
+               idx = fanout_demux_rollover(f, skb, idx, idx, num);
+               po = pkt_sk(f->arr[idx]);
+       }
 
        return po->prot_hook.func(skb, dev, &po->prot_hook, orig_dev);
 }
@@ -1175,10 +1259,13 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
        struct packet_sock *po = pkt_sk(sk);
        struct packet_fanout *f, *match;
        u8 type = type_flags & 0xff;
-       u8 defrag = (type_flags & PACKET_FANOUT_FLAG_DEFRAG) ? 1 : 0;
+       u8 flags = type_flags >> 8;
        int err;
 
        switch (type) {
+       case PACKET_FANOUT_ROLLOVER:
+               if (type_flags & PACKET_FANOUT_FLAG_ROLLOVER)
+                       return -EINVAL;
        case PACKET_FANOUT_HASH:
        case PACKET_FANOUT_LB:
        case PACKET_FANOUT_CPU:
@@ -1203,7 +1290,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
                }
        }
        err = -EINVAL;
-       if (match && match->defrag != defrag)
+       if (match && match->flags != flags)
                goto out;
        if (!match) {
                err = -ENOMEM;
@@ -1213,7 +1300,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
                write_pnet(&match->net, sock_net(sk));
                match->id = id;
                match->type = type;
-               match->defrag = defrag;
+               match->flags = flags;
                atomic_set(&match->rr_cur, 0);
                INIT_LIST_HEAD(&match->list);
                spin_lock_init(&match->lock);
@@ -1443,13 +1530,14 @@ retry:
        skb->dev = dev;
        skb->priority = sk->sk_priority;
        skb->mark = sk->sk_mark;
-       err = sock_tx_timestamp(sk, &skb_shinfo(skb)->tx_flags);
-       if (err < 0)
-               goto out_unlock;
+
+       sock_tx_timestamp(sk, &skb_shinfo(skb)->tx_flags);
 
        if (unlikely(extra_len == 4))
                skb->no_fcs = 1;
 
+       skb_probe_transport_header(skb, 0);
+
        dev_queue_xmit(skb);
        rcu_read_unlock();
        return len;
@@ -1600,27 +1688,40 @@ drop:
        return 0;
 }
 
+static void tpacket_get_timestamp(struct sk_buff *skb, struct timespec *ts,
+                                 unsigned int flags)
+{
+       struct skb_shared_hwtstamps *shhwtstamps = skb_hwtstamps(skb);
+
+       if (shhwtstamps) {
+               if ((flags & SOF_TIMESTAMPING_SYS_HARDWARE) &&
+                   ktime_to_timespec_cond(shhwtstamps->syststamp, ts))
+                       return;
+               if ((flags & SOF_TIMESTAMPING_RAW_HARDWARE) &&
+                   ktime_to_timespec_cond(shhwtstamps->hwtstamp, ts))
+                       return;
+       }
+
+       if (ktime_to_timespec_cond(skb->tstamp, ts))
+               return;
+
+       getnstimeofday(ts);
+}
+
 static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                       struct packet_type *pt, struct net_device *orig_dev)
 {
        struct sock *sk;
        struct packet_sock *po;
        struct sockaddr_ll *sll;
-       union {
-               struct tpacket_hdr *h1;
-               struct tpacket2_hdr *h2;
-               struct tpacket3_hdr *h3;
-               void *raw;
-       } h;
+       union tpacket_uhdr h;
        u8 *skb_head = skb->data;
        int skb_len = skb->len;
        unsigned int snaplen, res;
        unsigned long status = TP_STATUS_USER;
        unsigned short macoff, netoff, hdrlen;
        struct sk_buff *copy_skb = NULL;
-       struct timeval tv;
        struct timespec ts;
-       struct skb_shared_hwtstamps *shhwtstamps = skb_hwtstamps(skb);
 
        if (skb->pkt_type == PACKET_LOOPBACK)
                goto drop;
@@ -1703,6 +1804,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
        spin_unlock(&sk->sk_receive_queue.lock);
 
        skb_copy_bits(skb, 0, h.raw + macoff, snaplen);
+       tpacket_get_timestamp(skb, &ts, po->tp_tstamp);
 
        switch (po->tp_version) {
        case TPACKET_V1:
@@ -1710,18 +1812,8 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                h.h1->tp_snaplen = snaplen;
                h.h1->tp_mac = macoff;
                h.h1->tp_net = netoff;
-               if ((po->tp_tstamp & SOF_TIMESTAMPING_SYS_HARDWARE)
-                               && shhwtstamps->syststamp.tv64)
-                       tv = ktime_to_timeval(shhwtstamps->syststamp);
-               else if ((po->tp_tstamp & SOF_TIMESTAMPING_RAW_HARDWARE)
-                               && shhwtstamps->hwtstamp.tv64)
-                       tv = ktime_to_timeval(shhwtstamps->hwtstamp);
-               else if (skb->tstamp.tv64)
-                       tv = ktime_to_timeval(skb->tstamp);
-               else
-                       do_gettimeofday(&tv);
-               h.h1->tp_sec = tv.tv_sec;
-               h.h1->tp_usec = tv.tv_usec;
+               h.h1->tp_sec = ts.tv_sec;
+               h.h1->tp_usec = ts.tv_nsec / NSEC_PER_USEC;
                hdrlen = sizeof(*h.h1);
                break;
        case TPACKET_V2:
@@ -1729,16 +1821,6 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                h.h2->tp_snaplen = snaplen;
                h.h2->tp_mac = macoff;
                h.h2->tp_net = netoff;
-               if ((po->tp_tstamp & SOF_TIMESTAMPING_SYS_HARDWARE)
-                               && shhwtstamps->syststamp.tv64)
-                       ts = ktime_to_timespec(shhwtstamps->syststamp);
-               else if ((po->tp_tstamp & SOF_TIMESTAMPING_RAW_HARDWARE)
-                               && shhwtstamps->hwtstamp.tv64)
-                       ts = ktime_to_timespec(shhwtstamps->hwtstamp);
-               else if (skb->tstamp.tv64)
-                       ts = ktime_to_timespec(skb->tstamp);
-               else
-                       getnstimeofday(&ts);
                h.h2->tp_sec = ts.tv_sec;
                h.h2->tp_nsec = ts.tv_nsec;
                if (vlan_tx_tag_present(skb)) {
@@ -1759,16 +1841,6 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                h.h3->tp_snaplen = snaplen;
                h.h3->tp_mac = macoff;
                h.h3->tp_net = netoff;
-               if ((po->tp_tstamp & SOF_TIMESTAMPING_SYS_HARDWARE)
-                               && shhwtstamps->syststamp.tv64)
-                       ts = ktime_to_timespec(shhwtstamps->syststamp);
-               else if ((po->tp_tstamp & SOF_TIMESTAMPING_RAW_HARDWARE)
-                               && shhwtstamps->hwtstamp.tv64)
-                       ts = ktime_to_timespec(shhwtstamps->hwtstamp);
-               else if (skb->tstamp.tv64)
-                       ts = ktime_to_timespec(skb->tstamp);
-               else
-                       getnstimeofday(&ts);
                h.h3->tp_sec  = ts.tv_sec;
                h.h3->tp_nsec = ts.tv_nsec;
                hdrlen = sizeof(*h.h3);
@@ -1836,6 +1908,7 @@ static void tpacket_destruct_skb(struct sk_buff *skb)
                ph = skb_shinfo(skb)->destructor_arg;
                BUG_ON(atomic_read(&po->tx_ring.pending) == 0);
                atomic_dec(&po->tx_ring.pending);
+               __packet_set_timestamp(po, ph, skb->tstamp);
                __packet_set_status(po, ph, TP_STATUS_AVAILABLE);
        }
 
@@ -1846,11 +1919,7 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
                void *frame, struct net_device *dev, int size_max,
                __be16 proto, unsigned char *addr, int hlen)
 {
-       union {
-               struct tpacket_hdr *h1;
-               struct tpacket2_hdr *h2;
-               void *raw;
-       } ph;
+       union tpacket_uhdr ph;
        int to_write, offset, len, tp_len, nr_frags, len_max;
        struct socket *sock = po->sk.sk_socket;
        struct page *page;
@@ -1863,6 +1932,7 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
        skb->dev = dev;
        skb->priority = po->sk.sk_priority;
        skb->mark = po->sk.sk_mark;
+       sock_tx_timestamp(&po->sk, &skb_shinfo(skb)->tx_flags);
        skb_shinfo(skb)->destructor_arg = ph.raw;
 
        switch (po->tp_version) {
@@ -1880,6 +1950,7 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
 
        skb_reserve(skb, hlen);
        skb_reset_network_header(skb);
+       skb_probe_transport_header(skb, 0);
 
        if (po->tp_tx_has_off) {
                int off_min, off_max, off;
@@ -2247,9 +2318,8 @@ static int packet_snd(struct socket *sock,
        err = skb_copy_datagram_from_iovec(skb, offset, msg->msg_iov, 0, len);
        if (err)
                goto out_free;
-       err = sock_tx_timestamp(sk, &skb_shinfo(skb)->tx_flags);
-       if (err < 0)
-               goto out_free;
+
+       sock_tx_timestamp(sk, &skb_shinfo(skb)->tx_flags);
 
        if (!gso_type && (len > dev->mtu + reserve + extra_len)) {
                /* Earlier code assumed this would be a VLAN pkt,
@@ -2289,6 +2359,8 @@ static int packet_snd(struct socket *sock,
                len += vnet_hdr_len;
        }
 
+       skb_probe_transport_header(skb, reserve);
+
        if (unlikely(extra_len == 4))
                skb->no_fcs = 1;
 
@@ -3240,7 +3312,8 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
        case PACKET_FANOUT:
                val = (po->fanout ?
                       ((u32)po->fanout->id |
-                       ((u32)po->fanout->type << 16)) :
+                       ((u32)po->fanout->type << 16) |
+                       ((u32)po->fanout->flags << 24)) :
                       0);
                break;
        case PACKET_TX_HAS_OFF: