sctp: label accepted/peeled off sockets
[cascardo/linux.git] / net / sctp / socket.c
index 897c01c..529ed35 100644 (file)
@@ -972,7 +972,7 @@ static int sctp_setsockopt_bindx(struct sock *sk,
                return -EFAULT;
 
        /* Alloc space for the address array in kernel memory.  */
-       kaddrs = kmalloc(addrs_size, GFP_KERNEL);
+       kaddrs = kmalloc(addrs_size, GFP_USER | __GFP_NOWARN);
        if (unlikely(!kaddrs))
                return -ENOMEM;
 
@@ -1301,8 +1301,9 @@ static int __sctp_setsockopt_connectx(struct sock *sk,
                                      int addrs_size,
                                      sctp_assoc_t *assoc_id)
 {
-       int err = 0;
        struct sockaddr *kaddrs;
+       gfp_t gfp = GFP_KERNEL;
+       int err = 0;
 
        pr_debug("%s: sk:%p addrs:%p addrs_size:%d\n",
                 __func__, sk, addrs, addrs_size);
@@ -1315,7 +1316,9 @@ static int __sctp_setsockopt_connectx(struct sock *sk,
                return -EFAULT;
 
        /* Alloc space for the address array in kernel memory.  */
-       kaddrs = kmalloc(addrs_size, GFP_KERNEL);
+       if (sk->sk_socket->file)
+               gfp = GFP_USER | __GFP_NOWARN;
+       kaddrs = kmalloc(addrs_size, gfp);
        if (unlikely(!kaddrs))
                return -ENOMEM;
 
@@ -1952,8 +1955,6 @@ static int sctp_sendmsg(struct sock *sk, struct msghdr *msg, size_t msg_len)
 
        /* Now send the (possibly) fragmented message. */
        list_for_each_entry(chunk, &datamsg->chunks, frag_list) {
-               sctp_chunk_hold(chunk);
-
                /* Do accounting for the write space.  */
                sctp_set_owner_w(chunk);
 
@@ -1966,15 +1967,13 @@ static int sctp_sendmsg(struct sock *sk, struct msghdr *msg, size_t msg_len)
         * breaks.
         */
        err = sctp_primitive_SEND(net, asoc, datamsg);
+       sctp_datamsg_put(datamsg);
        /* Did the lower layer accept the chunk? */
-       if (err) {
-               sctp_datamsg_free(datamsg);
+       if (err)
                goto out_free;
-       }
 
        pr_debug("%s: we sent primitively\n", __func__);
 
-       sctp_datamsg_put(datamsg);
        err = msg_len;
 
        if (unlikely(wait_connect)) {
@@ -4928,7 +4927,7 @@ static int sctp_getsockopt_local_addrs(struct sock *sk, int len,
        to = optval + offsetof(struct sctp_getaddrs, addrs);
        space_left = len - offsetof(struct sctp_getaddrs, addrs);
 
-       addrs = kmalloc(space_left, GFP_KERNEL);
+       addrs = kmalloc(space_left, GFP_USER | __GFP_NOWARN);
        if (!addrs)
                return -ENOMEM;
 
@@ -5777,7 +5776,7 @@ static int sctp_getsockopt_assoc_ids(struct sock *sk, int len,
 
        len = sizeof(struct sctp_assoc_ids) + sizeof(sctp_assoc_t) * num;
 
-       ids = kmalloc(len, GFP_KERNEL);
+       ids = kmalloc(len, GFP_USER | __GFP_NOWARN);
        if (unlikely(!ids))
                return -ENOMEM;
 
@@ -6458,7 +6457,7 @@ unsigned int sctp_poll(struct file *file, struct socket *sock, poll_table *wait)
        if (sctp_writeable(sk)) {
                mask |= POLLOUT | POLLWRNORM;
        } else {
-               set_bit(SOCK_ASYNC_NOSPACE, &sk->sk_socket->flags);
+               sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);
                /*
                 * Since the socket is not locked, the buffer
                 * might be made available after the writeable check and
@@ -6801,26 +6800,30 @@ no_packet:
 static void __sctp_write_space(struct sctp_association *asoc)
 {
        struct sock *sk = asoc->base.sk;
-       struct socket *sock = sk->sk_socket;
 
-       if ((sctp_wspace(asoc) > 0) && sock) {
-               if (waitqueue_active(&asoc->wait))
-                       wake_up_interruptible(&asoc->wait);
+       if (sctp_wspace(asoc) <= 0)
+               return;
+
+       if (waitqueue_active(&asoc->wait))
+               wake_up_interruptible(&asoc->wait);
 
-               if (sctp_writeable(sk)) {
-                       wait_queue_head_t *wq = sk_sleep(sk);
+       if (sctp_writeable(sk)) {
+               struct socket_wq *wq;
 
-                       if (wq && waitqueue_active(wq))
-                               wake_up_interruptible(wq);
+               rcu_read_lock();
+               wq = rcu_dereference(sk->sk_wq);
+               if (wq) {
+                       if (waitqueue_active(&wq->wait))
+                               wake_up_interruptible(&wq->wait);
 
                        /* Note that we try to include the Async I/O support
                         * here by modeling from the current TCP/UDP code.
                         * We have not tested with it yet.
                         */
                        if (!(sk->sk_shutdown & SEND_SHUTDOWN))
-                               sock_wake_async(sock,
-                                               SOCK_WAKE_SPACE, POLL_OUT);
+                               sock_wake_async(wq, SOCK_WAKE_SPACE, POLL_OUT);
                }
+               rcu_read_unlock();
        }
 }
 
@@ -7163,6 +7166,7 @@ void sctp_copy_sock(struct sock *newsk, struct sock *sk,
        newsk->sk_type = sk->sk_type;
        newsk->sk_bound_dev_if = sk->sk_bound_dev_if;
        newsk->sk_flags = sk->sk_flags;
+       newsk->sk_tsflags = sk->sk_tsflags;
        newsk->sk_no_check_tx = sk->sk_no_check_tx;
        newsk->sk_no_check_rx = sk->sk_no_check_rx;
        newsk->sk_reuse = sk->sk_reuse;
@@ -7195,6 +7199,11 @@ void sctp_copy_sock(struct sock *newsk, struct sock *sk,
        newinet->mc_ttl = 1;
        newinet->mc_index = 0;
        newinet->mc_list = NULL;
+
+       if (newsk->sk_flags & SK_FLAGS_TIMESTAMP)
+               net_enable_timestamp();
+
+       security_sk_clone(sk, newsk);
 }
 
 static inline void sctp_copy_descendant(struct sock *sk_to,
@@ -7375,6 +7384,13 @@ struct proto sctp_prot = {
 
 #if IS_ENABLED(CONFIG_IPV6)
 
+#include <net/transp_v6.h>
+static void sctp_v6_destroy_sock(struct sock *sk)
+{
+       sctp_destroy_sock(sk);
+       inet6_destroy_sock(sk);
+}
+
 struct proto sctpv6_prot = {
        .name           = "SCTPv6",
        .owner          = THIS_MODULE,
@@ -7384,7 +7400,7 @@ struct proto sctpv6_prot = {
        .accept         = sctp_accept,
        .ioctl          = sctp_ioctl,
        .init           = sctp_init_sock,
-       .destroy        = sctp_destroy_sock,
+       .destroy        = sctp_v6_destroy_sock,
        .shutdown       = sctp_shutdown,
        .setsockopt     = sctp_setsockopt,
        .getsockopt     = sctp_getsockopt,