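Use the generic aes_check_keylen() helper to validate the key length for
AES algorithms instead of open-coding the check. The validation function
now takes the algorithm number from the transform context rather than
deriving the algorithm from the key length.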

Signed-off-by: Bibo Mao <[email protected]>
---
 .../virtio/virtio_crypto_skcipher_algs.c      | 20 ++++++++-----------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/drivers/crypto/virtio/virtio_crypto_skcipher_algs.c b/drivers/crypto/virtio/virtio_crypto_skcipher_algs.c
index 8a139de3d064..682d192a4ed7 100644
--- a/drivers/crypto/virtio/virtio_crypto_skcipher_algs.c
+++ b/drivers/crypto/virtio/virtio_crypto_skcipher_algs.c
@@ -94,18 +94,16 @@ static u64 virtio_crypto_alg_sg_nents_length(struct scatterlist *sg)
 }
 
 static int
-virtio_crypto_alg_validate_key(int key_len, uint32_t *alg)
+virtio_crypto_alg_validate_key(int key_len, uint32_t alg)
 {
-       switch (key_len) {
-       case AES_KEYSIZE_128:
-       case AES_KEYSIZE_192:
-       case AES_KEYSIZE_256:
-               *alg = VIRTIO_CRYPTO_CIPHER_AES_CBC;
-               break;
+       switch (alg) {
+       case VIRTIO_CRYPTO_CIPHER_AES_ECB:
+       case VIRTIO_CRYPTO_CIPHER_AES_CBC:
+       case VIRTIO_CRYPTO_CIPHER_AES_CTR:
+               return aes_check_keylen(key_len);
        default:
                return -EINVAL;
        }
-       return 0;
 }
 
 static int virtio_crypto_alg_skcipher_init_session(
@@ -248,7 +246,6 @@ static int virtio_crypto_alg_skcipher_init_sessions(
                struct virtio_crypto_skcipher_ctx *ctx,
                const uint8_t *key, unsigned int keylen)
 {
-       uint32_t alg;
        int ret;
        struct virtio_crypto *vcrypto = ctx->vcrypto;
 
@@ -257,7 +254,7 @@ static int virtio_crypto_alg_skcipher_init_sessions(
                return -EINVAL;
        }
 
-       if (virtio_crypto_alg_validate_key(keylen, &alg))
+       if (virtio_crypto_alg_validate_key(keylen, ctx->alg->algonum))
                return -EINVAL;
 
        /* Create encryption session */
@@ -279,10 +276,9 @@ static int virtio_crypto_skcipher_setkey(struct crypto_skcipher *tfm,
                                         unsigned int keylen)
 {
        struct virtio_crypto_skcipher_ctx *ctx = crypto_skcipher_ctx(tfm);
-       uint32_t alg;
        int ret;
 
-       ret = virtio_crypto_alg_validate_key(keylen, &alg);
+       ret = virtio_crypto_alg_validate_key(keylen, ctx->alg->algonum);
        if (ret)
                return ret;
 
-- 
2.39.3


Reply via email to