Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[cascardo/linux.git] / net / netlink / af_netlink.c
index 7a186e7..d479b32 100644 (file)
@@ -96,6 +96,14 @@ static DECLARE_WAIT_QUEUE_HEAD(nl_table_wait);
 static int netlink_dump(struct sock *sk);
 static void netlink_skb_destructor(struct sk_buff *skb);
 
+/* nl_table locking explained:
+ * Lookup and traversal are protected with nl_sk_hash_lock or nl_table_lock
+ * combined with an RCU read-side lock. Insertion and removal are protected
+ * with nl_sk_hash_lock while using RCU list modification primitives and may
+ * run in parallel to nl_table_lock protected lookups. Destruction of the
+ * Netlink socket may only occur *after* nl_table_lock has been acquired
+ * either during or after the socket has been removed from the list.
+ */
 DEFINE_RWLOCK(nl_table_lock);
 EXPORT_SYMBOL_GPL(nl_table_lock);
 static atomic_t nl_table_users = ATOMIC_INIT(0);
@@ -106,14 +114,14 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);
 DEFINE_MUTEX(nl_sk_hash_lock);
 EXPORT_SYMBOL_GPL(nl_sk_hash_lock);
 
-static int lockdep_nl_sk_hash_is_held(void)
+#ifdef CONFIG_PROVE_LOCKING
+static int lockdep_nl_sk_hash_is_held(void *parent)
 {
-#ifdef CONFIG_LOCKDEP
-       return (debug_locks) ? lockdep_is_held(&nl_sk_hash_lock) : 1;
-#else
+       if (debug_locks)
+               return lockdep_is_held(&nl_sk_hash_lock) || lockdep_is_held(&nl_table_lock);
        return 1;
-#endif
 }
+#endif
 
 static ATOMIC_NOTIFIER_HEAD(netlink_chain);
 
@@ -1028,11 +1036,13 @@ static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
        struct netlink_table *table = &nl_table[protocol];
        struct sock *sk;
 
+       read_lock(&nl_table_lock);
        rcu_read_lock();
        sk = __netlink_lookup(table, portid, net);
        if (sk)
                sock_hold(sk);
        rcu_read_unlock();
+       read_unlock(&nl_table_lock);
 
        return sk;
 }
@@ -1082,7 +1092,7 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
 
        nlk_sk(sk)->portid = portid;
        sock_hold(sk);
-       rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);
+       rhashtable_insert(&table->hash, &nlk_sk(sk)->node);
        err = 0;
 err:
        mutex_unlock(&nl_sk_hash_lock);
@@ -1095,7 +1105,7 @@ static void netlink_remove(struct sock *sk)
 
        mutex_lock(&nl_sk_hash_lock);
        table = &nl_table[sk->sk_protocol];
-       if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL)) {
+       if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) {
                WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
                __sock_put(sk);
        }
@@ -1257,9 +1267,6 @@ static int netlink_release(struct socket *sock)
        }
        netlink_table_ungrab();
 
-       /* Wait for readers to complete */
-       synchronize_net();
-
        kfree(nlk->groups);
        nlk->groups = NULL;
 
@@ -1281,6 +1288,7 @@ static int netlink_autobind(struct socket *sock)
 
 retry:
        cond_resched();
+       netlink_table_grab();
        rcu_read_lock();
        if (__netlink_lookup(table, portid, net)) {
                /* Bind collision, search negative portid values. */
@@ -1288,9 +1296,11 @@ retry:
                if (rover > -4097)
                        rover = -4097;
                rcu_read_unlock();
+               netlink_table_ungrab();
                goto retry;
        }
        rcu_read_unlock();
+       netlink_table_ungrab();
 
        err = netlink_insert(sk, net, portid);
        if (err == -EADDRINUSE)
@@ -1430,7 +1440,7 @@ static void netlink_unbind(int group, long unsigned int groups,
                return;
 
        for (undo = 0; undo < group; undo++)
-               if (test_bit(group, &groups))
+               if (test_bit(undo, &groups))
                        nlk->netlink_unbind(undo);
 }
 
@@ -1482,7 +1492,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
                        netlink_insert(sk, net, nladdr->nl_pid) :
                        netlink_autobind(sock);
                if (err) {
-                       netlink_unbind(nlk->ngroups - 1, groups, nlk);
+                       netlink_unbind(nlk->ngroups, groups, nlk);
                        return err;
                }
        }
@@ -2391,7 +2401,7 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock,
        }
 
        skb_reset_transport_header(data_skb);
-       err = skb_copy_datagram_iovec(data_skb, 0, msg->msg_iov, copied);
+       err = skb_copy_datagram_msg(data_skb, 0, msg, copied);
 
        if (msg->msg_name) {
                DECLARE_SOCKADDR(struct sockaddr_nl *, addr, msg->msg_name);
@@ -2499,6 +2509,7 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,
                nl_table[unit].module = module;
                if (cfg) {
                        nl_table[unit].bind = cfg->bind;
+                       nl_table[unit].unbind = cfg->unbind;
                        nl_table[unit].flags = cfg->flags;
                        if (cfg->compare)
                                nl_table[unit].compare = cfg->compare;
@@ -2921,14 +2932,16 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
 }
 
 static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
-       __acquires(RCU)
+       __acquires(nl_table_lock) __acquires(RCU)
 {
+       read_lock(&nl_table_lock);
        rcu_read_lock();
        return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;
 }
 
 static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 {
+       struct rhashtable *ht;
        struct netlink_sock *nlk;
        struct nl_seq_iter *iter;
        struct net *net;
@@ -2943,19 +2956,19 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
        iter = seq->private;
        nlk = v;
 
-       rht_for_each_entry_rcu(nlk, nlk->node.next, node)
+       i = iter->link;
+       ht = &nl_table[i].hash;
+       rht_for_each_entry(nlk, nlk->node.next, ht, node)
                if (net_eq(sock_net((struct sock *)nlk), net))
                        return nlk;
 
-       i = iter->link;
        j = iter->hash_idx + 1;
 
        do {
-               struct rhashtable *ht = &nl_table[i].hash;
                const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
 
                for (; j < tbl->size; j++) {
-                       rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
+                       rht_for_each_entry(nlk, tbl->buckets[j], ht, node) {
                                if (net_eq(sock_net((struct sock *)nlk), net)) {
                                        iter->link = i;
                                        iter->hash_idx = j;
@@ -2971,9 +2984,10 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 }
 
 static void netlink_seq_stop(struct seq_file *seq, void *v)
-       __releases(RCU)
+       __releases(RCU) __releases(nl_table_lock)
 {
        rcu_read_unlock();
+       read_unlock(&nl_table_lock);
 }
 
 
@@ -3120,7 +3134,9 @@ static int __init netlink_proto_init(void)
                .max_shift = 16, /* 64K */
                .grow_decision = rht_grow_above_75,
                .shrink_decision = rht_shrink_below_30,
+#ifdef CONFIG_PROVE_LOCKING
                .mutex_is_held = lockdep_nl_sk_hash_is_held,
+#endif
        };
 
        if (err != 0)