fib_trie: Optimize fib_table_lookup to avoid wasting time on loops/variables
[cascardo/linux.git] / net / ipv4 / fib_trie.c
index 18bcaf2..3fe4dd9 100644 (file)
 
 typedef unsigned int t_key;
 
-#define T_TNODE 0
-#define T_LEAF  1
-#define NODE_TYPE_MASK 0x1UL
-#define NODE_TYPE(node) ((node)->parent & NODE_TYPE_MASK)
+#define IS_TNODE(n) ((n)->bits)
+#define IS_LEAF(n) (!(n)->bits)
 
-#define IS_TNODE(n) (!(n->parent & T_LEAF))
-#define IS_LEAF(n) (n->parent & T_LEAF)
+#define get_shift(_kv) (KEYLENGTH - (_kv)->pos - (_kv)->bits)
+#define get_index(_key, _kv) (((_key) ^ (_kv)->key) >> get_shift(_kv))
 
-struct rt_trie_node {
-       unsigned long parent;
-       t_key key;
-};
-
-struct leaf {
-       unsigned long parent;
+struct tnode {
        t_key key;
-       struct hlist_head list;
+       unsigned char bits;             /* 2log(KEYLENGTH) bits needed */
+       unsigned char pos;              /* 2log(KEYLENGTH) bits needed */
+       struct tnode __rcu *parent;
        struct rcu_head rcu;
+       union {
+               /* The fields in this struct are valid if bits > 0 (TNODE) */
+               struct {
+                       unsigned int full_children;  /* KEYLENGTH bits needed */
+                       unsigned int empty_children; /* KEYLENGTH bits needed */
+                       struct tnode __rcu *child[0];
+               };
+               /* This list pointer if valid if bits == 0 (LEAF) */
+               struct hlist_head list;
+       };
 };
 
 struct leaf_info {
@@ -115,20 +119,6 @@ struct leaf_info {
        struct rcu_head rcu;
 };
 
-struct tnode {
-       unsigned long parent;
-       t_key key;
-       unsigned char pos;              /* 2log(KEYLENGTH) bits needed */
-       unsigned char bits;             /* 2log(KEYLENGTH) bits needed */
-       unsigned int full_children;     /* KEYLENGTH bits needed */
-       unsigned int empty_children;    /* KEYLENGTH bits needed */
-       union {
-               struct rcu_head rcu;
-               struct tnode *tnode_free;
-       };
-       struct rt_trie_node __rcu *child[0];
-};
-
 #ifdef CONFIG_IP_FIB_TRIE_STATS
 struct trie_use_stats {
        unsigned int gets;
@@ -151,19 +141,19 @@ struct trie_stat {
 };
 
 struct trie {
-       struct rt_trie_node __rcu *trie;
+       struct tnode __rcu *trie;
 #ifdef CONFIG_IP_FIB_TRIE_STATS
-       struct trie_use_stats stats;
+       struct trie_use_stats __percpu *stats;
 #endif
 };
 
-static void tnode_put_child_reorg(struct tnode *tn, int i, struct rt_trie_node *n,
+static void tnode_put_child_reorg(struct tnode *tn, int i, struct tnode *n,
                                  int wasfull);
-static struct rt_trie_node *resize(struct trie *t, struct tnode *tn);
+static struct tnode *resize(struct trie *t, struct tnode *tn);
 static struct tnode *inflate(struct trie *t, struct tnode *tn);
 static struct tnode *halve(struct trie *t, struct tnode *tn);
 /* tnodes to free after resize(); protected by RTNL */
-static struct tnode *tnode_free_head;
+static struct callback_head *tnode_free_head;
 static size_t tnode_free_size;
 
 /*
@@ -176,46 +166,35 @@ static const int sync_pages = 128;
 static struct kmem_cache *fn_alias_kmem __read_mostly;
 static struct kmem_cache *trie_leaf_kmem __read_mostly;
 
-/*
- * caller must hold RTNL
- */
-static inline struct tnode *node_parent(const struct rt_trie_node *node)
-{
-       unsigned long parent;
+/* caller must hold RTNL */
+#define node_parent(n) rtnl_dereference((n)->parent)
 
-       parent = rcu_dereference_index_check(node->parent, lockdep_rtnl_is_held());
+/* caller must hold RCU read lock or RTNL */
+#define node_parent_rcu(n) rcu_dereference_rtnl((n)->parent)
 
-       return (struct tnode *)(parent & ~NODE_TYPE_MASK);
-}
-
-/*
- * caller must hold RCU read lock or RTNL
- */
-static inline struct tnode *node_parent_rcu(const struct rt_trie_node *node)
+/* wrapper for rcu_assign_pointer */
+static inline void node_set_parent(struct tnode *n, struct tnode *tp)
 {
-       unsigned long parent;
-
-       parent = rcu_dereference_index_check(node->parent, rcu_read_lock_held() ||
-                                                          lockdep_rtnl_is_held());
-
-       return (struct tnode *)(parent & ~NODE_TYPE_MASK);
+       if (n)
+               rcu_assign_pointer(n->parent, tp);
 }
 
-/* Same as rcu_assign_pointer
- * but that macro() assumes that value is a pointer.
+#define NODE_INIT_PARENT(n, p) RCU_INIT_POINTER((n)->parent, p)
+
+/* This provides us with the number of children in this node, in the case of a
+ * leaf this will return 0 meaning none of the children are accessible.
  */
-static inline void node_set_parent(struct rt_trie_node *node, struct tnode *ptr)
+static inline int tnode_child_length(const struct tnode *tn)
 {
-       smp_wmb();
-       node->parent = (unsigned long)ptr | NODE_TYPE(node);
+       return (1ul << tn->bits) & ~(1ul);
 }
 
 /*
  * caller must hold RTNL
  */
-static inline struct rt_trie_node *tnode_get_child(const struct tnode *tn, unsigned int i)
+static inline struct tnode *tnode_get_child(const struct tnode *tn, unsigned int i)
 {
-       BUG_ON(i >= 1U << tn->bits);
+       BUG_ON(i >= tnode_child_length(tn));
 
        return rtnl_dereference(tn->child[i]);
 }
@@ -223,18 +202,13 @@ static inline struct rt_trie_node *tnode_get_child(const struct tnode *tn, unsig
 /*
  * caller must hold RCU read lock or RTNL
  */
-static inline struct rt_trie_node *tnode_get_child_rcu(const struct tnode *tn, unsigned int i)
+static inline struct tnode *tnode_get_child_rcu(const struct tnode *tn, unsigned int i)
 {
-       BUG_ON(i >= 1U << tn->bits);
+       BUG_ON(i >= tnode_child_length(tn));
 
        return rcu_dereference_rtnl(tn->child[i]);
 }
 
-static inline int tnode_child_length(const struct tnode *tn)
-{
-       return 1 << tn->bits;
-}
-
 static inline t_key mask_pfx(t_key k, unsigned int l)
 {
        return (l == 0) ? 0 : k >> (KEYLENGTH-l) << (KEYLENGTH-l);
@@ -336,11 +310,6 @@ static inline int tkey_mismatch(t_key a, int offset, t_key b)
 
 */
 
-static inline void check_tnode(const struct tnode *tn)
-{
-       WARN_ON(tn && tn->pos+tn->bits > 32);
-}
-
 static const int halve_threshold = 25;
 static const int inflate_threshold = 50;
 static const int halve_threshold_root = 15;
@@ -357,17 +326,23 @@ static inline void alias_free_mem_rcu(struct fib_alias *fa)
        call_rcu(&fa->rcu, __alias_free_mem);
 }
 
-static void __leaf_free_rcu(struct rcu_head *head)
-{
-       struct leaf *l = container_of(head, struct leaf, rcu);
-       kmem_cache_free(trie_leaf_kmem, l);
-}
+#define TNODE_KMALLOC_MAX \
+       ilog2((PAGE_SIZE - sizeof(struct tnode)) / sizeof(struct tnode *))
 
-static inline void free_leaf(struct leaf *l)
+static void __node_free_rcu(struct rcu_head *head)
 {
-       call_rcu(&l->rcu, __leaf_free_rcu);
+       struct tnode *n = container_of(head, struct tnode, rcu);
+
+       if (IS_LEAF(n))
+               kmem_cache_free(trie_leaf_kmem, n);
+       else if (n->bits <= TNODE_KMALLOC_MAX)
+               kfree(n);
+       else
+               vfree(n);
 }
 
+#define node_free(n) call_rcu(&n->rcu, __node_free_rcu)
+
 static inline void free_leaf_info(struct leaf_info *leaf)
 {
        kfree_rcu(leaf, rcu);
@@ -381,43 +356,24 @@ static struct tnode *tnode_alloc(size_t size)
                return vzalloc(size);
 }
 
-static void __tnode_free_rcu(struct rcu_head *head)
-{
-       struct tnode *tn = container_of(head, struct tnode, rcu);
-       size_t size = sizeof(struct tnode) +
-                     (sizeof(struct rt_trie_node *) << tn->bits);
-
-       if (size <= PAGE_SIZE)
-               kfree(tn);
-       else
-               vfree(tn);
-}
-
-static inline void tnode_free(struct tnode *tn)
-{
-       if (IS_LEAF(tn))
-               free_leaf((struct leaf *) tn);
-       else
-               call_rcu(&tn->rcu, __tnode_free_rcu);
-}
-
 static void tnode_free_safe(struct tnode *tn)
 {
        BUG_ON(IS_LEAF(tn));
-       tn->tnode_free = tnode_free_head;
-       tnode_free_head = tn;
-       tnode_free_size += sizeof(struct tnode) +
-                          (sizeof(struct rt_trie_node *) << tn->bits);
+       tn->rcu.next = tnode_free_head;
+       tnode_free_head = &tn->rcu;
 }
 
 static void tnode_free_flush(void)
 {
-       struct tnode *tn;
+       struct callback_head *head;
+
+       while ((head = tnode_free_head)) {
+               struct tnode *tn = container_of(head, struct tnode, rcu);
+
+               tnode_free_head = head->next;
+               tnode_free_size += offsetof(struct tnode, child[1 << tn->bits]);
 
-       while ((tn = tnode_free_head)) {
-               tnode_free_head = tn->tnode_free;
-               tn->tnode_free = NULL;
-               tnode_free(tn);
+               node_free(tn);
        }
 
        if (tnode_free_size >= PAGE_SIZE * sync_pages) {
@@ -426,11 +382,20 @@ static void tnode_free_flush(void)
        }
 }
 
-static struct leaf *leaf_new(void)
+static struct tnode *leaf_new(t_key key)
 {
-       struct leaf *l = kmem_cache_alloc(trie_leaf_kmem, GFP_KERNEL);
+       struct tnode *l = kmem_cache_alloc(trie_leaf_kmem, GFP_KERNEL);
        if (l) {
-               l->parent = T_LEAF;
+               l->parent = NULL;
+               /* set key and pos to reflect full key value
+                * any trailing zeros in the key should be ignored
+                * as the nodes are searched
+                */
+               l->key = key;
+               l->pos = KEYLENGTH;
+               /* set bits to 0 indicating we are not a tnode */
+               l->bits = 0;
+
                INIT_HLIST_HEAD(&l->list);
        }
        return l;
@@ -449,20 +414,24 @@ static struct leaf_info *leaf_info_new(int plen)
 
 static struct tnode *tnode_new(t_key key, int pos, int bits)
 {
-       size_t sz = sizeof(struct tnode) + (sizeof(struct rt_trie_node *) << bits);
+       size_t sz = offsetof(struct tnode, child[1 << bits]);
        struct tnode *tn = tnode_alloc(sz);
+       unsigned int shift = pos + bits;
+
+       /* verify bits and pos their msb bits clear and values are valid */
+       BUG_ON(!bits || (shift > KEYLENGTH));
 
        if (tn) {
-               tn->parent = T_TNODE;
+               tn->parent = NULL;
                tn->pos = pos;
                tn->bits = bits;
-               tn->key = key;
+               tn->key = mask_pfx(key, pos);
                tn->full_children = 0;
                tn->empty_children = 1<<bits;
        }
 
        pr_debug("AT %p s=%zu %zu\n", tn, sizeof(struct tnode),
-                sizeof(struct rt_trie_node *) << bits);
+                sizeof(struct tnode *) << bits);
        return tn;
 }
 
@@ -471,16 +440,13 @@ static struct tnode *tnode_new(t_key key, int pos, int bits)
  * and no bits are skipped. See discussion in dyntree paper p. 6
  */
 
-static inline int tnode_full(const struct tnode *tn, const struct rt_trie_node *n)
+static inline int tnode_full(const struct tnode *tn, const struct tnode *n)
 {
-       if (n == NULL || IS_LEAF(n))
-               return 0;
-
-       return ((struct tnode *) n)->pos == tn->pos + tn->bits;
+       return n && IS_TNODE(n) && (n->pos == (tn->pos + tn->bits));
 }
 
 static inline void put_child(struct tnode *tn, int i,
-                            struct rt_trie_node *n)
+                            struct tnode *n)
 {
        tnode_put_child_reorg(tn, i, n, -1);
 }
@@ -490,10 +456,10 @@ static inline void put_child(struct tnode *tn, int i,
   * Update the value of full_children and empty_children.
   */
 
-static void tnode_put_child_reorg(struct tnode *tn, int i, struct rt_trie_node *n,
+static void tnode_put_child_reorg(struct tnode *tn, int i, struct tnode *n,
                                  int wasfull)
 {
-       struct rt_trie_node *chi = rtnl_dereference(tn->child[i]);
+       struct tnode *chi = rtnl_dereference(tn->child[i]);
        int isfull;
 
        BUG_ON(i >= 1<<tn->bits);
@@ -514,17 +480,15 @@ static void tnode_put_child_reorg(struct tnode *tn, int i, struct rt_trie_node *
        else if (!wasfull && isfull)
                tn->full_children++;
 
-       if (n)
-               node_set_parent(n, tn);
+       node_set_parent(n, tn);
 
        rcu_assign_pointer(tn->child[i], n);
 }
 
 #define MAX_WORK 10
-static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
+static struct tnode *resize(struct trie *t, struct tnode *tn)
 {
-       int i;
-       struct tnode *old_tn;
+       struct tnode *old_tn, *n = NULL;
        int inflate_threshold_use;
        int halve_threshold_use;
        int max_work;
@@ -536,12 +500,11 @@ static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
                 tn, inflate_threshold, halve_threshold);
 
        /* No children */
-       if (tn->empty_children == tnode_child_length(tn)) {
-               tnode_free_safe(tn);
-               return NULL;
-       }
+       if (tn->empty_children > (tnode_child_length(tn) - 1))
+               goto no_children;
+
        /* One child */
-       if (tn->empty_children == tnode_child_length(tn) - 1)
+       if (tn->empty_children == (tnode_child_length(tn) - 1))
                goto one_child;
        /*
         * Double as long as the resulting node has a number of
@@ -607,11 +570,9 @@ static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
         *
         */
 
-       check_tnode(tn);
-
        /* Keep root node larger  */
 
-       if (!node_parent((struct rt_trie_node *)tn)) {
+       if (!node_parent(tn)) {
                inflate_threshold_use = inflate_threshold_root;
                halve_threshold_use = halve_threshold_root;
        } else {
@@ -631,17 +592,15 @@ static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
                if (IS_ERR(tn)) {
                        tn = old_tn;
 #ifdef CONFIG_IP_FIB_TRIE_STATS
-                       t->stats.resize_node_skipped++;
+                       this_cpu_inc(t->stats->resize_node_skipped);
 #endif
                        break;
                }
        }
 
-       check_tnode(tn);
-
        /* Return if at least one inflate is run */
        if (max_work != MAX_WORK)
-               return (struct rt_trie_node *) tn;
+               return tn;
 
        /*
         * Halve as long as the number of empty children in this
@@ -658,7 +617,7 @@ static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
                if (IS_ERR(tn)) {
                        tn = old_tn;
 #ifdef CONFIG_IP_FIB_TRIE_STATS
-                       t->stats.resize_node_skipped++;
+                       this_cpu_inc(t->stats->resize_node_skipped);
 #endif
                        break;
                }
@@ -666,43 +625,38 @@ static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
 
 
        /* Only one child remains */
-       if (tn->empty_children == tnode_child_length(tn) - 1) {
+       if (tn->empty_children == (tnode_child_length(tn) - 1)) {
+               unsigned long i;
 one_child:
-               for (i = 0; i < tnode_child_length(tn); i++) {
-                       struct rt_trie_node *n;
-
-                       n = rtnl_dereference(tn->child[i]);
-                       if (!n)
-                               continue;
-
-                       /* compress one level */
-
-                       node_set_parent(n, NULL);
-                       tnode_free_safe(tn);
-                       return n;
-               }
+               for (i = tnode_child_length(tn); !n && i;)
+                       n = tnode_get_child(tn, --i);
+no_children:
+               /* compress one level */
+               node_set_parent(n, NULL);
+               tnode_free_safe(tn);
+               return n;
        }
-       return (struct rt_trie_node *) tn;
+       return tn;
 }
 
 
 static void tnode_clean_free(struct tnode *tn)
 {
-       int i;
        struct tnode *tofree;
+       int i;
 
        for (i = 0; i < tnode_child_length(tn); i++) {
-               tofree = (struct tnode *)rtnl_dereference(tn->child[i]);
+               tofree = rtnl_dereference(tn->child[i]);
                if (tofree)
-                       tnode_free(tofree);
+                       node_free(tofree);
        }
-       tnode_free(tn);
+       node_free(tn);
 }
 
-static struct tnode *inflate(struct trie *t, struct tnode *tn)
+static struct tnode *inflate(struct trie *t, struct tnode *oldtnode)
 {
-       struct tnode *oldtnode = tn;
-       int olen = tnode_child_length(tn);
+       int olen = tnode_child_length(oldtnode);
+       struct tnode *tn;
        int i;
 
        pr_debug("In inflate\n");
@@ -722,11 +676,8 @@ static struct tnode *inflate(struct trie *t, struct tnode *tn)
        for (i = 0; i < olen; i++) {
                struct tnode *inode;
 
-               inode = (struct tnode *) tnode_get_child(oldtnode, i);
-               if (inode &&
-                   IS_TNODE(inode) &&
-                   inode->pos == oldtnode->pos + oldtnode->bits &&
-                   inode->bits > 1) {
+               inode = tnode_get_child(oldtnode, i);
+               if (tnode_full(oldtnode, inode) && inode->bits > 1) {
                        struct tnode *left, *right;
                        t_key m = ~0U << (KEYLENGTH - 1) >> inode->pos;
 
@@ -739,38 +690,33 @@ static struct tnode *inflate(struct trie *t, struct tnode *tn)
                                          inode->bits - 1);
 
                        if (!right) {
-                               tnode_free(left);
+                               node_free(left);
                                goto nomem;
                        }
 
-                       put_child(tn, 2*i, (struct rt_trie_node *) left);
-                       put_child(tn, 2*i+1, (struct rt_trie_node *) right);
+                       put_child(tn, 2*i, left);
+                       put_child(tn, 2*i+1, right);
                }
        }
 
        for (i = 0; i < olen; i++) {
-               struct tnode *inode;
-               struct rt_trie_node *node = tnode_get_child(oldtnode, i);
+               struct tnode *inode = tnode_get_child(oldtnode, i);
                struct tnode *left, *right;
                int size, j;
 
                /* An empty child */
-               if (node == NULL)
+               if (inode == NULL)
                        continue;
 
                /* A leaf or an internal node with skipped bits */
-
-               if (IS_LEAF(node) || ((struct tnode *) node)->pos >
-                  tn->pos + tn->bits - 1) {
+               if (!tnode_full(oldtnode, inode)) {
                        put_child(tn,
-                               tkey_extract_bits(node->key, oldtnode->pos, oldtnode->bits + 1),
-                               node);
+                               tkey_extract_bits(inode->key, tn->pos, tn->bits),
+                               inode);
                        continue;
                }
 
                /* An internal node with two children */
-               inode = (struct tnode *) node;
-
                if (inode->bits == 1) {
                        put_child(tn, 2*i, rtnl_dereference(inode->child[0]));
                        put_child(tn, 2*i+1, rtnl_dereference(inode->child[1]));
@@ -802,12 +748,12 @@ static struct tnode *inflate(struct trie *t, struct tnode *tn)
                 *   bit to zero.
                 */
 
-               left = (struct tnode *) tnode_get_child(tn, 2*i);
+               left = tnode_get_child(tn, 2*i);
                put_child(tn, 2*i, NULL);
 
                BUG_ON(!left);
 
-               right = (struct tnode *) tnode_get_child(tn, 2*i+1);
+               right = tnode_get_child(tn, 2*i+1);
                put_child(tn, 2*i+1, NULL);
 
                BUG_ON(!right);
@@ -829,12 +775,11 @@ nomem:
        return ERR_PTR(-ENOMEM);
 }
 
-static struct tnode *halve(struct trie *t, struct tnode *tn)
+static struct tnode *halve(struct trie *t, struct tnode *oldtnode)
 {
-       struct tnode *oldtnode = tn;
-       struct rt_trie_node *left, *right;
+       int olen = tnode_child_length(oldtnode);
+       struct tnode *tn, *left, *right;
        int i;
-       int olen = tnode_child_length(tn);
 
        pr_debug("In halve\n");
 
@@ -863,7 +808,7 @@ static struct tnode *halve(struct trie *t, struct tnode *tn)
                        if (!newn)
                                goto nomem;
 
-                       put_child(tn, i/2, (struct rt_trie_node *)newn);
+                       put_child(tn, i/2, newn);
                }
 
        }
@@ -888,7 +833,7 @@ static struct tnode *halve(struct trie *t, struct tnode *tn)
                }
 
                /* Two nonempty children */
-               newBinNode = (struct tnode *) tnode_get_child(tn, i/2);
+               newBinNode = tnode_get_child(tn, i/2);
                put_child(tn, i/2, NULL);
                put_child(newBinNode, 0, left);
                put_child(newBinNode, 1, right);
@@ -904,7 +849,7 @@ nomem:
 /* readside must use rcu_read_lock currently dump routines
  via get_fa_head and dump */
 
-static struct leaf_info *find_leaf_info(struct leaf *l, int plen)
+static struct leaf_info *find_leaf_info(struct tnode *l, int plen)
 {
        struct hlist_head *head = &l->list;
        struct leaf_info *li;
@@ -916,7 +861,7 @@ static struct leaf_info *find_leaf_info(struct leaf *l, int plen)
        return NULL;
 }
 
-static inline struct list_head *get_fa_head(struct leaf *l, int plen)
+static inline struct list_head *get_fa_head(struct tnode *l, int plen)
 {
        struct leaf_info *li = find_leaf_info(l, plen);
 
@@ -948,34 +893,25 @@ static void insert_leaf_info(struct hlist_head *head, struct leaf_info *new)
 
 /* rcu_read_lock needs to be hold by caller from readside */
 
-static struct leaf *
-fib_find_node(struct trie *t, u32 key)
+static struct tnode *fib_find_node(struct trie *t, u32 key)
 {
-       int pos;
-       struct tnode *tn;
-       struct rt_trie_node *n;
-
-       pos = 0;
-       n = rcu_dereference_rtnl(t->trie);
-
-       while (n != NULL &&  NODE_TYPE(n) == T_TNODE) {
-               tn = (struct tnode *) n;
+       struct tnode *n = rcu_dereference_rtnl(t->trie);
+       int pos = 0;
 
-               check_tnode(tn);
-
-               if (tkey_sub_equals(tn->key, pos, tn->pos-pos, key)) {
-                       pos = tn->pos + tn->bits;
-                       n = tnode_get_child_rcu(tn,
+       while (n && IS_TNODE(n)) {
+               if (tkey_sub_equals(n->key, pos, n->pos-pos, key)) {
+                       pos = n->pos + n->bits;
+                       n = tnode_get_child_rcu(n,
                                                tkey_extract_bits(key,
-                                                                 tn->pos,
-                                                                 tn->bits));
+                                                                 n->pos,
+                                                                 n->bits));
                } else
                        break;
        }
        /* Case we have found a leaf. Compare prefixes */
 
        if (n != NULL && IS_LEAF(n) && tkey_equals(key, n->key))
-               return (struct leaf *)n;
+               return n;
 
        return NULL;
 }
@@ -988,17 +924,16 @@ static void trie_rebalance(struct trie *t, struct tnode *tn)
 
        key = tn->key;
 
-       while (tn != NULL && (tp = node_parent((struct rt_trie_node *)tn)) != NULL) {
+       while (tn != NULL && (tp = node_parent(tn)) != NULL) {
                cindex = tkey_extract_bits(key, tp->pos, tp->bits);
                wasfull = tnode_full(tp, tnode_get_child(tp, cindex));
-               tn = (struct tnode *)resize(t, tn);
+               tn = resize(t, tn);
 
-               tnode_put_child_reorg(tp, cindex,
-                                     (struct rt_trie_node *)tn, wasfull);
+               tnode_put_child_reorg(tp, cindex, tn, wasfull);
 
-               tp = node_parent((struct rt_trie_node *) tn);
+               tp = node_parent(tn);
                if (!tp)
-                       rcu_assign_pointer(t->trie, (struct rt_trie_node *)tn);
+                       rcu_assign_pointer(t->trie, tn);
 
                tnode_free_flush();
                if (!tp)
@@ -1008,9 +943,9 @@ static void trie_rebalance(struct trie *t, struct tnode *tn)
 
        /* Handle last (top) tnode */
        if (IS_TNODE(tn))
-               tn = (struct tnode *)resize(t, tn);
+               tn = resize(t, tn);
 
-       rcu_assign_pointer(t->trie, (struct rt_trie_node *)tn);
+       rcu_assign_pointer(t->trie, tn);
        tnode_free_flush();
 }
 
@@ -1020,8 +955,8 @@ static struct list_head *fib_insert_node(struct trie *t, u32 key, int plen)
 {
        int pos, newpos;
        struct tnode *tp = NULL, *tn = NULL;
-       struct rt_trie_node *n;
-       struct leaf *l;
+       struct tnode *n;
+       struct tnode *l;
        int missbit;
        struct list_head *fa_head = NULL;
        struct leaf_info *li;
@@ -1048,20 +983,16 @@ static struct list_head *fib_insert_node(struct trie *t, u32 key, int plen)
         * If it doesn't, we need to replace it with a T_TNODE.
         */
 
-       while (n != NULL &&  NODE_TYPE(n) == T_TNODE) {
-               tn = (struct tnode *) n;
-
-               check_tnode(tn);
-
-               if (tkey_sub_equals(tn->key, pos, tn->pos-pos, key)) {
-                       tp = tn;
-                       pos = tn->pos + tn->bits;
-                       n = tnode_get_child(tn,
+       while (n && IS_TNODE(n)) {
+               if (tkey_sub_equals(n->key, pos, n->pos-pos, key)) {
+                       tp = n;
+                       pos = n->pos + n->bits;
+                       n = tnode_get_child(n,
                                            tkey_extract_bits(key,
-                                                             tn->pos,
-                                                             tn->bits));
+                                                             n->pos,
+                                                             n->bits));
 
-                       BUG_ON(n && node_parent(n) != tn);
+                       BUG_ON(n && node_parent(n) != tp);
                } else
                        break;
        }
@@ -1077,26 +1008,24 @@ static struct list_head *fib_insert_node(struct trie *t, u32 key, int plen)
        /* Case 1: n is a leaf. Compare prefixes */
 
        if (n != NULL && IS_LEAF(n) && tkey_equals(key, n->key)) {
-               l = (struct leaf *) n;
                li = leaf_info_new(plen);
 
                if (!li)
                        return NULL;
 
                fa_head = &li->falh;
-               insert_leaf_info(&l->list, li);
+               insert_leaf_info(&n->list, li);
                goto done;
        }
-       l = leaf_new();
+       l = leaf_new(key);
 
        if (!l)
                return NULL;
 
-       l->key = key;
        li = leaf_info_new(plen);
 
        if (!li) {
-               free_leaf(l);
+               node_free(l);
                return NULL;
        }
 
@@ -1106,10 +1035,10 @@ static struct list_head *fib_insert_node(struct trie *t, u32 key, int plen)
        if (t->trie && n == NULL) {
                /* Case 2: n is NULL, and will just insert a new leaf */
 
-               node_set_parent((struct rt_trie_node *)l, tp);
+               node_set_parent(l, tp);
 
                cindex = tkey_extract_bits(key, tp->pos, tp->bits);
-               put_child(tp, cindex, (struct rt_trie_node *)l);
+               put_child(tp, cindex, l);
        } else {
                /* Case 3: n is a LEAF or a TNODE and the key doesn't match. */
                /*
@@ -1128,21 +1057,21 @@ static struct list_head *fib_insert_node(struct trie *t, u32 key, int plen)
 
                if (!tn) {
                        free_leaf_info(li);
-                       free_leaf(l);
+                       node_free(l);
                        return NULL;
                }
 
-               node_set_parent((struct rt_trie_node *)tn, tp);
+               node_set_parent(tn, tp);
 
                missbit = tkey_extract_bits(key, newpos, 1);
-               put_child(tn, missbit, (struct rt_trie_node *)l);
+               put_child(tn, missbit, l);
                put_child(tn, 1-missbit, n);
 
                if (tp) {
                        cindex = tkey_extract_bits(key, tp->pos, tp->bits);
-                       put_child(tp, cindex, (struct rt_trie_node *)tn);
+                       put_child(tp, cindex, tn);
                } else {
-                       rcu_assign_pointer(t->trie, (struct rt_trie_node *)tn);
+                       rcu_assign_pointer(t->trie, tn);
                }
 
                tp = tn;
@@ -1172,7 +1101,7 @@ int fib_table_insert(struct fib_table *tb, struct fib_config *cfg)
        u8 tos = cfg->fc_tos;
        u32 key, mask;
        int err;
-       struct leaf *l;
+       struct tnode *l;
 
        if (plen > 32)
                return -EINVAL;
@@ -1330,7 +1259,7 @@ err:
 }
 
 /* should be called with rcu_read_lock */
-static int check_leaf(struct fib_table *tb, struct trie *t, struct leaf *l,
+static int check_leaf(struct fib_table *tb, struct trie *t, struct tnode *l,
                      t_key key,  const struct flowi4 *flp,
                      struct fib_result *res, int fib_flags)
 {
@@ -1355,9 +1284,9 @@ static int check_leaf(struct fib_table *tb, struct trie *t, struct leaf *l,
                                continue;
                        fib_alias_accessed(fa);
                        err = fib_props[fa->fa_type].error;
-                       if (err) {
+                       if (unlikely(err < 0)) {
 #ifdef CONFIG_IP_FIB_TRIE_STATS
-                               t->stats.semantic_match_passed++;
+                               this_cpu_inc(t->stats->semantic_match_passed);
 #endif
                                return err;
                        }
@@ -1372,12 +1301,12 @@ static int check_leaf(struct fib_table *tb, struct trie *t, struct leaf *l,
                                        continue;
 
 #ifdef CONFIG_IP_FIB_TRIE_STATS
-                               t->stats.semantic_match_passed++;
+                               this_cpu_inc(t->stats->semantic_match_passed);
 #endif
                                res->prefixlen = li->plen;
                                res->nh_sel = nhsel;
                                res->type = fa->fa_type;
-                               res->scope = fa->fa_info->fib_scope;
+                               res->scope = fi->fib_scope;
                                res->fi = fi;
                                res->table = tb;
                                res->fa_head = &li->falh;
@@ -1388,27 +1317,31 @@ static int check_leaf(struct fib_table *tb, struct trie *t, struct leaf *l,
                }
 
 #ifdef CONFIG_IP_FIB_TRIE_STATS
-               t->stats.semantic_match_miss++;
+               this_cpu_inc(t->stats->semantic_match_miss);
 #endif
        }
 
        return 1;
 }
 
+static inline t_key prefix_mismatch(t_key key, struct tnode *n)
+{
+       t_key prefix = n->key;
+
+       return (key ^ prefix) & (prefix | -prefix);
+}
+
 int fib_table_lookup(struct fib_table *tb, const struct flowi4 *flp,
                     struct fib_result *res, int fib_flags)
 {
-       struct trie *t = (struct trie *) tb->tb_data;
-       int ret;
-       struct rt_trie_node *n;
-       struct tnode *pn;
-       unsigned int pos, bits;
-       t_key key = ntohl(flp->daddr);
-       unsigned int chopped_off;
-       t_key cindex = 0;
-       unsigned int current_prefix_length = KEYLENGTH;
-       struct tnode *cn;
-       t_key pref_mismatch;
+       struct trie *t = (struct trie *)tb->tb_data;
+#ifdef CONFIG_IP_FIB_TRIE_STATS
+       struct trie_use_stats __percpu *stats = t->stats;
+#endif
+       const t_key key = ntohl(flp->daddr);
+       struct tnode *n, *pn;
+       t_key cindex;
+       int ret = 1;
 
        rcu_read_lock();
 
@@ -1417,173 +1350,105 @@ int fib_table_lookup(struct fib_table *tb, const struct flowi4 *flp,
                goto failed;
 
 #ifdef CONFIG_IP_FIB_TRIE_STATS
-       t->stats.gets++;
+       this_cpu_inc(stats->gets);
 #endif
 
-       /* Just a leaf? */
-       if (IS_LEAF(n)) {
-               ret = check_leaf(tb, t, (struct leaf *)n, key, flp, res, fib_flags);
-               goto found;
-       }
-
-       pn = (struct tnode *) n;
-       chopped_off = 0;
-
-       while (pn) {
-               pos = pn->pos;
-               bits = pn->bits;
-
-               if (!chopped_off)
-                       cindex = tkey_extract_bits(mask_pfx(key, current_prefix_length),
-                                                  pos, bits);
-
-               n = tnode_get_child_rcu(pn, cindex);
-
-               if (n == NULL) {
-#ifdef CONFIG_IP_FIB_TRIE_STATS
-                       t->stats.null_node_hit++;
-#endif
-                       goto backtrace;
-               }
+       pn = n;
+       cindex = 0;
+
+       /* Step 1: Travel to the longest prefix match in the trie */
+       for (;;) {
+               unsigned long index = get_index(key, n);
+
+               /* This bit of code is a bit tricky but it combines multiple
+                * checks into a single check.  The prefix consists of the
+                * prefix plus zeros for the "bits" in the prefix. The index
+                * is the difference between the key and this value.  From
+                * this we can actually derive several pieces of data.
+                *   if !(index >> bits)
+                *     we know the value is child index
+                *   else
+                *     we have a mismatch in skip bits and failed
+                */
+               if (index >> n->bits)
+                       break;
 
-               if (IS_LEAF(n)) {
-                       ret = check_leaf(tb, t, (struct leaf *)n, key, flp, res, fib_flags);
-                       if (ret > 0)
-                               goto backtrace;
+               /* we have found a leaf. Prefixes have already been compared */
+               if (IS_LEAF(n))
                        goto found;
-               }
-
-               cn = (struct tnode *)n;
-
-               /*
-                * It's a tnode, and we can do some extra checks here if we
-                * like, to avoid descending into a dead-end branch.
-                * This tnode is in the parent's child array at index
-                * key[p_pos..p_pos+p_bits] but potentially with some bits
-                * chopped off, so in reality the index may be just a
-                * subprefix, padded with zero at the end.
-                * We can also take a look at any skipped bits in this
-                * tnode - everything up to p_pos is supposed to be ok,
-                * and the non-chopped bits of the index (se previous
-                * paragraph) are also guaranteed ok, but the rest is
-                * considered unknown.
-                *
-                * The skipped bits are key[pos+bits..cn->pos].
-                */
 
-               /* If current_prefix_length < pos+bits, we are already doing
-                * actual prefix  matching, which means everything from
-                * pos+(bits-chopped_off) onward must be zero along some
-                * branch of this subtree - otherwise there is *no* valid
-                * prefix present. Here we can only check the skipped
-                * bits. Remember, since we have already indexed into the
-                * parent's child array, we know that the bits we chopped of
-                * *are* zero.
+               /* only record pn and cindex if we are going to be chopping
+                * bits later.  Otherwise we are just wasting cycles.
                 */
-
-               /* NOTA BENE: Checking only skipped bits
-                  for the new node here */
-
-               if (current_prefix_length < pos+bits) {
-                       if (tkey_extract_bits(cn->key, current_prefix_length,
-                                               cn->pos - current_prefix_length)
-                           || !(cn->child[0]))
-                               goto backtrace;
+               if (index) {
+                       pn = n;
+                       cindex = index;
                }
 
-               /*
-                * If chopped_off=0, the index is fully validated and we
-                * only need to look at the skipped bits for this, the new,
-                * tnode. What we actually want to do is to find out if
-                * these skipped bits match our key perfectly, or if we will
-                * have to count on finding a matching prefix further down,
-                * because if we do, we would like to have some way of
-                * verifying the existence of such a prefix at this point.
-                */
+               n = rcu_dereference(n->child[index]);
+               if (unlikely(!n))
+                       goto backtrace;
+       }
 
-               /* The only thing we can do at this point is to verify that
-                * any such matching prefix can indeed be a prefix to our
-                * key, and if the bits in the node we are inspecting that
-                * do not match our key are not ZERO, this cannot be true.
-                * Thus, find out where there is a mismatch (before cn->pos)
-                * and verify that all the mismatching bits are zero in the
-                * new tnode's key.
-                */
+       /* Step 2: Sort out leaves and begin backtracing for longest prefix */
+       for (;;) {
+               /* record the pointer where our next node pointer is stored */
+               struct tnode __rcu **cptr = n->child;
 
-               /*
-                * Note: We aren't very concerned about the piece of
-                * the key that precede pn->pos+pn->bits, since these
-                * have already been checked. The bits after cn->pos
-                * aren't checked since these are by definition
-                * "unknown" at this point. Thus, what we want to see
-                * is if we are about to enter the "prefix matching"
-                * state, and in that case verify that the skipped
-                * bits that will prevail throughout this subtree are
-                * zero, as they have to be if we are to find a
-                * matching prefix.
+               /* This test verifies that none of the bits that differ
+                * between the key and the prefix exist in the region of
+                * the lsb and higher in the prefix.
                 */
+               if (unlikely(prefix_mismatch(key, n)))
+                       goto backtrace;
 
-               pref_mismatch = mask_pfx(cn->key ^ key, cn->pos);
+               /* exit out and process leaf */
+               if (unlikely(IS_LEAF(n)))
+                       break;
 
-               /*
-                * In short: If skipped bits in this node do not match
-                * the search key, enter the "prefix matching"
-                * state.directly.
+               /* Don't bother recording parent info.  Since we are in
+                * prefix match mode we will have to come back to wherever
+                * we started this traversal anyway
                 */
-               if (pref_mismatch) {
-                       /* fls(x) = __fls(x) + 1 */
-                       int mp = KEYLENGTH - __fls(pref_mismatch) - 1;
-
-                       if (tkey_extract_bits(cn->key, mp, cn->pos - mp) != 0)
-                               goto backtrace;
-
-                       if (current_prefix_length >= cn->pos)
-                               current_prefix_length = mp;
-               }
-
-               pn = (struct tnode *)n; /* Descend */
-               chopped_off = 0;
-               continue;
 
+               while ((n = rcu_dereference(*cptr)) == NULL) {
 backtrace:
-               chopped_off++;
-
-               /* As zero don't change the child key (cindex) */
-               while ((chopped_off <= pn->bits)
-                      && !(cindex & (1<<(chopped_off-1))))
-                       chopped_off++;
-
-               /* Decrease current_... with bits chopped off */
-               if (current_prefix_length > pn->pos + pn->bits - chopped_off)
-                       current_prefix_length = pn->pos + pn->bits
-                               - chopped_off;
-
-               /*
-                * Either we do the actual chop off according or if we have
-                * chopped off all bits in this tnode walk up to our parent.
-                */
-
-               if (chopped_off <= pn->bits) {
-                       cindex &= ~(1 << (chopped_off-1));
-               } else {
-                       struct tnode *parent = node_parent_rcu((struct rt_trie_node *) pn);
-                       if (!parent)
-                               goto failed;
-
-                       /* Get Child's index */
-                       cindex = tkey_extract_bits(pn->key, parent->pos, parent->bits);
-                       pn = parent;
-                       chopped_off = 0;
-
 #ifdef CONFIG_IP_FIB_TRIE_STATS
-                       t->stats.backtrack++;
+                       if (!n)
+                               this_cpu_inc(stats->null_node_hit);
+#endif
+                       /* If we are at cindex 0 there are no more bits for
+                        * us to strip at this level so we must ascend back
+                        * up one level to see if there are any more bits to
+                        * be stripped there.
+                        */
+                       while (!cindex) {
+                               t_key pkey = pn->key;
+
+                               pn = node_parent_rcu(pn);
+                               if (unlikely(!pn))
+                                       goto failed;
+#ifdef CONFIG_IP_FIB_TRIE_STATS
+                               this_cpu_inc(stats->backtrack);
 #endif
-                       goto backtrace;
+                               /* Get Child's index */
+                               cindex = get_index(pkey, pn);
+                       }
+
+                       /* strip the least significant bit from the cindex */
+                       cindex &= cindex - 1;
+
+                       /* grab pointer for next child node */
+                       cptr = &pn->child[cindex];
                }
        }
-failed:
-       ret = 1;
+
 found:
+       /* Step 3: Process the leaf, if that fails fall back to backtracing */
+       ret = check_leaf(tb, t, n, key, flp, res, fib_flags);
+       if (unlikely(ret > 0))
+               goto backtrace;
+failed:
        rcu_read_unlock();
        return ret;
 }
@@ -1592,9 +1457,9 @@ EXPORT_SYMBOL_GPL(fib_table_lookup);
 /*
  * Remove the leaf and return parent.
  */
-static void trie_leaf_remove(struct trie *t, struct leaf *l)
+static void trie_leaf_remove(struct trie *t, struct tnode *l)
 {
-       struct tnode *tp = node_parent((struct rt_trie_node *) l);
+       struct tnode *tp = node_parent(l);
 
        pr_debug("entering trie_leaf_remove(%p)\n", l);
 
@@ -1605,7 +1470,7 @@ static void trie_leaf_remove(struct trie *t, struct leaf *l)
        } else
                RCU_INIT_POINTER(t->trie, NULL);
 
-       free_leaf(l);
+       node_free(l);
 }
 
 /*
@@ -1619,7 +1484,7 @@ int fib_table_delete(struct fib_table *tb, struct fib_config *cfg)
        u8 tos = cfg->fc_tos;
        struct fib_alias *fa, *fa_to_delete;
        struct list_head *fa_head;
-       struct leaf *l;
+       struct tnode *l;
        struct leaf_info *li;
 
        if (plen > 32)
@@ -1717,7 +1582,7 @@ static int trie_flush_list(struct list_head *head)
        return found;
 }
 
-static int trie_flush_leaf(struct leaf *l)
+static int trie_flush_leaf(struct tnode *l)
 {
        int found = 0;
        struct hlist_head *lih = &l->list;
@@ -1739,7 +1604,7 @@ static int trie_flush_leaf(struct leaf *l)
  * Scan for the next right leaf starting at node p->child[idx]
  * Since we have back pointer, no recursion necessary.
  */
-static struct leaf *leaf_walk_rcu(struct tnode *p, struct rt_trie_node *c)
+static struct tnode *leaf_walk_rcu(struct tnode *p, struct tnode *c)
 {
        do {
                t_key idx;
@@ -1755,47 +1620,46 @@ static struct leaf *leaf_walk_rcu(struct tnode *p, struct rt_trie_node *c)
                                continue;
 
                        if (IS_LEAF(c))
-                               return (struct leaf *) c;
+                               return c;
 
                        /* Rescan start scanning in new node */
-                       p = (struct tnode *) c;
+                       p = c;
                        idx = 0;
                }
 
                /* Node empty, walk back up to parent */
-               c = (struct rt_trie_node *) p;
+               c = p;
        } while ((p = node_parent_rcu(c)) != NULL);
 
        return NULL; /* Root of trie */
 }
 
-static struct leaf *trie_firstleaf(struct trie *t)
+static struct tnode *trie_firstleaf(struct trie *t)
 {
-       struct tnode *n = (struct tnode *)rcu_dereference_rtnl(t->trie);
+       struct tnode *n = rcu_dereference_rtnl(t->trie);
 
        if (!n)
                return NULL;
 
        if (IS_LEAF(n))          /* trie is just a leaf */
-               return (struct leaf *) n;
+               return n;
 
        return leaf_walk_rcu(n, NULL);
 }
 
-static struct leaf *trie_nextleaf(struct leaf *l)
+static struct tnode *trie_nextleaf(struct tnode *l)
 {
-       struct rt_trie_node *c = (struct rt_trie_node *) l;
-       struct tnode *p = node_parent_rcu(c);
+       struct tnode *p = node_parent_rcu(l);
 
        if (!p)
                return NULL;    /* trie with just one leaf */
 
-       return leaf_walk_rcu(p, c);
+       return leaf_walk_rcu(p, l);
 }
 
-static struct leaf *trie_leafindex(struct trie *t, int index)
+static struct tnode *trie_leafindex(struct trie *t, int index)
 {
-       struct leaf *l = trie_firstleaf(t);
+       struct tnode *l = trie_firstleaf(t);
 
        while (l && index-- > 0)
                l = trie_nextleaf(l);
@@ -1810,7 +1674,7 @@ static struct leaf *trie_leafindex(struct trie *t, int index)
 int fib_table_flush(struct fib_table *tb)
 {
        struct trie *t = (struct trie *) tb->tb_data;
-       struct leaf *l, *ll = NULL;
+       struct tnode *l, *ll = NULL;
        int found = 0;
 
        for (l = trie_firstleaf(t); l; l = trie_nextleaf(l)) {
@@ -1830,6 +1694,11 @@ int fib_table_flush(struct fib_table *tb)
 
 void fib_free_table(struct fib_table *tb)
 {
+#ifdef CONFIG_IP_FIB_TRIE_STATS
+       struct trie *t = (struct trie *)tb->tb_data;
+
+       free_percpu(t->stats);
+#endif /* CONFIG_IP_FIB_TRIE_STATS */
        kfree(tb);
 }
 
@@ -1870,7 +1739,7 @@ static int fn_trie_dump_fa(t_key key, int plen, struct list_head *fah,
        return skb->len;
 }
 
-static int fn_trie_dump_leaf(struct leaf *l, struct fib_table *tb,
+static int fn_trie_dump_leaf(struct tnode *l, struct fib_table *tb,
                        struct sk_buff *skb, struct netlink_callback *cb)
 {
        struct leaf_info *li;
@@ -1906,7 +1775,7 @@ static int fn_trie_dump_leaf(struct leaf *l, struct fib_table *tb,
 int fib_table_dump(struct fib_table *tb, struct sk_buff *skb,
                   struct netlink_callback *cb)
 {
-       struct leaf *l;
+       struct tnode *l;
        struct trie *t = (struct trie *) tb->tb_data;
        t_key key = cb->args[2];
        int count = cb->args[3];
@@ -1952,7 +1821,7 @@ void __init fib_trie_init(void)
                                          0, SLAB_PANIC, NULL);
 
        trie_leaf_kmem = kmem_cache_create("ip_fib_trie",
-                                          max(sizeof(struct leaf),
+                                          max(sizeof(struct tnode),
                                               sizeof(struct leaf_info)),
                                           0, SLAB_PANIC, NULL);
 }
@@ -1973,7 +1842,14 @@ struct fib_table *fib_trie_table(u32 id)
        tb->tb_num_default = 0;
 
        t = (struct trie *) tb->tb_data;
-       memset(t, 0, sizeof(*t));
+       RCU_INIT_POINTER(t->trie, NULL);
+#ifdef CONFIG_IP_FIB_TRIE_STATS
+       t->stats = alloc_percpu(struct trie_use_stats);
+       if (!t->stats) {
+               kfree(tb);
+               tb = NULL;
+       }
+#endif
 
        return tb;
 }
@@ -1988,7 +1864,7 @@ struct fib_trie_iter {
        unsigned int depth;
 };
 
-static struct rt_trie_node *fib_trie_get_next(struct fib_trie_iter *iter)
+static struct tnode *fib_trie_get_next(struct fib_trie_iter *iter)
 {
        struct tnode *tn = iter->tnode;
        unsigned int cindex = iter->index;
@@ -2002,7 +1878,7 @@ static struct rt_trie_node *fib_trie_get_next(struct fib_trie_iter *iter)
                 iter->tnode, iter->index, iter->depth);
 rescan:
        while (cindex < (1<<tn->bits)) {
-               struct rt_trie_node *n = tnode_get_child_rcu(tn, cindex);
+               struct tnode *n = tnode_get_child_rcu(tn, cindex);
 
                if (n) {
                        if (IS_LEAF(n)) {
@@ -2010,7 +1886,7 @@ rescan:
                                iter->index = cindex + 1;
                        } else {
                                /* push down one level */
-                               iter->tnode = (struct tnode *) n;
+                               iter->tnode = n;
                                iter->index = 0;
                                ++iter->depth;
                        }
@@ -2021,7 +1897,7 @@ rescan:
        }
 
        /* Current node exhausted, pop back up */
-       p = node_parent_rcu((struct rt_trie_node *)tn);
+       p = node_parent_rcu(tn);
        if (p) {
                cindex = tkey_extract_bits(tn->key, p->pos, p->bits)+1;
                tn = p;
@@ -2033,10 +1909,10 @@ rescan:
        return NULL;
 }
 
-static struct rt_trie_node *fib_trie_get_first(struct fib_trie_iter *iter,
+static struct tnode *fib_trie_get_first(struct fib_trie_iter *iter,
                                       struct trie *t)
 {
-       struct rt_trie_node *n;
+       struct tnode *n;
 
        if (!t)
                return NULL;
@@ -2046,7 +1922,7 @@ static struct rt_trie_node *fib_trie_get_first(struct fib_trie_iter *iter,
                return NULL;
 
        if (IS_TNODE(n)) {
-               iter->tnode = (struct tnode *) n;
+               iter->tnode = n;
                iter->index = 0;
                iter->depth = 1;
        } else {
@@ -2060,7 +1936,7 @@ static struct rt_trie_node *fib_trie_get_first(struct fib_trie_iter *iter,
 
 static void trie_collect_stats(struct trie *t, struct trie_stat *s)
 {
-       struct rt_trie_node *n;
+       struct tnode *n;
        struct fib_trie_iter iter;
 
        memset(s, 0, sizeof(*s));
@@ -2068,7 +1944,6 @@ static void trie_collect_stats(struct trie *t, struct trie_stat *s)
        rcu_read_lock();
        for (n = fib_trie_get_first(&iter, t); n; n = fib_trie_get_next(&iter)) {
                if (IS_LEAF(n)) {
-                       struct leaf *l = (struct leaf *)n;
                        struct leaf_info *li;
 
                        s->leaves++;
@@ -2076,18 +1951,17 @@ static void trie_collect_stats(struct trie *t, struct trie_stat *s)
                        if (iter.depth > s->maxdepth)
                                s->maxdepth = iter.depth;
 
-                       hlist_for_each_entry_rcu(li, &l->list, hlist)
+                       hlist_for_each_entry_rcu(li, &n->list, hlist)
                                ++s->prefixes;
                } else {
-                       const struct tnode *tn = (const struct tnode *) n;
                        int i;
 
                        s->tnodes++;
-                       if (tn->bits < MAX_STAT_DEPTH)
-                               s->nodesizes[tn->bits]++;
+                       if (n->bits < MAX_STAT_DEPTH)
+                               s->nodesizes[n->bits]++;
 
-                       for (i = 0; i < (1<<tn->bits); i++)
-                               if (!tn->child[i])
+                       for (i = 0; i < tnode_child_length(n); i++)
+                               if (!rcu_access_pointer(n->child[i]))
                                        s->nullpointers++;
                }
        }
@@ -2111,7 +1985,7 @@ static void trie_show_stats(struct seq_file *seq, struct trie_stat *stat)
        seq_printf(seq, "\tMax depth:      %u\n", stat->maxdepth);
 
        seq_printf(seq, "\tLeaves:         %u\n", stat->leaves);
-       bytes = sizeof(struct leaf) * stat->leaves;
+       bytes = sizeof(struct tnode) * stat->leaves;
 
        seq_printf(seq, "\tPrefixes:       %u\n", stat->prefixes);
        bytes += sizeof(struct leaf_info) * stat->prefixes;
@@ -2132,25 +2006,38 @@ static void trie_show_stats(struct seq_file *seq, struct trie_stat *stat)
        seq_putc(seq, '\n');
        seq_printf(seq, "\tPointers: %u\n", pointers);
 
-       bytes += sizeof(struct rt_trie_node *) * pointers;
+       bytes += sizeof(struct tnode *) * pointers;
        seq_printf(seq, "Null ptrs: %u\n", stat->nullpointers);
        seq_printf(seq, "Total size: %u  kB\n", (bytes + 1023) / 1024);
 }
 
 #ifdef CONFIG_IP_FIB_TRIE_STATS
 static void trie_show_usage(struct seq_file *seq,
-                           const struct trie_use_stats *stats)
+                           const struct trie_use_stats __percpu *stats)
 {
+       struct trie_use_stats s = { 0 };
+       int cpu;
+
+       /* loop through all of the CPUs and gather up the stats */
+       for_each_possible_cpu(cpu) {
+               const struct trie_use_stats *pcpu = per_cpu_ptr(stats, cpu);
+
+               s.gets += pcpu->gets;
+               s.backtrack += pcpu->backtrack;
+               s.semantic_match_passed += pcpu->semantic_match_passed;
+               s.semantic_match_miss += pcpu->semantic_match_miss;
+               s.null_node_hit += pcpu->null_node_hit;
+               s.resize_node_skipped += pcpu->resize_node_skipped;
+       }
+
        seq_printf(seq, "\nCounters:\n---------\n");
-       seq_printf(seq, "gets = %u\n", stats->gets);
-       seq_printf(seq, "backtracks = %u\n", stats->backtrack);
+       seq_printf(seq, "gets = %u\n", s.gets);
+       seq_printf(seq, "backtracks = %u\n", s.backtrack);
        seq_printf(seq, "semantic match passed = %u\n",
-                  stats->semantic_match_passed);
-       seq_printf(seq, "semantic match miss = %u\n",
-                  stats->semantic_match_miss);
-       seq_printf(seq, "null node hit= %u\n", stats->null_node_hit);
-       seq_printf(seq, "skipped node resize = %u\n\n",
-                  stats->resize_node_skipped);
+                  s.semantic_match_passed);
+       seq_printf(seq, "semantic match miss = %u\n", s.semantic_match_miss);
+       seq_printf(seq, "null node hit= %u\n", s.null_node_hit);
+       seq_printf(seq, "skipped node resize = %u\n\n", s.resize_node_skipped);
 }
 #endif /*  CONFIG_IP_FIB_TRIE_STATS */
 
@@ -2173,7 +2060,7 @@ static int fib_triestat_seq_show(struct seq_file *seq, void *v)
        seq_printf(seq,
                   "Basic info: size of leaf:"
                   " %Zd bytes, size of tnode: %Zd bytes.\n",
-                  sizeof(struct leaf), sizeof(struct tnode));
+                  sizeof(struct tnode), sizeof(struct tnode));
 
        for (h = 0; h < FIB_TABLE_HASHSZ; h++) {
                struct hlist_head *head = &net->ipv4.fib_table_hash[h];
@@ -2191,7 +2078,7 @@ static int fib_triestat_seq_show(struct seq_file *seq, void *v)
                        trie_collect_stats(t, &stat);
                        trie_show_stats(seq, &stat);
 #ifdef CONFIG_IP_FIB_TRIE_STATS
-                       trie_show_usage(seq, &t->stats);
+                       trie_show_usage(seq, t->stats);
 #endif
                }
        }
@@ -2212,7 +2099,7 @@ static const struct file_operations fib_triestat_fops = {
        .release = single_release_net,
 };
 
-static struct rt_trie_node *fib_trie_get_idx(struct seq_file *seq, loff_t pos)
+static struct tnode *fib_trie_get_idx(struct seq_file *seq, loff_t pos)
 {
        struct fib_trie_iter *iter = seq->private;
        struct net *net = seq_file_net(seq);
@@ -2224,7 +2111,7 @@ static struct rt_trie_node *fib_trie_get_idx(struct seq_file *seq, loff_t pos)
                struct fib_table *tb;
 
                hlist_for_each_entry_rcu(tb, head, tb_hlist) {
-                       struct rt_trie_node *n;
+                       struct tnode *n;
 
                        for (n = fib_trie_get_first(iter,
                                                    (struct trie *) tb->tb_data);
@@ -2253,7 +2140,7 @@ static void *fib_trie_seq_next(struct seq_file *seq, void *v, loff_t *pos)
        struct fib_table *tb = iter->tb;
        struct hlist_node *tb_node;
        unsigned int h;
-       struct rt_trie_node *n;
+       struct tnode *n;
 
        ++*pos;
        /* next node in same table */
@@ -2339,29 +2226,26 @@ static inline const char *rtn_type(char *buf, size_t len, unsigned int t)
 static int fib_trie_seq_show(struct seq_file *seq, void *v)
 {
        const struct fib_trie_iter *iter = seq->private;
-       struct rt_trie_node *n = v;
+       struct tnode *n = v;
 
        if (!node_parent_rcu(n))
                fib_table_print(seq, iter->tb);
 
        if (IS_TNODE(n)) {
-               struct tnode *tn = (struct tnode *) n;
-               __be32 prf = htonl(mask_pfx(tn->key, tn->pos));
+               __be32 prf = htonl(n->key);
 
-               seq_indent(seq, iter->depth-1);
+               seq_indent(seq, iter->depth - 1);
                seq_printf(seq, "  +-- %pI4/%d %d %d %d\n",
-                          &prf, tn->pos, tn->bits, tn->full_children,
-                          tn->empty_children);
-
+                          &prf, n->pos, n->bits, n->full_children,
+                          n->empty_children);
        } else {
-               struct leaf *l = (struct leaf *) n;
                struct leaf_info *li;
-               __be32 val = htonl(l->key);
+               __be32 val = htonl(n->key);
 
                seq_indent(seq, iter->depth);
                seq_printf(seq, "  |-- %pI4\n", &val);
 
-               hlist_for_each_entry_rcu(li, &l->list, hlist) {
+               hlist_for_each_entry_rcu(li, &n->list, hlist) {
                        struct fib_alias *fa;
 
                        list_for_each_entry_rcu(fa, &li->falh, fa_list) {
@@ -2411,9 +2295,9 @@ struct fib_route_iter {
        t_key   key;
 };
 
-static struct leaf *fib_route_get_idx(struct fib_route_iter *iter, loff_t pos)
+static struct tnode *fib_route_get_idx(struct fib_route_iter *iter, loff_t pos)
 {
-       struct leaf *l = NULL;
+       struct tnode *l = NULL;
        struct trie *t = iter->main_trie;
 
        /* use cache location of last found key */
@@ -2458,7 +2342,7 @@ static void *fib_route_seq_start(struct seq_file *seq, loff_t *pos)
 static void *fib_route_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 {
        struct fib_route_iter *iter = seq->private;
-       struct leaf *l = v;
+       struct tnode *l = v;
 
        ++*pos;
        if (v == SEQ_START_TOKEN) {
@@ -2504,7 +2388,7 @@ static unsigned int fib_flag_trans(int type, __be32 mask, const struct fib_info
  */
 static int fib_route_seq_show(struct seq_file *seq, void *v)
 {
-       struct leaf *l = v;
+       struct tnode *l = v;
        struct leaf_info *li;
 
        if (v == SEQ_START_TOKEN) {