There are links between the pcb/socket layer and pf as an optimisation, and links on mbufs between both sides of a forwarded connection. These links let pf skip an RB tree lookup for outgoing packets.
Right now these links are between pf_state_key structs, which are the things that contain the actual addresses used by the connection, but you then have to iterate over a list in the pf_state_key to get to the pf_state structures. I don't understand why we don't just link the actual pf_state structs. My best guess is there wasn't enough machinery (i.e., refcnts and mtxes) on a pf_state struct to make it safe, so the compromise was the pf_state keys. It still got to avoid the tree lookup. I wanted this to make it easier to look up information on pf states from the socket layer, but sashan@ said I should send it out. I do think it makes things a bit easier to understand. The most worrying bit is the change to pf_find_state(). Thoughts? ok? Index: kern/uipc_mbuf.c =================================================================== RCS file: /cvs/src/sys/kern/uipc_mbuf.c,v retrieving revision 1.287 diff -u -p -r1.287 uipc_mbuf.c --- kern/uipc_mbuf.c 23 Jun 2023 04:36:49 -0000 1.287 +++ kern/uipc_mbuf.c 17 Aug 2023 01:31:04 -0000 @@ -308,7 +308,7 @@ m_clearhdr(struct mbuf *m) /* delete all mbuf tags to reset the state */ m_tag_delete_chain(m); #if NPF > 0 - pf_mbuf_unlink_state_key(m); + pf_mbuf_unlink_state(m); pf_mbuf_unlink_inpcb(m); #endif /* NPF > 0 */ @@ -440,7 +440,7 @@ m_free(struct mbuf *m) if (m->m_flags & M_PKTHDR) { m_tag_delete_chain(m); #if NPF > 0 - pf_mbuf_unlink_state_key(m); + pf_mbuf_unlink_state(m); pf_mbuf_unlink_inpcb(m); #endif /* NPF > 0 */ } @@ -1398,8 +1398,8 @@ m_dup_pkthdr(struct mbuf *to, struct mbu to->m_pkthdr = from->m_pkthdr; #if NPF > 0 - to->m_pkthdr.pf.statekey = NULL; - pf_mbuf_link_state_key(to, from->m_pkthdr.pf.statekey); + to->m_pkthdr.pf.st = NULL; + pf_mbuf_link_state(to, from->m_pkthdr.pf.st); to->m_pkthdr.pf.inp = NULL; pf_mbuf_link_inpcb(to, from->m_pkthdr.pf.inp); #endif /* NPF > 0 */ @@ -1526,8 +1526,8 @@ m_print(void *v, m->m_pkthdr.csum_flags, MCS_BITS); (*pr)("m_pkthdr.ether_vtag: %u\tm_ptkhdr.ph_rtableid: %u\n", 
m->m_pkthdr.ether_vtag, m->m_pkthdr.ph_rtableid); - (*pr)("m_pkthdr.pf.statekey: %p\tm_pkthdr.pf.inp %p\n", - m->m_pkthdr.pf.statekey, m->m_pkthdr.pf.inp); + (*pr)("m_pkthdr.pf.st: %p\tm_pkthdr.pf.inp %p\n", + m->m_pkthdr.pf.st, m->m_pkthdr.pf.inp); (*pr)("m_pkthdr.pf.qid: %u\tm_pkthdr.pf.tag: %u\n", m->m_pkthdr.pf.qid, m->m_pkthdr.pf.tag); (*pr)("m_pkthdr.pf.flags: %b\n", Index: net/if_mpw.c =================================================================== RCS file: /cvs/src/sys/net/if_mpw.c,v retrieving revision 1.63 diff -u -p -r1.63 if_mpw.c --- net/if_mpw.c 29 Aug 2022 07:51:45 -0000 1.63 +++ net/if_mpw.c 17 Aug 2023 01:31:04 -0000 @@ -620,7 +620,7 @@ mpw_input(struct mpw_softc *sc, struct m m->m_pkthdr.ph_rtableid = ifp->if_rdomain; /* packet has not been processed by PF yet. */ - KASSERT(m->m_pkthdr.pf.statekey == NULL); + KASSERT(m->m_pkthdr.pf.st == NULL); if_vinput(ifp, m); return; Index: net/if_tpmr.c =================================================================== RCS file: /cvs/src/sys/net/if_tpmr.c,v retrieving revision 1.33 diff -u -p -r1.33 if_tpmr.c --- net/if_tpmr.c 16 May 2023 14:32:54 -0000 1.33 +++ net/if_tpmr.c 17 Aug 2023 01:31:04 -0000 @@ -303,7 +303,7 @@ tpmr_pf(struct ifnet *ifp0, int dir, str return (NULL); if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) { - pf_mbuf_unlink_state_key(m); + pf_mbuf_unlink_state(m); pf_mbuf_unlink_inpcb(m); (*fam->ip_input)(ifp0, m); return (NULL); Index: net/if_veb.c =================================================================== RCS file: /cvs/src/sys/net/if_veb.c,v retrieving revision 1.31 diff -u -p -r1.31 if_veb.c --- net/if_veb.c 16 May 2023 14:32:54 -0000 1.31 +++ net/if_veb.c 17 Aug 2023 01:31:04 -0000 @@ -654,7 +654,7 @@ veb_pf(struct ifnet *ifp0, int dir, stru return (NULL); if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) { - pf_mbuf_unlink_state_key(m); + pf_mbuf_unlink_state(m); pf_mbuf_unlink_inpcb(m); (*fam->ip_input)(ifp0, m); return (NULL); Index: 
net/pf.c =================================================================== RCS file: /cvs/src/sys/net/pf.c,v retrieving revision 1.1184 diff -u -p -r1.1184 pf.c --- net/pf.c 31 Jul 2023 11:13:09 -0000 1.1184 +++ net/pf.c 17 Aug 2023 01:31:04 -0000 @@ -247,16 +247,17 @@ int pf_state_insert(struct pfi_kif *, struct pf_state_key **, struct pf_state_key **, struct pf_state *); +int pf_state_isvalid(struct pf_state *); int pf_state_key_isvalid(struct pf_state_key *); struct pf_state_key *pf_state_key_ref(struct pf_state_key *); void pf_state_key_unref(struct pf_state_key *); -void pf_state_key_link_reverse(struct pf_state_key *, - struct pf_state_key *); -void pf_state_key_unlink_reverse(struct pf_state_key *); -void pf_state_key_link_inpcb(struct pf_state_key *, +void pf_state_link_reverse(struct pf_state *, + struct pf_state *); +void pf_state_unlink_reverse(struct pf_state *); +void pf_state_link_inpcb(struct pf_state *, struct inpcb *); -void pf_state_key_unlink_inpcb(struct pf_state_key *); -void pf_inpcb_unlink_state_key(struct inpcb *); +void pf_state_unlink_inpcb(struct pf_state *); +void pf_inpcb_unlink_state(struct inpcb *); void pf_pktenqueue_delayed(void *); int32_t pf_state_expires(const struct pf_state *, uint8_t); @@ -852,8 +853,6 @@ pf_state_key_detach(struct pf_state *st, if (TAILQ_EMPTY(&sk->sk_states)) { RBT_REMOVE(pf_state_tree, &pf_statetbl, sk); sk->sk_removed = 1; - pf_state_key_unlink_reverse(sk); - pf_state_key_unlink_inpcb(sk); pf_state_key_unref(sk); } @@ -1115,13 +1114,41 @@ pf_compare_state_keys(struct pf_state_ke } } +static inline struct pf_state * +pf_find_state_lookup(struct pf_pdesc *pd, const struct pf_state_key_cmp *key) +{ + struct pf_state_key *sk; + struct pf_state_item *si; + struct pf_state *st; + uint8_t dir = pd->dir; + + sk = RBT_FIND(pf_state_tree, &pf_statetbl, (struct pf_state_key *)key); + if (sk == NULL) + return (NULL); + + /* list is sorted, if-bound states before floating ones */ + TAILQ_FOREACH(si, &sk->sk_states, 
si_entry) { + st = si->si_st; + if (st->timeout == PFTM_PURGE) + continue; + if (st->kif != pfi_all && st->kif != pd->kif) + continue; + + if (st->key[dir == PF_IN ? PF_SK_WIRE : PF_SK_STACK] == sk) + return (st); + } + + return (NULL); +} + int pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key, struct pf_state **stp) { - struct pf_state_key *sk, *pkt_sk, *inp_sk; - struct pf_state_item *si; struct pf_state *st = NULL; + struct pf_state *strev = NULL; + struct inpcb *inp = NULL; + int rv = PF_DROP; pf_status.fcounters[FCNT_STATE_SEARCH]++; if (pf_status.debug >= LOG_DEBUG) { @@ -1131,80 +1158,67 @@ pf_find_state(struct pf_pdesc *pd, struc addlog("\n"); } - inp_sk = NULL; - pkt_sk = NULL; - sk = NULL; if (pd->dir == PF_OUT) { + /* take the references */ + strev = pd->m->m_pkthdr.pf.st; + inp = pd->m->m_pkthdr.pf.inp; + /* first if block deals with outbound forwarded packet */ - pkt_sk = pd->m->m_pkthdr.pf.statekey; + if (strev != NULL) { + pd->m->m_pkthdr.pf.st = NULL; + KASSERT(inp == NULL); - if (!pf_state_key_isvalid(pkt_sk)) { - pf_mbuf_unlink_state_key(pd->m); - pkt_sk = NULL; - } + if (pf_state_isvalid(strev)) { + st = strev->reverse; + if (st != NULL && pf_state_isvalid(st)) + goto match; + } - if (pkt_sk && pf_state_key_isvalid(pkt_sk->sk_reverse)) - sk = pkt_sk->sk_reverse; + /* this handles st not being valid too */ + pf_state_unlink_reverse(strev); - if (pkt_sk == NULL) { + } else if (inp != NULL) { /* here we deal with local outbound packet */ - if (pd->m->m_pkthdr.pf.inp != NULL) { - inp_sk = pd->m->m_pkthdr.pf.inp->inp_pf_sk; - if (pf_state_key_isvalid(inp_sk)) - sk = inp_sk; - else - pf_inpcb_unlink_state_key( - pd->m->m_pkthdr.pf.inp); + pd->m->m_pkthdr.pf.inp = NULL; + + st = inp->inp_pf_st; + if (st != NULL) { + if (pf_state_isvalid(st)) + goto match; + + pf_inpcb_unlink_state(inp); } } } - if (sk == NULL) { - if ((sk = RBT_FIND(pf_state_tree, &pf_statetbl, - (struct pf_state_key *)key)) == NULL) - return (PF_DROP); - if (pd->dir == 
PF_OUT && pkt_sk && - pf_compare_state_keys(pkt_sk, sk, pd->kif, pd->dir) == 0) - pf_state_key_link_reverse(sk, pkt_sk); - else if (pd->dir == PF_OUT && pd->m->m_pkthdr.pf.inp && - !pd->m->m_pkthdr.pf.inp->inp_pf_sk && !sk->sk_inp) - pf_state_key_link_inpcb(sk, pd->m->m_pkthdr.pf.inp); - } - - /* remove firewall data from outbound packet */ - if (pd->dir == PF_OUT) - pf_pkt_addr_changed(pd->m); + st = pf_find_state_lookup(pd, key); + if (st == NULL || ISSET(st->state_flags, PFSTATE_INP_UNLINKED)) + goto drop; - /* list is sorted, if-bound states before floating ones */ - TAILQ_FOREACH(si, &sk->sk_states, si_entry) { - struct pf_state *sist = si->si_st; - if (sist->timeout != PFTM_PURGE && - (sist->kif == pfi_all || sist->kif == pd->kif) && - ((sist->key[PF_SK_WIRE]->af == sist->key[PF_SK_STACK]->af && - sk == (pd->dir == PF_IN ? sist->key[PF_SK_WIRE] : - sist->key[PF_SK_STACK])) || - (sist->key[PF_SK_WIRE]->af != sist->key[PF_SK_STACK]->af - && pd->dir == PF_IN && (sk == sist->key[PF_SK_STACK] || - sk == sist->key[PF_SK_WIRE])))) { - st = sist; - break; - } + if (pd->dir == PF_OUT) { + if (strev != NULL) + pf_state_link_reverse(st, strev); + else if (inp != NULL) + pf_state_link_inpcb(st, inp); } - if (st == NULL) - return (PF_DROP); - if (ISSET(st->state_flags, PFSTATE_INP_UNLINKED)) - return (PF_DROP); - +match: if (st->rule.ptr->pktrate.limit && pd->dir == st->direction) { pf_add_threshold(&st->rule.ptr->pktrate); if (pf_check_threshold(&st->rule.ptr->pktrate)) - return (PF_DROP); + goto drop; } *stp = st; + rv = PF_MATCH; + +drop: + if (strev != NULL) + pf_state_unref(strev); + else if (inp != NULL) + in_pcbunref(inp); - return (PF_MATCH); + return (rv); } struct pf_state * @@ -1763,6 +1777,9 @@ pf_remove_state(struct pf_state *st) st->timeout = PFTM_UNLINKED; + pf_state_unlink_reverse(st); + pf_state_unlink_inpcb(st); + /* handle load balancing related tasks */ pf_postprocess_addr(st); @@ -1792,38 +1809,32 @@ pf_remove_state(struct pf_state *st) } void 
-pf_remove_divert_state(struct pf_state_key *sk) +pf_remove_divert_state(struct pf_state *st) { - struct pf_state_item *si; - PF_ASSERT_UNLOCKED(); PF_LOCK(); PF_STATE_ENTER_WRITE(); - TAILQ_FOREACH(si, &sk->sk_states, si_entry) { - struct pf_state *sist = si->si_st; - if (sk == sist->key[PF_SK_STACK] && sist->rule.ptr && - (sist->rule.ptr->divert.type == PF_DIVERT_TO || - sist->rule.ptr->divert.type == PF_DIVERT_REPLY)) { - if (sist->key[PF_SK_STACK]->proto == IPPROTO_TCP && - sist->key[PF_SK_WIRE] != sist->key[PF_SK_STACK]) { - /* - * If the local address is translated, keep - * the state for "tcp.closed" seconds to - * prevent its source port from being reused. - */ - if (sist->src.state < TCPS_FIN_WAIT_2 || - sist->dst.state < TCPS_FIN_WAIT_2) { - pf_set_protostate(sist, PF_PEER_BOTH, - TCPS_TIME_WAIT); - sist->timeout = PFTM_TCP_CLOSED; - sist->expire = getuptime(); - } - sist->state_flags |= PFSTATE_INP_UNLINKED; - } else - pf_remove_state(sist); - break; - } + if (st->rule.ptr && + (st->rule.ptr->divert.type == PF_DIVERT_TO || + st->rule.ptr->divert.type == PF_DIVERT_REPLY)) { + if (st->key[PF_SK_STACK]->proto == IPPROTO_TCP && + st->key[PF_SK_WIRE] != st->key[PF_SK_STACK]) { + /* + * If the local address is translated, keep + * the state for "tcp.closed" seconds to + * prevent its source port from being reused. 
+ */ + if (st->src.state < TCPS_FIN_WAIT_2 || + st->dst.state < TCPS_FIN_WAIT_2) { + pf_set_protostate(st, PF_PEER_BOTH, + TCPS_TIME_WAIT); + st->timeout = PFTM_TCP_CLOSED; + st->expire = getuptime(); + } + st->state_flags |= PFSTATE_INP_UNLINKED; + } else + pf_remove_state(st); } PF_STATE_EXIT_WRITE(); PF_UNLOCK(); @@ -7836,17 +7847,22 @@ done: if (action == PF_PASS && qid) pd.m->m_pkthdr.pf.qid = qid; - if (pd.dir == PF_IN && st && st->key[PF_SK_STACK]) - pf_mbuf_link_state_key(pd.m, st->key[PF_SK_STACK]); - if (pd.dir == PF_OUT && - pd.m->m_pkthdr.pf.inp && !pd.m->m_pkthdr.pf.inp->inp_pf_sk && - st && st->key[PF_SK_STACK] && !st->key[PF_SK_STACK]->sk_inp) - pf_state_key_link_inpcb(st->key[PF_SK_STACK], - pd.m->m_pkthdr.pf.inp); - - if (st != NULL && !ISSET(pd.m->m_pkthdr.csum_flags, M_FLOWID)) { - pd.m->m_pkthdr.ph_flowid = st->key[PF_SK_WIRE]->hash; - SET(pd.m->m_pkthdr.csum_flags, M_FLOWID); + if (st != NULL) { + struct mbuf *m = pd.m; + struct inpcb *inp = m->m_pkthdr.pf.inp; + + if (pd.dir == PF_IN) { + KASSERT(inp == NULL); + pf_mbuf_link_state(m, st); + } else { + if (inp != NULL && inp->inp_pf_st == NULL) + pf_state_link_inpcb(st, inp); + } + + if (!ISSET(m->m_pkthdr.csum_flags, M_FLOWID)) { + m->m_pkthdr.ph_flowid = st->key[PF_SK_WIRE]->hash; + SET(m->m_pkthdr.csum_flags, M_FLOWID); + } } /* @@ -8004,14 +8020,14 @@ done: int pf_ouraddr(struct mbuf *m) { - struct pf_state_key *sk; + struct pf_state *st; if (m->m_pkthdr.pf.flags & PF_TAG_DIVERTED) return (1); - sk = m->m_pkthdr.pf.statekey; - if (sk != NULL) { - if (sk->sk_inp != NULL) + st = m->m_pkthdr.pf.st; + if (st != NULL) { + if (st->inp != NULL) return (1); } @@ -8025,7 +8041,7 @@ pf_ouraddr(struct mbuf *m) void pf_pkt_addr_changed(struct mbuf *m) { - pf_mbuf_unlink_state_key(m); + pf_mbuf_unlink_state(m); pf_mbuf_unlink_inpcb(m); } @@ -8033,71 +8049,56 @@ struct inpcb * pf_inp_lookup(struct mbuf *m) { struct inpcb *inp = NULL; - struct pf_state_key *sk = m->m_pkthdr.pf.statekey; + struct pf_state 
*st; - if (!pf_state_key_isvalid(sk)) - pf_mbuf_unlink_state_key(m); - else - inp = m->m_pkthdr.pf.statekey->sk_inp; + st = m->m_pkthdr.pf.st; + if (st == NULL) + return (NULL); + if (!pf_state_isvalid(st)) { + pf_mbuf_unlink_state(m); + return (NULL); + } + + inp = st->inp; + if (inp == NULL) + return (NULL); - if (inp && inp->inp_pf_sk) - KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk); + KASSERT(inp->inp_pf_st == NULL || inp->inp_pf_st == st); - in_pcbref(inp); - return (inp); + return (in_pcbref(inp)); } +/* + * This is called from the IP stack after it's found an inpcb for + * an mbuf so it can link the pf_state to that pcb. + */ void pf_inp_link(struct mbuf *m, struct inpcb *inp) { - struct pf_state_key *sk = m->m_pkthdr.pf.statekey; + struct pf_state *st; - if (!pf_state_key_isvalid(sk)) { - pf_mbuf_unlink_state_key(m); + st = m->m_pkthdr.pf.st; + if (st == NULL) return; - } /* * we don't need to grab PF-lock here. At worst case we link inp to * state, which might be just being marked as deleted by another * thread. */ - if (inp && !sk->sk_inp && !inp->inp_pf_sk) - pf_state_key_link_inpcb(sk, inp); + if (pf_state_isvalid(st)) { + if (st->inp == NULL && inp->inp_pf_st == NULL) + pf_state_link_inpcb(st, inp); + } /* The statekey has finished finding the inp, it is no longer needed. */ - pf_mbuf_unlink_state_key(m); + pf_mbuf_unlink_state(m); } void pf_inp_unlink(struct inpcb *inp) { - pf_inpcb_unlink_state_key(inp); -} - -void -pf_state_key_link_reverse(struct pf_state_key *sk, struct pf_state_key *skrev) -{ - struct pf_state_key *old_reverse; - - old_reverse = atomic_cas_ptr(&sk->sk_reverse, NULL, skrev); - if (old_reverse != NULL) - KASSERT(old_reverse == skrev); - else { - pf_state_key_ref(skrev); - - /* - * NOTE: if sk == skrev, then KASSERT() below holds true, we - * still want to grab a reference in such case, because - * pf_state_key_unlink_reverse() does not check whether keys - * are identical or not. 
- */ - old_reverse = atomic_cas_ptr(&skrev->sk_reverse, NULL, sk); - if (old_reverse != NULL) - KASSERT(old_reverse == sk); - - pf_state_key_ref(sk); - } + pf_inpcb_unlink_state(inp); } #if NPFLOG > 0 @@ -8132,10 +8133,6 @@ pf_state_key_unref(struct pf_state_key * if (PF_REF_RELE(sk->sk_refcnt)) { /* state key must be removed from tree */ KASSERT(!pf_state_key_isvalid(sk)); - /* state key must be unlinked from reverse key */ - KASSERT(sk->sk_reverse == NULL); - /* state key must be unlinked from socket */ - KASSERT(sk->sk_inp == NULL); pool_put(&pf_state_key_pl, sk); } } @@ -8146,21 +8143,28 @@ pf_state_key_isvalid(struct pf_state_key return ((sk != NULL) && (sk->sk_removed == 0)); } +int +pf_state_isvalid(struct pf_state *st) +{ + return (st->timeout < PFTM_MAX); +} + void -pf_mbuf_link_state_key(struct mbuf *m, struct pf_state_key *sk) +pf_mbuf_link_state(struct mbuf *m, struct pf_state *st) { - KASSERT(m->m_pkthdr.pf.statekey == NULL); - m->m_pkthdr.pf.statekey = pf_state_key_ref(sk); + KASSERT(m->m_pkthdr.pf.st == NULL); + m->m_pkthdr.pf.st = pf_state_ref(st); } void -pf_mbuf_unlink_state_key(struct mbuf *m) +pf_mbuf_unlink_state(struct mbuf *m) { - struct pf_state_key *sk = m->m_pkthdr.pf.statekey; + struct pf_state *st; - if (sk != NULL) { - m->m_pkthdr.pf.statekey = NULL; - pf_state_key_unref(sk); + st = m->m_pkthdr.pf.st; + if (st != NULL) { + m->m_pkthdr.pf.st = NULL; + pf_state_unref(st); } } @@ -8174,64 +8178,107 @@ pf_mbuf_link_inpcb(struct mbuf *m, struc void pf_mbuf_unlink_inpcb(struct mbuf *m) { - struct inpcb *inp = m->m_pkthdr.pf.inp; + struct inpcb *inp; + inp = m->m_pkthdr.pf.inp; if (inp != NULL) { m->m_pkthdr.pf.inp = NULL; in_pcbunref(inp); } } +/* assumes caller has an exclusive lock around inp */ void -pf_state_key_link_inpcb(struct pf_state_key *sk, struct inpcb *inp) +pf_state_link_inpcb(struct pf_state *st, struct inpcb *inp) { - KASSERT(sk->sk_inp == NULL); - sk->sk_inp = in_pcbref(inp); - KASSERT(inp->inp_pf_sk == NULL); - 
inp->inp_pf_sk = pf_state_key_ref(sk); + KASSERT(inp->inp_pf_st == NULL); + inp->inp_pf_st = pf_state_ref(st); + + mtx_enter(&st->mtx); + KASSERT(st->inp == NULL); + st->inp = in_pcbref(inp); + mtx_leave(&st->mtx); } +/* assumes caller has an exclusive lock around inp */ void -pf_inpcb_unlink_state_key(struct inpcb *inp) +pf_inpcb_unlink_state(struct inpcb *inp) { - struct pf_state_key *sk = inp->inp_pf_sk; + struct pf_state *st; - if (sk != NULL) { - KASSERT(sk->sk_inp == inp); - sk->sk_inp = NULL; - inp->inp_pf_sk = NULL; - pf_state_key_unref(sk); + st = inp->inp_pf_st; + if (st != NULL) { + inp->inp_pf_st = NULL; + + mtx_enter(&st->mtx); + KASSERT(st->inp == inp); + st->inp = NULL; + mtx_leave(&st->mtx); in_pcbunref(inp); + + pf_state_unref(st); } } void -pf_state_key_unlink_inpcb(struct pf_state_key *sk) +pf_state_unlink_inpcb(struct pf_state *st) { - struct inpcb *inp = sk->sk_inp; + struct inpcb *inp; + + mtx_enter(&st->mtx); + inp = st->inp; + if (inp != NULL) + st->inp = NULL; + mtx_leave(&st->mtx); + /* XXX wtf lock? */ if (inp != NULL) { - KASSERT(inp->inp_pf_sk == sk); - sk->sk_inp = NULL; - inp->inp_pf_sk = NULL; - pf_state_key_unref(sk); + KASSERT(inp->inp_pf_st == st); + inp->inp_pf_st = NULL; + + pf_state_unref(st); in_pcbunref(inp); } } void -pf_state_key_unlink_reverse(struct pf_state_key *sk) +pf_state_link_reverse(struct pf_state *st, struct pf_state *strev) { - struct pf_state_key *skrev = sk->sk_reverse; + mtx_enter(&st->mtx); + if (st->reverse == NULL) + st->reverse = pf_state_ref(strev); + mtx_leave(&st->mtx); - /* Note that sk and skrev may be equal, then we unref twice. 
*/ - if (skrev != NULL) { - KASSERT(skrev->sk_reverse == sk); - sk->sk_reverse = NULL; - skrev->sk_reverse = NULL; - pf_state_key_unref(skrev); - pf_state_key_unref(sk); - } + mtx_enter(&strev->mtx); + if (strev->reverse == NULL) + strev->reverse = pf_state_ref(st); + mtx_leave(&strev->mtx); +} + +void +pf_state_unlink_reverse(struct pf_state *st) +{ + struct pf_state *strev; + + mtx_enter(&st->mtx); + strev = st->reverse; + if (strev != NULL) + st->reverse = NULL; /* take over strev reference */ + mtx_leave(&st->mtx); + + if (strev == NULL) + return; + + mtx_enter(&strev->mtx); + if (strev->reverse == st) + strev->reverse = NULL; + else + st = NULL; + mtx_leave(&strev->mtx); + + pf_state_unref(strev); /* drop the reference we just inherited */ + if (st != NULL) + pf_state_unref(st); /* drop the reference strev had */ } struct pf_state * @@ -8257,6 +8304,11 @@ pf_state_unref(struct pf_state *st) pf_state_key_unref(st->key[PF_SK_WIRE]); pf_state_key_unref(st->key[PF_SK_STACK]); + + /* state must be unlinked from reverse */ + KASSERT(st->reverse == NULL); + /* state must be unlinked from socket */ + KASSERT(st->inp == NULL); pool_put(&pf_state_pl, st); } Index: net/pfvar.h =================================================================== RCS file: /cvs/src/sys/net/pfvar.h,v retrieving revision 1.533 diff -u -p -r1.533 pfvar.h --- net/pfvar.h 6 Jul 2023 04:55:05 -0000 1.533 +++ net/pfvar.h 17 Aug 2023 01:31:04 -0000 @@ -1606,7 +1606,7 @@ extern void pf_calc_skip_steps(struct extern void pf_purge_expired_src_nodes(void); extern void pf_purge_expired_rules(void); extern void pf_remove_state(struct pf_state *); -extern void pf_remove_divert_state(struct pf_state_key *); +extern void pf_remove_divert_state(struct pf_state *); extern void pf_free_state(struct pf_state *); int pf_insert_src_node(struct pf_src_node **, struct pf_rule *, enum pf_sn_types, @@ -1860,9 +1860,8 @@ int pf_map_addr(sa_family_t, struct p struct pf_pool *, enum pf_sn_types); int 
pf_postprocess_addr(struct pf_state *); -void pf_mbuf_link_state_key(struct mbuf *, - struct pf_state_key *); -void pf_mbuf_unlink_state_key(struct mbuf *); +void pf_mbuf_link_state(struct mbuf *, struct pf_state *); +void pf_mbuf_unlink_state(struct mbuf *); void pf_mbuf_link_inpcb(struct mbuf *, struct inpcb *); void pf_mbuf_unlink_inpcb(struct mbuf *); Index: net/pfvar_priv.h =================================================================== RCS file: /cvs/src/sys/net/pfvar_priv.h,v retrieving revision 1.34 diff -u -p -r1.34 pfvar_priv.h --- net/pfvar_priv.h 6 Jul 2023 04:55:05 -0000 1.34 +++ net/pfvar_priv.h 17 Aug 2023 01:31:04 -0000 @@ -69,8 +69,6 @@ struct pf_state_key { RB_ENTRY(pf_state_key) sk_entry; struct pf_statelisthead sk_states; - struct pf_state_key *sk_reverse; - struct inpcb *sk_inp; pf_refcnt_t sk_refcnt; u_int8_t sk_removed; }; @@ -115,6 +113,8 @@ struct pf_state { struct pf_sn_head src_nodes; /* [I] */ struct pf_state_key *key[2]; /* [I] stack and wire */ struct pfi_kif *kif; /* [I] */ + struct pf_state *reverse; /* [M] */ + struct inpcb *inp; /* [M] */ struct mutex mtx; pf_refcnt_t refcnt; u_int64_t packets[2]; Index: netinet/in_pcb.c =================================================================== RCS file: /cvs/src/sys/netinet/in_pcb.c,v retrieving revision 1.277 diff -u -p -r1.277 in_pcb.c --- netinet/in_pcb.c 24 Jun 2023 20:54:46 -0000 1.277 +++ netinet/in_pcb.c 17 Aug 2023 01:31:04 -0000 @@ -538,8 +538,8 @@ void in_pcbdisconnect(struct inpcb *inp) { #if NPF > 0 - if (inp->inp_pf_sk) { - pf_remove_divert_state(inp->inp_pf_sk); + if (inp->inp_pf_st) { + pf_remove_divert_state(inp->inp_pf_st); /* pf_remove_divert_state() may have detached the state */ pf_inp_unlink(inp); } @@ -588,8 +588,8 @@ in_pcbdetach(struct inpcb *inp) #endif ip_freemoptions(inp->inp_moptions); #if NPF > 0 - if (inp->inp_pf_sk) { - pf_remove_divert_state(inp->inp_pf_sk); + if (inp->inp_pf_st) { + pf_remove_divert_state(inp->inp_pf_st); /* pf_remove_divert_state() 
may have detached the state */ pf_inp_unlink(inp); } Index: netinet/in_pcb.h =================================================================== RCS file: /cvs/src/sys/netinet/in_pcb.h,v retrieving revision 1.136 diff -u -p -r1.136 in_pcb.h --- netinet/in_pcb.h 24 Jun 2023 20:54:46 -0000 1.136 +++ netinet/in_pcb.h 17 Aug 2023 01:31:04 -0000 @@ -84,7 +84,7 @@ * p inpcb_mtx pcb mutex */ -struct pf_state_key; +struct pf_state; union inpaddru { struct in6_addr iau_addr6; @@ -155,7 +155,7 @@ struct inpcb { #define inp_csumoffset inp_cksum6 #endif struct icmp6_filter *inp_icmp6filt; - struct pf_state_key *inp_pf_sk; + struct pf_state *inp_pf_st; struct mbuf *(*inp_upcall)(void *, struct mbuf *, struct ip *, struct ip6_hdr *, void *, int); void *inp_upcall_arg; Index: sys/mbuf.h =================================================================== RCS file: /cvs/src/sys/sys/mbuf.h,v retrieving revision 1.261 diff -u -p -r1.261 mbuf.h --- sys/mbuf.h 16 Jul 2023 03:01:31 -0000 1.261 +++ sys/mbuf.h 17 Aug 2023 01:31:04 -0000 @@ -92,11 +92,11 @@ struct m_hdr { }; /* pf stuff */ -struct pf_state_key; +struct pf_state; struct inpcb; struct pkthdr_pf { - struct pf_state_key *statekey; /* pf stackside statekey */ + struct pf_state *st; /* pf state */ struct inpcb *inp; /* connected pcb for outgoing packet */ u_int32_t qid; /* queue id */ u_int16_t tag; /* tag id */ @@ -327,7 +327,7 @@ u_int mextfree_register(void (*)(caddr_t (to)->m_pkthdr = (from)->m_pkthdr; \ (from)->m_flags &= ~M_PKTHDR; \ SLIST_INIT(&(from)->m_pkthdr.ph_tags); \ - (from)->m_pkthdr.pf.statekey = NULL; \ + (from)->m_pkthdr.pf.st = NULL; \ } while (/* CONSTCOND */ 0) /*