there are links between the pcb/socket layer and pf as an optimisation,
and links on mbufs between both sides of a forwarded connection.
these links let pf skip an rb tree lookup for outgoing packets.

right now these links are between pf_state_key structs, which are the
things that contain the actual addresses used by the connection, but you
then have to iterate over a list in pf_state_keys to get to the pf_state
structures.

i don't understand why we don't just link the actual pf_state structs.
my best guess is there wasn't enough machinery (i.e., refcnts and
mtxes) on a pf_state struct to make it safe, so the compromise was to
link the pf_state_keys instead. that still avoided the tree lookup.

i wanted this to make it easier to look up information on pf states from
the socket layer, but sashan@ said i should send it out. i do think it
makes things a bit easier to understand.

the most worrying bit is the change to pf_find_state().

thoughts? ok?

Index: kern/uipc_mbuf.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_mbuf.c,v
retrieving revision 1.287
diff -u -p -r1.287 uipc_mbuf.c
--- kern/uipc_mbuf.c    23 Jun 2023 04:36:49 -0000      1.287
+++ kern/uipc_mbuf.c    17 Aug 2023 01:31:04 -0000
@@ -308,7 +308,7 @@ m_clearhdr(struct mbuf *m)
        /* delete all mbuf tags to reset the state */
        m_tag_delete_chain(m);
 #if NPF > 0
-       pf_mbuf_unlink_state_key(m);
+       pf_mbuf_unlink_state(m);        /* drops the mbuf's state ref */
        pf_mbuf_unlink_inpcb(m);
 #endif /* NPF > 0 */
 
@@ -440,7 +440,7 @@ m_free(struct mbuf *m)
        if (m->m_flags & M_PKTHDR) {
                m_tag_delete_chain(m);
 #if NPF > 0
-               pf_mbuf_unlink_state_key(m);
+               pf_mbuf_unlink_state(m);        /* drops the mbuf's state ref */
                pf_mbuf_unlink_inpcb(m);
 #endif /* NPF > 0 */
        }
@@ -1398,8 +1398,8 @@ m_dup_pkthdr(struct mbuf *to, struct mbu
        to->m_pkthdr = from->m_pkthdr;
 
 #if NPF > 0
-       to->m_pkthdr.pf.statekey = NULL;
-       pf_mbuf_link_state_key(to, from->m_pkthdr.pf.statekey);
+       to->m_pkthdr.pf.st = NULL;      /* XXX pf_state_ref(NULL) ok? */
+       pf_mbuf_link_state(to, from->m_pkthdr.pf.st);
        to->m_pkthdr.pf.inp = NULL;
        pf_mbuf_link_inpcb(to, from->m_pkthdr.pf.inp);
 #endif /* NPF > 0 */
@@ -1526,8 +1526,8 @@ m_print(void *v,
                    m->m_pkthdr.csum_flags, MCS_BITS);
                (*pr)("m_pkthdr.ether_vtag: %u\tm_ptkhdr.ph_rtableid: %u\n",
                    m->m_pkthdr.ether_vtag, m->m_pkthdr.ph_rtableid);
-               (*pr)("m_pkthdr.pf.statekey: %p\tm_pkthdr.pf.inp %p\n",
-                   m->m_pkthdr.pf.statekey, m->m_pkthdr.pf.inp);
+               (*pr)("m_pkthdr.pf.st: %p\tm_pkthdr.pf.inp %p\n",
+                   m->m_pkthdr.pf.st, m->m_pkthdr.pf.inp);
                (*pr)("m_pkthdr.pf.qid: %u\tm_pkthdr.pf.tag: %u\n",
                    m->m_pkthdr.pf.qid, m->m_pkthdr.pf.tag);
                (*pr)("m_pkthdr.pf.flags: %b\n",
Index: net/if_mpw.c
===================================================================
RCS file: /cvs/src/sys/net/if_mpw.c,v
retrieving revision 1.63
diff -u -p -r1.63 if_mpw.c
--- net/if_mpw.c        29 Aug 2022 07:51:45 -0000      1.63
+++ net/if_mpw.c        17 Aug 2023 01:31:04 -0000
@@ -620,7 +620,7 @@ mpw_input(struct mpw_softc *sc, struct m
        m->m_pkthdr.ph_rtableid = ifp->if_rdomain;
 
        /* packet has not been processed by PF yet. */
-       KASSERT(m->m_pkthdr.pf.statekey == NULL);
+       KASSERT(m->m_pkthdr.pf.st == NULL);     /* no pf_state linked */
 
        if_vinput(ifp, m);
        return;
Index: net/if_tpmr.c
===================================================================
RCS file: /cvs/src/sys/net/if_tpmr.c,v
retrieving revision 1.33
diff -u -p -r1.33 if_tpmr.c
--- net/if_tpmr.c       16 May 2023 14:32:54 -0000      1.33
+++ net/if_tpmr.c       17 Aug 2023 01:31:04 -0000
@@ -303,7 +303,7 @@ tpmr_pf(struct ifnet *ifp0, int dir, str
                return (NULL);
 
        if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) {
-               pf_mbuf_unlink_state_key(m);
+               pf_mbuf_unlink_state(m);        /* presumably pf runs again */
                pf_mbuf_unlink_inpcb(m);
                (*fam->ip_input)(ifp0, m);
                return (NULL);
Index: net/if_veb.c
===================================================================
RCS file: /cvs/src/sys/net/if_veb.c,v
retrieving revision 1.31
diff -u -p -r1.31 if_veb.c
--- net/if_veb.c        16 May 2023 14:32:54 -0000      1.31
+++ net/if_veb.c        17 Aug 2023 01:31:04 -0000
@@ -654,7 +654,7 @@ veb_pf(struct ifnet *ifp0, int dir, stru
                return (NULL);
 
        if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) {
-               pf_mbuf_unlink_state_key(m);
+               pf_mbuf_unlink_state(m);        /* presumably pf runs again */
                pf_mbuf_unlink_inpcb(m);
                (*fam->ip_input)(ifp0, m);
                return (NULL);
Index: net/pf.c
===================================================================
RCS file: /cvs/src/sys/net/pf.c,v
retrieving revision 1.1184
diff -u -p -r1.1184 pf.c
--- net/pf.c    31 Jul 2023 11:13:09 -0000      1.1184
+++ net/pf.c    17 Aug 2023 01:31:04 -0000
@@ -247,16 +247,17 @@ int                        pf_state_insert(struct pfi_kif *,
                            struct pf_state_key **, struct pf_state_key **,
                            struct pf_state *);
 
+int                     pf_state_isvalid(struct pf_state *);
 int                     pf_state_key_isvalid(struct pf_state_key *);
 struct pf_state_key    *pf_state_key_ref(struct pf_state_key *);
 void                    pf_state_key_unref(struct pf_state_key *);
-void                    pf_state_key_link_reverse(struct pf_state_key *,
-                           struct pf_state_key *);
-void                    pf_state_key_unlink_reverse(struct pf_state_key *);
-void                    pf_state_key_link_inpcb(struct pf_state_key *,
+void                    pf_state_link_reverse(struct pf_state *,
+                           struct pf_state *);
+void                    pf_state_unlink_reverse(struct pf_state *);
+void                    pf_state_link_inpcb(struct pf_state *,
                            struct inpcb *);
-void                    pf_state_key_unlink_inpcb(struct pf_state_key *);
-void                    pf_inpcb_unlink_state_key(struct inpcb *);
+void                    pf_state_unlink_inpcb(struct pf_state *);
+void                    pf_inpcb_unlink_state(struct inpcb *);
 void                    pf_pktenqueue_delayed(void *);
 int32_t                         pf_state_expires(const struct pf_state *, uint8_t);
 
@@ -852,8 +853,6 @@ pf_state_key_detach(struct pf_state *st,
        if (TAILQ_EMPTY(&sk->sk_states)) {
                RBT_REMOVE(pf_state_tree, &pf_statetbl, sk);
                sk->sk_removed = 1;
-               pf_state_key_unlink_reverse(sk);
-               pf_state_key_unlink_inpcb(sk);
                pf_state_key_unref(sk);
        }
 
@@ -1115,13 +1114,49 @@ pf_compare_state_keys(struct pf_state_ke
        }
 }
 
+static inline struct pf_state *
+pf_find_state_lookup(struct pf_pdesc *pd, const struct pf_state_key_cmp *key)
+{
+       struct pf_state_key     *sk;
+       struct pf_state_item    *si;
+       struct pf_state         *st;
+       uint8_t                  dir = pd->dir;
+
+       sk = RBT_FIND(pf_state_tree, &pf_statetbl, (struct pf_state_key *)key);
+       if (sk == NULL)
+               return (NULL);
+
+       /* list is sorted, if-bound states before floating ones */
+       TAILQ_FOREACH(si, &sk->sk_states, si_entry) {
+               st = si->si_st;
+               if (st->timeout == PFTM_PURGE)
+                       continue;
+               if (st->kif != pfi_all && st->kif != pd->kif)
+                       continue;
+
+               /* af-to states only match inbound, but on either key */
+               if (st->key[PF_SK_WIRE]->af != st->key[PF_SK_STACK]->af) {
+                       if (dir == PF_IN && (sk == st->key[PF_SK_WIRE] ||
+                           sk == st->key[PF_SK_STACK]))
+                               return (st);
+                       continue;
+               }
+
+               if (st->key[dir == PF_IN ? PF_SK_WIRE : PF_SK_STACK] == sk)
+                       return (st);
+       }
+
+       return (NULL);
+}
+
 int
 pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key,
     struct pf_state **stp)
 {
-       struct pf_state_key     *sk, *pkt_sk, *inp_sk;
-       struct pf_state_item    *si;
        struct pf_state         *st = NULL;
+       struct pf_state         *strev = NULL;
+       struct inpcb            *inp = NULL;
+       int                      rv = PF_DROP;
 
        pf_status.fcounters[FCNT_STATE_SEARCH]++;
        if (pf_status.debug >= LOG_DEBUG) {
@@ -1131,80 +1166,81 @@ pf_find_state(struct pf_pdesc *pd, struc
                addlog("\n");
        }
 
-       inp_sk = NULL;
-       pkt_sk = NULL;
-       sk = NULL;
        if (pd->dir == PF_OUT) {
+               /* take the references */
+               strev = pd->m->m_pkthdr.pf.st;
+               inp = pd->m->m_pkthdr.pf.inp;
+
                /* first if block deals with outbound forwarded packet */
-               pkt_sk = pd->m->m_pkthdr.pf.statekey;
+               if (strev != NULL) {
+                       pd->m->m_pkthdr.pf.st = NULL;
+                       KASSERT(inp == NULL);
 
-               if (!pf_state_key_isvalid(pkt_sk)) {
-                       pf_mbuf_unlink_state_key(pd->m);
-                       pkt_sk = NULL;
-               }
+                       if (pf_state_isvalid(strev)) {
+                               st = strev->reverse;
+                               if (st != NULL && pf_state_isvalid(st))
+                                       goto match;
+                       }
 
-               if (pkt_sk && pf_state_key_isvalid(pkt_sk->sk_reverse))
-                       sk = pkt_sk->sk_reverse;
+                       /* this handles st not being valid too */
+                       pf_state_unlink_reverse(strev);
 
-               if (pkt_sk == NULL) {
+               } else if (inp != NULL) {
                        /* here we deal with local outbound packet */
-                       if (pd->m->m_pkthdr.pf.inp != NULL) {
-                               inp_sk = pd->m->m_pkthdr.pf.inp->inp_pf_sk;
-                               if (pf_state_key_isvalid(inp_sk))
-                                       sk = inp_sk;
-                               else
-                                       pf_inpcb_unlink_state_key(
-                                           pd->m->m_pkthdr.pf.inp);
+                       pd->m->m_pkthdr.pf.inp = NULL;
+
+                       st = inp->inp_pf_st;
+                       if (st != NULL) {
+                               if (pf_state_isvalid(st))
+                                       goto match;
+
+                               pf_inpcb_unlink_state(inp);
                        }
                }
        }
 
-       if (sk == NULL) {
-               if ((sk = RBT_FIND(pf_state_tree, &pf_statetbl,
-                   (struct pf_state_key *)key)) == NULL)
-                       return (PF_DROP);
-               if (pd->dir == PF_OUT && pkt_sk &&
-                   pf_compare_state_keys(pkt_sk, sk, pd->kif, pd->dir) == 0)
-                       pf_state_key_link_reverse(sk, pkt_sk);
-               else if (pd->dir == PF_OUT && pd->m->m_pkthdr.pf.inp &&
-                   !pd->m->m_pkthdr.pf.inp->inp_pf_sk && !sk->sk_inp)
-                       pf_state_key_link_inpcb(sk, pd->m->m_pkthdr.pf.inp);
-       }
-
-       /* remove firewall data from outbound packet */
-       if (pd->dir == PF_OUT)
-               pf_pkt_addr_changed(pd->m);
+       st = pf_find_state_lookup(pd, key);
+       if (st == NULL || ISSET(st->state_flags, PFSTATE_INP_UNLINKED))
+               goto drop;
 
-       /* list is sorted, if-bound states before floating ones */
-       TAILQ_FOREACH(si, &sk->sk_states, si_entry) {
-               struct pf_state *sist = si->si_st;
-               if (sist->timeout != PFTM_PURGE &&
-                   (sist->kif == pfi_all || sist->kif == pd->kif) &&
-                   ((sist->key[PF_SK_WIRE]->af == sist->key[PF_SK_STACK]->af &&
-                     sk == (pd->dir == PF_IN ? sist->key[PF_SK_WIRE] :
-                   sist->key[PF_SK_STACK])) ||
-                   (sist->key[PF_SK_WIRE]->af != sist->key[PF_SK_STACK]->af
-                   && pd->dir == PF_IN && (sk == sist->key[PF_SK_STACK] ||
-                   sk == sist->key[PF_SK_WIRE])))) {
-                       st = sist;
-                       break;
-               }
+       if (pd->dir == PF_OUT) {
+               if (strev != NULL) {
+                       /*
+                        * don't link a strev that has been removed (nothing
+                        * would unlink it again), or states whose keys do
+                        * not mirror each other.
+                        */
+                       if (pf_state_isvalid(strev) &&
+                           pf_compare_state_keys(strev->key[PF_SK_STACK],
+                           st->key[PF_SK_STACK], pd->kif, pd->dir) == 0)
+                               pf_state_link_reverse(st, strev);
+               } else if (inp != NULL && inp->inp_pf_st == NULL &&
+                   st->inp == NULL) {
+                       /* pf_state_link_inpcb() asserts both are unlinked */
+                       pf_state_link_inpcb(st, inp);
+               }
        }
 
-       if (st == NULL)
-               return (PF_DROP);
-       if (ISSET(st->state_flags, PFSTATE_INP_UNLINKED))
-               return (PF_DROP);
-
+match:
+       /* the cached links skip the lookup above, so recheck this */
+       if (ISSET(st->state_flags, PFSTATE_INP_UNLINKED))
+               goto drop;
        if (st->rule.ptr->pktrate.limit && pd->dir == st->direction) {
                pf_add_threshold(&st->rule.ptr->pktrate);
                if (pf_check_threshold(&st->rule.ptr->pktrate))
-                       return (PF_DROP);
+                       goto drop;
        }
 
        *stp = st;
+       rv = PF_MATCH;
+
+drop:
+       if (strev != NULL)
+               pf_state_unref(strev);
+       else if (inp != NULL)
+               in_pcbunref(inp);
 
-       return (PF_MATCH);
+       return (rv);
 }
 
 struct pf_state *
@@ -1763,6 +1777,9 @@ pf_remove_state(struct pf_state *st)
 
        st->timeout = PFTM_UNLINKED;
 
+       pf_state_unlink_reverse(st);    /* clear st->reverse and back link */
+       pf_state_unlink_inpcb(st);      /* clear st->inp and inp_pf_st */
+
        /* handle load balancing related tasks */
        pf_postprocess_addr(st);
 
@@ -1792,38 +1809,37 @@ pf_remove_state(struct pf_state *st)
 }
 
 void
-pf_remove_divert_state(struct pf_state_key *sk)
+pf_remove_divert_state(struct pf_state *st)
 {
-       struct pf_state_item    *si;
-
        PF_ASSERT_UNLOCKED();
 
        PF_LOCK();
        PF_STATE_ENTER_WRITE();
-       TAILQ_FOREACH(si, &sk->sk_states, si_entry) {
-               struct pf_state *sist = si->si_st;
-               if (sk == sist->key[PF_SK_STACK] && sist->rule.ptr &&
-                   (sist->rule.ptr->divert.type == PF_DIVERT_TO ||
-                    sist->rule.ptr->divert.type == PF_DIVERT_REPLY)) {
-                       if (sist->key[PF_SK_STACK]->proto == IPPROTO_TCP &&
-                           sist->key[PF_SK_WIRE] != sist->key[PF_SK_STACK]) {
-                               /*
-                                * If the local address is translated, keep
-                                * the state for "tcp.closed" seconds to
-                                * prevent its source port from being reused.
-                                */
-                               if (sist->src.state < TCPS_FIN_WAIT_2 ||
-                                   sist->dst.state < TCPS_FIN_WAIT_2) {
-                                       pf_set_protostate(sist, PF_PEER_BOTH,
-                                           TCPS_TIME_WAIT);
-                                       sist->timeout = PFTM_TCP_CLOSED;
-                                       sist->expire = getuptime();
-                               }
-                               sist->state_flags |= PFSTATE_INP_UNLINKED;
-                       } else
-                               pf_remove_state(sist);
-                       break;
-               }
+       /*
+        * st was read from the pcb before the pf lock was taken, so it may
+        * have been removed since. overwriting the timeout of a removed
+        * state would hide the PFTM_UNLINKED marker pf_remove_state() set.
+        */
+       if (pf_state_isvalid(st) && st->rule.ptr &&
+           (st->rule.ptr->divert.type == PF_DIVERT_TO ||
+            st->rule.ptr->divert.type == PF_DIVERT_REPLY)) {
+               if (st->key[PF_SK_STACK]->proto == IPPROTO_TCP &&
+                   st->key[PF_SK_WIRE] != st->key[PF_SK_STACK]) {
+                       /*
+                        * If the local address is translated, keep
+                        * the state for "tcp.closed" seconds to
+                        * prevent its source port from being reused.
+                        */
+                       if (st->src.state < TCPS_FIN_WAIT_2 ||
+                           st->dst.state < TCPS_FIN_WAIT_2) {
+                               pf_set_protostate(st, PF_PEER_BOTH,
+                                   TCPS_TIME_WAIT);
+                               st->timeout = PFTM_TCP_CLOSED;
+                               st->expire = getuptime();
+                       }
+                       st->state_flags |= PFSTATE_INP_UNLINKED;
+               } else
+                       pf_remove_state(st);
        }
        PF_STATE_EXIT_WRITE();
        PF_UNLOCK();
@@ -7836,17 +7847,24 @@ done:
 
        if (action == PF_PASS && qid)
                pd.m->m_pkthdr.pf.qid = qid;
-       if (pd.dir == PF_IN && st && st->key[PF_SK_STACK])
-               pf_mbuf_link_state_key(pd.m, st->key[PF_SK_STACK]);
-       if (pd.dir == PF_OUT &&
-           pd.m->m_pkthdr.pf.inp && !pd.m->m_pkthdr.pf.inp->inp_pf_sk &&
-           st && st->key[PF_SK_STACK] && !st->key[PF_SK_STACK]->sk_inp)
-               pf_state_key_link_inpcb(st->key[PF_SK_STACK],
-                   pd.m->m_pkthdr.pf.inp);
-
-       if (st != NULL && !ISSET(pd.m->m_pkthdr.csum_flags, M_FLOWID)) {
-               pd.m->m_pkthdr.ph_flowid = st->key[PF_SK_WIRE]->hash;
-               SET(pd.m->m_pkthdr.csum_flags, M_FLOWID);
+       if (st != NULL) {
+               struct mbuf *m = pd.m;
+               struct inpcb *inp = m->m_pkthdr.pf.inp;
+
+               if (pd.dir == PF_IN) {
+                       KASSERT(inp == NULL);
+                       pf_mbuf_link_state(m, st);
+               } else {
+                       /* pf_state_link_inpcb() asserts both are unlinked */
+                       if (inp != NULL && inp->inp_pf_st == NULL &&
+                           st->inp == NULL)
+                               pf_state_link_inpcb(st, inp);
+               }
+
+               if (!ISSET(m->m_pkthdr.csum_flags, M_FLOWID)) {
+                       m->m_pkthdr.ph_flowid = st->key[PF_SK_WIRE]->hash;
+                       SET(m->m_pkthdr.csum_flags, M_FLOWID);
+               }
        }
 
        /*
@@ -8004,14 +8020,14 @@ done:
 int
 pf_ouraddr(struct mbuf *m)
 {
-       struct pf_state_key     *sk;
+       struct pf_state         *st;
 
        if (m->m_pkthdr.pf.flags & PF_TAG_DIVERTED)
                return (1);
 
-       sk = m->m_pkthdr.pf.statekey;
-       if (sk != NULL) {
-               if (sk->sk_inp != NULL)
+       st = m->m_pkthdr.pf.st;
+       if (st != NULL) {
+               if (st->inp != NULL)    /* state is linked to a pcb */
                        return (1);
        }
 
@@ -8025,7 +8041,7 @@ pf_ouraddr(struct mbuf *m)
 void
 pf_pkt_addr_changed(struct mbuf *m)
 {
-       pf_mbuf_unlink_state_key(m);
+       pf_mbuf_unlink_state(m);        /* forget the cached state */
        pf_mbuf_unlink_inpcb(m);
 }
 
@@ -8033,71 +8049,64 @@ struct inpcb *
 pf_inp_lookup(struct mbuf *m)
 {
        struct inpcb *inp = NULL;
-       struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
+       struct pf_state *st;
 
-       if (!pf_state_key_isvalid(sk))
-               pf_mbuf_unlink_state_key(m);
-       else
-               inp = m->m_pkthdr.pf.statekey->sk_inp;
+       st = m->m_pkthdr.pf.st;
+       if (st == NULL)
+               return (NULL);
+       if (!pf_state_isvalid(st)) {
+               pf_mbuf_unlink_state(m);
+               return (NULL);
+       }
+
+       /*
+        * st->inp is set and cleared under st->mtx, so take the pcb
+        * reference while holding it.
+        */
+       mtx_enter(&st->mtx);
+       inp = st->inp;
+       if (inp != NULL)
+               in_pcbref(inp);
+       mtx_leave(&st->mtx);
+       if (inp == NULL)
+               return (NULL);
 
-       if (inp && inp->inp_pf_sk)
-               KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk);
+       KASSERT(inp->inp_pf_st == NULL || inp->inp_pf_st == st);
 
-       in_pcbref(inp);
-       return (inp);
+       return (inp);
 }
 
+/*
+ * This is called from the IP stack after it's found an inpcb for
+ * an mbuf so it can link the pf_state to that pcb.
+ */
 void
 pf_inp_link(struct mbuf *m, struct inpcb *inp)
 {
-       struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
+       struct pf_state *st;
 
-       if (!pf_state_key_isvalid(sk)) {
-               pf_mbuf_unlink_state_key(m);
+       st = m->m_pkthdr.pf.st;
+       if (st == NULL)
                return;
-       }
 
        /*
         * we don't need to grab PF-lock here. At worst case we link inp to
         * state, which might be just being marked as deleted by another
         * thread.
         */
-       if (inp && !sk->sk_inp && !inp->inp_pf_sk)
-               pf_state_key_link_inpcb(sk, inp);
+       if (inp != NULL && pf_state_isvalid(st)) {
+               if (st->inp == NULL && inp->inp_pf_st == NULL)
+                       pf_state_link_inpcb(st, inp);
+       }
 
-       /* The statekey has finished finding the inp, it is no longer needed. */
-       pf_mbuf_unlink_state_key(m);
+       /* The state has finished finding the inp, it is no longer needed. */
+       pf_mbuf_unlink_state(m);
 }
 
 void
 pf_inp_unlink(struct inpcb *inp)
 {
-       pf_inpcb_unlink_state_key(inp);
-}
-
-void
-pf_state_key_link_reverse(struct pf_state_key *sk, struct pf_state_key *skrev)
-{
-       struct pf_state_key *old_reverse;
-
-       old_reverse = atomic_cas_ptr(&sk->sk_reverse, NULL, skrev);
-       if (old_reverse != NULL)
-               KASSERT(old_reverse == skrev);
-       else {
-               pf_state_key_ref(skrev);
-
-               /*
-                * NOTE: if sk == skrev, then KASSERT() below holds true, we
-                * still want to grab a reference in such case, because
-                * pf_state_key_unlink_reverse() does not check whether keys
-                * are identical or not.
-                */
-               old_reverse = atomic_cas_ptr(&skrev->sk_reverse, NULL, sk);
-               if (old_reverse != NULL)
-                       KASSERT(old_reverse == sk);
-
-               pf_state_key_ref(sk);
-       }
+       pf_inpcb_unlink_state(inp);
 }
 
 #if NPFLOG > 0
@@ -8132,10 +8133,6 @@ pf_state_key_unref(struct pf_state_key *
        if (PF_REF_RELE(sk->sk_refcnt)) {
                /* state key must be removed from tree */
                KASSERT(!pf_state_key_isvalid(sk));
-               /* state key must be unlinked from reverse key */
-               KASSERT(sk->sk_reverse == NULL);
-               /* state key must be unlinked from socket */
-               KASSERT(sk->sk_inp == NULL);
                pool_put(&pf_state_key_pl, sk);
        }
 }
@@ -8146,21 +8143,28 @@ pf_state_key_isvalid(struct pf_state_key
        return ((sk != NULL) && (sk->sk_removed == 0));
 }
 
+int
+pf_state_isvalid(struct pf_state *st)
+{
+       return (st->timeout < PFTM_MAX); /* assumes PURGE, UNLINKED > MAX */
+}
+
 void
-pf_mbuf_link_state_key(struct mbuf *m, struct pf_state_key *sk)
+pf_mbuf_link_state(struct mbuf *m, struct pf_state *st)
 {
-       KASSERT(m->m_pkthdr.pf.statekey == NULL);
-       m->m_pkthdr.pf.statekey = pf_state_key_ref(sk);
+       KASSERT(m->m_pkthdr.pf.st == NULL);
+       m->m_pkthdr.pf.st = pf_state_ref(st);  /* the mbuf holds a ref */
 }
 
 void
-pf_mbuf_unlink_state_key(struct mbuf *m)
+pf_mbuf_unlink_state(struct mbuf *m)
 {
-       struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
+       struct pf_state *st;
 
-       if (sk != NULL) {
-               m->m_pkthdr.pf.statekey = NULL;
-               pf_state_key_unref(sk);
+       st = m->m_pkthdr.pf.st;
+       if (st != NULL) {
+               m->m_pkthdr.pf.st = NULL;
+               pf_state_unref(st);     /* drop the mbuf's ref */
        }
 }
 
@@ -8174,64 +8178,117 @@ pf_mbuf_link_inpcb(struct mbuf *m, struc
 void
 pf_mbuf_unlink_inpcb(struct mbuf *m)
 {
-       struct inpcb *inp = m->m_pkthdr.pf.inp;
+       struct inpcb *inp;
 
+       inp = m->m_pkthdr.pf.inp;
        if (inp != NULL) {
                m->m_pkthdr.pf.inp = NULL;
                in_pcbunref(inp);
        }
 }
 
+/* assumes caller has an exclusive lock around inp */
 void
-pf_state_key_link_inpcb(struct pf_state_key *sk, struct inpcb *inp)
+pf_state_link_inpcb(struct pf_state *st, struct inpcb *inp)
 {
-       KASSERT(sk->sk_inp == NULL);
-       sk->sk_inp = in_pcbref(inp);
-       KASSERT(inp->inp_pf_sk == NULL);
-       inp->inp_pf_sk = pf_state_key_ref(sk);
+       KASSERT(inp->inp_pf_st == NULL);
+       inp->inp_pf_st = pf_state_ref(st);
+
+       mtx_enter(&st->mtx);
+       KASSERT(st->inp == NULL);
+       st->inp = in_pcbref(inp);
+       mtx_leave(&st->mtx);
 }
 
+/* assumes caller has an exclusive lock around inp */
 void
-pf_inpcb_unlink_state_key(struct inpcb *inp)
+pf_inpcb_unlink_state(struct inpcb *inp)
 {
-       struct pf_state_key *sk = inp->inp_pf_sk;
+       struct pf_state *st;
 
-       if (sk != NULL) {
-               KASSERT(sk->sk_inp == inp);
-               sk->sk_inp = NULL;
-               inp->inp_pf_sk = NULL;
-               pf_state_key_unref(sk);
+       st = inp->inp_pf_st;
+       if (st != NULL) {
+               struct inpcb *linked;
+
+               /*
+                * pf_state_unlink_inpcb() may be tearing the link down
+                * concurrently; only the side that clears st->inp under
+                * st->mtx may clear inp_pf_st and drop the references.
+                */
+               mtx_enter(&st->mtx);
+               linked = st->inp;
+               if (linked == inp)
+                       st->inp = NULL;
+               mtx_leave(&st->mtx);
+               if (linked != inp)
+                       return;
+
+               inp->inp_pf_st = NULL;
                in_pcbunref(inp);
+
+               pf_state_unref(st);
        }
 }
 
 void
-pf_state_key_unlink_inpcb(struct pf_state_key *sk)
+pf_state_unlink_inpcb(struct pf_state *st)
 {
-       struct inpcb *inp = sk->sk_inp;
+       struct inpcb *inp;
+
+       mtx_enter(&st->mtx);
+       inp = st->inp;
+       if (inp != NULL)
+               st->inp = NULL;
+       mtx_leave(&st->mtx);
 
+       /* XXX wtf lock? */
        if (inp != NULL) {
-               KASSERT(inp->inp_pf_sk == sk);
-               sk->sk_inp = NULL;
-               inp->inp_pf_sk = NULL;
-               pf_state_key_unref(sk);
+               KASSERT(inp->inp_pf_st == st);
+               inp->inp_pf_st = NULL;
+
+               pf_state_unref(st);
                in_pcbunref(inp);
        }
 }
 
 void
-pf_state_key_unlink_reverse(struct pf_state_key *sk)
+pf_state_link_reverse(struct pf_state *st, struct pf_state *strev)
 {
-       struct pf_state_key *skrev = sk->sk_reverse;
+       mtx_enter(&st->mtx);
+       if (st->reverse == NULL)
+               st->reverse = pf_state_ref(strev);
+       mtx_leave(&st->mtx);
 
-       /* Note that sk and skrev may be equal, then we unref twice. */
-       if (skrev != NULL) {
-               KASSERT(skrev->sk_reverse == sk);
-               sk->sk_reverse = NULL;
-               skrev->sk_reverse = NULL;
-               pf_state_key_unref(skrev);
-               pf_state_key_unref(sk);
-       }
+       mtx_enter(&strev->mtx);
+       if (strev->reverse == NULL)
+               strev->reverse = pf_state_ref(st);
+       mtx_leave(&strev->mtx);
+}
+
+void
+pf_state_unlink_reverse(struct pf_state *st)
+{
+       struct pf_state *strev;
+
+       mtx_enter(&st->mtx);
+       strev = st->reverse;
+       if (strev != NULL)
+               st->reverse = NULL; /* take over strev reference */
+       mtx_leave(&st->mtx);
+
+       if (strev == NULL)
+               return;
+
+       mtx_enter(&strev->mtx);
+       if (strev->reverse == st)
+               strev->reverse = NULL;
+       else
+               st = NULL;
+       mtx_leave(&strev->mtx);
+
+       pf_state_unref(strev); /* drop the reference we just inherited */
+       if (st != NULL)
+               pf_state_unref(st); /* drop the reference strev had */
 }
 
 struct pf_state *
@@ -8257,6 +8304,11 @@ pf_state_unref(struct pf_state *st)
 
                pf_state_key_unref(st->key[PF_SK_WIRE]);
                pf_state_key_unref(st->key[PF_SK_STACK]);
+
+               /* state must be unlinked from reverse, cf. pf_remove_state() */
+               KASSERT(st->reverse == NULL);
+               /* state must be unlinked from socket, cf. pf_remove_state() */
+               KASSERT(st->inp == NULL);
 
                pool_put(&pf_state_pl, st);
        }
Index: net/pfvar.h
===================================================================
RCS file: /cvs/src/sys/net/pfvar.h,v
retrieving revision 1.533
diff -u -p -r1.533 pfvar.h
--- net/pfvar.h 6 Jul 2023 04:55:05 -0000       1.533
+++ net/pfvar.h 17 Aug 2023 01:31:04 -0000
@@ -1606,7 +1606,7 @@ extern void                        pf_calc_skip_steps(struct
 extern void                     pf_purge_expired_src_nodes(void);
 extern void                     pf_purge_expired_rules(void);
 extern void                     pf_remove_state(struct pf_state *);
-extern void                     pf_remove_divert_state(struct pf_state_key *);
+extern void                     pf_remove_divert_state(struct pf_state *);
 extern void                     pf_free_state(struct pf_state *);
 int                             pf_insert_src_node(struct pf_src_node **,
                                    struct pf_rule *, enum pf_sn_types,
@@ -1860,9 +1860,8 @@ int                        pf_map_addr(sa_family_t, struct p
                            struct pf_pool *, enum pf_sn_types);
 int                     pf_postprocess_addr(struct pf_state *);
 
-void                    pf_mbuf_link_state_key(struct mbuf *,
-                           struct pf_state_key *);
-void                    pf_mbuf_unlink_state_key(struct mbuf *);
+void                    pf_mbuf_link_state(struct mbuf *, struct pf_state *);
+void                    pf_mbuf_unlink_state(struct mbuf *);
 void                    pf_mbuf_link_inpcb(struct mbuf *, struct inpcb *);
 void                    pf_mbuf_unlink_inpcb(struct mbuf *);
 
Index: net/pfvar_priv.h
===================================================================
RCS file: /cvs/src/sys/net/pfvar_priv.h,v
retrieving revision 1.34
diff -u -p -r1.34 pfvar_priv.h
--- net/pfvar_priv.h    6 Jul 2023 04:55:05 -0000       1.34
+++ net/pfvar_priv.h    17 Aug 2023 01:31:04 -0000
@@ -69,8 +69,6 @@ struct pf_state_key {
 
        RB_ENTRY(pf_state_key)   sk_entry;
        struct pf_statelisthead  sk_states;
-       struct pf_state_key     *sk_reverse;
-       struct inpcb            *sk_inp;
        pf_refcnt_t              sk_refcnt;
        u_int8_t                 sk_removed;
 };
@@ -115,6 +113,8 @@ struct pf_state {
        struct pf_sn_head        src_nodes;     /* [I] */
        struct pf_state_key     *key[2];        /* [I] stack and wire */
        struct pfi_kif          *kif;           /* [I] */
+       struct pf_state         *reverse;       /* [M] reverse state, holds ref */
+       struct inpcb            *inp;           /* [M] linked pcb, holds ref */
        struct mutex             mtx;
        pf_refcnt_t              refcnt;
        u_int64_t                packets[2];
Index: netinet/in_pcb.c
===================================================================
RCS file: /cvs/src/sys/netinet/in_pcb.c,v
retrieving revision 1.277
diff -u -p -r1.277 in_pcb.c
--- netinet/in_pcb.c    24 Jun 2023 20:54:46 -0000      1.277
+++ netinet/in_pcb.c    17 Aug 2023 01:31:04 -0000
@@ -538,8 +538,8 @@ void
 in_pcbdisconnect(struct inpcb *inp)
 {
 #if NPF > 0
-       if (inp->inp_pf_sk) {
-               pf_remove_divert_state(inp->inp_pf_sk);
+       if (inp->inp_pf_st) {   /* the pcb holds a ref on this state */
+               pf_remove_divert_state(inp->inp_pf_st);
                /* pf_remove_divert_state() may have detached the state */
                pf_inp_unlink(inp);
        }
@@ -588,8 +588,8 @@ in_pcbdetach(struct inpcb *inp)
 #endif
                ip_freemoptions(inp->inp_moptions);
 #if NPF > 0
-       if (inp->inp_pf_sk) {
-               pf_remove_divert_state(inp->inp_pf_sk);
+       if (inp->inp_pf_st) {   /* the pcb holds a ref on this state */
+               pf_remove_divert_state(inp->inp_pf_st);
                /* pf_remove_divert_state() may have detached the state */
                pf_inp_unlink(inp);
        }
Index: netinet/in_pcb.h
===================================================================
RCS file: /cvs/src/sys/netinet/in_pcb.h,v
retrieving revision 1.136
diff -u -p -r1.136 in_pcb.h
--- netinet/in_pcb.h    24 Jun 2023 20:54:46 -0000      1.136
+++ netinet/in_pcb.h    17 Aug 2023 01:31:04 -0000
@@ -84,7 +84,7 @@
  *     p       inpcb_mtx               pcb mutex
  */
 
-struct pf_state_key;
+struct pf_state;
 
 union inpaddru {
        struct in6_addr iau_addr6;
@@ -155,7 +155,7 @@ struct inpcb {
 #define inp_csumoffset inp_cksum6
 #endif
        struct  icmp6_filter *inp_icmp6filt;
-       struct  pf_state_key *inp_pf_sk;
+       struct  pf_state *inp_pf_st;    /* linked pf state, holds a ref */
        struct  mbuf *(*inp_upcall)(void *, struct mbuf *,
                    struct ip *, struct ip6_hdr *, void *, int);
        void    *inp_upcall_arg;
Index: sys/mbuf.h
===================================================================
RCS file: /cvs/src/sys/sys/mbuf.h,v
retrieving revision 1.261
diff -u -p -r1.261 mbuf.h
--- sys/mbuf.h  16 Jul 2023 03:01:31 -0000      1.261
+++ sys/mbuf.h  17 Aug 2023 01:31:04 -0000
@@ -92,11 +92,11 @@ struct m_hdr {
 };
 
 /* pf stuff */
-struct pf_state_key;
+struct pf_state;
 struct inpcb;
 
 struct pkthdr_pf {
-       struct pf_state_key *statekey;  /* pf stackside statekey */
+       struct pf_state *st;            /* pf state, holds a ref */
        struct inpcb    *inp;           /* connected pcb for outgoing packet */
        u_int32_t        qid;           /* queue id */
        u_int16_t        tag;           /* tag id */
@@ -327,7 +327,7 @@ u_int mextfree_register(void (*)(caddr_t
        (to)->m_pkthdr = (from)->m_pkthdr;                              \
        (from)->m_flags &= ~M_PKTHDR;                                   \
        SLIST_INIT(&(from)->m_pkthdr.ph_tags);                          \
-       (from)->m_pkthdr.pf.statekey = NULL;                            \
+       (from)->m_pkthdr.pf.st = NULL;  /* ref moves to (to) */         \
 } while (/* CONSTCOND */ 0)
 
 /*

Reply via email to