The goal of this patch is to protect the JIT against an attacker with a
memory-write primitive. The JIT allocates a buffer which will eventually
be marked executable (+x), so we need to make sure that what was written
to this buffer is what the JIT intended.

We achieve this by building a hash of the instruction buffer as
instructions are emitted, and then comparing that hash to one recomputed
at the end of the JIT compile, after the buffer has been marked
read-only.
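
Schematically, the verification half looks like the sketch below (the
helper name and the preallocated desc are illustrative only, not the
exact code from the patch; SHA256_DIGEST_SIZE comes from <crypto/sha.h>):

    /*
     * Recompute SHA-256 over nonce || image now that the buffer is
     * read-only, and compare against the digest that was accumulated
     * while the instructions were emitted.
     */
    static bool jit_image_matches(struct shash_desc *desc, u32 nonce,
                                  const u8 *image, u32 len,
                                  const u8 *expected)
    {
            u8 actual[SHA256_DIGEST_SIZE];

            if (crypto_shash_init(desc) ||
                crypto_shash_update(desc, (u8 *)&nonce, sizeof(nonce)) ||
                crypto_shash_update(desc, image, len) ||
                crypto_shash_final(desc, actual))
                    return false;

            return !memcmp(expected, actual, SHA256_DIGEST_SIZE);
    }

Seeding the hash with a random nonce means the expected digest is not
predictable to an attacker who cannot also read kernel memory.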

Signed-off-by: Tycho Andersen <ty...@docker.com>
CC: Daniel Borkmann <dan...@iogearbox.net>
CC: Alexei Starovoitov <a...@kernel.org>
CC: Kees Cook <keesc...@chromium.org>
CC: Mickaël Salaün <m...@digikod.net>
---
 arch/x86/Kconfig            |  11 ++++
 arch/x86/net/bpf_jit_comp.c | 147 ++++++++++++++++++++++++++++++++++++++++----
 2 files changed, 147 insertions(+), 11 deletions(-)

diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig
index cc98d5a..7b2db2c 100644
--- a/arch/x86/Kconfig
+++ b/arch/x86/Kconfig
@@ -2789,6 +2789,17 @@ config X86_DMA_REMAP
 
 source "net/Kconfig"
 
+config EBPF_JIT_HASH_OUTPUT
+       def_bool y
+       depends on HAVE_EBPF_JIT
+       depends on BPF_JIT
+       select CRYPTO_SHA256
+       ---help---
+         Enables a double check of the JIT's output after it is marked
+         read-only, to ensure that it matches what the JIT generated.
+
+         Note, this only applies when /proc/sys/net/core/bpf_jit_harden > 0.
+
 source "drivers/Kconfig"
 
 source "drivers/firmware/Kconfig"
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 32322ce..be1271e 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -13,9 +13,15 @@
 #include <linux/if_vlan.h>
 #include <asm/cacheflush.h>
 #include <linux/bpf.h>
+#include <linux/crypto.h>
+#include <crypto/hash.h>
 
 int bpf_jit_enable __read_mostly;
 
+#ifdef CONFIG_EBPF_JIT_HASH_OUTPUT
+static struct crypto_shash *tfm __read_mostly;
+#endif
+
 /*
  * assembly code in arch/x86/net/bpf_jit.S
  */
@@ -25,7 +31,8 @@ extern u8 sk_load_byte_positive_offset[];
 extern u8 sk_load_word_negative_offset[], sk_load_half_negative_offset[];
 extern u8 sk_load_byte_negative_offset[];
 
-static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
+static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len,
+                    struct shash_desc *hash)
 {
        if (len == 1)
                *ptr = bytes;
@@ -35,11 +42,15 @@ static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
                *(u32 *)ptr = bytes;
                barrier();
        }
+
+       if (IS_ENABLED(CONFIG_EBPF_JIT_HASH_OUTPUT) && hash)
+               crypto_shash_update(hash, (u8 *) &bytes, len);
+
        return ptr + len;
 }
 
 #define EMIT(bytes, len) \
-       do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)
+       do { prog = emit_code(prog, bytes, len, hash); cnt += len; } while (0)
 
 #define EMIT1(b1)              EMIT(b1, 1)
 #define EMIT2(b1, b2)          EMIT((b1) + ((b2) << 8), 2)
@@ -206,7 +217,7 @@ struct jit_context {
 /* emit x64 prologue code for BPF program and check it's size.
  * bpf_tail_call helper will skip it while jumping into another program
  */
-static void emit_prologue(u8 **pprog)
+static void emit_prologue(u8 **pprog, struct shash_desc *hash)
 {
        u8 *prog = *pprog;
        int cnt = 0;
@@ -264,7 +275,7 @@ static void emit_prologue(u8 **pprog)
  *   goto *(prog->bpf_func + prologue_size);
  * out:
  */
-static void emit_bpf_tail_call(u8 **pprog)
+static void emit_bpf_tail_call(u8 **pprog, struct shash_desc *hash)
 {
        u8 *prog = *pprog;
        int label1, label2, label3;
@@ -328,7 +339,7 @@ static void emit_bpf_tail_call(u8 **pprog)
 }
 
 
-static void emit_load_skb_data_hlen(u8 **pprog)
+static void emit_load_skb_data_hlen(u8 **pprog, struct shash_desc *hash)
 {
        u8 *prog = *pprog;
        int cnt = 0;
@@ -348,7 +359,8 @@ static void emit_load_skb_data_hlen(u8 **pprog)
 }
 
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
-                 int oldproglen, struct jit_context *ctx)
+                 int oldproglen, struct jit_context *ctx,
+                 struct shash_desc *hash)
 {
        struct bpf_insn *insn = bpf_prog->insnsi;
        int insn_cnt = bpf_prog->len;
@@ -360,10 +372,10 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
        int proglen = 0;
        u8 *prog = temp;
 
-       emit_prologue(&prog);
+       emit_prologue(&prog, hash);
 
        if (seen_ld_abs)
-               emit_load_skb_data_hlen(&prog);
+               emit_load_skb_data_hlen(&prog, hash);
 
        for (i = 0; i < insn_cnt; i++, insn++) {
                const s32 imm32 = insn->imm;
@@ -875,7 +887,7 @@ xadd:                       if (is_imm8(insn->off))
                        if (seen_ld_abs) {
                                if (reload_skb_data) {
                                        EMIT1(0x5F); /* pop %rdi */
-                                       emit_load_skb_data_hlen(&prog);
+                                       emit_load_skb_data_hlen(&prog, hash);
                                } else {
                                        EMIT2(0x41, 0x59); /* pop %r9 */
                                        EMIT2(0x41, 0x5A); /* pop %r10 */
@@ -884,7 +896,7 @@ xadd:                       if (is_imm8(insn->off))
                        break;
 
                case BPF_JMP | BPF_CALL | BPF_X:
-                       emit_bpf_tail_call(&prog);
+                       emit_bpf_tail_call(&prog, hash);
                        break;
 
                        /* cond jump */
@@ -1085,6 +1097,106 @@ xadd:                   if (is_imm8(insn->off))
        return proglen;
 }
 
+#ifdef CONFIG_EBPF_JIT_HASH_OUTPUT
+static struct shash_desc *bpf_alloc_hash_desc(void)
+{
+       struct shash_desc *hash;
+       int sz = sizeof(struct shash_desc) + crypto_shash_descsize(tfm);
+
+       hash = kzalloc(sz, GFP_KERNEL);
+       if (hash)
+               hash->tfm = tfm;
+       return hash;
+}
+
+static int init_hash(struct shash_desc **hash, u32 *nonce)
+{
+       if (!bpf_jit_harden)
+               return 0;
+
+       *nonce = get_random_int();
+
+       if (!tfm) {
+               tfm = crypto_alloc_shash("sha256", 0, 0);
+               if (IS_ERR(tfm))
+                       return PTR_ERR(tfm);
+       }
+
+       if (!*hash) {
+               *hash = bpf_alloc_hash_desc();
+               if (!*hash)
+                       return -ENOMEM;
+       }
+
+       if (crypto_shash_init(*hash) < 0)
+               return -EINVAL;
+
+       return crypto_shash_update(*hash, (u8 *) nonce, sizeof(*nonce));
+}
+
+static bool check_jit_hash(u8 *buf, u32 len, struct shash_desc *out_d,
+                          u32 nonce)
+{
+       struct shash_desc *check_d;
+       void *out, *check;
+       unsigned int sz;
+       bool match = false;
+
+       if (!out_d)
+               return true;
+
+       BUG_ON(out_d->tfm != tfm);
+
+       sz = crypto_shash_digestsize(out_d->tfm);
+       out = kzalloc(2 * sz, GFP_KERNEL);
+       if (!out)
+               return false;
+
+       if (crypto_shash_final(out_d, out) < 0) {
+               kfree(out);
+               return false;
+       }
+
+       check_d = bpf_alloc_hash_desc();
+       if (!check_d) {
+               kfree(out);
+               return false;
+       }
+
+       if (crypto_shash_init(check_d) < 0)
+               goto out;
+
+       if (crypto_shash_update(check_d, (u8 *) &nonce, sizeof(nonce)) < 0)
+               goto out;
+
+       if (crypto_shash_update(check_d, buf, len) < 0)
+               goto out;
+
+       check = out + sz;
+       if (crypto_shash_final(check_d, check) < 0)
+               goto out;
+
+       if (!memcmp(out, check, sz))
+               match = true;
+
+out:
+       kfree(out);
+       kfree(check_d);
+       return match;
+}
+#else
+static int init_hash(struct shash_desc **hash, u32 *nonce)
+{
+       return 0;
+}
+
+static bool check_jit_hash(u8 *buf, u32 len, struct shash_desc *out_d,
+                          u32 nonce)
+{
+       return true;
+}
+#endif
+
 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 {
        struct bpf_binary_header *header = NULL;
@@ -1096,6 +1208,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
        int *addrs;
        int pass;
        int i;
+       struct shash_desc *hash = NULL;
+       u32 nonce;
 
        if (!bpf_jit_enable)
                return orig_prog;
@@ -1132,7 +1246,15 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
         * pass to emit the final image
         */
        for (pass = 0; pass < 10 || image; pass++) {
-               proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
+               if (init_hash(&hash, &nonce) < 0) {
+                       image = NULL;
+                       if (header)
+                               bpf_jit_binary_free(header);
+                       prog = orig_prog;
+                       goto out_addrs;
+               }
+
+               proglen = do_jit(prog, addrs, image, oldproglen, &ctx, hash);
                if (proglen <= 0) {
                        image = NULL;
                        if (header)
@@ -1166,6 +1288,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
        if (image) {
                bpf_flush_icache(header, image + proglen);
                bpf_jit_binary_lock_ro(header);
+               if (!check_jit_hash(image, proglen, hash, nonce))
+                       BUG();
                prog->bpf_func = (void *)image;
                prog->jited = 1;
        } else {
@@ -1174,6 +1298,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 
 out_addrs:
        kfree(addrs);
+       kfree(hash);
 out:
        if (tmp_blinded)
                bpf_jit_prog_release_other(prog, prog == orig_prog ?
-- 
2.9.3
