netlink: update listeners directly when removing socket
[cascardo/linux.git] / net / netlink / af_netlink.c
index 0007b81..6a9fb7c 100644 (file)
@@ -114,14 +114,14 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);
 DEFINE_MUTEX(nl_sk_hash_lock);
 EXPORT_SYMBOL_GPL(nl_sk_hash_lock);
 
-static int lockdep_nl_sk_hash_is_held(void)
+#ifdef CONFIG_PROVE_LOCKING
+static int lockdep_nl_sk_hash_is_held(void *parent)
 {
-#ifdef CONFIG_LOCKDEP
        if (debug_locks)
                return lockdep_is_held(&nl_sk_hash_lock) || lockdep_is_held(&nl_table_lock);
-#endif
        return 1;
 }
+#endif
 
 static ATOMIC_NOTIFIER_HEAD(netlink_chain);
 
@@ -142,8 +142,7 @@ int netlink_add_tap(struct netlink_tap *nt)
        list_add_rcu(&nt->list, &netlink_tap_all);
        spin_unlock(&netlink_tap_lock);
 
-       if (nt->module)
-               __module_get(nt->module);
+       __module_get(nt->module);
 
        return 0;
 }
@@ -526,14 +525,14 @@ out:
        return err;
 }
 
-static void netlink_frame_flush_dcache(const struct nl_mmap_hdr *hdr)
+static void netlink_frame_flush_dcache(const struct nl_mmap_hdr *hdr, unsigned int nm_len)
 {
 #if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE == 1
        struct page *p_start, *p_end;
 
        /* First page is flushed through netlink_{get,set}_status */
        p_start = pgvec_to_page(hdr + PAGE_SIZE);
-       p_end   = pgvec_to_page((void *)hdr + NL_MMAP_HDRLEN + hdr->nm_len - 1);
+       p_end   = pgvec_to_page((void *)hdr + NL_MMAP_HDRLEN + nm_len - 1);
        while (p_start <= p_end) {
                flush_dcache_page(p_start);
                p_start++;
@@ -551,9 +550,9 @@ static enum nl_mmap_status netlink_get_status(const struct nl_mmap_hdr *hdr)
 static void netlink_set_status(struct nl_mmap_hdr *hdr,
                               enum nl_mmap_status status)
 {
+       smp_mb();
        hdr->nm_status = status;
        flush_dcache_page(pgvec_to_page(hdr));
-       smp_wmb();
 }
 
 static struct nl_mmap_hdr *
@@ -715,24 +714,16 @@ static int netlink_mmap_sendmsg(struct sock *sk, struct msghdr *msg,
        struct nl_mmap_hdr *hdr;
        struct sk_buff *skb;
        unsigned int maxlen;
-       bool excl = true;
        int err = 0, len = 0;
 
-       /* Netlink messages are validated by the receiver before processing.
-        * In order to avoid userspace changing the contents of the message
-        * after validation, the socket and the ring may only be used by a
-        * single process, otherwise we fall back to copying.
-        */
-       if (atomic_long_read(&sk->sk_socket->file->f_count) > 1 ||
-           atomic_read(&nlk->mapped) > 1)
-               excl = false;
-
        mutex_lock(&nlk->pg_vec_lock);
 
        ring   = &nlk->tx_ring;
        maxlen = ring->frame_size - NL_MMAP_HDRLEN;
 
        do {
+               unsigned int nm_len;
+
                hdr = netlink_current_frame(ring, NL_MMAP_STATUS_VALID);
                if (hdr == NULL) {
                        if (!(msg->msg_flags & MSG_DONTWAIT) &&
@@ -740,35 +731,23 @@ static int netlink_mmap_sendmsg(struct sock *sk, struct msghdr *msg,
                                schedule();
                        continue;
                }
-               if (hdr->nm_len > maxlen) {
+
+               nm_len = ACCESS_ONCE(hdr->nm_len);
+               if (nm_len > maxlen) {
                        err = -EINVAL;
                        goto out;
                }
 
-               netlink_frame_flush_dcache(hdr);
+               netlink_frame_flush_dcache(hdr, nm_len);
 
-               if (likely(dst_portid == 0 && dst_group == 0 && excl)) {
-                       skb = alloc_skb_head(GFP_KERNEL);
-                       if (skb == NULL) {
-                               err = -ENOBUFS;
-                               goto out;
-                       }
-                       sock_hold(sk);
-                       netlink_ring_setup_skb(skb, sk, ring, hdr);
-                       NETLINK_CB(skb).flags |= NETLINK_SKB_TX;
-                       __skb_put(skb, hdr->nm_len);
-                       netlink_set_status(hdr, NL_MMAP_STATUS_RESERVED);
-                       atomic_inc(&ring->pending);
-               } else {
-                       skb = alloc_skb(hdr->nm_len, GFP_KERNEL);
-                       if (skb == NULL) {
-                               err = -ENOBUFS;
-                               goto out;
-                       }
-                       __skb_put(skb, hdr->nm_len);
-                       memcpy(skb->data, (void *)hdr + NL_MMAP_HDRLEN, hdr->nm_len);
-                       netlink_set_status(hdr, NL_MMAP_STATUS_UNUSED);
+               skb = alloc_skb(nm_len, GFP_KERNEL);
+               if (skb == NULL) {
+                       err = -ENOBUFS;
+                       goto out;
                }
+               __skb_put(skb, nm_len);
+               memcpy(skb->data, (void *)hdr + NL_MMAP_HDRLEN, nm_len);
+               netlink_set_status(hdr, NL_MMAP_STATUS_UNUSED);
 
                netlink_increment_head(ring);
 
@@ -814,7 +793,7 @@ static void netlink_queue_mmaped_skb(struct sock *sk, struct sk_buff *skb)
        hdr->nm_pid     = NETLINK_CB(skb).creds.pid;
        hdr->nm_uid     = from_kuid(sk_user_ns(sk), NETLINK_CB(skb).creds.uid);
        hdr->nm_gid     = from_kgid(sk_user_ns(sk), NETLINK_CB(skb).creds.gid);
-       netlink_frame_flush_dcache(hdr);
+       netlink_frame_flush_dcache(hdr, hdr->nm_len);
        netlink_set_status(hdr, NL_MMAP_STATUS_VALID);
 
        NETLINK_CB(skb).flags |= NETLINK_SKB_DELIVERED;
@@ -1092,7 +1071,7 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
 
        nlk_sk(sk)->portid = portid;
        sock_hold(sk);
-       rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);
+       rhashtable_insert(&table->hash, &nlk_sk(sk)->node);
        err = 0;
 err:
        mutex_unlock(&nl_sk_hash_lock);
@@ -1105,15 +1084,17 @@ static void netlink_remove(struct sock *sk)
 
        mutex_lock(&nl_sk_hash_lock);
        table = &nl_table[sk->sk_protocol];
-       if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL)) {
+       if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) {
                WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
                __sock_put(sk);
        }
        mutex_unlock(&nl_sk_hash_lock);
 
        netlink_table_grab();
-       if (nlk_sk(sk)->subscriptions)
+       if (nlk_sk(sk)->subscriptions) {
                __sk_del_bind_node(sk);
+               netlink_update_listeners(sk);
+       }
        netlink_table_ungrab();
 }
 
@@ -1247,8 +1228,8 @@ static int netlink_release(struct socket *sock)
 
        module_put(nlk->module);
 
-       netlink_table_grab();
        if (netlink_is_kernel(sk)) {
+               netlink_table_grab();
                BUG_ON(nl_table[sk->sk_protocol].registered == 0);
                if (--nl_table[sk->sk_protocol].registered == 0) {
                        struct listeners *old;
@@ -1262,10 +1243,8 @@ static int netlink_release(struct socket *sock)
                        nl_table[sk->sk_protocol].flags = 0;
                        nl_table[sk->sk_protocol].registered = 0;
                }
-       } else if (nlk->subscriptions) {
-               netlink_update_listeners(sk);
+               netlink_table_ungrab();
        }
-       netlink_table_ungrab();
 
        kfree(nlk->groups);
        nlk->groups = NULL;
@@ -1431,8 +1410,8 @@ static int netlink_realloc_groups(struct sock *sk)
        return err;
 }
 
-static void netlink_unbind(int group, long unsigned int groups,
-                          struct netlink_sock *nlk)
+static void netlink_undo_bind(int group, long unsigned int groups,
+                             struct netlink_sock *nlk)
 {
        int undo;
 
@@ -1482,7 +1461,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
                        err = nlk->netlink_bind(group);
                        if (!err)
                                continue;
-                       netlink_unbind(group, groups, nlk);
+                       netlink_undo_bind(group, groups, nlk);
                        return err;
                }
        }
@@ -1492,7 +1471,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
                        netlink_insert(sk, net, nladdr->nl_pid) :
                        netlink_autobind(sock);
                if (err) {
-                       netlink_unbind(nlk->ngroups, groups, nlk);
+                       netlink_undo_bind(nlk->ngroups, groups, nlk);
                        return err;
                }
        }
@@ -2306,7 +2285,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
        }
 
        if (netlink_tx_is_mmaped(sk) &&
-           msg->msg_iov->iov_base == NULL) {
+           msg->msg_iter.iov->iov_base == NULL) {
                err = netlink_mmap_sendmsg(sk, msg, dst_portid, dst_group,
                                           siocb);
                goto out;
@@ -2326,7 +2305,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
        NETLINK_CB(skb).flags   = netlink_skb_flags;
 
        err = -EFAULT;
-       if (memcpy_fromiovec(skb_put(skb, len), msg->msg_iov, len)) {
+       if (memcpy_from_msg(skb_put(skb, len), msg, len)) {
                kfree_skb(skb);
                goto out;
        }
@@ -2401,7 +2380,7 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock,
        }
 
        skb_reset_transport_header(data_skb);
-       err = skb_copy_datagram_iovec(data_skb, 0, msg->msg_iov, copied);
+       err = skb_copy_datagram_msg(data_skb, 0, msg, copied);
 
        if (msg->msg_name) {
                DECLARE_SOCKADDR(struct sockaddr_nl *, addr, msg->msg_name);
@@ -3130,11 +3109,13 @@ static int __init netlink_proto_init(void)
                .head_offset = offsetof(struct netlink_sock, node),
                .key_offset = offsetof(struct netlink_sock, portid),
                .key_len = sizeof(u32), /* portid */
-               .hashfn = arch_fast_hash,
+               .hashfn = jhash,
                .max_shift = 16, /* 64K */
                .grow_decision = rht_grow_above_75,
                .shrink_decision = rht_shrink_below_30,
+#ifdef CONFIG_PROVE_LOCKING
                .mutex_is_held = lockdep_nl_sk_hash_is_held,
+#endif
        };
 
        if (err != 0)