Instead of calling into kernel_neon_end() and kernel_neon_begin() (and
potentially into schedule()) from the assembler code when running in
task mode and a reschedule is pending, perform only the preempt count
check in assembler, but simply return early in this case, and let the C
code deal with the consequences.

This reverts commit 7edc86cb1c18b4c274672232117586ea2bef1d9a.

Signed-off-by: Ard Biesheuvel <a...@kernel.org>
---
 arch/arm64/crypto/sha3-ce-core.S | 81 ++++++++------------
 arch/arm64/crypto/sha3-ce-glue.c | 14 ++--
 2 files changed, 39 insertions(+), 56 deletions(-)

diff --git a/arch/arm64/crypto/sha3-ce-core.S b/arch/arm64/crypto/sha3-ce-core.S
index 1cfb768df350..6f5208414fe3 100644
--- a/arch/arm64/crypto/sha3-ce-core.S
+++ b/arch/arm64/crypto/sha3-ce-core.S
@@ -37,20 +37,13 @@
        .endm
 
        /*
-        * sha3_ce_transform(u64 *st, const u8 *data, int blocks, int dg_size)
+        * int sha3_ce_transform(u64 *st, const u8 *data, int blocks, int dg_size)
         */
        .text
 SYM_FUNC_START(sha3_ce_transform)
-       frame_push      4
-
-       mov     x19, x0
-       mov     x20, x1
-       mov     x21, x2
-       mov     x22, x3
-
-0:     /* load state */
-       add     x8, x19, #32
-       ld1     { v0.1d- v3.1d}, [x19]
+       /* load state */
+       add     x8, x0, #32
+       ld1     { v0.1d- v3.1d}, [x0]
        ld1     { v4.1d- v7.1d}, [x8], #32
        ld1     { v8.1d-v11.1d}, [x8], #32
        ld1     {v12.1d-v15.1d}, [x8], #32
@@ -58,13 +51,13 @@ SYM_FUNC_START(sha3_ce_transform)
        ld1     {v20.1d-v23.1d}, [x8], #32
        ld1     {v24.1d}, [x8]
 
-1:     sub     w21, w21, #1
+0:     sub     w2, w2, #1
        mov     w8, #24
        adr_l   x9, .Lsha3_rcon
 
        /* load input */
-       ld1     {v25.8b-v28.8b}, [x20], #32
-       ld1     {v29.8b-v31.8b}, [x20], #24
+       ld1     {v25.8b-v28.8b}, [x1], #32
+       ld1     {v29.8b-v31.8b}, [x1], #24
        eor     v0.8b, v0.8b, v25.8b
        eor     v1.8b, v1.8b, v26.8b
        eor     v2.8b, v2.8b, v27.8b
@@ -73,10 +66,10 @@ SYM_FUNC_START(sha3_ce_transform)
        eor     v5.8b, v5.8b, v30.8b
        eor     v6.8b, v6.8b, v31.8b
 
-       tbnz    x22, #6, 3f             // SHA3-512
+       tbnz    x3, #6, 2f              // SHA3-512
 
-       ld1     {v25.8b-v28.8b}, [x20], #32
-       ld1     {v29.8b-v30.8b}, [x20], #16
+       ld1     {v25.8b-v28.8b}, [x1], #32
+       ld1     {v29.8b-v30.8b}, [x1], #16
        eor      v7.8b,  v7.8b, v25.8b
        eor      v8.8b,  v8.8b, v26.8b
        eor      v9.8b,  v9.8b, v27.8b
@@ -84,34 +77,34 @@ SYM_FUNC_START(sha3_ce_transform)
        eor     v11.8b, v11.8b, v29.8b
        eor     v12.8b, v12.8b, v30.8b
 
-       tbnz    x22, #4, 2f             // SHA3-384 or SHA3-224
+       tbnz    x3, #4, 1f              // SHA3-384 or SHA3-224
 
        // SHA3-256
-       ld1     {v25.8b-v28.8b}, [x20], #32
+       ld1     {v25.8b-v28.8b}, [x1], #32
        eor     v13.8b, v13.8b, v25.8b
        eor     v14.8b, v14.8b, v26.8b
        eor     v15.8b, v15.8b, v27.8b
        eor     v16.8b, v16.8b, v28.8b
-       b       4f
+       b       3f
 
-2:     tbz     x22, #2, 4f             // bit 2 cleared? SHA-384
+1:     tbz     x3, #2, 3f              // bit 2 cleared? SHA-384
 
        // SHA3-224
-       ld1     {v25.8b-v28.8b}, [x20], #32
-       ld1     {v29.8b}, [x20], #8
+       ld1     {v25.8b-v28.8b}, [x1], #32
+       ld1     {v29.8b}, [x1], #8
        eor     v13.8b, v13.8b, v25.8b
        eor     v14.8b, v14.8b, v26.8b
        eor     v15.8b, v15.8b, v27.8b
        eor     v16.8b, v16.8b, v28.8b
        eor     v17.8b, v17.8b, v29.8b
-       b       4f
+       b       3f
 
        // SHA3-512
-3:     ld1     {v25.8b-v26.8b}, [x20], #16
+2:     ld1     {v25.8b-v26.8b}, [x1], #16
        eor      v7.8b,  v7.8b, v25.8b
        eor      v8.8b,  v8.8b, v26.8b
 
-4:     sub     w8, w8, #1
+3:     sub     w8, w8, #1
 
        eor3    v29.16b,  v4.16b,  v9.16b, v14.16b
        eor3    v26.16b,  v1.16b,  v6.16b, v11.16b
@@ -190,33 +183,19 @@ SYM_FUNC_START(sha3_ce_transform)
 
        eor      v0.16b,  v0.16b, v31.16b
 
-       cbnz    w8, 4b
-       cbz     w21, 5f
-
-       if_will_cond_yield_neon
-       add     x8, x19, #32
-       st1     { v0.1d- v3.1d}, [x19]
-       st1     { v4.1d- v7.1d}, [x8], #32
-       st1     { v8.1d-v11.1d}, [x8], #32
-       st1     {v12.1d-v15.1d}, [x8], #32
-       st1     {v16.1d-v19.1d}, [x8], #32
-       st1     {v20.1d-v23.1d}, [x8], #32
-       st1     {v24.1d}, [x8]
-       do_cond_yield_neon
-       b               0b
-       endif_yield_neon
-
-       b       1b
+       cbnz    w8, 3b
+       cond_yield 3f, x8
+       cbnz    w2, 0b
 
        /* save state */
-5:     st1     { v0.1d- v3.1d}, [x19], #32
-       st1     { v4.1d- v7.1d}, [x19], #32
-       st1     { v8.1d-v11.1d}, [x19], #32
-       st1     {v12.1d-v15.1d}, [x19], #32
-       st1     {v16.1d-v19.1d}, [x19], #32
-       st1     {v20.1d-v23.1d}, [x19], #32
-       st1     {v24.1d}, [x19]
-       frame_pop
+3:     st1     { v0.1d- v3.1d}, [x0], #32
+       st1     { v4.1d- v7.1d}, [x0], #32
+       st1     { v8.1d-v11.1d}, [x0], #32
+       st1     {v12.1d-v15.1d}, [x0], #32
+       st1     {v16.1d-v19.1d}, [x0], #32
+       st1     {v20.1d-v23.1d}, [x0], #32
+       st1     {v24.1d}, [x0]
+       mov     w0, w2
        ret
 SYM_FUNC_END(sha3_ce_transform)
 
diff --git a/arch/arm64/crypto/sha3-ce-glue.c b/arch/arm64/crypto/sha3-ce-glue.c
index 7288d3046354..8c65cecf560a 100644
--- a/arch/arm64/crypto/sha3-ce-glue.c
+++ b/arch/arm64/crypto/sha3-ce-glue.c
@@ -28,8 +28,8 @@ MODULE_ALIAS_CRYPTO("sha3-256");
 MODULE_ALIAS_CRYPTO("sha3-384");
 MODULE_ALIAS_CRYPTO("sha3-512");
 
-asmlinkage void sha3_ce_transform(u64 *st, const u8 *data, int blocks,
-                                 int md_len);
+asmlinkage int sha3_ce_transform(u64 *st, const u8 *data, int blocks,
+                                int md_len);
 
 static int sha3_update(struct shash_desc *desc, const u8 *data,
                       unsigned int len)
@@ -59,11 +59,15 @@ static int sha3_update(struct shash_desc *desc, const u8 *data,
                blocks = len / sctx->rsiz;
                len %= sctx->rsiz;
 
-               if (blocks) {
+               while (blocks) {
+                       int rem;
+
                        kernel_neon_begin();
-                       sha3_ce_transform(sctx->st, data, blocks, digest_size);
+                       rem = sha3_ce_transform(sctx->st, data, blocks,
+                                               digest_size);
                        kernel_neon_end();
-                       data += blocks * sctx->rsiz;
+                       data += (blocks - rem) * sctx->rsiz;
+                       blocks = rem;
                }
        }
 
-- 
2.30.0

Reply via email to