cris: buggered copy_from_user/copy_to_user/clear_user
[cascardo/linux.git] / crypto / rsa_helper.c
index d226f48..4df6451 100644 (file)
@@ -22,20 +22,29 @@ int rsa_get_n(void *context, size_t hdrlen, unsigned char tag,
              const void *value, size_t vlen)
 {
        struct rsa_key *key = context;
+       const u8 *ptr = value;
+       size_t n_sz = vlen;
 
-       key->n = mpi_read_raw_data(value, vlen);
-
-       if (!key->n)
-               return -ENOMEM;
-
-       /* In FIPS mode only allow key size 2K & 3K */
-       if (fips_enabled && (mpi_get_size(key->n) != 256 &&
-                            mpi_get_size(key->n) != 384)) {
-               pr_err("RSA: key size not allowed in FIPS mode\n");
-               mpi_free(key->n);
-               key->n = NULL;
+       /* invalid key provided */
+       if (!value || !vlen)
                return -EINVAL;
+
+       if (fips_enabled) {
+               while (!*ptr && n_sz) {
+                       ptr++;
+                       n_sz--;
+               }
+
+               /* In FIPS mode only allow key size 2K & 3K */
+               if (n_sz != 256 && n_sz != 384) {
+                       pr_err("RSA: key size not allowed in FIPS mode\n");
+                       return -EINVAL;
+               }
        }
+
+       key->n = value;
+       key->n_sz = vlen;
+
        return 0;
 }
 
@@ -44,10 +53,12 @@ int rsa_get_e(void *context, size_t hdrlen, unsigned char tag,
 {
        struct rsa_key *key = context;
 
-       key->e = mpi_read_raw_data(value, vlen);
+       /* invalid key provided */
+       if (!value || !key->n_sz || !vlen || vlen > key->n_sz)
+               return -EINVAL;
 
-       if (!key->e)
-               return -ENOMEM;
+       key->e = value;
+       key->e_sz = vlen;
 
        return 0;
 }
@@ -57,46 +68,95 @@ int rsa_get_d(void *context, size_t hdrlen, unsigned char tag,
 {
        struct rsa_key *key = context;
 
-       key->d = mpi_read_raw_data(value, vlen);
+       /* invalid key provided */
+       if (!value || !key->n_sz || !vlen || vlen > key->n_sz)
+               return -EINVAL;
 
-       if (!key->d)
-               return -ENOMEM;
+       key->d = value;
+       key->d_sz = vlen;
 
-       /* In FIPS mode only allow key size 2K & 3K */
-       if (fips_enabled && (mpi_get_size(key->d) != 256 &&
-                            mpi_get_size(key->d) != 384)) {
-               pr_err("RSA: key size not allowed in FIPS mode\n");
-               mpi_free(key->d);
-               key->d = NULL;
+       return 0;
+}
+
+int rsa_get_p(void *context, size_t hdrlen, unsigned char tag,
+             const void *value, size_t vlen)
+{
+       struct rsa_key *key = context;
+
+       /* invalid key provided */
+       if (!value || !vlen || vlen > key->n_sz)
                return -EINVAL;
-       }
+
+       key->p = value;
+       key->p_sz = vlen;
+
        return 0;
 }
 
-static void free_mpis(struct rsa_key *key)
+int rsa_get_q(void *context, size_t hdrlen, unsigned char tag,
+             const void *value, size_t vlen)
 {
-       mpi_free(key->n);
-       mpi_free(key->e);
-       mpi_free(key->d);
-       key->n = NULL;
-       key->e = NULL;
-       key->d = NULL;
+       struct rsa_key *key = context;
+
+       /* invalid key provided */
+       if (!value || !vlen || vlen > key->n_sz)
+               return -EINVAL;
+
+       key->q = value;
+       key->q_sz = vlen;
+
+       return 0;
 }
 
-/**
- * rsa_free_key() - frees rsa key allocated by rsa_parse_key()
- *
- * @rsa_key:   struct rsa_key key representation
- */
-void rsa_free_key(struct rsa_key *key)
+int rsa_get_dp(void *context, size_t hdrlen, unsigned char tag,
+              const void *value, size_t vlen)
+{
+       struct rsa_key *key = context;
+
+       /* invalid key provided */
+       if (!value || !vlen || vlen > key->n_sz)
+               return -EINVAL;
+
+       key->dp = value;
+       key->dp_sz = vlen;
+
+       return 0;
+}
+
+int rsa_get_dq(void *context, size_t hdrlen, unsigned char tag,
+              const void *value, size_t vlen)
 {
-       free_mpis(key);
+       struct rsa_key *key = context;
+
+       /* invalid key provided */
+       if (!value || !vlen || vlen > key->n_sz)
+               return -EINVAL;
+
+       key->dq = value;
+       key->dq_sz = vlen;
+
+       return 0;
+}
+
+int rsa_get_qinv(void *context, size_t hdrlen, unsigned char tag,
+                const void *value, size_t vlen)
+{
+       struct rsa_key *key = context;
+
+       /* invalid key provided */
+       if (!value || !vlen || vlen > key->n_sz)
+               return -EINVAL;
+
+       key->qinv = value;
+       key->qinv_sz = vlen;
+
+       return 0;
 }
-EXPORT_SYMBOL_GPL(rsa_free_key);
 
 /**
- * rsa_parse_pub_key() - extracts an rsa public key from BER encoded buffer
- *                      and stores it in the provided struct rsa_key
+ * rsa_parse_pub_key() - decodes the BER encoded buffer and stores in the
+ *                       provided struct rsa_key, pointers to the raw key as is,
+ *                       so that the caller can copy it or MPI parse it, etc.
  *
  * @rsa_key:   struct rsa_key key representation
  * @key:       key in BER format
@@ -107,23 +167,15 @@ EXPORT_SYMBOL_GPL(rsa_free_key);
 int rsa_parse_pub_key(struct rsa_key *rsa_key, const void *key,
                      unsigned int key_len)
 {
-       int ret;
-
-       free_mpis(rsa_key);
-       ret = asn1_ber_decoder(&rsapubkey_decoder, rsa_key, key, key_len);
-       if (ret < 0)
-               goto error;
-
-       return 0;
-error:
-       free_mpis(rsa_key);
-       return ret;
+       return asn1_ber_decoder(&rsapubkey_decoder, rsa_key, key, key_len);
 }
 EXPORT_SYMBOL_GPL(rsa_parse_pub_key);
 
 /**
- * rsa_parse_pub_key() - extracts an rsa private key from BER encoded buffer
- *                      and stores it in the provided struct rsa_key
+ * rsa_parse_priv_key() - decodes the BER encoded buffer and stores in the
+ *                        provided struct rsa_key, pointers to the raw key
+ *                        as is, so that the caller can copy it or MPI parse it,
+ *                        etc.
  *
  * @rsa_key:   struct rsa_key key representation
  * @key:       key in BER format
@@ -134,16 +186,6 @@ EXPORT_SYMBOL_GPL(rsa_parse_pub_key);
 int rsa_parse_priv_key(struct rsa_key *rsa_key, const void *key,
                       unsigned int key_len)
 {
-       int ret;
-
-       free_mpis(rsa_key);
-       ret = asn1_ber_decoder(&rsaprivkey_decoder, rsa_key, key, key_len);
-       if (ret < 0)
-               goto error;
-
-       return 0;
-error:
-       free_mpis(rsa_key);
-       return ret;
+       return asn1_ber_decoder(&rsaprivkey_decoder, rsa_key, key, key_len);
 }
 EXPORT_SYMBOL_GPL(rsa_parse_priv_key);