Hi,

Protect the tdb hash tables (tdbh, tdbdst, tdbsrc) with the new
tdb_sadb_mtx mutex instead of the NET_LOCK().
ok? bluhm Index: netinet/ip_ipsp.c =================================================================== RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/ip_ipsp.c,v retrieving revision 1.246 diff -u -p -r1.246 ip_ipsp.c --- netinet/ip_ipsp.c 13 Oct 2021 14:36:31 -0000 1.246 +++ netinet/ip_ipsp.c 25 Oct 2021 14:00:18 -0000 @@ -84,7 +84,7 @@ void tdb_hashstats(void); do { } while (0) #endif -void tdb_rehash(void); +int tdb_rehash(void); void tdb_reaper(void *); void tdb_timeout(void *); void tdb_firstuse(void *); @@ -186,11 +186,12 @@ const struct xformsw *const xformswNXFOR #define TDB_HASHSIZE_INIT 32 -/* Protected by the NET_LOCK(). */ +/* Protected by the tdb_sadb_mtx. */ +struct mutex tdb_sadb_mtx = MUTEX_INITIALIZER(IPL_NET); static SIPHASH_KEY tdbkey; -static struct tdb **tdbh = NULL; -static struct tdb **tdbdst = NULL; -static struct tdb **tdbsrc = NULL; +static struct tdb **tdbh; +static struct tdb **tdbdst; +static struct tdb **tdbsrc; static u_int tdb_hashmask = TDB_HASHSIZE_INIT - 1; static int tdb_count; @@ -199,6 +200,14 @@ ipsp_init(void) { pool_init(&tdb_pool, sizeof(struct tdb), 0, IPL_SOFTNET, 0, "tdb", NULL); + + arc4random_buf(&tdbkey, sizeof(tdbkey)); + tdbh = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB, + M_WAITOK | M_ZERO); + tdbdst = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB, + M_WAITOK | M_ZERO); + tdbsrc = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB, + M_WAITOK | M_ZERO); } /* @@ -211,7 +220,7 @@ tdb_hash(u_int32_t spi, union sockaddr_u { SIPHASH_CTX ctx; - NET_ASSERT_LOCKED(); + MUTEX_ASSERT_LOCKED(&tdb_sadb_mtx); SipHash24_Init(&ctx, &tdbkey); SipHash24_Update(&ctx, &spi, sizeof(spi)); @@ -332,11 +341,7 @@ gettdb_dir(u_int rdomain, u_int32_t spi, u_int32_t hashval; struct tdb *tdbp; - NET_ASSERT_LOCKED(); - - if (tdbh == NULL) - return (struct tdb *) NULL; - + mtx_enter(&tdb_sadb_mtx); hashval = tdb_hash(spi, dst, proto); for (tdbp = tdbh[hashval]; tdbp != NULL; tdbp = tdbp->tdb_hnext) @@ -346,6 
+351,7 @@ gettdb_dir(u_int rdomain, u_int32_t spi, !memcmp(&tdbp->tdb_dst, dst, dst->sa.sa_len)) break; + mtx_leave(&tdb_sadb_mtx); return tdbp; } @@ -362,11 +368,7 @@ gettdbbysrcdst_dir(u_int rdomain, u_int3 struct tdb *tdbp; union sockaddr_union su_null; - NET_ASSERT_LOCKED(); - - if (tdbsrc == NULL) - return (struct tdb *) NULL; - + mtx_enter(&tdb_sadb_mtx); hashval = tdb_hash(0, src, proto); for (tdbp = tdbsrc[hashval]; tdbp != NULL; tdbp = tdbp->tdb_snext) @@ -380,8 +382,10 @@ gettdbbysrcdst_dir(u_int rdomain, u_int3 !memcmp(&tdbp->tdb_src, src, src->sa.sa_len)) break; - if (tdbp != NULL) - return (tdbp); + if (tdbp != NULL) { + mtx_leave(&tdb_sadb_mtx); + return tdbp; + } memset(&su_null, 0, sizeof(su_null)); su_null.sa.sa_len = sizeof(struct sockaddr); @@ -398,7 +402,8 @@ gettdbbysrcdst_dir(u_int rdomain, u_int3 tdbp->tdb_src.sa.sa_family == AF_UNSPEC) break; - return (tdbp); + mtx_leave(&tdb_sadb_mtx); + return tdbp; } /* @@ -450,11 +455,7 @@ gettdbbydst(u_int rdomain, union sockadd u_int32_t hashval; struct tdb *tdbp; - NET_ASSERT_LOCKED(); - - if (tdbdst == NULL) - return (struct tdb *) NULL; - + mtx_enter(&tdb_sadb_mtx); hashval = tdb_hash(0, dst, sproto); for (tdbp = tdbdst[hashval]; tdbp != NULL; tdbp = tdbp->tdb_dnext) @@ -462,12 +463,13 @@ gettdbbydst(u_int rdomain, union sockadd (tdbp->tdb_rdomain == rdomain) && ((tdbp->tdb_flags & TDBF_INVALID) == 0) && (!memcmp(&tdbp->tdb_dst, dst, dst->sa.sa_len))) { - /* Do IDs match ? 
*/ + /* Check whether IDs match */ if (!ipsp_aux_match(tdbp, ids, filter, filtermask)) continue; break; } + mtx_leave(&tdb_sadb_mtx); return tdbp; } @@ -483,11 +485,7 @@ gettdbbysrc(u_int rdomain, union sockadd u_int32_t hashval; struct tdb *tdbp; - NET_ASSERT_LOCKED(); - - if (tdbsrc == NULL) - return (struct tdb *) NULL; - + mtx_enter(&tdb_sadb_mtx); hashval = tdb_hash(0, src, sproto); for (tdbp = tdbsrc[hashval]; tdbp != NULL; tdbp = tdbp->tdb_snext) @@ -496,16 +494,16 @@ gettdbbysrc(u_int rdomain, union sockadd ((tdbp->tdb_flags & TDBF_INVALID) == 0) && (!memcmp(&tdbp->tdb_src, src, src->sa.sa_len))) { /* Check whether IDs match */ - if (!ipsp_aux_match(tdbp, ids, filter, - filtermask)) + if (!ipsp_aux_match(tdbp, ids, filter, filtermask)) continue; break; } + mtx_leave(&tdb_sadb_mtx); return tdbp; } -#if DDB +#ifdef DDB #define NBUCKETS 16 void @@ -542,12 +540,8 @@ tdb_walk(u_int rdomain, int (*walker)(st int i, rval = 0; struct tdb *tdbp, *next; - NET_ASSERT_LOCKED(); - - if (tdbh == NULL) - return ENOENT; - - for (i = 0; i <= tdb_hashmask; i++) + mtx_enter(&tdb_sadb_mtx); + for (i = 0; i <= tdb_hashmask; i++) { for (tdbp = tdbh[i]; rval == 0 && tdbp != NULL; tdbp = next) { next = tdbp->tdb_hnext; @@ -559,6 +553,8 @@ tdb_walk(u_int rdomain, int (*walker)(st else rval = walker(tdbp, (void *)arg, 0); } + } + mtx_leave(&tdb_sadb_mtx); return rval; }
@@ -622,24 +618,36 @@ tdb_soft_firstuse(void *v) NET_UNLOCK(); } -void +int tdb_rehash(void) { struct tdb **new_tdbh, **new_tdbdst, **new_srcaddr, *tdbp, *tdbnp; - u_int i, old_hashmask = tdb_hashmask; + u_int i, old_hashmask, new_hashmask; u_int32_t hashval; - NET_ASSERT_LOCKED(); + MUTEX_ASSERT_LOCKED(&tdb_sadb_mtx); - tdb_hashmask = (tdb_hashmask << 1) | 1; + old_hashmask = tdb_hashmask; + new_hashmask = (tdb_hashmask << 1) | 1; - arc4random_buf(&tdbkey, sizeof(tdbkey)); - new_tdbh = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB, - M_WAITOK | M_ZERO); - new_tdbdst = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB, - M_WAITOK | M_ZERO); - new_srcaddr = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB, - M_WAITOK | M_ZERO); + new_tdbh = mallocarray(new_hashmask + 1, sizeof(struct tdb *), M_TDB, + M_NOWAIT | M_ZERO); + new_tdbdst = mallocarray(new_hashmask + 1, sizeof(struct tdb *), M_TDB, + M_NOWAIT | M_ZERO); + new_srcaddr = mallocarray(new_hashmask + 1, sizeof(struct tdb *), M_TDB, + M_NOWAIT | M_ZERO); + if (new_tdbh == NULL || + new_tdbdst == NULL || + new_srcaddr == NULL) { + free(new_tdbh, M_TDB, 0); + free(new_tdbdst, M_TDB, 0); + free(new_srcaddr, M_TDB, 0); + return (ENOMEM); + } + + /* Commit new size and key only now; on ENOMEM nothing changed. */ + tdb_hashmask = new_hashmask; + arc4random_buf(&tdbkey, sizeof(tdbkey)); for (i = 0; i <= old_hashmask; i++) { for (tdbp = tdbh[i]; tdbp != NULL; tdbp = tdbnp) { @@ -673,6 +681,8 @@ tdb_rehash(void) free(tdbsrc, M_TDB, 0); tdbsrc = new_srcaddr; + + return 0; } /* @@ -683,18 +693,7 @@ puttdb(struct tdb *tdbp) { u_int32_t hashval; - NET_ASSERT_LOCKED(); - - if (tdbh == NULL) { - arc4random_buf(&tdbkey, sizeof(tdbkey)); - tdbh = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), - M_TDB, M_WAITOK | M_ZERO); - tdbdst = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), - M_TDB, M_WAITOK | M_ZERO); - tdbsrc = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), - M_TDB, M_WAITOK | M_ZERO); - } - + mtx_enter(&tdb_sadb_mtx); hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst, tdbp->tdb_sproto); /* @@ -707,9 +706,9 @@ puttdb(struct tdb *tdbp) */ if (tdbh[hashval] != NULL && tdbh[hashval]->tdb_hnext != NULL && tdb_count * 10 > tdb_hashmask + 1) { - tdb_rehash(); - hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst, - tdbp->tdb_sproto); + if (tdb_rehash() == 0) + hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst, + tdbp->tdb_sproto); } tdbp->tdb_hnext = tdbh[hashval]; @@ -730,6 +729,7 @@ puttdb(struct tdb *tdbp) #endif /* IPSEC */ ipsec_last_added = getuptime(); + mtx_leave(&tdb_sadb_mtx); } void @@ -738,11 +738,7 @@ tdb_unlink(struct tdb *tdbp) struct tdb *tdbpp; u_int32_t hashval; - NET_ASSERT_LOCKED(); - - if (tdbh == NULL) - return; - + mtx_enter(&tdb_sadb_mtx); hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst, tdbp->tdb_sproto); if (tdbh[hashval] == tdbp) { @@ -799,6 +795,7 @@ tdb_unlink(struct tdb *tdbp) ipsecstat_inc(ipsec_prevtunnels); } #endif /* IPSEC */ + mtx_leave(&tdb_sadb_mtx); } void