Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[cascardo/linux.git] / net / kcm / kcmsock.c
index 4116932..b7f869a 100644 (file)
@@ -1,3 +1,13 @@
+/*
+ * Kernel Connection Multiplexor
+ *
+ * Copyright (c) 2016 Tom Herbert <tom@herbertland.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2
+ * as published by the Free Software Foundation.
+ */
+
 #include <linux/bpf.h>
 #include <linux/errno.h>
 #include <linux/errqueue.h>
@@ -17,7 +27,6 @@
 #include <net/kcm.h>
 #include <net/netns/generic.h>
 #include <net/sock.h>
-#include <net/tcp.h>
 #include <uapi/linux/kcm.h>
 
 unsigned int kcm_net_id;
@@ -36,38 +45,12 @@ static inline struct kcm_tx_msg *kcm_tx_msg(struct sk_buff *skb)
        return (struct kcm_tx_msg *)skb->cb;
 }
 
-static inline struct kcm_rx_msg *kcm_rx_msg(struct sk_buff *skb)
-{
-       return (struct kcm_rx_msg *)((void *)skb->cb +
-                                    offsetof(struct qdisc_skb_cb, data));
-}
-
 static void report_csk_error(struct sock *csk, int err)
 {
        csk->sk_err = EPIPE;
        csk->sk_error_report(csk);
 }
 
-/* Callback lock held */
-static void kcm_abort_rx_psock(struct kcm_psock *psock, int err,
-                              struct sk_buff *skb)
-{
-       struct sock *csk = psock->sk;
-
-       /* Unrecoverable error in receive */
-
-       del_timer(&psock->rx_msg_timer);
-
-       if (psock->rx_stopped)
-               return;
-
-       psock->rx_stopped = 1;
-       KCM_STATS_INCR(psock->stats.rx_aborts);
-
-       /* Report an error on the lower socket */
-       report_csk_error(csk, err);
-}
-
 static void kcm_abort_tx_psock(struct kcm_psock *psock, int err,
                               bool wakeup_kcm)
 {
@@ -110,12 +93,13 @@ static void kcm_abort_tx_psock(struct kcm_psock *psock, int err,
 static void kcm_update_rx_mux_stats(struct kcm_mux *mux,
                                    struct kcm_psock *psock)
 {
-       KCM_STATS_ADD(mux->stats.rx_bytes,
-                     psock->stats.rx_bytes - psock->saved_rx_bytes);
+       STRP_STATS_ADD(mux->stats.rx_bytes,
+                      psock->strp.stats.rx_bytes -
+                      psock->saved_rx_bytes);
        mux->stats.rx_msgs +=
-               psock->stats.rx_msgs - psock->saved_rx_msgs;
-       psock->saved_rx_msgs = psock->stats.rx_msgs;
-       psock->saved_rx_bytes = psock->stats.rx_bytes;
+               psock->strp.stats.rx_msgs - psock->saved_rx_msgs;
+       psock->saved_rx_msgs = psock->strp.stats.rx_msgs;
+       psock->saved_rx_bytes = psock->strp.stats.rx_bytes;
 }
 
 static void kcm_update_tx_mux_stats(struct kcm_mux *mux,
@@ -168,11 +152,11 @@ static void kcm_rcv_ready(struct kcm_sock *kcm)
                 */
                list_del(&psock->psock_ready_list);
                psock->ready_rx_msg = NULL;
-
                /* Commit clearing of ready_rx_msg for queuing work */
                smp_mb();
 
-               queue_work(kcm_wq, &psock->rx_work);
+               strp_unpause(&psock->strp);
+               strp_check_rcv(&psock->strp);
        }
 
        /* Buffer limit is okay now, add to ready list */
@@ -286,6 +270,7 @@ static struct kcm_sock *reserve_rx_kcm(struct kcm_psock *psock,
 
        if (list_empty(&mux->kcm_rx_waiters)) {
                psock->ready_rx_msg = head;
+               strp_pause(&psock->strp);
                list_add_tail(&psock->psock_ready_list,
                              &mux->psocks_ready);
                spin_unlock_bh(&mux->rx_lock);
@@ -354,346 +339,60 @@ static void unreserve_rx_kcm(struct kcm_psock *psock,
        spin_unlock_bh(&mux->rx_lock);
 }
 
-static void kcm_start_rx_timer(struct kcm_psock *psock)
-{
-       if (psock->sk->sk_rcvtimeo)
-               mod_timer(&psock->rx_msg_timer, psock->sk->sk_rcvtimeo);
-}
-
-/* Macro to invoke filter function. */
-#define KCM_RUN_FILTER(prog, ctx) \
-       (*prog->bpf_func)(ctx, prog->insnsi)
-
-/* Lower socket lock held */
-static int kcm_tcp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
-                       unsigned int orig_offset, size_t orig_len)
-{
-       struct kcm_psock *psock = (struct kcm_psock *)desc->arg.data;
-       struct kcm_rx_msg *rxm;
-       struct kcm_sock *kcm;
-       struct sk_buff *head, *skb;
-       size_t eaten = 0, cand_len;
-       ssize_t extra;
-       int err;
-       bool cloned_orig = false;
-
-       if (psock->ready_rx_msg)
-               return 0;
-
-       head = psock->rx_skb_head;
-       if (head) {
-               /* Message already in progress */
-
-               rxm = kcm_rx_msg(head);
-               if (unlikely(rxm->early_eaten)) {
-                       /* Already some number of bytes on the receive sock
-                        * data saved in rx_skb_head, just indicate they
-                        * are consumed.
-                        */
-                       eaten = orig_len <= rxm->early_eaten ?
-                               orig_len : rxm->early_eaten;
-                       rxm->early_eaten -= eaten;
-
-                       return eaten;
-               }
-
-               if (unlikely(orig_offset)) {
-                       /* Getting data with a non-zero offset when a message is
-                        * in progress is not expected. If it does happen, we
-                        * need to clone and pull since we can't deal with
-                        * offsets in the skbs for a message expect in the head.
-                        */
-                       orig_skb = skb_clone(orig_skb, GFP_ATOMIC);
-                       if (!orig_skb) {
-                               KCM_STATS_INCR(psock->stats.rx_mem_fail);
-                               desc->error = -ENOMEM;
-                               return 0;
-                       }
-                       if (!pskb_pull(orig_skb, orig_offset)) {
-                               KCM_STATS_INCR(psock->stats.rx_mem_fail);
-                               kfree_skb(orig_skb);
-                               desc->error = -ENOMEM;
-                               return 0;
-                       }
-                       cloned_orig = true;
-                       orig_offset = 0;
-               }
-
-               if (!psock->rx_skb_nextp) {
-                       /* We are going to append to the frags_list of head.
-                        * Need to unshare the frag_list.
-                        */
-                       err = skb_unclone(head, GFP_ATOMIC);
-                       if (err) {
-                               KCM_STATS_INCR(psock->stats.rx_mem_fail);
-                               desc->error = err;
-                               return 0;
-                       }
-
-                       if (unlikely(skb_shinfo(head)->frag_list)) {
-                               /* We can't append to an sk_buff that already
-                                * has a frag_list. We create a new head, point
-                                * the frag_list of that to the old head, and
-                                * then are able to use the old head->next for
-                                * appending to the message.
-                                */
-                               if (WARN_ON(head->next)) {
-                                       desc->error = -EINVAL;
-                                       return 0;
-                               }
-
-                               skb = alloc_skb(0, GFP_ATOMIC);
-                               if (!skb) {
-                                       KCM_STATS_INCR(psock->stats.rx_mem_fail);
-                                       desc->error = -ENOMEM;
-                                       return 0;
-                               }
-                               skb->len = head->len;
-                               skb->data_len = head->len;
-                               skb->truesize = head->truesize;
-                               *kcm_rx_msg(skb) = *kcm_rx_msg(head);
-                               psock->rx_skb_nextp = &head->next;
-                               skb_shinfo(skb)->frag_list = head;
-                               psock->rx_skb_head = skb;
-                               head = skb;
-                       } else {
-                               psock->rx_skb_nextp =
-                                   &skb_shinfo(head)->frag_list;
-                       }
-               }
-       }
-
-       while (eaten < orig_len) {
-               /* Always clone since we will consume something */
-               skb = skb_clone(orig_skb, GFP_ATOMIC);
-               if (!skb) {
-                       KCM_STATS_INCR(psock->stats.rx_mem_fail);
-                       desc->error = -ENOMEM;
-                       break;
-               }
-
-               cand_len = orig_len - eaten;
-
-               head = psock->rx_skb_head;
-               if (!head) {
-                       head = skb;
-                       psock->rx_skb_head = head;
-                       /* Will set rx_skb_nextp on next packet if needed */
-                       psock->rx_skb_nextp = NULL;
-                       rxm = kcm_rx_msg(head);
-                       memset(rxm, 0, sizeof(*rxm));
-                       rxm->offset = orig_offset + eaten;
-               } else {
-                       /* Unclone since we may be appending to an skb that we
-                        * already share a frag_list with.
-                        */
-                       err = skb_unclone(skb, GFP_ATOMIC);
-                       if (err) {
-                               KCM_STATS_INCR(psock->stats.rx_mem_fail);
-                               desc->error = err;
-                               break;
-                       }
-
-                       rxm = kcm_rx_msg(head);
-                       *psock->rx_skb_nextp = skb;
-                       psock->rx_skb_nextp = &skb->next;
-                       head->data_len += skb->len;
-                       head->len += skb->len;
-                       head->truesize += skb->truesize;
-               }
-
-               if (!rxm->full_len) {
-                       ssize_t len;
-
-                       len = KCM_RUN_FILTER(psock->bpf_prog, head);
-
-                       if (!len) {
-                               /* Need more header to determine length */
-                               if (!rxm->accum_len) {
-                                       /* Start RX timer for new message */
-                                       kcm_start_rx_timer(psock);
-                               }
-                               rxm->accum_len += cand_len;
-                               eaten += cand_len;
-                               KCM_STATS_INCR(psock->stats.rx_need_more_hdr);
-                               WARN_ON(eaten != orig_len);
-                               break;
-                       } else if (len > psock->sk->sk_rcvbuf) {
-                               /* Message length exceeds maximum allowed */
-                               KCM_STATS_INCR(psock->stats.rx_msg_too_big);
-                               desc->error = -EMSGSIZE;
-                               psock->rx_skb_head = NULL;
-                               kcm_abort_rx_psock(psock, EMSGSIZE, head);
-                               break;
-                       } else if (len <= (ssize_t)head->len -
-                                         skb->len - rxm->offset) {
-                               /* Length must be into new skb (and also
-                                * greater than zero)
-                                */
-                               KCM_STATS_INCR(psock->stats.rx_bad_hdr_len);
-                               desc->error = -EPROTO;
-                               psock->rx_skb_head = NULL;
-                               kcm_abort_rx_psock(psock, EPROTO, head);
-                               break;
-                       }
-
-                       rxm->full_len = len;
-               }
-
-               extra = (ssize_t)(rxm->accum_len + cand_len) - rxm->full_len;
-
-               if (extra < 0) {
-                       /* Message not complete yet. */
-                       if (rxm->full_len - rxm->accum_len >
-                           tcp_inq(psock->sk)) {
-                               /* Don't have the whole messages in the socket
-                                * buffer. Set psock->rx_need_bytes to wait for
-                                * the rest of the message. Also, set "early
-                                * eaten" since we've already buffered the skb
-                                * but don't consume yet per tcp_read_sock.
-                                */
-
-                               if (!rxm->accum_len) {
-                                       /* Start RX timer for new message */
-                                       kcm_start_rx_timer(psock);
-                               }
-
-                               psock->rx_need_bytes = rxm->full_len -
-                                                      rxm->accum_len;
-                               rxm->accum_len += cand_len;
-                               rxm->early_eaten = cand_len;
-                               KCM_STATS_ADD(psock->stats.rx_bytes, cand_len);
-                               desc->count = 0; /* Stop reading socket */
-                               break;
-                       }
-                       rxm->accum_len += cand_len;
-                       eaten += cand_len;
-                       WARN_ON(eaten != orig_len);
-                       break;
-               }
-
-               /* Positive extra indicates ore bytes than needed for the
-                * message
-                */
-
-               WARN_ON(extra > cand_len);
-
-               eaten += (cand_len - extra);
-
-               /* Hurray, we have a new message! */
-               del_timer(&psock->rx_msg_timer);
-               psock->rx_skb_head = NULL;
-               KCM_STATS_INCR(psock->stats.rx_msgs);
-
-try_queue:
-               kcm = reserve_rx_kcm(psock, head);
-               if (!kcm) {
-                       /* Unable to reserve a KCM, message is held in psock. */
-                       break;
-               }
-
-               if (kcm_queue_rcv_skb(&kcm->sk, head)) {
-                       /* Should mean socket buffer full */
-                       unreserve_rx_kcm(psock, false);
-                       goto try_queue;
-               }
-       }
-
-       if (cloned_orig)
-               kfree_skb(orig_skb);
-
-       KCM_STATS_ADD(psock->stats.rx_bytes, eaten);
-
-       return eaten;
-}
-
-/* Called with lock held on lower socket */
-static int psock_tcp_read_sock(struct kcm_psock *psock)
-{
-       read_descriptor_t desc;
-
-       desc.arg.data = psock;
-       desc.error = 0;
-       desc.count = 1; /* give more than one skb per call */
-
-       /* sk should be locked here, so okay to do tcp_read_sock */
-       tcp_read_sock(psock->sk, &desc, kcm_tcp_recv);
-
-       unreserve_rx_kcm(psock, true);
-
-       return desc.error;
-}
-
 /* Lower sock lock held */
-static void psock_tcp_data_ready(struct sock *sk)
+static void psock_data_ready(struct sock *sk)
 {
        struct kcm_psock *psock;
 
        read_lock_bh(&sk->sk_callback_lock);
 
        psock = (struct kcm_psock *)sk->sk_user_data;
-       if (unlikely(!psock || psock->rx_stopped))
-               goto out;
+       if (likely(psock))
+               strp_data_ready(&psock->strp);
 
-       if (psock->ready_rx_msg)
-               goto out;
-
-       if (psock->rx_need_bytes) {
-               if (tcp_inq(sk) >= psock->rx_need_bytes)
-                       psock->rx_need_bytes = 0;
-               else
-                       goto out;
-       }
-
-       if (psock_tcp_read_sock(psock) == -ENOMEM)
-               queue_delayed_work(kcm_wq, &psock->rx_delayed_work, 0);
-
-out:
        read_unlock_bh(&sk->sk_callback_lock);
 }
 
-static void do_psock_rx_work(struct kcm_psock *psock)
+/* Called with lower sock held */
+static void kcm_rcv_strparser(struct strparser *strp, struct sk_buff *skb)
 {
-       read_descriptor_t rd_desc;
-       struct sock *csk = psock->sk;
-
-       /* We need the read lock to synchronize with psock_tcp_data_ready. We
-        * need the socket lock for calling tcp_read_sock.
-        */
-       lock_sock(csk);
-       read_lock_bh(&csk->sk_callback_lock);
-
-       if (unlikely(csk->sk_user_data != psock))
-               goto out;
-
-       if (unlikely(psock->rx_stopped))
-               goto out;
-
-       if (psock->ready_rx_msg)
-               goto out;
-
-       rd_desc.arg.data = psock;
+       struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp);
+       struct kcm_sock *kcm;
 
-       if (psock_tcp_read_sock(psock) == -ENOMEM)
-               queue_delayed_work(kcm_wq, &psock->rx_delayed_work, 0);
+try_queue:
+       kcm = reserve_rx_kcm(psock, skb);
+       if (!kcm) {
+                /* Unable to reserve a KCM, message is held in psock and strp
+                 * is paused.
+                 */
+               return;
+       }
 
-out:
-       read_unlock_bh(&csk->sk_callback_lock);
-       release_sock(csk);
+       if (kcm_queue_rcv_skb(&kcm->sk, skb)) {
+               /* Should mean socket buffer full */
+               unreserve_rx_kcm(psock, false);
+               goto try_queue;
+       }
 }
 
-static void psock_rx_work(struct work_struct *w)
+static int kcm_parse_func_strparser(struct strparser *strp, struct sk_buff *skb)
 {
-       do_psock_rx_work(container_of(w, struct kcm_psock, rx_work));
+       struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp);
+       struct bpf_prog *prog = psock->bpf_prog;
+
+       return (*prog->bpf_func)(skb, prog->insnsi);
 }
 
-static void psock_rx_delayed_work(struct work_struct *w)
+static int kcm_read_sock_done(struct strparser *strp, int err)
 {
-       do_psock_rx_work(container_of(w, struct kcm_psock,
-                                     rx_delayed_work.work));
+       struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp);
+
+       unreserve_rx_kcm(psock, true);
+
+       return err;
 }
 
-static void psock_tcp_state_change(struct sock *sk)
+static void psock_state_change(struct sock *sk)
 {
        /* TCP only does a POLLIN for a half close. Do a POLLHUP here
         * since application will normally not poll with POLLIN
@@ -703,7 +402,7 @@ static void psock_tcp_state_change(struct sock *sk)
        report_csk_error(sk, EPIPE);
 }
 
-static void psock_tcp_write_space(struct sock *sk)
+static void psock_write_space(struct sock *sk)
 {
        struct kcm_psock *psock;
        struct kcm_mux *mux;
@@ -714,14 +413,13 @@ static void psock_tcp_write_space(struct sock *sk)
        psock = (struct kcm_psock *)sk->sk_user_data;
        if (unlikely(!psock))
                goto out;
-
        mux = psock->mux;
 
        spin_lock_bh(&mux->lock);
 
        /* Check if the socket is reserved so someone is waiting for sending. */
        kcm = psock->tx_kcm;
-       if (kcm)
+       if (kcm && !unlikely(kcm->tx_stopped))
                queue_work(kcm_wq, &kcm->tx_work);
 
        spin_unlock_bh(&mux->lock);
@@ -1412,7 +1110,7 @@ static int kcm_recvmsg(struct socket *sock, struct msghdr *msg,
        struct kcm_sock *kcm = kcm_sk(sk);
        int err = 0;
        long timeo;
-       struct kcm_rx_msg *rxm;
+       struct strp_rx_msg *rxm;
        int copied = 0;
        struct sk_buff *skb;
 
@@ -1426,7 +1124,7 @@ static int kcm_recvmsg(struct socket *sock, struct msghdr *msg,
 
        /* Okay, have a message on the receive queue */
 
-       rxm = kcm_rx_msg(skb);
+       rxm = strp_rx_msg(skb);
 
        if (len > rxm->full_len)
                len = rxm->full_len;
@@ -1482,7 +1180,7 @@ static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos,
        struct sock *sk = sock->sk;
        struct kcm_sock *kcm = kcm_sk(sk);
        long timeo;
-       struct kcm_rx_msg *rxm;
+       struct strp_rx_msg *rxm;
        int err = 0;
        ssize_t copied;
        struct sk_buff *skb;
@@ -1499,7 +1197,7 @@ static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos,
 
        /* Okay, have a message on the receive queue */
 
-       rxm = kcm_rx_msg(skb);
+       rxm = strp_rx_msg(skb);
 
        if (len > rxm->full_len)
                len = rxm->full_len;
@@ -1675,15 +1373,6 @@ static void init_kcm_sock(struct kcm_sock *kcm, struct kcm_mux *mux)
        spin_unlock_bh(&mux->rx_lock);
 }
 
-static void kcm_rx_msg_timeout(unsigned long arg)
-{
-       struct kcm_psock *psock = (struct kcm_psock *)arg;
-
-       /* Message assembly timed out */
-       KCM_STATS_INCR(psock->stats.rx_msg_timeouts);
-       kcm_abort_rx_psock(psock, ETIMEDOUT, NULL);
-}
-
 static int kcm_attach(struct socket *sock, struct socket *csock,
                      struct bpf_prog *prog)
 {
@@ -1693,19 +1382,13 @@ static int kcm_attach(struct socket *sock, struct socket *csock,
        struct kcm_psock *psock = NULL, *tpsock;
        struct list_head *head;
        int index = 0;
-
-       if (csock->ops->family != PF_INET &&
-           csock->ops->family != PF_INET6)
-               return -EINVAL;
+       struct strp_callbacks cb;
+       int err;
 
        csk = csock->sk;
        if (!csk)
                return -EINVAL;
 
-       /* Only support TCP for now */
-       if (csk->sk_protocol != IPPROTO_TCP)
-               return -EINVAL;
-
        psock = kmem_cache_zalloc(kcm_psockp, GFP_KERNEL);
        if (!psock)
                return -ENOMEM;
@@ -1714,11 +1397,16 @@ static int kcm_attach(struct socket *sock, struct socket *csock,
        psock->sk = csk;
        psock->bpf_prog = prog;
 
-       setup_timer(&psock->rx_msg_timer, kcm_rx_msg_timeout,
-                   (unsigned long)psock);
+       cb.rcv_msg = kcm_rcv_strparser;
+       cb.abort_parser = NULL;
+       cb.parse_msg = kcm_parse_func_strparser;
+       cb.read_sock_done = kcm_read_sock_done;
 
-       INIT_WORK(&psock->rx_work, psock_rx_work);
-       INIT_DELAYED_WORK(&psock->rx_delayed_work, psock_rx_delayed_work);
+       err = strp_init(&psock->strp, csk, &cb);
+       if (err) {
+               kmem_cache_free(kcm_psockp, psock);
+               return err;
+       }
 
        sock_hold(csk);
 
@@ -1727,9 +1415,9 @@ static int kcm_attach(struct socket *sock, struct socket *csock,
        psock->save_write_space = csk->sk_write_space;
        psock->save_state_change = csk->sk_state_change;
        csk->sk_user_data = psock;
-       csk->sk_data_ready = psock_tcp_data_ready;
-       csk->sk_write_space = psock_tcp_write_space;
-       csk->sk_state_change = psock_tcp_state_change;
+       csk->sk_data_ready = psock_data_ready;
+       csk->sk_write_space = psock_write_space;
+       csk->sk_state_change = psock_state_change;
        write_unlock_bh(&csk->sk_callback_lock);
 
        /* Finished initialization, now add the psock to the MUX. */
@@ -1751,7 +1439,7 @@ static int kcm_attach(struct socket *sock, struct socket *csock,
        spin_unlock_bh(&mux->lock);
 
        /* Schedule RX work in case there are already bytes queued */
-       queue_work(kcm_wq, &psock->rx_work);
+       strp_check_rcv(&psock->strp);
 
        return 0;
 }
@@ -1791,6 +1479,8 @@ static void kcm_unattach(struct kcm_psock *psock)
        struct sock *csk = psock->sk;
        struct kcm_mux *mux = psock->mux;
 
+       lock_sock(csk);
+
        /* Stop getting callbacks from TCP socket. After this there should
         * be no way to reserve a kcm for this psock.
         */
@@ -1799,7 +1489,7 @@ static void kcm_unattach(struct kcm_psock *psock)
        csk->sk_data_ready = psock->save_data_ready;
        csk->sk_write_space = psock->save_write_space;
        csk->sk_state_change = psock->save_state_change;
-       psock->rx_stopped = 1;
+       strp_stop(&psock->strp);
 
        if (WARN_ON(psock->rx_kcm)) {
                write_unlock_bh(&csk->sk_callback_lock);
@@ -1822,18 +1512,17 @@ static void kcm_unattach(struct kcm_psock *psock)
 
        write_unlock_bh(&csk->sk_callback_lock);
 
-       del_timer_sync(&psock->rx_msg_timer);
-       cancel_work_sync(&psock->rx_work);
-       cancel_delayed_work_sync(&psock->rx_delayed_work);
+       /* Call strp_done without sock lock */
+       release_sock(csk);
+       strp_done(&psock->strp);
+       lock_sock(csk);
 
        bpf_prog_put(psock->bpf_prog);
 
-       kfree_skb(psock->rx_skb_head);
-       psock->rx_skb_head = NULL;
-
        spin_lock_bh(&mux->lock);
 
        aggregate_psock_stats(&psock->stats, &mux->aggregate_psock_stats);
+       save_strp_stats(&psock->strp, &mux->aggregate_strp_stats);
 
        KCM_STATS_INCR(mux->stats.psock_unattach);
 
@@ -1876,6 +1565,8 @@ no_reserved:
                fput(csk->sk_socket->file);
                kmem_cache_free(kcm_psockp, psock);
        }
+
+       release_sock(csk);
 }
 
 static int kcm_unattach_ioctl(struct socket *sock, struct kcm_unattach *info)
@@ -1916,6 +1607,7 @@ static int kcm_unattach_ioctl(struct socket *sock, struct kcm_unattach *info)
 
                spin_unlock_bh(&mux->lock);
 
+               /* Lower socket lock should already be held */
                kcm_unattach(psock);
 
                err = 0;
@@ -2073,6 +1765,8 @@ static void release_mux(struct kcm_mux *mux)
        aggregate_mux_stats(&mux->stats, &knet->aggregate_mux_stats);
        aggregate_psock_stats(&mux->aggregate_psock_stats,
                              &knet->aggregate_psock_stats);
+       aggregate_strp_stats(&mux->aggregate_strp_stats,
+                            &knet->aggregate_strp_stats);
        list_del_rcu(&mux->kcm_mux_list);
        knet->count--;
        mutex_unlock(&knet->mutex);
@@ -2152,6 +1846,13 @@ static int kcm_release(struct socket *sock)
         * it will just return.
         */
        __skb_queue_purge(&sk->sk_write_queue);
+
+       /* Set tx_stopped. This is checked when psock is bound to a kcm and we
+        * get a writespace callback. This prevents further work being queued
+        * from the callback (unbinding the psock occurs after canceling work).
+        */
+       kcm->tx_stopped = 1;
+
        release_sock(sk);
 
        spin_lock_bh(&mux->lock);