mac80211: add generic cipher scheme support
[cascardo/linux.git] / net / mac80211 / key.c
index 620677e..e568d98 100644 (file)
@@ -260,25 +260,29 @@ static void ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
        int idx;
        bool defunikey, defmultikey, defmgmtkey;
 
+       /* caller must provide at least one old/new */
+       if (WARN_ON(!new && !old))
+               return;
+
        if (new)
                list_add_tail(&new->list, &sdata->key_list);
 
-       if (sta && pairwise) {
-               rcu_assign_pointer(sta->ptk, new);
-       } else if (sta) {
-               if (old)
-                       idx = old->conf.keyidx;
-               else
-                       idx = new->conf.keyidx;
-               rcu_assign_pointer(sta->gtk[idx], new);
-       } else {
-               WARN_ON(new && old && new->conf.keyidx != old->conf.keyidx);
+       WARN_ON(new && old && new->conf.keyidx != old->conf.keyidx);
 
-               if (old)
-                       idx = old->conf.keyidx;
-               else
-                       idx = new->conf.keyidx;
+       if (old)
+               idx = old->conf.keyidx;
+       else
+               idx = new->conf.keyidx;
 
+       if (sta) {
+               if (pairwise) {
+                       rcu_assign_pointer(sta->ptk[idx], new);
+                       sta->ptk_idx = idx;
+               } else {
+                       rcu_assign_pointer(sta->gtk[idx], new);
+                       sta->gtk_idx = idx;
+               }
+       } else {
                defunikey = old &&
                        old == key_mtx_dereference(sdata->local,
                                                sdata->default_unicast_key);
@@ -312,9 +316,11 @@ static void ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
                list_del(&old->list);
 }
 
-struct ieee80211_key *ieee80211_key_alloc(u32 cipher, int idx, size_t key_len,
-                                         const u8 *key_data,
-                                         size_t seq_len, const u8 *seq)
+struct ieee80211_key *
+ieee80211_key_alloc(u32 cipher, int idx, size_t key_len,
+                   const u8 *key_data,
+                   size_t seq_len, const u8 *seq,
+                   const struct ieee80211_cipher_scheme *cs)
 {
        struct ieee80211_key *key;
        int i, j, err;
@@ -393,6 +399,18 @@ struct ieee80211_key *ieee80211_key_alloc(u32 cipher, int idx, size_t key_len,
                        return ERR_PTR(err);
                }
                break;
+       default:
+               if (cs) {
+                       size_t len = (seq_len > MAX_PN_LEN) ?
+                                               MAX_PN_LEN : seq_len;
+
+                       key->conf.iv_len = cs->hdr_len;
+                       key->conf.icv_len = cs->mic_len;
+                       for (i = 0; i < IEEE80211_NUM_TIDS + 1; i++)
+                               for (j = 0; j < len; j++)
+                                       key->u.gen.rx_pn[i][j] =
+                                                       seq[len - j - 1];
+               }
        }
        memcpy(key->conf.key, key_data, key_len);
        INIT_LIST_HEAD(&key->list);
@@ -475,7 +493,7 @@ int ieee80211_key_link(struct ieee80211_key *key,
        mutex_lock(&sdata->local->key_mtx);
 
        if (sta && pairwise)
-               old_key = key_mtx_dereference(sdata->local, sta->ptk);
+               old_key = key_mtx_dereference(sdata->local, sta->ptk[idx]);
        else if (sta)
                old_key = key_mtx_dereference(sdata->local, sta->gtk[idx]);
        else
@@ -625,8 +643,10 @@ void ieee80211_free_sta_keys(struct ieee80211_local *local,
                list_add(&key->list, &keys);
        }
 
-       key = key_mtx_dereference(local, sta->ptk);
-       if (key) {
+       for (i = 0; i < NUM_DEFAULT_KEYS; i++) {
+               key = key_mtx_dereference(local, sta->ptk[i]);
+               if (!key)
+                       continue;
                ieee80211_key_replace(key->sdata, key->sta,
                                key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
                                key, NULL);
@@ -877,9 +897,9 @@ ieee80211_gtk_rekey_add(struct ieee80211_vif *vif,
 
        key = ieee80211_key_alloc(keyconf->cipher, keyconf->keyidx,
                                  keyconf->keylen, keyconf->key,
-                                 0, NULL);
+                                 0, NULL, NULL);
        if (IS_ERR(key))
-               return ERR_PTR(PTR_ERR(key));
+               return ERR_CAST(key);
 
        if (sdata->u.mgd.mfp != IEEE80211_MFP_DISABLED)
                key->conf.flags |= IEEE80211_KEY_FLAG_RX_MGMT;