crypto: rsa - Generate fixed-length output
[cascardo/linux.git] / crypto / rsa.c
index 77d737f..4c280b6 100644 (file)
  */
 
 #include <linux/module.h>
+#include <linux/mpi.h>
 #include <crypto/internal/rsa.h>
 #include <crypto/internal/akcipher.h>
 #include <crypto/akcipher.h>
 #include <crypto/algapi.h>
 
+struct rsa_mpi_key {
+       MPI n;
+       MPI e;
+       MPI d;
+};
+
 /*
  * RSAEP function [RFC3447 sec 5.1.1]
  * c = m^e mod n;
  */
-static int _rsa_enc(const struct rsa_key *key, MPI c, MPI m)
+static int _rsa_enc(const struct rsa_mpi_key *key, MPI c, MPI m)
 {
        /* (1) Validate 0 <= m < n */
        if (mpi_cmp_ui(m, 0) < 0 || mpi_cmp(m, key->n) >= 0)
@@ -33,7 +40,7 @@ static int _rsa_enc(const struct rsa_key *key, MPI c, MPI m)
  * RSADP function [RFC3447 sec 5.1.2]
  * m = c^d mod n;
  */
-static int _rsa_dec(const struct rsa_key *key, MPI m, MPI c)
+static int _rsa_dec(const struct rsa_mpi_key *key, MPI m, MPI c)
 {
        /* (1) Validate 0 <= c < n */
        if (mpi_cmp_ui(c, 0) < 0 || mpi_cmp(c, key->n) >= 0)
@@ -47,7 +54,7 @@ static int _rsa_dec(const struct rsa_key *key, MPI m, MPI c)
  * RSASP1 function [RFC3447 sec 5.2.1]
  * s = m^d mod n
  */
-static int _rsa_sign(const struct rsa_key *key, MPI s, MPI m)
+static int _rsa_sign(const struct rsa_mpi_key *key, MPI s, MPI m)
 {
        /* (1) Validate 0 <= m < n */
        if (mpi_cmp_ui(m, 0) < 0 || mpi_cmp(m, key->n) >= 0)
@@ -61,7 +68,7 @@ static int _rsa_sign(const struct rsa_key *key, MPI s, MPI m)
  * RSAVP1 function [RFC3447 sec 5.2.2]
  * m = s^e mod n;
  */
-static int _rsa_verify(const struct rsa_key *key, MPI m, MPI s)
+static int _rsa_verify(const struct rsa_mpi_key *key, MPI m, MPI s)
 {
        /* (1) Validate 0 <= s < n */
        if (mpi_cmp_ui(s, 0) < 0 || mpi_cmp(s, key->n) >= 0)
@@ -71,7 +78,7 @@ static int _rsa_verify(const struct rsa_key *key, MPI m, MPI s)
        return mpi_powm(m, s, key->e, key->n);
 }
 
-static inline struct rsa_key *rsa_get_key(struct crypto_akcipher *tfm)
+static inline struct rsa_mpi_key *rsa_get_key(struct crypto_akcipher *tfm)
 {
        return akcipher_tfm_ctx(tfm);
 }
@@ -79,7 +86,7 @@ static inline struct rsa_key *rsa_get_key(struct crypto_akcipher *tfm)
 static int rsa_enc(struct akcipher_request *req)
 {
        struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
-       const struct rsa_key *pkey = rsa_get_key(tfm);
+       const struct rsa_mpi_key *pkey = rsa_get_key(tfm);
        MPI m, c = mpi_alloc(0);
        int ret = 0;
        int sign;
@@ -101,7 +108,7 @@ static int rsa_enc(struct akcipher_request *req)
        if (ret)
                goto err_free_m;
 
-       ret = mpi_write_to_sgl(c, req->dst, &req->dst_len, &sign);
+       ret = mpi_write_to_sgl(c, req->dst, req->dst_len, &sign);
        if (ret)
                goto err_free_m;
 
@@ -118,7 +125,7 @@ err_free_c:
 static int rsa_dec(struct akcipher_request *req)
 {
        struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
-       const struct rsa_key *pkey = rsa_get_key(tfm);
+       const struct rsa_mpi_key *pkey = rsa_get_key(tfm);
        MPI c, m = mpi_alloc(0);
        int ret = 0;
        int sign;
@@ -140,7 +147,7 @@ static int rsa_dec(struct akcipher_request *req)
        if (ret)
                goto err_free_c;
 
-       ret = mpi_write_to_sgl(m, req->dst, &req->dst_len, &sign);
+       ret = mpi_write_to_sgl(m, req->dst, req->dst_len, &sign);
        if (ret)
                goto err_free_c;
 
@@ -156,7 +163,7 @@ err_free_m:
 static int rsa_sign(struct akcipher_request *req)
 {
        struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
-       const struct rsa_key *pkey = rsa_get_key(tfm);
+       const struct rsa_mpi_key *pkey = rsa_get_key(tfm);
        MPI m, s = mpi_alloc(0);
        int ret = 0;
        int sign;
@@ -178,7 +185,7 @@ static int rsa_sign(struct akcipher_request *req)
        if (ret)
                goto err_free_m;
 
-       ret = mpi_write_to_sgl(s, req->dst, &req->dst_len, &sign);
+       ret = mpi_write_to_sgl(s, req->dst, req->dst_len, &sign);
        if (ret)
                goto err_free_m;
 
@@ -195,7 +202,7 @@ err_free_s:
 static int rsa_verify(struct akcipher_request *req)
 {
        struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
-       const struct rsa_key *pkey = rsa_get_key(tfm);
+       const struct rsa_mpi_key *pkey = rsa_get_key(tfm);
        MPI s, m = mpi_alloc(0);
        int ret = 0;
        int sign;
@@ -219,7 +226,7 @@ static int rsa_verify(struct akcipher_request *req)
        if (ret)
                goto err_free_s;
 
-       ret = mpi_write_to_sgl(m, req->dst, &req->dst_len, &sign);
+       ret = mpi_write_to_sgl(m, req->dst, req->dst_len, &sign);
        if (ret)
                goto err_free_s;
 
@@ -233,6 +240,16 @@ err_free_m:
        return ret;
 }
 
+static void rsa_free_mpi_key(struct rsa_mpi_key *key)
+{
+       mpi_free(key->d);
+       mpi_free(key->e);
+       mpi_free(key->n);
+       key->d = NULL;
+       key->e = NULL;
+       key->n = NULL;
+}
+
 static int rsa_check_key_length(unsigned int len)
 {
        switch (len) {
@@ -251,49 +268,87 @@ static int rsa_check_key_length(unsigned int len)
 static int rsa_set_pub_key(struct crypto_akcipher *tfm, const void *key,
                           unsigned int keylen)
 {
-       struct rsa_key *pkey = akcipher_tfm_ctx(tfm);
+       struct rsa_mpi_key *mpi_key = akcipher_tfm_ctx(tfm);
+       struct rsa_key raw_key = {0};
        int ret;
 
-       ret = rsa_parse_pub_key(pkey, key, keylen);
+       /* Free the old MPI key if any */
+       rsa_free_mpi_key(mpi_key);
+
+       ret = rsa_parse_pub_key(&raw_key, key, keylen);
        if (ret)
                return ret;
 
-       if (rsa_check_key_length(mpi_get_size(pkey->n) << 3)) {
-               rsa_free_key(pkey);
-               ret = -EINVAL;
+       mpi_key->e = mpi_read_raw_data(raw_key.e, raw_key.e_sz);
+       if (!mpi_key->e)
+               goto err;
+
+       mpi_key->n = mpi_read_raw_data(raw_key.n, raw_key.n_sz);
+       if (!mpi_key->n)
+               goto err;
+
+       if (rsa_check_key_length(mpi_get_size(mpi_key->n) << 3)) {
+               rsa_free_mpi_key(mpi_key);
+               return -EINVAL;
        }
-       return ret;
+
+       return 0;
+
+err:
+       rsa_free_mpi_key(mpi_key);
+       return -ENOMEM;
 }
 
 static int rsa_set_priv_key(struct crypto_akcipher *tfm, const void *key,
                            unsigned int keylen)
 {
-       struct rsa_key *pkey = akcipher_tfm_ctx(tfm);
+       struct rsa_mpi_key *mpi_key = akcipher_tfm_ctx(tfm);
+       struct rsa_key raw_key = {0};
        int ret;
 
-       ret = rsa_parse_priv_key(pkey, key, keylen);
+       /* Free the old MPI key if any */
+       rsa_free_mpi_key(mpi_key);
+
+       ret = rsa_parse_priv_key(&raw_key, key, keylen);
        if (ret)
                return ret;
 
-       if (rsa_check_key_length(mpi_get_size(pkey->n) << 3)) {
-               rsa_free_key(pkey);
-               ret = -EINVAL;
+       mpi_key->d = mpi_read_raw_data(raw_key.d, raw_key.d_sz);
+       if (!mpi_key->d)
+               goto err;
+
+       mpi_key->e = mpi_read_raw_data(raw_key.e, raw_key.e_sz);
+       if (!mpi_key->e)
+               goto err;
+
+       mpi_key->n = mpi_read_raw_data(raw_key.n, raw_key.n_sz);
+       if (!mpi_key->n)
+               goto err;
+
+       if (rsa_check_key_length(mpi_get_size(mpi_key->n) << 3)) {
+               rsa_free_mpi_key(mpi_key);
+               return -EINVAL;
        }
-       return ret;
+
+       return 0;
+
+err:
+       rsa_free_mpi_key(mpi_key);
+       return -ENOMEM;
 }
 
 static int rsa_max_size(struct crypto_akcipher *tfm)
 {
-       struct rsa_key *pkey = akcipher_tfm_ctx(tfm);
+       struct rsa_mpi_key *pkey = akcipher_tfm_ctx(tfm);
 
        return pkey->n ? mpi_get_size(pkey->n) : -EINVAL;
 }
 
 static void rsa_exit_tfm(struct crypto_akcipher *tfm)
 {
-       struct rsa_key *pkey = akcipher_tfm_ctx(tfm);
+       struct rsa_mpi_key *pkey = akcipher_tfm_ctx(tfm);
 
-       rsa_free_key(pkey);
+       rsa_free_mpi_key(pkey);
 }
 
 static struct akcipher_alg rsa = {
@@ -310,7 +365,7 @@ static struct akcipher_alg rsa = {
                .cra_driver_name = "rsa-generic",
                .cra_priority = 100,
                .cra_module = THIS_MODULE,
-               .cra_ctxsize = sizeof(struct rsa_key),
+               .cra_ctxsize = sizeof(struct rsa_mpi_key),
        },
 };