Merge tag 'docs-for-linus' of git://git.lwn.net/linux
[cascardo/linux.git] / net / rds / bind.c
index bc6b93e..b22ea95 100644 (file)
 #include <linux/ratelimit.h>
 #include "rds.h"
 
-struct bind_bucket {
-       rwlock_t                lock;
-       struct hlist_head       head;
+static struct rhashtable bind_hash_table;
+
+static struct rhashtable_params ht_parms = {
+       .nelem_hint = 768,
+       .key_len = sizeof(u64),
+       .key_offset = offsetof(struct rds_sock, rs_bound_key),
+       .head_offset = offsetof(struct rds_sock, rs_bound_node),
+       .max_size = 16384,
+       .min_size = 1024,
 };
 
-#define BIND_HASH_SIZE 1024
-static struct bind_bucket bind_hash_table[BIND_HASH_SIZE];
-
-static struct bind_bucket *hash_to_bucket(__be32 addr, __be16 port)
-{
-       return bind_hash_table + (jhash_2words((u32)addr, (u32)port, 0) &
-                                 (BIND_HASH_SIZE - 1));
-}
-
-/* must hold either read or write lock (write lock for insert != NULL) */
-static struct rds_sock *rds_bind_lookup(struct bind_bucket *bucket,
-                                       __be32 addr, __be16 port,
-                                       struct rds_sock *insert)
-{
-       struct rds_sock *rs;
-       struct hlist_head *head = &bucket->head;
-       u64 cmp;
-       u64 needle = ((u64)be32_to_cpu(addr) << 32) | be16_to_cpu(port);
-
-       hlist_for_each_entry(rs, head, rs_bound_node) {
-               cmp = ((u64)be32_to_cpu(rs->rs_bound_addr) << 32) |
-                     be16_to_cpu(rs->rs_bound_port);
-
-               if (cmp == needle) {
-                       rds_sock_addref(rs);
-                       return rs;
-               }
-       }
-
-       if (insert) {
-               /*
-                * make sure our addr and port are set before
-                * we are added to the list.
-                */
-               insert->rs_bound_addr = addr;
-               insert->rs_bound_port = port;
-               rds_sock_addref(insert);
-
-               hlist_add_head(&insert->rs_bound_node, head);
-       }
-       return NULL;
-}
-
 /*
  * Return the rds_sock bound at the given local address.
  *
@@ -94,18 +57,14 @@ static struct rds_sock *rds_bind_lookup(struct bind_bucket *bucket,
  */
 struct rds_sock *rds_find_bound(__be32 addr, __be16 port)
 {
+       u64 key = ((u64)addr << 32) | port;
        struct rds_sock *rs;
-       unsigned long flags;
-       struct bind_bucket *bucket = hash_to_bucket(addr, port);
-
-       read_lock_irqsave(&bucket->lock, flags);
-       rs = rds_bind_lookup(bucket, addr, port, NULL);
-       read_unlock_irqrestore(&bucket->lock, flags);
 
-       if (rs && sock_flag(rds_rs_to_sk(rs), SOCK_DEAD)) {
-               rds_sock_put(rs);
+       rs = rhashtable_lookup_fast(&bind_hash_table, &key, ht_parms);
+       if (rs && !sock_flag(rds_rs_to_sk(rs), SOCK_DEAD))
+               rds_sock_addref(rs);
+       else
                rs = NULL;
-       }
 
        rdsdebug("returning rs %p for %pI4:%u\n", rs, &addr,
                ntohs(port));
@@ -116,10 +75,9 @@ struct rds_sock *rds_find_bound(__be32 addr, __be16 port)
 /* returns -ve errno or +ve port */
 static int rds_add_bound(struct rds_sock *rs, __be32 addr, __be16 *port)
 {
-       unsigned long flags;
        int ret = -EADDRINUSE;
        u16 rover, last;
-       struct bind_bucket *bucket;
+       u64 key;
 
        if (*port != 0) {
                rover = be16_to_cpu(*port);
@@ -130,22 +88,29 @@ static int rds_add_bound(struct rds_sock *rs, __be32 addr, __be16 *port)
        }
 
        do {
-               struct rds_sock *rrs;
                if (rover == 0)
                        rover++;
 
-               bucket = hash_to_bucket(addr, cpu_to_be16(rover));
-               write_lock_irqsave(&bucket->lock, flags);
-               rrs = rds_bind_lookup(bucket, addr, cpu_to_be16(rover), rs);
-               write_unlock_irqrestore(&bucket->lock, flags);
-               if (!rrs) {
+               key = ((u64)addr << 32) | cpu_to_be16(rover);
+               if (rhashtable_lookup_fast(&bind_hash_table, &key, ht_parms))
+                       continue;
+
+               rs->rs_bound_key = key;
+               rs->rs_bound_addr = addr;
+               rs->rs_bound_port = cpu_to_be16(rover);
+               rs->rs_bound_node.next = NULL;
+               rds_sock_addref(rs);
+               if (!rhashtable_insert_fast(&bind_hash_table,
+                                           &rs->rs_bound_node, ht_parms)) {
                        *port = rs->rs_bound_port;
                        ret = 0;
                        rdsdebug("rs %p binding to %pI4:%d\n",
                          rs, &addr, (int)ntohs(*port));
                        break;
                } else {
-                       rds_sock_put(rrs);
+                       rds_sock_put(rs);
+                       ret = -ENOMEM;
+                       break;
                }
        } while (rover++ != last);
 
@@ -154,23 +119,17 @@ static int rds_add_bound(struct rds_sock *rs, __be32 addr, __be16 *port)
 
 void rds_remove_bound(struct rds_sock *rs)
 {
-       unsigned long flags;
-       struct bind_bucket *bucket =
-               hash_to_bucket(rs->rs_bound_addr, rs->rs_bound_port);
-
-       write_lock_irqsave(&bucket->lock, flags);
 
-       if (rs->rs_bound_addr) {
-               rdsdebug("rs %p unbinding from %pI4:%d\n",
-                 rs, &rs->rs_bound_addr,
-                 ntohs(rs->rs_bound_port));
+       if (!rs->rs_bound_addr)
+               return;
 
-               hlist_del_init(&rs->rs_bound_node);
-               rds_sock_put(rs);
-               rs->rs_bound_addr = 0;
-       }
+       rdsdebug("rs %p unbinding from %pI4:%d\n",
+                rs, &rs->rs_bound_addr,
+                ntohs(rs->rs_bound_port));
 
-       write_unlock_irqrestore(&bucket->lock, flags);
+       rhashtable_remove_fast(&bind_hash_table, &rs->rs_bound_node, ht_parms);
+       rds_sock_put(rs);
+       rs->rs_bound_addr = 0;
 }
 
 int rds_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
@@ -196,7 +155,14 @@ int rds_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
                goto out;
 
        if (rs->rs_transport) { /* previously bound */
-               ret = 0;
+               trans = rs->rs_transport;
+               if (trans->laddr_check(sock_net(sock->sk),
+                                      sin->sin_addr.s_addr) != 0) {
+                       ret = -ENOPROTOOPT;
+                       rds_remove_bound(rs);
+               } else {
+                       ret = 0;
+               }
                goto out;
        }
        trans = rds_trans_get_preferred(sock_net(sock->sk),
@@ -217,10 +183,12 @@ out:
        return ret;
 }
 
-void rds_bind_lock_init(void)
+void rds_bind_lock_destroy(void)
 {
-       int i;
+       rhashtable_destroy(&bind_hash_table);
+}
 
-       for (i = 0; i < BIND_HASH_SIZE; i++)
-               rwlock_init(&bind_hash_table[i].lock);
+int rds_bind_lock_init(void)
+{
+       return rhashtable_init(&bind_hash_table, &ht_parms);
 }