zram: check comp algorithm availability earlier
[cascardo/linux.git] / crypto / testmgr.c
index 915a9ef..975e1ea 100644 (file)
@@ -30,6 +30,7 @@
 #include <linux/string.h>
 #include <crypto/rng.h>
 #include <crypto/drbg.h>
+#include <crypto/akcipher.h>
 
 #include "internal.h"
 
@@ -116,6 +117,11 @@ struct drbg_test_suite {
        unsigned int count;
 };
 
+struct akcipher_test_suite {
+       struct akcipher_testvec *vecs;
+       unsigned int count;
+};
+
 struct alg_test_desc {
        const char *alg;
        int (*test)(const struct alg_test_desc *desc, const char *driver,
@@ -130,6 +136,7 @@ struct alg_test_desc {
                struct hash_test_suite hash;
                struct cprng_test_suite cprng;
                struct drbg_test_suite drbg;
+               struct akcipher_test_suite akcipher;
        } suite;
 };
 
@@ -1825,6 +1832,147 @@ static int alg_test_drbg(const struct alg_test_desc *desc, const char *driver,
 
 }
 
+static int do_test_rsa(struct crypto_akcipher *tfm,
+                      struct akcipher_testvec *vecs)
+{
+       struct akcipher_request *req;
+       void *outbuf_enc = NULL;
+       void *outbuf_dec = NULL;
+       struct tcrypt_result result;
+       unsigned int out_len_max, out_len = 0;
+       int err = -ENOMEM;
+
+       req = akcipher_request_alloc(tfm, GFP_KERNEL);
+       if (!req)
+               return err;
+
+       init_completion(&result.completion);
+       err = crypto_akcipher_setkey(tfm, vecs->key, vecs->key_len);
+       if (err)
+               goto free_req;
+
+       akcipher_request_set_crypt(req, vecs->m, outbuf_enc, vecs->m_size,
+                                  out_len);
+       /* expect this to fail, and update the required buf len */
+       crypto_akcipher_encrypt(req);
+       out_len = req->dst_len;
+       if (!out_len) {
+               err = -EINVAL;
+               goto free_req;
+       }
+
+       out_len_max = out_len;
+       err = -ENOMEM;
+       outbuf_enc = kzalloc(out_len_max, GFP_KERNEL);
+       if (!outbuf_enc)
+               goto free_req;
+
+       akcipher_request_set_crypt(req, vecs->m, outbuf_enc, vecs->m_size,
+                                  out_len);
+       akcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
+                                     tcrypt_complete, &result);
+
+       /* Run RSA encrypt - c = m^e mod n;*/
+       err = wait_async_op(&result, crypto_akcipher_encrypt(req));
+       if (err) {
+               pr_err("alg: rsa: encrypt test failed. err %d\n", err);
+               goto free_all;
+       }
+       if (out_len != vecs->c_size) {
+               pr_err("alg: rsa: encrypt test failed. Invalid output len\n");
+               err = -EINVAL;
+               goto free_all;
+       }
+       /* verify that encrypted message is equal to expected */
+       if (memcmp(vecs->c, outbuf_enc, vecs->c_size)) {
+               pr_err("alg: rsa: encrypt test failed. Invalid output\n");
+               err = -EINVAL;
+               goto free_all;
+       }
+       /* Don't invoke decrypt for vectors with public key */
+       if (vecs->public_key_vec) {
+               err = 0;
+               goto free_all;
+       }
+       outbuf_dec = kzalloc(out_len_max, GFP_KERNEL);
+       if (!outbuf_dec) {
+               err = -ENOMEM;
+               goto free_all;
+       }
+       init_completion(&result.completion);
+       akcipher_request_set_crypt(req, outbuf_enc, outbuf_dec, vecs->c_size,
+                                  out_len);
+
+       /* Run RSA decrypt - m = c^d mod n;*/
+       err = wait_async_op(&result, crypto_akcipher_decrypt(req));
+       if (err) {
+               pr_err("alg: rsa: decrypt test failed. err %d\n", err);
+               goto free_all;
+       }
+       out_len = req->dst_len;
+       if (out_len != vecs->m_size) {
+               pr_err("alg: rsa: decrypt test failed. Invalid output len\n");
+               err = -EINVAL;
+               goto free_all;
+       }
+       /* verify that decrypted message is equal to the original msg */
+       if (memcmp(vecs->m, outbuf_dec, vecs->m_size)) {
+               pr_err("alg: rsa: decrypt test failed. Invalid output\n");
+               err = -EINVAL;
+       }
+free_all:
+       kfree(outbuf_dec);
+       kfree(outbuf_enc);
+free_req:
+       akcipher_request_free(req);
+       return err;
+}
+
+static int test_rsa(struct crypto_akcipher *tfm, struct akcipher_testvec *vecs,
+                   unsigned int tcount)
+{
+       int ret, i;
+
+       for (i = 0; i < tcount; i++) {
+               ret = do_test_rsa(tfm, vecs++);
+               if (ret) {
+                       pr_err("alg: rsa: test failed on vector %d, err=%d\n",
+                              i + 1, ret);
+                       return ret;
+               }
+       }
+       return 0;
+}
+
+static int test_akcipher(struct crypto_akcipher *tfm, const char *alg,
+                        struct akcipher_testvec *vecs, unsigned int tcount)
+{
+       if (strncmp(alg, "rsa", 3) == 0)
+               return test_rsa(tfm, vecs, tcount);
+
+       return 0;
+}
+
+static int alg_test_akcipher(const struct alg_test_desc *desc,
+                            const char *driver, u32 type, u32 mask)
+{
+       struct crypto_akcipher *tfm;
+       int err = 0;
+
+       tfm = crypto_alloc_akcipher(driver, type | CRYPTO_ALG_INTERNAL, mask);
+       if (IS_ERR(tfm)) {
+               pr_err("alg: akcipher: Failed to load tfm for %s: %ld\n",
+                      driver, PTR_ERR(tfm));
+               return PTR_ERR(tfm);
+       }
+       if (desc->suite.akcipher.vecs)
+               err = test_akcipher(tfm, desc->alg, desc->suite.akcipher.vecs,
+                                   desc->suite.akcipher.count);
+
+       crypto_free_akcipher(tfm);
+       return err;
+}
+
 static int alg_test_null(const struct alg_test_desc *desc,
                             const char *driver, u32 type, u32 mask)
 {
@@ -3401,6 +3549,21 @@ static const struct alg_test_desc alg_test_descs[] = {
                                },
                        }
                }
+       }, {
+               .alg = "rfc7539esp(chacha20,poly1305)",
+               .test = alg_test_aead,
+               .suite = {
+                       .aead = {
+                               .enc = {
+                                       .vecs = rfc7539esp_enc_tv_template,
+                                       .count = RFC7539ESP_ENC_TEST_VECTORS
+                               },
+                               .dec = {
+                                       .vecs = rfc7539esp_dec_tv_template,
+                                       .count = RFC7539ESP_DEC_TEST_VECTORS
+                               },
+                       }
+               }
        }, {
                .alg = "rmd128",
                .test = alg_test_hash,
@@ -3437,6 +3600,16 @@ static const struct alg_test_desc alg_test_descs[] = {
                                .count = RMD320_TEST_VECTORS
                        }
                }
+       }, {
+               .alg = "rsa",
+               .test = alg_test_akcipher,
+               .fips_allowed = 1,
+               .suite = {
+                       .akcipher = {
+                               .vecs = rsa_tv_template,
+                               .count = RSA_TEST_VECTORS
+                       }
+               }
        }, {
                .alg = "salsa20",
                .test = alg_test_skcipher,