This does more or less the same thing for PF_KEY that we now do in PF_ROUTE.
Use one PCB LIST on the keycb and embed the rawcb in that PF_KEY cb.
The diff also has a few variable renames in it to make this code look less
alien compared to the rest of our kernel. Mainly, it uses so instead of
socket, and pfkeyv2_socket is replaced with better variable names.

This needs the previous diff I just sent out for PF_ROUTE.
After that I can make pfkey use the same SRPL_LIST as PF_ROUTE (from
another diff) to unlock both of them further.
-- 
:wq Claudio

Index: net/pfkeyv2.c
===================================================================
RCS file: /cvs/src/sys/net/pfkeyv2.c,v
retrieving revision 1.160
diff -u -p -r1.160 pfkeyv2.c
--- net/pfkeyv2.c       29 May 2017 20:31:12 -0000      1.160
+++ net/pfkeyv2.c       30 May 2017 08:44:05 -0000
@@ -131,14 +131,16 @@ extern struct radix_node_head **spd_tabl
 struct sockaddr pfkey_addr = { 2, PF_KEY, };
 struct domain pfkeydomain;
 
-struct pfkeyv2_socket {
-       LIST_ENTRY(pfkeyv2_socket)      kcb_list;
-       struct socket *socket;
+struct keycb {
+       struct rawcb                    rcb;
+       LIST_ENTRY(keycb)       kcb_list;
        int flags;
        uint32_t pid;
        uint32_t registration;    /* Increase size if SATYPE_MAX > 31 */
        uint rdomain;
 };
+#define sotokeycb(so) ((struct keycb *)(so)->so_pcb)
+
 
 struct dump_state {
        struct sadb_msg *sadb_msg;
@@ -146,8 +148,7 @@ struct dump_state {
 };
 
 /* Static globals */
-static LIST_HEAD(, pfkeyv2_socket) pfkeyv2_sockets =
-    LIST_HEAD_INITIALIZER(pfkeyv2_sockets);
+static LIST_HEAD(, keycb) pfkeyv2_sockets = LIST_HEAD_INITIALIZER(keycb);
 static uint32_t pfkeyv2_seq = 1;
 static int nregistered = 0;
 static int npromisc = 0;
@@ -160,7 +161,7 @@ int pfkeyv2_usrreq(struct socket *, int,
     struct mbuf *, struct proc *);
 int pfkeyv2_output(struct mbuf *, struct socket *, struct sockaddr *,
     struct mbuf *);
-int pfkey_sendup(struct socket *socket, struct mbuf *packet, int more);
+int pfkey_sendup(struct keycb *, struct mbuf *, int);
 
 /*
  * Wrapper around m_devget(); copy data from contiguous buffer to mbuf
@@ -212,71 +213,62 @@ pfkey_init(void)
 int
 pfkeyv2_attach(struct socket *so, int proto)
 {
-       struct pfkeyv2_socket *pfkeyv2_socket;
+       struct rawcb *rp;
+       struct keycb *pk;
        int error;
 
        if ((so->so_state & SS_PRIV) == 0)
                return EACCES;
 
-       if (!(so->so_pcb = malloc(sizeof(struct rawcb),
-           M_PCB, M_DONTWAIT | M_ZERO)))
-               return (ENOMEM);
-
-       error = raw_attach(so, so->so_proto->pr_protocol);
-       if (error)
-               goto ret;
-
-       ((struct rawcb *)so->so_pcb)->rcb_faddr = &pfkey_addr;
-
-       if (!(pfkeyv2_socket = malloc(sizeof(struct pfkeyv2_socket),
-           M_PFKEY, M_NOWAIT | M_ZERO)))
-               return (ENOMEM);
+       pk = malloc(sizeof(struct keycb), M_PCB, M_WAITOK | M_ZERO);
+       rp = &pk->rcb;
+       so->so_pcb = rp;
+
+       error = raw_attach(so, proto);
+       if (error) {
+               free(pk, M_PCB, sizeof(struct keycb));
+               return (error);
+       }
 
-       LIST_INSERT_HEAD(&pfkeyv2_sockets, pfkeyv2_socket, kcb_list);
-       pfkeyv2_socket->socket = so;
-       pfkeyv2_socket->pid = curproc->p_p->ps_pid;
+       rp->rcb_faddr = &pfkey_addr;
+       pk->pid = curproc->p_p->ps_pid;
 
        /*
         * XXX we should get this from the socket instead but
         * XXX rawcb doesn't store the rdomain like inpcb does.
         */
-       pfkeyv2_socket->rdomain = rtable_l2(curproc->p_p->ps_rtableid);
+       pk->rdomain = rtable_l2(curproc->p_p->ps_rtableid);
+
+       LIST_INSERT_HEAD(&pfkeyv2_sockets, pk, kcb_list);
 
        so->so_options |= SO_USELOOPBACK;
        soisconnected(so);
 
        return (0);
-ret:
-       free(so->so_pcb, M_PCB, sizeof(struct rawcb));
-       return (error);
 }
 
 /*
  * Close a PF_KEYv2 socket.
  */
 int
-pfkeyv2_detach(struct socket *socket, struct proc *p)
+pfkeyv2_detach(struct socket *so, struct proc *p)
 {
-       struct pfkeyv2_socket *pp;
+       struct keycb *pp;
        int error;
 
-       LIST_FOREACH(pp, &pfkeyv2_sockets, kcb_list)
-               if (pp->socket == socket)
-                       break;
-
-       if (pp) {
-               LIST_REMOVE(pp, kcb_list);
+       pp = sotokeycb(so);
+       if (pp == NULL)
+               return ENOTCONN;
 
-               if (pp->flags & PFKEYV2_SOCKETFLAGS_REGISTERED)
-                       nregistered--;
+       LIST_REMOVE(pp, kcb_list);
 
-               if (pp->flags & PFKEYV2_SOCKETFLAGS_PROMISC)
-                       npromisc--;
+       if (pp->flags & PFKEYV2_SOCKETFLAGS_REGISTERED)
+               nregistered--;
 
-               free(pp, M_PFKEY, 0);
-       }
+       if (pp->flags & PFKEYV2_SOCKETFLAGS_PROMISC)
+               npromisc--;
 
-       error = raw_usrreq(socket, PRU_DETACH, NULL, NULL, NULL, p);
+       error = raw_usrreq(so, PRU_DETACH, NULL, NULL, NULL, p);
        return (error);
 }
 
@@ -293,7 +285,7 @@ pfkeyv2_usrreq(struct socket *so, int re
 }
 
 int
-pfkeyv2_output(struct mbuf *mbuf, struct socket *socket,
+pfkeyv2_output(struct mbuf *mbuf, struct socket *so,
     struct sockaddr *dstaddr, struct mbuf *control)
 {
        void *message;
@@ -319,7 +311,7 @@ pfkeyv2_output(struct mbuf *mbuf, struct
 
        m_copydata(mbuf, 0, mbuf->m_pkthdr.len, message);
 
-       error = pfkeyv2_send(socket, message, mbuf->m_pkthdr.len);
+       error = pfkeyv2_send(so, message, mbuf->m_pkthdr.len);
 
 ret:
        m_freem(mbuf);
@@ -327,8 +319,9 @@ ret:
 }
 
 int
-pfkey_sendup(struct socket *socket, struct mbuf *packet, int more)
+pfkey_sendup(struct keycb *kp, struct mbuf *packet, int more)
 {
+       struct socket *so = kp->rcb.rcb_socket;
        struct mbuf *packet2;
 
        NET_ASSERT_LOCKED();
@@ -339,12 +332,12 @@ pfkey_sendup(struct socket *socket, stru
        } else
                packet2 = packet;
 
-       if (!sbappendaddr(&socket->so_rcv, &pfkey_addr, packet2, NULL)) {
+       if (!sbappendaddr(&so->so_rcv, &pfkey_addr, packet2, NULL)) {
                m_freem(packet2);
                return (ENOBUFS);
        }
 
-       sorwakeup(socket);
+       sorwakeup(so);
        return (0);
 }
 
@@ -354,13 +347,13 @@ pfkey_sendup(struct socket *socket, stru
  * third argument.
  */
 int
-pfkeyv2_sendmessage(void **headers, int mode, struct socket *socket,
+pfkeyv2_sendmessage(void **headers, int mode, struct socket *so,
     u_int8_t satype, int count, u_int rdomain)
 {
        int i, j, rval;
        void *p, *buffer = NULL;
        struct mbuf *packet;
-       struct pfkeyv2_socket *s;
+       struct keycb *s;
        struct sadb_msg *smsg;
 
        /* Find out how much space we'll need... */
@@ -401,7 +394,7 @@ pfkeyv2_sendmessage(void **headers, int 
                 * Send message to the specified socket, plus all
                 * promiscuous listeners.
                 */
-               pfkey_sendup(socket, packet, 0);
+               pfkey_sendup(sotokeycb(so), packet, 0);
 
                /*
                 * Promiscuous messages contain the original message
@@ -426,9 +419,9 @@ pfkeyv2_sendmessage(void **headers, int 
                 */
                LIST_FOREACH(s, &pfkeyv2_sockets, kcb_list) {
                        if ((s->flags & PFKEYV2_SOCKETFLAGS_PROMISC) &&
-                           (s->socket != socket) &&
+                           (s->rcb.rcb_socket != so) &&
                            (s->rdomain == rdomain))
-                               pfkey_sendup(s->socket, packet, 1);
+                               pfkey_sendup(s, packet, 1);
                }
                m_freem(packet);
                break;
@@ -442,11 +435,11 @@ pfkeyv2_sendmessage(void **headers, int 
                        if ((s->flags & PFKEYV2_SOCKETFLAGS_REGISTERED) &&
                            (s->rdomain == rdomain)) {
                                if (!satype)    /* Just send to everyone 
registered */
-                                       pfkey_sendup(s->socket, packet, 1);
+                                       pfkey_sendup(s, packet, 1);
                                else {
                                        /* Check for specified satype */
                                        if ((1 << satype) & s->registration)
-                                               pfkey_sendup(s->socket, packet, 
1);
+                                               pfkey_sendup(s, packet, 1);
                                }
                        }
                }
@@ -472,7 +465,7 @@ pfkeyv2_sendmessage(void **headers, int 
                        if ((s->flags & PFKEYV2_SOCKETFLAGS_PROMISC) &&
                            !(s->flags & PFKEYV2_SOCKETFLAGS_REGISTERED) &&
                            (s->rdomain == rdomain))
-                               pfkey_sendup(s->socket, packet, 1);
+                               pfkey_sendup(s, packet, 1);
                }
                m_freem(packet);
                break;
@@ -481,7 +474,7 @@ pfkeyv2_sendmessage(void **headers, int 
                /* Send message to all sockets */
                LIST_FOREACH(s, &pfkeyv2_sockets, kcb_list) {
                        if (s->rdomain == rdomain)
-                               pfkey_sendup(s->socket, packet, 1);
+                               pfkey_sendup(s, packet, 1);
                }
                m_freem(packet);
                break;
@@ -940,7 +933,7 @@ pfkeyv2_get_proto_alg(u_int8_t satype, u
  * Handle all messages from userland to kernel.
  */
 int
-pfkeyv2_send(struct socket *socket, void *message, int len)
+pfkeyv2_send(struct socket *so, void *message, int len)
 {
        int i, j, s, rval = 0, mode = PFKEYV2_SENDMESSAGE_BROADCAST;
        int delflag = 0;
@@ -950,7 +943,7 @@ pfkeyv2_send(struct socket *socket, void
        struct radix_node_head *rnh;
        struct radix_node *rn = NULL;
 
-       struct pfkeyv2_socket *pfkeyv2_socket, *so = NULL;
+       struct keycb *pk, *bpk = NULL;
 
        void *freeme = NULL, *bckptr = NULL;
        void *headers[SADB_EXT_MAX + 1];
@@ -972,16 +965,14 @@ pfkeyv2_send(struct socket *socket, void
        /* Verify that we received this over a legitimate pfkeyv2 socket */
        bzero(headers, sizeof(headers));
 
-       LIST_FOREACH(pfkeyv2_socket, &pfkeyv2_sockets, kcb_list)
-               if (pfkeyv2_socket->socket == socket)
-                       break;
+       pk = sotokeycb(so);
 
-       if (!pfkeyv2_socket) {
+       if (!pk) {
                rval = EINVAL;
                goto ret;
        }
 
-       rdomain = pfkeyv2_socket->rdomain;
+       rdomain = pk->rdomain;
 
        /* If we have any promiscuous listeners, send them a copy of the 
message */
        if (npromisc) {
@@ -1010,10 +1001,10 @@ pfkeyv2_send(struct socket *socket, void
                        goto ret;
 
                /* Send to all promiscuous listeners */
-               LIST_FOREACH(so, &pfkeyv2_sockets, kcb_list) {
-                       if ((so->flags & PFKEYV2_SOCKETFLAGS_PROMISC) &&
-                           (so->rdomain == rdomain))
-                               pfkey_sendup(so->socket, packet, 1);
+               LIST_FOREACH(bpk, &pfkeyv2_sockets, kcb_list) {
+                       if ((bpk->flags & PFKEYV2_SOCKETFLAGS_PROMISC) &&
+                           (bpk->rdomain == rdomain))
+                               pfkey_sendup(bpk, packet, 1);
                }
 
                m_freem(packet);
@@ -1393,8 +1384,8 @@ pfkeyv2_send(struct socket *socket, void
                break;
 
        case SADB_REGISTER:
-               if (!(pfkeyv2_socket->flags & PFKEYV2_SOCKETFLAGS_REGISTERED)) {
-                       pfkeyv2_socket->flags |= PFKEYV2_SOCKETFLAGS_REGISTERED;
+               if (!(pk->flags & PFKEYV2_SOCKETFLAGS_REGISTERED)) {
+                       pk->flags |= PFKEYV2_SOCKETFLAGS_REGISTERED;
                        nregistered++;
                }
 
@@ -1424,7 +1415,7 @@ pfkeyv2_send(struct socket *socket, void
                }
 
                /* Keep track what this socket has registered for */
-               pfkeyv2_socket->registration |= (1 << ((struct sadb_msg 
*)message)->sadb_msg_satype);
+               pk->registration |= (1 << ((struct sadb_msg 
*)message)->sadb_msg_satype);
 
                ssup = (struct sadb_supported *) freeme;
                ssup->sadb_supported_len = i / sizeof(uint64_t);
@@ -1497,7 +1488,7 @@ pfkeyv2_send(struct socket *socket, void
        {
                struct dump_state dump_state;
                dump_state.sadb_msg = (struct sadb_msg *) headers[0];
-               dump_state.socket = socket;
+               dump_state.socket = so;
 
                rval = tdb_walk(rdomain, pfkeyv2_dump_walker, &dump_state);
                if (!rval)
@@ -1769,12 +1760,12 @@ pfkeyv2_send(struct socket *socket, void
                        if ((rval = pfdatatopacket(message, len, &packet)) != 0)
                                goto ret;
 
-                       LIST_FOREACH(so, &pfkeyv2_sockets, kcb_list)
-                               if ((so != pfkeyv2_socket) &&
-                                   (so->rdomain == rdomain) &&
+                       LIST_FOREACH(bpk, &pfkeyv2_sockets, kcb_list)
+                               if ((bpk != pk) &&
+                                   (bpk->rdomain == rdomain) &&
                                    (!smsg->sadb_msg_seq ||
-                                   (smsg->sadb_msg_seq == 
pfkeyv2_socket->pid)))
-                                       pfkey_sendup(so->socket, packet, 1);
+                                   (smsg->sadb_msg_seq == pk->pid)))
+                                       pfkey_sendup(bpk, packet, 1);
 
                        m_freem(packet);
                } else {
@@ -1783,17 +1774,17 @@ pfkeyv2_send(struct socket *socket, void
                                goto ret;
                        }
 
-                       i = (pfkeyv2_socket->flags &
+                       i = (pk->flags &
                            PFKEYV2_SOCKETFLAGS_PROMISC) ? 1 : 0;
                        j = smsg->sadb_msg_satype ? 1 : 0;
 
                        if (i ^ j) {
                                if (j) {
-                                       pfkeyv2_socket->flags |=
+                                       pk->flags |=
                                            PFKEYV2_SOCKETFLAGS_PROMISC;
                                        npromisc++;
                                } else {
-                                       pfkeyv2_socket->flags &=
+                                       pk->flags &=
                                            ~PFKEYV2_SOCKETFLAGS_PROMISC;
                                        npromisc--;
                                }
@@ -1832,7 +1823,7 @@ ret:
                        goto realret;
        }
 
-       rval = pfkeyv2_sendmessage(headers, mode, socket, 0, 0, rdomain);
+       rval = pfkeyv2_sendmessage(headers, mode, so, 0, 0, rdomain);
 
 realret:
        NET_UNLOCK(s);

Reply via email to