On Mon, Oct 07, 2019 at 06:45:48PM +0200, Ard Biesheuvel wrote:
> +static int chacha_stream_xor(struct skcipher_request *req,
> +                          const struct chacha_ctx *ctx, const u8 *iv)
> +{
> +     struct skcipher_walk walk;
> +     u32 state[16];
> +     int err;
> +
> +     err = skcipher_walk_virt(&walk, req, false);
> +
> +     chacha_init_generic(state, ctx->key, iv);
> +
> +     while (walk.nbytes > 0) {
> +             unsigned int nbytes = walk.nbytes;
> +
> +             if (nbytes < walk.total)
> +                     nbytes = round_down(nbytes, walk.stride);
> +
> +             chacha_doarm(walk.dst.virt.addr, walk.src.virt.addr,
> +                          nbytes, state, ctx->nrounds);
> +             state[12] += DIV_ROUND_UP(nbytes, CHACHA_BLOCK_SIZE);
> +             err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
> +     }
> +
> +     return err;
> +}
> +
> +static int chacha_arm(struct skcipher_request *req)
> +{
> +     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> +     struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
> +
> +     return chacha_stream_xor(req, ctx, req->iv);
> +}
> +
> +static int chacha_neon_stream_xor(struct skcipher_request *req,
> +                               const struct chacha_ctx *ctx, const u8 *iv)
> +{
> +     struct skcipher_walk walk;
> +     u32 state[16];
> +     bool do_neon;
> +     int err;
> +
> +     err = skcipher_walk_virt(&walk, req, false);
> +
> +     chacha_init_generic(state, ctx->key, iv);
> +
> +     do_neon = (req->cryptlen > CHACHA_BLOCK_SIZE) && crypto_simd_usable();
> +     while (walk.nbytes > 0) {
> +             unsigned int nbytes = walk.nbytes;
> +
> +             if (nbytes < walk.total)
> +                     nbytes = round_down(nbytes, walk.stride);
> +
> +             if (!do_neon) {
> +                     chacha_doarm(walk.dst.virt.addr, walk.src.virt.addr,
> +                                  nbytes, state, ctx->nrounds);
> +                     state[12] += DIV_ROUND_UP(nbytes, CHACHA_BLOCK_SIZE);
> +             } else {
> +                     kernel_neon_begin();
> +                     chacha_doneon(state, walk.dst.virt.addr,
> +                                   walk.src.virt.addr, nbytes, ctx->nrounds);
> +                     kernel_neon_end();
> +             }
> +             err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
> +     }
> +
> +     return err;
> +}
> +
> +static int chacha_neon(struct skcipher_request *req)
> +{
> +     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> +     struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
> +
> +     return chacha_neon_stream_xor(req, ctx, req->iv);
> +}
> +
> +static int xchacha_arm(struct skcipher_request *req)
> +{
> +     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> +     struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
> +     struct chacha_ctx subctx;
> +     u32 state[16];
> +     u8 real_iv[16];
> +
> +     chacha_init_generic(state, ctx->key, req->iv);
> +
> +     hchacha_block_arm(state, subctx.key, ctx->nrounds);
> +     subctx.nrounds = ctx->nrounds;
> +
> +     memcpy(&real_iv[0], req->iv + 24, 8);
> +     memcpy(&real_iv[8], req->iv + 16, 8);
> +     return chacha_stream_xor(req, &subctx, real_iv);
> +}
> +
> +static int xchacha_neon(struct skcipher_request *req)
> +{
> +     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> +     struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
> +     struct chacha_ctx subctx;
> +     u32 state[16];
> +     u8 real_iv[16];
> +
> +     chacha_init_generic(state, ctx->key, req->iv);
> +
> +     if (!crypto_simd_usable()) {
> +             hchacha_block_arm(state, subctx.key, ctx->nrounds);
> +     } else {
> +             kernel_neon_begin();
> +             hchacha_block_neon(state, subctx.key, ctx->nrounds);
> +             kernel_neon_end();
> +     }
> +     subctx.nrounds = ctx->nrounds;
> +
> +     memcpy(&real_iv[0], req->iv + 24, 8);
> +     memcpy(&real_iv[8], req->iv + 16, 8);
> +     return chacha_neon_stream_xor(req, &subctx, real_iv);
> +}

There is some code duplication here: two implementations of stream_xor, and two
implementations of xchacha (hchacha + stream_xor).  How about doing something
like the following?

diff --git a/arch/arm/crypto/chacha-glue.c b/arch/arm/crypto/chacha-glue.c
index dae69a63b640..1952cbda2168 100644
--- a/arch/arm/crypto/chacha-glue.c
+++ b/arch/arm/crypto/chacha-glue.c
@@ -32,6 +32,11 @@ asmlinkage void chacha_doarm(u8 *dst, const u8 *src, unsigned int bytes,
 
 static __ro_after_init DEFINE_STATIC_KEY_FALSE(use_neon);
 
+static inline bool neon_usable(void)
+{
+       return static_branch_likely(&use_neon) && crypto_simd_usable();
+}
+
 static void chacha_doneon(u32 *state, u8 *dst, const u8 *src,
                          unsigned int bytes, int nrounds)
 {
@@ -95,7 +100,8 @@ void chacha_crypt(u32 *state, u8 *dst, const u8 *src, unsigned int bytes,
 EXPORT_SYMBOL(chacha_crypt);
 
 static int chacha_stream_xor(struct skcipher_request *req,
-                            const struct chacha_ctx *ctx, const u8 *iv)
+                            const struct chacha_ctx *ctx, const u8 *iv,
+                            bool neon)
 {
        struct skcipher_walk walk;
        u32 state[16];
@@ -105,49 +111,14 @@ static int chacha_stream_xor(struct skcipher_request *req,
 
        chacha_init_generic(state, ctx->key, iv);
 
+       neon &= (req->cryptlen > CHACHA_BLOCK_SIZE);
        while (walk.nbytes > 0) {
                unsigned int nbytes = walk.nbytes;
 
                if (nbytes < walk.total)
                        nbytes = round_down(nbytes, walk.stride);
 
-               chacha_doarm(walk.dst.virt.addr, walk.src.virt.addr,
-                            nbytes, state, ctx->nrounds);
-               state[12] += DIV_ROUND_UP(nbytes, CHACHA_BLOCK_SIZE);
-               err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
-       }
-
-       return err;
-}
-
-static int chacha_arm(struct skcipher_request *req)
-{
-       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-       struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
-
-       return chacha_stream_xor(req, ctx, req->iv);
-}
-
-static int chacha_neon_stream_xor(struct skcipher_request *req,
-                                 const struct chacha_ctx *ctx, const u8 *iv)
-{
-       struct skcipher_walk walk;
-       u32 state[16];
-       bool do_neon;
-       int err;
-
-       err = skcipher_walk_virt(&walk, req, false);
-
-       chacha_init_generic(state, ctx->key, iv);
-
-       do_neon = (req->cryptlen > CHACHA_BLOCK_SIZE) && crypto_simd_usable();
-       while (walk.nbytes > 0) {
-               unsigned int nbytes = walk.nbytes;
-
-               if (nbytes < walk.total)
-                       nbytes = round_down(nbytes, walk.stride);
-
-               if (!static_branch_likely(&use_neon) || !do_neon) {
+               if (!neon) {
                        chacha_doarm(walk.dst.virt.addr, walk.src.virt.addr,
                                     nbytes, state, ctx->nrounds);
                        state[12] += DIV_ROUND_UP(nbytes, CHACHA_BLOCK_SIZE);
@@ -163,33 +134,25 @@ static int chacha_neon_stream_xor(struct skcipher_request *req,
        return err;
 }
 
-static int chacha_neon(struct skcipher_request *req)
+static int do_chacha(struct skcipher_request *req, bool neon)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
 
-       return chacha_neon_stream_xor(req, ctx, req->iv);
+       return chacha_stream_xor(req, ctx, req->iv, neon);
 }
 
-static int xchacha_arm(struct skcipher_request *req)
+static int chacha_arm(struct skcipher_request *req)
 {
-       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-       struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
-       struct chacha_ctx subctx;
-       u32 state[16];
-       u8 real_iv[16];
-
-       chacha_init_generic(state, ctx->key, req->iv);
-
-       hchacha_block_arm(state, subctx.key, ctx->nrounds);
-       subctx.nrounds = ctx->nrounds;
+       return do_chacha(req, false);
+}
 
-       memcpy(&real_iv[0], req->iv + 24, 8);
-       memcpy(&real_iv[8], req->iv + 16, 8);
-       return chacha_stream_xor(req, &subctx, real_iv);
+static int chacha_neon(struct skcipher_request *req)
+{
+       return do_chacha(req, neon_usable());
 }
 
-static int xchacha_neon(struct skcipher_request *req)
+static int do_xchacha(struct skcipher_request *req, bool neon)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -199,7 +162,7 @@ static int xchacha_neon(struct skcipher_request *req)
 
        chacha_init_generic(state, ctx->key, req->iv);
 
-       if (!static_branch_likely(&use_neon) || !crypto_simd_usable()) {
+       if (!neon) {
                hchacha_block_arm(state, subctx.key, ctx->nrounds);
        } else {
                kernel_neon_begin();
@@ -210,7 +173,17 @@ static int xchacha_neon(struct skcipher_request *req)
 
        memcpy(&real_iv[0], req->iv + 24, 8);
        memcpy(&real_iv[8], req->iv + 16, 8);
-       return chacha_neon_stream_xor(req, &subctx, real_iv);
+       return chacha_stream_xor(req, &subctx, real_iv, neon);
+}
+
+static int xchacha_arm(struct skcipher_request *req)
+{
+       return do_xchacha(req, false);
+}
+
+static int xchacha_neon(struct skcipher_request *req)
+{
+       return do_xchacha(req, neon_usable());
 }
 
 static int chacha20_setkey(struct crypto_skcipher *tfm, const u8 *key,
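
For completeness, the above still relies on the existing use_neon static key.
Roughly how that key would get flipped at module init (a hypothetical sketch,
assuming NEON presence is detected via elf_hwcap; the function name and flow
here are illustrative, not taken from the patch):

static int __init chacha_simd_mod_init(void)
{
        /*
         * Enable the NEON fast path only if kernel-mode NEON is built in
         * and the CPU actually implements NEON; otherwise neon_usable()
         * always returns false via the static key.
         */
        if (IS_ENABLED(CONFIG_KERNEL_MODE_NEON) && (elf_hwcap & HWCAP_NEON))
                static_branch_enable(&use_neon);

        /* ... register the chacha/xchacha skcipher algorithms as before ... */
        return 0;
}

With that, chacha_arm()/xchacha_arm() always take the scalar path, while the
_neon entry points fall back to scalar whenever neon_usable() returns false.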
