Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[cascardo/linux.git] / net / xfrm / xfrm_policy.c
index 45f9cf9..fd69866 100644 (file)
@@ -49,6 +49,7 @@ static struct xfrm_policy_afinfo __rcu *xfrm_policy_afinfo[NPROTO]
                                                __read_mostly;
 
 static struct kmem_cache *xfrm_dst_cache __read_mostly;
+static __read_mostly seqcount_t xfrm_policy_hash_generation;
 
 static void xfrm_init_pmtu(struct dst_entry *dst);
 static int stale_bundle(struct dst_entry *dst);
@@ -59,6 +60,11 @@ static void __xfrm_policy_link(struct xfrm_policy *pol, int dir);
 static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol,
                                                int dir);
 
+static inline bool xfrm_pol_hold_rcu(struct xfrm_policy *policy)
+{      /* Conditional reference grab used by the lockless lookup paths in this patch. */
+       return atomic_inc_not_zero(&policy->refcnt); /* false: refcnt was already 0, no reference taken; callers retry */
+}
+
 static inline bool
 __xfrm4_selector_match(const struct xfrm_selector *sel, const struct flowi *fl)
 {
@@ -385,9 +391,11 @@ static struct hlist_head *policy_hash_bysel(struct net *net,
        __get_hash_thresh(net, family, dir, &dbits, &sbits);
        hash = __sel_hash(sel, family, hmask, dbits, sbits);
 
-       return (hash == hmask + 1 ?
-               &net->xfrm.policy_inexact[dir] :
-               net->xfrm.policy_bydst[dir].table + hash);
+       if (hash == hmask + 1)
+               return &net->xfrm.policy_inexact[dir];
+
+       return rcu_dereference_check(net->xfrm.policy_bydst[dir].table,
+                    lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash;
 }
 
 static struct hlist_head *policy_hash_direct(struct net *net,
@@ -403,7 +411,8 @@ static struct hlist_head *policy_hash_direct(struct net *net,
        __get_hash_thresh(net, family, dir, &dbits, &sbits);
        hash = __addr_hash(daddr, saddr, family, hmask, dbits, sbits);
 
-       return net->xfrm.policy_bydst[dir].table + hash;
+       return rcu_dereference_check(net->xfrm.policy_bydst[dir].table,
+                    lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash;
 }
 
 static void xfrm_dst_hash_transfer(struct net *net,
@@ -426,14 +435,14 @@ redo:
                h = __addr_hash(&pol->selector.daddr, &pol->selector.saddr,
                                pol->family, nhashmask, dbits, sbits);
                if (!entry0) {
-                       hlist_del(&pol->bydst);
-                       hlist_add_head(&pol->bydst, ndsttable+h);
+                       hlist_del_rcu(&pol->bydst);
+                       hlist_add_head_rcu(&pol->bydst, ndsttable + h);
                        h0 = h;
                } else {
                        if (h != h0)
                                continue;
-                       hlist_del(&pol->bydst);
-                       hlist_add_behind(&pol->bydst, entry0);
+                       hlist_del_rcu(&pol->bydst);
+                       hlist_add_behind_rcu(&pol->bydst, entry0);
                }
                entry0 = &pol->bydst;
        }
@@ -468,22 +477,32 @@ static void xfrm_bydst_resize(struct net *net, int dir)
        unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
        unsigned int nhashmask = xfrm_new_hash_mask(hmask);
        unsigned int nsize = (nhashmask + 1) * sizeof(struct hlist_head);
-       struct hlist_head *odst = net->xfrm.policy_bydst[dir].table;
        struct hlist_head *ndst = xfrm_hash_alloc(nsize);
+       struct hlist_head *odst;
        int i;
 
        if (!ndst)
                return;
 
-       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
+       /* Bump the generation so lockless lookups notice the resize and retry. */
+       write_seqcount_begin(&xfrm_policy_hash_generation);
+
+       /* Fetch the old table once, under the policy lock. */
+       odst = rcu_dereference_protected(net->xfrm.policy_bydst[dir].table,
+                               lockdep_is_held(&net->xfrm.xfrm_policy_lock));
 
        for (i = hmask; i >= 0; i--)
                xfrm_dst_hash_transfer(net, odst + i, ndst, nhashmask, dir);
 
-       net->xfrm.policy_bydst[dir].table = ndst;
+       rcu_assign_pointer(net->xfrm.policy_bydst[dir].table, ndst);
        net->xfrm.policy_bydst[dir].hmask = nhashmask;
 
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       write_seqcount_end(&xfrm_policy_hash_generation);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
+
+       /* Let RCU readers of the old table finish before it is freed below. */
+       synchronize_rcu();
 
        xfrm_hash_free(odst, (hmask + 1) * sizeof(struct hlist_head));
 }
@@ -500,7 +519,7 @@ static void xfrm_byidx_resize(struct net *net, int total)
        if (!nidx)
                return;
 
-       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
 
        for (i = hmask; i >= 0; i--)
                xfrm_idx_hash_transfer(oidx + i, nidx, nhashmask);
@@ -508,7 +527,7 @@ static void xfrm_byidx_resize(struct net *net, int total)
        net->xfrm.policy_byidx = nidx;
        net->xfrm.policy_idx_hmask = nhashmask;
 
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
 
        xfrm_hash_free(oidx, (hmask + 1) * sizeof(struct hlist_head));
 }
@@ -541,7 +560,6 @@ static inline int xfrm_byidx_should_resize(struct net *net, int total)
 
 void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si)
 {
-       read_lock_bh(&net->xfrm.xfrm_policy_lock);
        si->incnt = net->xfrm.policy_count[XFRM_POLICY_IN];
        si->outcnt = net->xfrm.policy_count[XFRM_POLICY_OUT];
        si->fwdcnt = net->xfrm.policy_count[XFRM_POLICY_FWD];
@@ -550,7 +568,6 @@ void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si)
        si->fwdscnt = net->xfrm.policy_count[XFRM_POLICY_FWD+XFRM_POLICY_MAX];
        si->spdhcnt = net->xfrm.policy_idx_hmask;
        si->spdhmcnt = xfrm_policy_hashmax;
-       read_unlock_bh(&net->xfrm.xfrm_policy_lock);
 }
 EXPORT_SYMBOL(xfrm_spd_getinfo);
 
@@ -600,7 +617,7 @@ static void xfrm_hash_rebuild(struct work_struct *work)
                rbits6 = net->xfrm.policy_hthresh.rbits6;
        } while (read_seqretry(&net->xfrm.policy_hthresh.lock, seq));
 
-       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
 
        /* reset the bydst and inexact table in all directions */
        for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
@@ -646,7 +663,7 @@ static void xfrm_hash_rebuild(struct work_struct *work)
                        hlist_add_head(&policy->bydst, chain);
        }
 
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
 
        mutex_unlock(&hash_resize_mutex);
 }
@@ -757,7 +774,7 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl)
        struct hlist_head *chain;
        struct hlist_node *newpos;
 
-       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
        chain = policy_hash_bysel(net, &policy->selector, policy->family, dir);
        delpol = NULL;
        newpos = NULL;
@@ -768,7 +785,7 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl)
                    xfrm_sec_ctx_match(pol->security, policy->security) &&
                    !WARN_ON(delpol)) {
                        if (excl) {
-                               write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+                               spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
                                return -EEXIST;
                        }
                        delpol = pol;
@@ -804,7 +821,7 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl)
        policy->curlft.use_time = 0;
        if (!mod_timer(&policy->timer, jiffies + HZ))
                xfrm_pol_hold(policy);
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
 
        if (delpol)
                xfrm_policy_kill(delpol);
@@ -824,7 +841,7 @@ struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u32 mark, u8 type,
        struct hlist_head *chain;
 
        *err = 0;
-       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
        chain = policy_hash_bysel(net, sel, sel->family, dir);
        ret = NULL;
        hlist_for_each_entry(pol, chain, bydst) {
@@ -837,7 +854,7 @@ struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u32 mark, u8 type,
                                *err = security_xfrm_policy_delete(
                                                                pol->security);
                                if (*err) {
-                                       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+                                       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
                                        return pol;
                                }
                                __xfrm_policy_unlink(pol, dir);
@@ -846,7 +863,7 @@ struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u32 mark, u8 type,
                        break;
                }
        }
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
 
        if (ret && delete)
                xfrm_policy_kill(ret);
@@ -865,7 +882,7 @@ struct xfrm_policy *xfrm_policy_byid(struct net *net, u32 mark, u8 type,
                return NULL;
 
        *err = 0;
-       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
        chain = net->xfrm.policy_byidx + idx_hash(net, id);
        ret = NULL;
        hlist_for_each_entry(pol, chain, byidx) {
@@ -876,7 +893,7 @@ struct xfrm_policy *xfrm_policy_byid(struct net *net, u32 mark, u8 type,
                                *err = security_xfrm_policy_delete(
                                                                pol->security);
                                if (*err) {
-                                       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+                                       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
                                        return pol;
                                }
                                __xfrm_policy_unlink(pol, dir);
@@ -885,7 +902,7 @@ struct xfrm_policy *xfrm_policy_byid(struct net *net, u32 mark, u8 type,
                        break;
                }
        }
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
 
        if (ret && delete)
                xfrm_policy_kill(ret);
@@ -943,7 +960,7 @@ int xfrm_policy_flush(struct net *net, u8 type, bool task_valid)
 {
        int dir, err = 0, cnt = 0;
 
-       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
 
        err = xfrm_policy_flush_secctx_check(net, type, task_valid);
        if (err)
@@ -959,14 +976,14 @@ int xfrm_policy_flush(struct net *net, u8 type, bool task_valid)
                        if (pol->type != type)
                                continue;
                        __xfrm_policy_unlink(pol, dir);
-                       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+                       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
                        cnt++;
 
                        xfrm_audit_policy_delete(pol, 1, task_valid);
 
                        xfrm_policy_kill(pol);
 
-                       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+                       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
                        goto again1;
                }
 
@@ -978,13 +995,13 @@ int xfrm_policy_flush(struct net *net, u8 type, bool task_valid)
                                if (pol->type != type)
                                        continue;
                                __xfrm_policy_unlink(pol, dir);
-                               write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+                               spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
                                cnt++;
 
                                xfrm_audit_policy_delete(pol, 1, task_valid);
                                xfrm_policy_kill(pol);
 
-                               write_lock_bh(&net->xfrm.xfrm_policy_lock);
+                               spin_lock_bh(&net->xfrm.xfrm_policy_lock);
                                goto again2;
                        }
                }
@@ -993,7 +1010,7 @@ int xfrm_policy_flush(struct net *net, u8 type, bool task_valid)
        if (!cnt)
                err = -ESRCH;
 out:
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
        return err;
 }
 EXPORT_SYMBOL(xfrm_policy_flush);
@@ -1013,7 +1030,7 @@ int xfrm_policy_walk(struct net *net, struct xfrm_policy_walk *walk,
        if (list_empty(&walk->walk.all) && walk->seq != 0)
                return 0;
 
-       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
        if (list_empty(&walk->walk.all))
                x = list_first_entry(&net->xfrm.policy_all, struct xfrm_policy_walk_entry, all);
        else
@@ -1041,7 +1058,7 @@ int xfrm_policy_walk(struct net *net, struct xfrm_policy_walk *walk,
        }
        list_del_init(&walk->walk.all);
 out:
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
        return error;
 }
 EXPORT_SYMBOL(xfrm_policy_walk);
@@ -1060,9 +1077,9 @@ void xfrm_policy_walk_done(struct xfrm_policy_walk *walk, struct net *net)
        if (list_empty(&walk->walk.all))
                return;
 
-       write_lock_bh(&net->xfrm.xfrm_policy_lock); /*FIXME where is net? */
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock); /*FIXME where is net? */
        list_del(&walk->walk.all);
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
 }
 EXPORT_SYMBOL(xfrm_policy_walk_done);
 
@@ -1100,17 +1117,24 @@ static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type,
        struct xfrm_policy *pol, *ret;
        const xfrm_address_t *daddr, *saddr;
        struct hlist_head *chain;
-       u32 priority = ~0U;
+       unsigned int sequence;
+       u32 priority;
 
        daddr = xfrm_flowi_daddr(fl, family);
        saddr = xfrm_flowi_saddr(fl, family);
        if (unlikely(!daddr || !saddr))
                return NULL;
 
-       read_lock_bh(&net->xfrm.xfrm_policy_lock);
-       chain = policy_hash_direct(net, daddr, saddr, family, dir);
+       rcu_read_lock();
+ retry:
+       do {
+               sequence = read_seqcount_begin(&xfrm_policy_hash_generation);
+               chain = policy_hash_direct(net, daddr, saddr, family, dir);
+       } while (read_seqcount_retry(&xfrm_policy_hash_generation, sequence));
+
+       priority = ~0U;
        ret = NULL;
-       hlist_for_each_entry(pol, chain, bydst) {
+       hlist_for_each_entry_rcu(pol, chain, bydst) {
                err = xfrm_policy_match(pol, fl, type, family, dir);
                if (err) {
                        if (err == -ESRCH)
@@ -1126,7 +1150,7 @@ static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type,
                }
        }
        chain = &net->xfrm.policy_inexact[dir];
-       hlist_for_each_entry(pol, chain, bydst) {
+       hlist_for_each_entry_rcu(pol, chain, bydst) {
                if ((pol->priority >= priority) && ret)
                        break;
 
@@ -1144,9 +1168,13 @@ static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type,
                }
        }
 
-       xfrm_pol_hold(ret);
+       if (read_seqcount_retry(&xfrm_policy_hash_generation, sequence))
+               goto retry;
+
+       if (ret && !xfrm_pol_hold_rcu(ret))
+               goto retry;
 fail:
-       read_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       rcu_read_unlock();
 
        return ret;
 }
@@ -1223,10 +1251,9 @@ static struct xfrm_policy *xfrm_sk_policy_lookup(const struct sock *sk, int dir,
                                                 const struct flowi *fl)
 {
        struct xfrm_policy *pol;
-       struct net *net = sock_net(sk);
 
        rcu_read_lock();
-       read_lock_bh(&net->xfrm.xfrm_policy_lock);
+ again:
        pol = rcu_dereference(sk->sk_policy[dir]);
        if (pol != NULL) {
                bool match = xfrm_selector_match(&pol->selector, fl,
@@ -1241,8 +1268,10 @@ static struct xfrm_policy *xfrm_sk_policy_lookup(const struct sock *sk, int dir,
                        err = security_xfrm_policy_lookup(pol->security,
                                                      fl->flowi_secid,
                                                      policy_to_flow_dir(dir));
-                       if (!err)
-                               xfrm_pol_hold(pol);
-                       else if (err == -ESRCH)
+                       if (!err) {
+                               /* refcnt already dropped to zero: policy is going away, look it up again */
+                               if (!xfrm_pol_hold_rcu(pol))
+                                       goto again;
+                       } else if (err == -ESRCH)
                                pol = NULL;
                        else
@@ -1251,7 +1278,6 @@ static struct xfrm_policy *xfrm_sk_policy_lookup(const struct sock *sk, int dir,
                        pol = NULL;
        }
 out:
-       read_unlock_bh(&net->xfrm.xfrm_policy_lock);
        rcu_read_unlock();
        return pol;
 }
@@ -1275,7 +1301,7 @@ static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol,
 
        /* Socket policies are not hashed. */
        if (!hlist_unhashed(&pol->bydst)) {
-               hlist_del(&pol->bydst);
+               hlist_del_rcu(&pol->bydst);
                hlist_del(&pol->byidx);
        }
 
@@ -1299,9 +1325,9 @@ int xfrm_policy_delete(struct xfrm_policy *pol, int dir)
 {
        struct net *net = xp_net(pol);
 
-       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
        pol = __xfrm_policy_unlink(pol, dir);
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
        if (pol) {
                xfrm_policy_kill(pol);
                return 0;
@@ -1320,7 +1346,7 @@ int xfrm_sk_policy_insert(struct sock *sk, int dir, struct xfrm_policy *pol)
                return -EINVAL;
 #endif
 
-       write_lock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
        old_pol = rcu_dereference_protected(sk->sk_policy[dir],
                                lockdep_is_held(&net->xfrm.xfrm_policy_lock));
        if (pol) {
@@ -1338,7 +1364,7 @@ int xfrm_sk_policy_insert(struct sock *sk, int dir, struct xfrm_policy *pol)
                 */
                xfrm_sk_policy_unlink(old_pol, dir);
        }
-       write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
 
        if (old_pol) {
                xfrm_policy_kill(old_pol);
@@ -1368,9 +1394,9 @@ static struct xfrm_policy *clone_policy(const struct xfrm_policy *old, int dir)
                newp->type = old->type;
                memcpy(newp->xfrm_vec, old->xfrm_vec,
                       newp->xfrm_nr*sizeof(struct xfrm_tmpl));
-               write_lock_bh(&net->xfrm.xfrm_policy_lock);
+               spin_lock_bh(&net->xfrm.xfrm_policy_lock);
                xfrm_sk_policy_link(newp, dir);
-               write_unlock_bh(&net->xfrm.xfrm_policy_lock);
+               spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
                xfrm_pol_put(newp);
        }
        return newp;
@@ -3052,7 +3078,7 @@ static int __net_init xfrm_net_init(struct net *net)
 
        /* Initialize the per-net locks here */
        spin_lock_init(&net->xfrm.xfrm_state_lock);
-       rwlock_init(&net->xfrm.xfrm_policy_lock);
+       spin_lock_init(&net->xfrm.xfrm_policy_lock);
        mutex_init(&net->xfrm.xfrm_cfg_mutex);
 
        return 0;
@@ -3086,6 +3112,7 @@ static struct pernet_operations __net_initdata xfrm_net_ops = {
 void __init xfrm_init(void)
 {
        register_pernet_subsys(&xfrm_net_ops);
+       seqcount_init(&xfrm_policy_hash_generation); /* global generation counter written by xfrm_bydst_resize() and read by the lockless policy lookup */
        xfrm_input_init();
 }
 
@@ -3183,7 +3210,7 @@ static struct xfrm_policy *xfrm_migrate_policy_find(const struct xfrm_selector *
        struct hlist_head *chain;
        u32 priority = ~0U;
 
-       read_lock_bh(&net->xfrm.xfrm_policy_lock); /*FIXME*/
+       spin_lock_bh(&net->xfrm.xfrm_policy_lock);
        chain = policy_hash_direct(net, &sel->daddr, &sel->saddr, sel->family, dir);
        hlist_for_each_entry(pol, chain, bydst) {
                if (xfrm_migrate_selector_match(sel, &pol->selector) &&
@@ -3207,7 +3234,7 @@ static struct xfrm_policy *xfrm_migrate_policy_find(const struct xfrm_selector *
 
        xfrm_pol_hold(ret);
 
-       read_unlock_bh(&net->xfrm.xfrm_policy_lock);
+       spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
 
        return ret;
 }