Merge tag 'gpio-v3.17-1' of git://git.kernel.org/pub/scm/linux/kernel/git/linusw...
[cascardo/linux.git] / net / netlink / diag.c
index 1af2962..de8c74a 100644 (file)
@@ -4,6 +4,7 @@
 #include <linux/netlink.h>
 #include <linux/sock_diag.h>
 #include <linux/netlink_diag.h>
+#include <linux/rhashtable.h>
 
 #include "af_netlink.h"
 
@@ -101,16 +102,20 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
                                int protocol, int s_num)
 {
        struct netlink_table *tbl = &nl_table[protocol];
-       struct nl_portid_hash *hash = &tbl->hash;
+       struct rhashtable *ht = &tbl->hash;
+       const struct bucket_table *htbl = rht_dereference(ht->tbl, ht);
        struct net *net = sock_net(skb->sk);
        struct netlink_diag_req *req;
+       struct netlink_sock *nlsk;
        struct sock *sk;
        int ret = 0, num = 0, i;
 
        req = nlmsg_data(cb->nlh);
 
-       for (i = 0; i <= hash->mask; i++) {
-               sk_for_each(sk, &hash->table[i]) {
+       for (i = 0; i < htbl->size; i++) {
+               rht_for_each_entry(nlsk, htbl->buckets[i], ht, node) {
+                       sk = (struct sock *)nlsk;
+
                        if (!net_eq(sock_net(sk), net))
                                continue;
                        if (num < s_num) {
@@ -165,6 +170,7 @@ static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
 
        req = nlmsg_data(cb->nlh);
 
+       mutex_lock(&nl_sk_hash_lock);
        read_lock(&nl_table_lock);
 
        if (req->sdiag_protocol == NDIAG_PROTO_ALL) {
@@ -178,6 +184,7 @@ static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
        } else {
                if (req->sdiag_protocol >= MAX_LINKS) {
                        read_unlock(&nl_table_lock);
+                       mutex_unlock(&nl_sk_hash_lock);
                        return -ENOENT;
                }
 
@@ -185,6 +192,7 @@ static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
        }
 
        read_unlock(&nl_table_lock);
+       mutex_unlock(&nl_sk_hash_lock);
 
        return skb->len;
 }