batman-adv: protect neighbor nodes with reference counters
[cascardo/linux.git] / net / batman-adv / originator.c
index 54863c9..b1b1773 100644 (file)
@@ -59,9 +59,18 @@ err:
        return 0;
 }
 
-struct neigh_node *
-create_neighbor(struct orig_node *orig_node, struct orig_node *orig_neigh_node,
-               uint8_t *neigh, struct batman_if *if_incoming)
+void neigh_node_free_ref(struct kref *refcount)
+{
+       struct neigh_node *neigh_node;
+
+       neigh_node = container_of(refcount, struct neigh_node, refcount);
+       kfree(neigh_node);
+}
+
+struct neigh_node *create_neighbor(struct orig_node *orig_node,
+                                  struct orig_node *orig_neigh_node,
+                                  uint8_t *neigh,
+                                  struct batman_if *if_incoming)
 {
        struct bat_priv *bat_priv = netdev_priv(if_incoming->soft_iface);
        struct neigh_node *neigh_node;
@@ -78,6 +87,7 @@ create_neighbor(struct orig_node *orig_node, struct orig_node *orig_neigh_node,
        memcpy(neigh_node->addr, neigh, ETH_ALEN);
        neigh_node->orig_node = orig_neigh_node;
        neigh_node->if_incoming = if_incoming;
+       kref_init(&neigh_node->refcount);
 
        list_add_tail(&neigh_node->list, &orig_node->neigh_list);
        return neigh_node;
@@ -95,7 +105,7 @@ static void free_orig_node(void *data, void *arg)
                neigh_node = list_entry(list_pos, struct neigh_node, list);
 
                list_del(list_pos);
-               kfree(neigh_node);
+               kref_put(&neigh_node->refcount, neigh_node_free_ref);
        }
 
        frag_list_free(&orig_node->frag_list);
@@ -216,7 +226,7 @@ static bool purge_orig_neighbors(struct bat_priv *bat_priv,
 
                        neigh_purged = true;
                        list_del(list_pos);
-                       kfree(neigh_node);
+                       kref_put(&neigh_node->refcount, neigh_node_free_ref);
                } else {
                        if ((!*best_neigh_node) ||
                            (neigh_node->tq_avg > (*best_neigh_node)->tq_avg))