The RSA output must be processed as per the EMSA-PSS-VERIFY operation
from RFC 8017, which forms the core of PSS signature verification.

Implement the verification callback, which operates on the RSA output
buffer.

Reference: https://tools.ietf.org/html/rfc8017#section-9.1.2
Signed-off-by: Varad Gautam <varad.gau...@suse.com>
---
v2: Allow mgf_hash_algo to be different from the digest hash algorithm.

 crypto/rsa-psspad.c | 114 +++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 113 insertions(+), 1 deletion(-)

diff --git a/crypto/rsa-psspad.c b/crypto/rsa-psspad.c
index 4ba4d69f6ce17..87e90479a4fa7 100644
--- a/crypto/rsa-psspad.c
+++ b/crypto/rsa-psspad.c
@@ -8,6 +8,7 @@
 
 #include <crypto/hash.h>
 #include <crypto/internal/akcipher.h>
+#include <crypto/internal/rsa.h>
 #include <crypto/internal/rsa-common.h>
 #include <crypto/public_key.h>
 
@@ -107,7 +108,118 @@ static int pkcs1_mgf1(u8 *seed, unsigned int seed_len,
 
 static int psspad_verify_complete(struct akcipher_request *req, int err)
 {
-       return -EOPNOTSUPP;
+       struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
+       struct rsapad_tfm_ctx *ctx = akcipher_tfm_ctx(tfm);
+       struct rsapad_akciper_req_ctx *req_ctx = akcipher_request_ctx(req);
+       struct akcipher_instance *inst = akcipher_alg_instance(tfm);
+       struct rsapad_inst_ctx *ictx = akcipher_instance_ctx(inst);
+       const struct rsa_asn1_template *digest_info = ictx->digest_info;
+       struct crypto_shash *hash_tfm = NULL;
+       struct shash_desc *desc = NULL;
+       struct rsa_mpi_key *pkey = akcipher_tfm_ctx(ctx->child);
+       /* EMSA-PSS-VERIFY (RFC 8017, section 9.1.2) over the RSA output. */
+       u8 *em, *h, *salt, *maskeddb;
+       unsigned int em_len, em_bits, h_len, s_len, maskeddb_len;
+       u8 *m_hash, *db_mask, *db, *h_;
+       static u8 zeroes[8] = { 0 }; /* eight zero octets that prefix M' */
+       unsigned int pos;
+       /* A non-zero err passed in is returned as is, after freeing out_buf. */
+       if (err)
+               goto out;
+       /* No digest description on the instance: reject with -EINVAL. */
+       err = -EINVAL;
+       if (!digest_info)
+               goto out;
+       /* EM is out_buf; emBits is one less than the bit count of pkey->n. */
+       em = req_ctx->out_buf;
+       em_len = ctx->key_size;
+       em_bits = mpi_get_nbits(pkey->n) - 1;
+       if ((em_bits & 0x7) == 0) {
+               em_len--; /* emBits is a multiple of 8: EM is one octet shorter */
+               em++; /* NOTE(review): skipped octet is not checked for zero - confirm */
+       }
+       /* hLen comes from req->dst_len, sLen from the instance context. */
+       h_len = req->dst_len;
+       s_len = ictx->salt_len;
+       /* Step 3: emLen must be at least hLen + sLen + 2. */
+       if (em_len < h_len + s_len + 2)
+               goto out;
+       /* Step 4: the rightmost octet of EM must be 0xbc. */
+       if (em[em_len - 1] != 0xbc)
+               goto out;
+       /* Step 5: EM = maskedDB || H || 0xbc. */
+       maskeddb = em;
+       maskeddb_len = em_len - h_len - 1;
+       h = em + maskeddb_len;
+       /* Step 6: the leftmost 8 * emLen - emBits bits of EM must be zero. */
+       if (em[0] & ~((u8) 0xff >> (8 * em_len - em_bits)))
+               goto out;
+       /* Scratch buffer for dbMask; released at out_db_mask. */
+       db_mask = kzalloc(maskeddb_len, GFP_KERNEL);
+       if (!db_mask) {
+               err = -ENOMEM;
+               goto out;
+       }
+       /* Hash used by MGF1; it may differ from the message digest hash. */
+       err = psspad_setup_shash(&hash_tfm, &desc, ictx->mgf_hash_algo);
+       if (err < 0)
+               goto out_db_mask;
+       /* Step 7: dbMask = MGF1(H, emLen - hLen - 1). */
+       err = pkcs1_mgf1(h, h_len, desc, db_mask, maskeddb_len);
+       if (err < 0)
+               goto out_shash;
+       /* Step 8: DB = maskedDB ^ dbMask, computed in place inside out_buf. */
+       for (pos = 0; pos < maskeddb_len; pos++)
+               maskeddb[pos] ^= db_mask[pos];
+       db = maskeddb;
+       /* Step 9: clear the leftmost 8 * emLen - emBits bits of DB. */
+       db[0] &= ((u8) 0xff >> (8 * em_len - em_bits));
+       /* Step 10: DB must be zero padding, then 0x01, then the salt. */
+       err = -EINVAL;
+       for (pos = 0; pos < em_len - h_len - s_len - 2; pos++) {
+               if (db[pos] != 0)
+                       goto out_shash;
+       }
+       if (db[pos] != 0x01)
+               goto out_shash;
+
+       salt = db + (maskeddb_len - s_len); /* Step 11: last sLen octets of DB */
+       /* mHash: req->dst_len bytes of req->src, skipping ctx->key_size bytes. */
+       m_hash = req_ctx->out_buf + ctx->key_size;
+       sg_pcopy_to_buffer(req->src,
+                          sg_nents_for_len(req->src, req->src_len + 
req->dst_len),
+                          m_hash,
+                          req->dst_len, ctx->key_size);
+       /* Switch to the digest hash if MGF1 used a different algorithm. */
+       if (strcmp(ictx->mgf_hash_algo, digest_info->name) != 0) {
+               psspad_free_shash(hash_tfm, desc);
+               err = psspad_setup_shash(&hash_tfm, &desc, digest_info->name);
+               if (err < 0)
+                       goto out_db_mask; /* previous shash was already freed above */
+       }
+       /* Steps 12-13: H' = Hash(zeroes || mHash || salt), written over mHash. */
+       err = crypto_shash_init(desc);
+       if (!err)
+               err = crypto_shash_update(desc, zeroes, 8);
+       if (!err)
+               err = crypto_shash_update(desc, m_hash, h_len);
+       if (!err)
+               err = crypto_shash_finup(desc, salt, s_len, m_hash);
+       if (err < 0)
+               goto out_shash;
+
+       h_ = m_hash;
+       /* Step 14: a mismatch between H' and H yields -EKEYREJECTED. */
+       if (memcmp(h_, h, h_len) != 0)
+               err = -EKEYREJECTED;
+
+out_shash:
+       psspad_free_shash(hash_tfm, desc);
+out_db_mask:
+       kfree(db_mask);
+out:
+       kfree_sensitive(req_ctx->out_buf); /* reached on every path */
+       return err; /* the err set by the last step above */
 }
 
 static void psspad_verify_complete_cb(struct crypto_async_request 
*child_async_req,
-- 
2.30.2

Reply via email to