fib_trie: Make leaf and tnode more uniform
[cascardo/linux.git] / net / ipv4 / fib_trie.c
index d3dbb48..2b72207 100644 (file)
 
 typedef unsigned int t_key;
 
-#define T_TNODE 0
-#define T_LEAF  1
-#define NODE_TYPE_MASK 0x1UL
-#define NODE_TYPE(node) ((node)->parent & NODE_TYPE_MASK)
+#define IS_TNODE(n) ((n)->bits)
+#define IS_LEAF(n) (!(n)->bits)
 
-#define IS_TNODE(n) (!(n->parent & T_LEAF))
-#define IS_LEAF(n) (n->parent & T_LEAF)
+struct tnode {
+       t_key key;
+       unsigned char bits;             /* 2log(KEYLENGTH) bits needed */
+       unsigned char pos;              /* 2log(KEYLENGTH) bits needed */
+       struct tnode __rcu *parent;
+       union {
+               struct rcu_head rcu;
+               struct tnode *tnode_free;
+       };
+       unsigned int full_children;     /* KEYLENGTH bits needed */
+       unsigned int empty_children;    /* KEYLENGTH bits needed */
+       struct rt_trie_node __rcu *child[0];
+};
 
 struct rt_trie_node {
-       unsigned long parent;
        t_key key;
+       unsigned char bits;
+       unsigned char pos;
+       struct tnode __rcu *parent;
+       struct rcu_head rcu;
 };
 
 struct leaf {
-       unsigned long parent;
        t_key key;
-       struct hlist_head list;
+       unsigned char bits;
+       unsigned char pos;
+       struct tnode __rcu *parent;
        struct rcu_head rcu;
+       struct hlist_head list;
 };
 
 struct leaf_info {
@@ -115,20 +129,6 @@ struct leaf_info {
        struct rcu_head rcu;
 };
 
-struct tnode {
-       unsigned long parent;
-       t_key key;
-       unsigned char pos;              /* 2log(KEYLENGTH) bits needed */
-       unsigned char bits;             /* 2log(KEYLENGTH) bits needed */
-       unsigned int full_children;     /* KEYLENGTH bits needed */
-       unsigned int empty_children;    /* KEYLENGTH bits needed */
-       union {
-               struct rcu_head rcu;
-               struct tnode *tnode_free;
-       };
-       struct rt_trie_node __rcu *child[0];
-};
-
 #ifdef CONFIG_IP_FIB_TRIE_STATS
 struct trie_use_stats {
        unsigned int gets;
@@ -176,38 +176,27 @@ static const int sync_pages = 128;
 static struct kmem_cache *fn_alias_kmem __read_mostly;
 static struct kmem_cache *trie_leaf_kmem __read_mostly;
 
-/*
- * caller must hold RTNL
- */
-static inline struct tnode *node_parent(const struct rt_trie_node *node)
-{
-       unsigned long parent;
+/* caller must hold RTNL */
+#define node_parent(n) rtnl_dereference((n)->parent)
 
-       parent = rcu_dereference_index_check(node->parent, lockdep_rtnl_is_held());
+/* caller must hold RCU read lock or RTNL */
+#define node_parent_rcu(n) rcu_dereference_rtnl((n)->parent)
 
-       return (struct tnode *)(parent & ~NODE_TYPE_MASK);
-}
-
-/*
- * caller must hold RCU read lock or RTNL
- */
-static inline struct tnode *node_parent_rcu(const struct rt_trie_node *node)
+/* wrapper for rcu_assign_pointer */
+static inline void node_set_parent(struct rt_trie_node *node, struct tnode *ptr)
 {
-       unsigned long parent;
-
-       parent = rcu_dereference_index_check(node->parent, rcu_read_lock_held() ||
-                                                          lockdep_rtnl_is_held());
-
-       return (struct tnode *)(parent & ~NODE_TYPE_MASK);
+       if (node)
+               rcu_assign_pointer(node->parent, ptr);
 }
 
-/* Same as rcu_assign_pointer
- * but that macro() assumes that value is a pointer.
+#define NODE_INIT_PARENT(n, p) RCU_INIT_POINTER((n)->parent, p)
+
+/* This provides us with the number of children in this node, in the case of a
+ * leaf this will return 0 meaning none of the children are accessible.
  */
-static inline void node_set_parent(struct rt_trie_node *node, struct tnode *ptr)
+static inline int tnode_child_length(const struct tnode *tn)
 {
-       smp_wmb();
-       node->parent = (unsigned long)ptr | NODE_TYPE(node);
+       return (1ul << tn->bits) & ~(1ul);
 }
 
 /*
@@ -215,7 +204,7 @@ static inline void node_set_parent(struct rt_trie_node *node, struct tnode *ptr)
  */
 static inline struct rt_trie_node *tnode_get_child(const struct tnode *tn, unsigned int i)
 {
-       BUG_ON(i >= 1U << tn->bits);
+       BUG_ON(i >= tnode_child_length(tn));
 
        return rtnl_dereference(tn->child[i]);
 }
@@ -225,16 +214,11 @@ static inline struct rt_trie_node *tnode_get_child(const struct tnode *tn, unsig
  */
 static inline struct rt_trie_node *tnode_get_child_rcu(const struct tnode *tn, unsigned int i)
 {
-       BUG_ON(i >= 1U << tn->bits);
+       BUG_ON(i >= tnode_child_length(tn));
 
        return rcu_dereference_rtnl(tn->child[i]);
 }
 
-static inline int tnode_child_length(const struct tnode *tn)
-{
-       return 1 << tn->bits;
-}
-
 static inline t_key mask_pfx(t_key k, unsigned int l)
 {
        return (l == 0) ? 0 : k >> (KEYLENGTH-l) << (KEYLENGTH-l);
@@ -336,11 +320,6 @@ static inline int tkey_mismatch(t_key a, int offset, t_key b)
 
 */
 
-static inline void check_tnode(const struct tnode *tn)
-{
-       WARN_ON(tn && tn->pos+tn->bits > 32);
-}
-
 static const int halve_threshold = 25;
 static const int inflate_threshold = 50;
 static const int halve_threshold_root = 15;
@@ -426,11 +405,20 @@ static void tnode_free_flush(void)
        }
 }
 
-static struct leaf *leaf_new(void)
+static struct leaf *leaf_new(t_key key)
 {
        struct leaf *l = kmem_cache_alloc(trie_leaf_kmem, GFP_KERNEL);
        if (l) {
-               l->parent = T_LEAF;
+               l->parent = NULL;
+               /* set key and pos to reflect full key value
+                * any trailing zeros in the key should be ignored
+                * as the nodes are searched
+                */
+               l->key = key;
+               l->pos = KEYLENGTH;
+               /* set bits to 0 indicating we are not a tnode */
+               l->bits = 0;
+
                INIT_HLIST_HEAD(&l->list);
        }
        return l;
@@ -451,12 +439,16 @@ static struct tnode *tnode_new(t_key key, int pos, int bits)
 {
        size_t sz = sizeof(struct tnode) + (sizeof(struct rt_trie_node *) << bits);
        struct tnode *tn = tnode_alloc(sz);
+       unsigned int shift = pos + bits;
+
+       /* verify bits and pos their msb bits clear and values are valid */
+       BUG_ON(!bits || (shift > KEYLENGTH));
 
        if (tn) {
-               tn->parent = T_TNODE;
+               tn->parent = NULL;
                tn->pos = pos;
                tn->bits = bits;
-               tn->key = key;
+               tn->key = mask_pfx(key, pos);
                tn->full_children = 0;
                tn->empty_children = 1<<bits;
        }
@@ -473,10 +465,7 @@ static struct tnode *tnode_new(t_key key, int pos, int bits)
 
 static inline int tnode_full(const struct tnode *tn, const struct rt_trie_node *n)
 {
-       if (n == NULL || IS_LEAF(n))
-               return 0;
-
-       return ((struct tnode *) n)->pos == tn->pos + tn->bits;
+       return n && IS_TNODE(n) && (n->pos == (tn->pos + tn->bits));
 }
 
 static inline void put_child(struct tnode *tn, int i,
@@ -514,8 +503,7 @@ static void tnode_put_child_reorg(struct tnode *tn, int i, struct rt_trie_node *
        else if (!wasfull && isfull)
                tn->full_children++;
 
-       if (n)
-               node_set_parent(n, tn);
+       node_set_parent(n, tn);
 
        rcu_assign_pointer(tn->child[i], n);
 }
@@ -523,7 +511,7 @@ static void tnode_put_child_reorg(struct tnode *tn, int i, struct rt_trie_node *
 #define MAX_WORK 10
 static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
 {
-       int i;
+       struct rt_trie_node *n = NULL;
        struct tnode *old_tn;
        int inflate_threshold_use;
        int halve_threshold_use;
@@ -536,12 +524,11 @@ static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
                 tn, inflate_threshold, halve_threshold);
 
        /* No children */
-       if (tn->empty_children == tnode_child_length(tn)) {
-               tnode_free_safe(tn);
-               return NULL;
-       }
+       if (tn->empty_children > (tnode_child_length(tn) - 1))
+               goto no_children;
+
        /* One child */
-       if (tn->empty_children == tnode_child_length(tn) - 1)
+       if (tn->empty_children == (tnode_child_length(tn) - 1))
                goto one_child;
        /*
         * Double as long as the resulting node has a number of
@@ -607,11 +594,9 @@ static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
         *
         */
 
-       check_tnode(tn);
-
        /* Keep root node larger  */
 
-       if (!node_parent((struct rt_trie_node *)tn)) {
+       if (!node_parent(tn)) {
                inflate_threshold_use = inflate_threshold_root;
                halve_threshold_use = halve_threshold_root;
        } else {
@@ -637,8 +622,6 @@ static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
                }
        }
 
-       check_tnode(tn);
-
        /* Return if at least one inflate is run */
        if (max_work != MAX_WORK)
                return (struct rt_trie_node *) tn;
@@ -666,21 +649,16 @@ static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
 
 
        /* Only one child remains */
-       if (tn->empty_children == tnode_child_length(tn) - 1) {
+       if (tn->empty_children == (tnode_child_length(tn) - 1)) {
+               unsigned long i;
 one_child:
-               for (i = 0; i < tnode_child_length(tn); i++) {
-                       struct rt_trie_node *n;
-
-                       n = rtnl_dereference(tn->child[i]);
-                       if (!n)
-                               continue;
-
-                       /* compress one level */
-
-                       node_set_parent(n, NULL);
-                       tnode_free_safe(tn);
-                       return n;
-               }
+               for (i = tnode_child_length(tn); !n && i;)
+                       n = tnode_get_child(tn, --i);
+no_children:
+               /* compress one level */
+               node_set_parent(n, NULL);
+               tnode_free_safe(tn);
+               return n;
        }
        return (struct rt_trie_node *) tn;
 }
@@ -760,8 +738,7 @@ static struct tnode *inflate(struct trie *t, struct tnode *tn)
 
                /* A leaf or an internal node with skipped bits */
 
-               if (IS_LEAF(node) || ((struct tnode *) node)->pos >
-                  tn->pos + tn->bits - 1) {
+               if (IS_LEAF(node) || (node->pos > (tn->pos + tn->bits - 1))) {
                        put_child(tn,
                                tkey_extract_bits(node->key, oldtnode->pos, oldtnode->bits + 1),
                                node);
@@ -958,11 +935,9 @@ fib_find_node(struct trie *t, u32 key)
        pos = 0;
        n = rcu_dereference_rtnl(t->trie);
 
-       while (n != NULL &&  NODE_TYPE(n) == T_TNODE) {
+       while (n && IS_TNODE(n)) {
                tn = (struct tnode *) n;
 
-               check_tnode(tn);
-
                if (tkey_sub_equals(tn->key, pos, tn->pos-pos, key)) {
                        pos = tn->pos + tn->bits;
                        n = tnode_get_child_rcu(tn,
@@ -988,7 +963,7 @@ static void trie_rebalance(struct trie *t, struct tnode *tn)
 
        key = tn->key;
 
-       while (tn != NULL && (tp = node_parent((struct rt_trie_node *)tn)) != NULL) {
+       while (tn != NULL && (tp = node_parent(tn)) != NULL) {
                cindex = tkey_extract_bits(key, tp->pos, tp->bits);
                wasfull = tnode_full(tp, tnode_get_child(tp, cindex));
                tn = (struct tnode *)resize(t, tn);
@@ -996,7 +971,7 @@ static void trie_rebalance(struct trie *t, struct tnode *tn)
                tnode_put_child_reorg(tp, cindex,
                                      (struct rt_trie_node *)tn, wasfull);
 
-               tp = node_parent((struct rt_trie_node *) tn);
+               tp = node_parent(tn);
                if (!tp)
                        rcu_assign_pointer(t->trie, (struct rt_trie_node *)tn);
 
@@ -1048,11 +1023,9 @@ static struct list_head *fib_insert_node(struct trie *t, u32 key, int plen)
         * If it doesn't, we need to replace it with a T_TNODE.
         */
 
-       while (n != NULL &&  NODE_TYPE(n) == T_TNODE) {
+       while (n && IS_TNODE(n)) {
                tn = (struct tnode *) n;
 
-               check_tnode(tn);
-
                if (tkey_sub_equals(tn->key, pos, tn->pos-pos, key)) {
                        tp = tn;
                        pos = tn->pos + tn->bits;
@@ -1087,12 +1060,11 @@ static struct list_head *fib_insert_node(struct trie *t, u32 key, int plen)
                insert_leaf_info(&l->list, li);
                goto done;
        }
-       l = leaf_new();
+       l = leaf_new(key);
 
        if (!l)
                return NULL;
 
-       l->key = key;
        li = leaf_info_new(plen);
 
        if (!li) {
@@ -1569,7 +1541,7 @@ backtrace:
                if (chopped_off <= pn->bits) {
                        cindex &= ~(1 << (chopped_off-1));
                } else {
-                       struct tnode *parent = node_parent_rcu((struct rt_trie_node *) pn);
+                       struct tnode *parent = node_parent_rcu(pn);
                        if (!parent)
                                goto failed;
 
@@ -1597,7 +1569,7 @@ EXPORT_SYMBOL_GPL(fib_table_lookup);
  */
 static void trie_leaf_remove(struct trie *t, struct leaf *l)
 {
-       struct tnode *tp = node_parent((struct rt_trie_node *) l);
+       struct tnode *tp = node_parent(l);
 
        pr_debug("entering trie_leaf_remove(%p)\n", l);
 
@@ -2374,7 +2346,7 @@ static int fib_trie_seq_show(struct seq_file *seq, void *v)
 
        if (IS_TNODE(n)) {
                struct tnode *tn = (struct tnode *) n;
-               __be32 prf = htonl(mask_pfx(tn->key, tn->pos));
+               __be32 prf = htonl(tn->key);
 
                seq_indent(seq, iter->depth-1);
                seq_printf(seq, "  +-- %pI4/%d %d %d %d\n",