Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost
[cascardo/linux.git] / drivers / vhost / net.c
index 0965f86..5dc128a 100644 (file)
@@ -302,6 +302,32 @@ static bool vhost_can_busy_poll(struct vhost_dev *dev,
               !vhost_has_work(dev);
 }
 
+/* Stop vhost_poll-driven wakeups for @vq's backing socket.
+ * No-op when the vq has no socket attached (private_data unset),
+ * since in that case polling was never started.
+ */
+static void vhost_net_disable_vq(struct vhost_net *n,
+                                struct vhost_virtqueue *vq)
+{
+       struct vhost_net_virtqueue *nvq =
+               container_of(vq, struct vhost_net_virtqueue, vq);
+       /* n->poll[] is indexed in lockstep with n->vqs[]: recover the
+        * vq's index via pointer difference to find its poll entry. */
+       struct vhost_poll *poll = n->poll + (nvq - n->vqs);
+       if (!vq->private_data)
+               return;
+       vhost_poll_stop(poll);
+}
+
+/* (Re-)arm polling on @vq's backing socket so socket events queue the
+ * vq's work again.  Returns 0 when no socket is attached (nothing to
+ * poll), otherwise the result of vhost_poll_start() on the socket's
+ * file.
+ */
+static int vhost_net_enable_vq(struct vhost_net *n,
+                               struct vhost_virtqueue *vq)
+{
+       struct vhost_net_virtqueue *nvq =
+               container_of(vq, struct vhost_net_virtqueue, vq);
+       /* n->poll[] is indexed in lockstep with n->vqs[]. */
+       struct vhost_poll *poll = n->poll + (nvq - n->vqs);
+       struct socket *sock;
+
+       sock = vq->private_data;
+       if (!sock)
+               return 0;
+
+       return vhost_poll_start(poll, sock->file);
+}
+
 static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
                                    struct vhost_virtqueue *vq,
                                    struct iovec iov[], unsigned int iov_size,
@@ -459,10 +485,14 @@ out:
 
 static int peek_head_len(struct sock *sk)
 {
+       struct socket *sock = sk->sk_socket;
        struct sk_buff *head;
        int len = 0;
        unsigned long flags;
 
+       if (sock->ops->peek_len)
+               return sock->ops->peek_len(sock);
+
        spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
        head = skb_peek(&sk->sk_receive_queue);
        if (likely(head)) {
@@ -475,6 +505,16 @@ static int peek_head_len(struct sock *sk)
        return len;
 }
 
+/* Return nonzero when @sk has data available to receive.
+ * Sockets providing a ->peek_len() hook report through it; otherwise
+ * fall back to inspecting the sk_receive_queue directly.
+ */
+static int sk_has_rx_data(struct sock *sk)
+{
+       struct socket *sock = sk->sk_socket;
+
+       if (sock->ops->peek_len)
+               return sock->ops->peek_len(sock);
+
+       /* skb_queue_empty() is true when there is NO data queued, so it
+        * must be negated to honor this function's "has data" contract.
+        * The unnegated form made the busy-poll loop in
+        * vhost_net_rx_peek_head_len() (which spins on
+        * !sk_has_rx_data(sk)) exit immediately whenever the receive
+        * queue was empty — the opposite of the intended behavior.
+        */
+       return !skb_queue_empty(&sk->sk_receive_queue);
+}
+
 static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
 {
        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
@@ -491,7 +531,7 @@ static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
                endtime = busy_clock() + vq->busyloop_timeout;
 
                while (vhost_can_busy_poll(&net->dev, endtime) &&
-                      skb_queue_empty(&sk->sk_receive_queue) &&
+                      !sk_has_rx_data(sk) &&
                       vhost_vq_avail_empty(&net->dev, vq))
                        cpu_relax_lowlatency();
 
@@ -621,6 +661,7 @@ static void handle_rx(struct vhost_net *net)
                goto out;
 
        vhost_disable_notify(&net->dev, vq);
+       vhost_net_disable_vq(net, vq);
 
        vhost_hlen = nvq->vhost_hlen;
        sock_hlen = nvq->sock_hlen;
@@ -637,7 +678,7 @@ static void handle_rx(struct vhost_net *net)
                                        likely(mergeable) ? UIO_MAXIOV : 1);
                /* On error, stop handling until the next kick. */
                if (unlikely(headcount < 0))
-                       break;
+                       goto out;
                /* On overrun, truncate and discard */
                if (unlikely(headcount > UIO_MAXIOV)) {
                        iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
@@ -656,7 +697,7 @@ static void handle_rx(struct vhost_net *net)
                        }
                        /* Nothing new?  Wait for eventfd to tell us
                         * they refilled. */
-                       break;
+                       goto out;
                }
                /* We don't need to be notified again. */
                iov_iter_init(&msg.msg_iter, READ, vq->iov, in, vhost_len);
@@ -684,7 +725,7 @@ static void handle_rx(struct vhost_net *net)
                                         &fixup) != sizeof(hdr)) {
                                vq_err(vq, "Unable to write vnet_hdr "
                                       "at addr %p\n", vq->iov->iov_base);
-                               break;
+                               goto out;
                        }
                } else {
                        /* Header came from socket; we'll need to patch
@@ -700,7 +741,7 @@ static void handle_rx(struct vhost_net *net)
                                 &fixup) != sizeof num_buffers) {
                        vq_err(vq, "Failed num_buffers write");
                        vhost_discard_vq_desc(vq, headcount);
-                       break;
+                       goto out;
                }
                vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
                                            headcount);
@@ -709,9 +750,10 @@ static void handle_rx(struct vhost_net *net)
                total_len += vhost_len;
                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
                        vhost_poll_queue(&vq->poll);
-                       break;
+                       goto out;
                }
        }
+       vhost_net_enable_vq(net, vq);
 out:
        mutex_unlock(&vq->mutex);
 }
@@ -790,32 +832,6 @@ static int vhost_net_open(struct inode *inode, struct file *f)
        return 0;
 }
 
-static void vhost_net_disable_vq(struct vhost_net *n,
-                                struct vhost_virtqueue *vq)
-{
-       struct vhost_net_virtqueue *nvq =
-               container_of(vq, struct vhost_net_virtqueue, vq);
-       struct vhost_poll *poll = n->poll + (nvq - n->vqs);
-       if (!vq->private_data)
-               return;
-       vhost_poll_stop(poll);
-}
-
-static int vhost_net_enable_vq(struct vhost_net *n,
-                               struct vhost_virtqueue *vq)
-{
-       struct vhost_net_virtqueue *nvq =
-               container_of(vq, struct vhost_net_virtqueue, vq);
-       struct vhost_poll *poll = n->poll + (nvq - n->vqs);
-       struct socket *sock;
-
-       sock = vq->private_data;
-       if (!sock)
-               return 0;
-
-       return vhost_poll_start(poll, sock->file);
-}
-
 static struct socket *vhost_net_stop_vq(struct vhost_net *n,
                                        struct vhost_virtqueue *vq)
 {