Signed-off-by: Ard Biesheuvel <ard.biesheu...@linaro.org>
---
 arch/x86/crypto/des3_ede_glue.c | 39 ++++++++++----------
 1 file changed, 20 insertions(+), 19 deletions(-)

diff --git a/arch/x86/crypto/des3_ede_glue.c b/arch/x86/crypto/des3_ede_glue.c
index 571966e5c542..9c743246f5ad 100644
--- a/arch/x86/crypto/des3_ede_glue.c
+++ b/arch/x86/crypto/des3_ede_glue.c
@@ -21,7 +21,7 @@
  */
 
 #include <crypto/algapi.h>
-#include <crypto/internal/des.h>
+#include <crypto/des.h>
 #include <crypto/internal/skcipher.h>
 #include <linux/crypto.h>
 #include <linux/init.h>
@@ -29,8 +29,8 @@
 #include <linux/types.h>
 
 struct des3_ede_x86_ctx {
-       u32 enc_expkey[DES3_EDE_EXPKEY_WORDS];
-       u32 dec_expkey[DES3_EDE_EXPKEY_WORDS];
+       struct des3_ede_ctx enc;
+       struct des3_ede_ctx dec;
 };
 
 /* regular block cipher functions */
@@ -44,7 +44,7 @@ asmlinkage void des3_ede_x86_64_crypt_blk_3way(const u32 *expkey, u8 *dst,
 static inline void des3_ede_enc_blk(struct des3_ede_x86_ctx *ctx, u8 *dst,
                                    const u8 *src)
 {
-       u32 *enc_ctx = ctx->enc_expkey;
+       u32 *enc_ctx = ctx->enc.expkey;
 
        des3_ede_x86_64_crypt_blk(enc_ctx, dst, src);
 }
@@ -52,7 +52,7 @@ static inline void des3_ede_enc_blk(struct des3_ede_x86_ctx *ctx, u8 *dst,
 static inline void des3_ede_dec_blk(struct des3_ede_x86_ctx *ctx, u8 *dst,
                                    const u8 *src)
 {
-       u32 *dec_ctx = ctx->dec_expkey;
+       u32 *dec_ctx = ctx->dec.expkey;
 
        des3_ede_x86_64_crypt_blk(dec_ctx, dst, src);
 }
@@ -60,7 +60,7 @@ static inline void des3_ede_dec_blk(struct des3_ede_x86_ctx *ctx, u8 *dst,
 static inline void des3_ede_enc_blk_3way(struct des3_ede_x86_ctx *ctx, u8 *dst,
                                         const u8 *src)
 {
-       u32 *enc_ctx = ctx->enc_expkey;
+       u32 *enc_ctx = ctx->enc.expkey;
 
        des3_ede_x86_64_crypt_blk_3way(enc_ctx, dst, src);
 }
@@ -68,7 +68,7 @@ static inline void des3_ede_enc_blk_3way(struct des3_ede_x86_ctx *ctx, u8 *dst,
 static inline void des3_ede_dec_blk_3way(struct des3_ede_x86_ctx *ctx, u8 *dst,
                                         const u8 *src)
 {
-       u32 *dec_ctx = ctx->dec_expkey;
+       u32 *dec_ctx = ctx->dec.expkey;
 
        des3_ede_x86_64_crypt_blk_3way(dec_ctx, dst, src);
 }
@@ -132,7 +132,7 @@ static int ecb_encrypt(struct skcipher_request *req)
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct des3_ede_x86_ctx *ctx = crypto_skcipher_ctx(tfm);
 
-       return ecb_crypt(req, ctx->enc_expkey);
+       return ecb_crypt(req, ctx->enc.expkey);
 }
 
 static int ecb_decrypt(struct skcipher_request *req)
@@ -140,7 +140,7 @@ static int ecb_decrypt(struct skcipher_request *req)
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct des3_ede_x86_ctx *ctx = crypto_skcipher_ctx(tfm);
 
-       return ecb_crypt(req, ctx->dec_expkey);
+       return ecb_crypt(req, ctx->dec.expkey);
 }
 
 static unsigned int __cbc_encrypt(struct des3_ede_x86_ctx *ctx,
@@ -358,24 +358,25 @@ static int des3_ede_x86_setkey(struct crypto_tfm *tfm, const u8 *key,
        u32 i, j, tmp;
        int err;
 
-       err = des3_ede_verify_key(tfm, key, keylen);
-       if (unlikely(err))
-               return err;
+       err = des3_ede_expand_key(&ctx->enc, key, keylen);
+       if (err == -ENOKEY &&
+           !(crypto_tfm_get_flags(tfm) & CRYPTO_TFM_REQ_FORBID_WEAK_KEYS))
+               err = 0;
 
-       /* Generate encryption context using generic implementation. */
-       err = __des3_ede_setkey(ctx->enc_expkey, &tfm->crt_flags, key, keylen);
-       if (err < 0)
+       if (err) {
+               memzero_explicit(ctx, sizeof(*ctx));
                return err;
+       }
 
        /* Fix encryption context for this implementation and form decryption
         * context. */
        j = DES3_EDE_EXPKEY_WORDS - 2;
        for (i = 0; i < DES3_EDE_EXPKEY_WORDS; i += 2, j -= 2) {
-               tmp = ror32(ctx->enc_expkey[i + 1], 4);
-               ctx->enc_expkey[i + 1] = tmp;
+               tmp = ror32(ctx->enc.expkey[i + 1], 4);
+               ctx->enc.expkey[i + 1] = tmp;
 
-               ctx->dec_expkey[j + 0] = ctx->enc_expkey[i + 0];
-               ctx->dec_expkey[j + 1] = tmp;
+               ctx->dec.expkey[j + 0] = ctx->enc.expkey[i + 0];
+               ctx->dec.expkey[j + 1] = tmp;
        }
 
        return 0;
-- 
2.20.1

Reply via email to