After adding checks to ensure TCP is in the ESTABLISHED state when a
sock is added, we also need to ensure that the user does not transition
through tcp_disconnect() and back into the ESTABLISHED state without
sockmap removing the sock.

To do this, add an unhash hook and remove the sock from the map there.

Reported-by: Eric Dumazet <eduma...@google.com>
Fixes: 81110384441a ("bpf: sockmap, add hash map support")
Signed-off-by: John Fastabend <john.fastab...@gmail.com>
---
 kernel/bpf/sockmap.c |   67 +++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 50 insertions(+), 17 deletions(-)

diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
index 2e848cd..91c7b47 100644
--- a/kernel/bpf/sockmap.c
+++ b/kernel/bpf/sockmap.c
@@ -130,6 +130,7 @@ struct smap_psock {
 
        struct proto *sk_proto;
        void (*save_close)(struct sock *sk, long timeout);
+       void (*save_unhash)(struct sock *sk);
        void (*save_data_ready)(struct sock *sk);
        void (*save_write_space)(struct sock *sk);
 };
@@ -141,6 +142,7 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
                            int offset, size_t size, int flags);
 static void bpf_tcp_close(struct sock *sk, long timeout);
+static void bpf_tcp_unhash(struct sock *sk);
 
 static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
 {
@@ -182,6 +184,7 @@ static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS],
 {
        prot[SOCKMAP_BASE]                      = *base;
        prot[SOCKMAP_BASE].close                = bpf_tcp_close;
+       prot[SOCKMAP_BASE].unhash               = bpf_tcp_unhash;
        prot[SOCKMAP_BASE].recvmsg              = bpf_tcp_recvmsg;
        prot[SOCKMAP_BASE].stream_memory_read   = bpf_tcp_stream_read;
 
@@ -215,6 +218,7 @@ static int bpf_tcp_init(struct sock *sk)
        }
 
        psock->save_close = sk->sk_prot->close;
+       psock->save_unhash = sk->sk_prot->unhash;
        psock->sk_proto = sk->sk_prot;
 
        /* Build IPv6 sockmap whenever the address of tcpv6_prot changes */
@@ -302,28 +306,12 @@ struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
        return e;
 }
 
-static void bpf_tcp_close(struct sock *sk, long timeout)
+static void bpf_tcp_remove(struct sock *sk, struct smap_psock *psock)
 {
-       void (*close_fun)(struct sock *sk, long timeout);
        struct smap_psock_map_entry *e;
        struct sk_msg_buff *md, *mtmp;
-       struct smap_psock *psock;
        struct sock *osk;
 
-       rcu_read_lock();
-       psock = smap_psock_sk(sk);
-       if (unlikely(!psock)) {
-               rcu_read_unlock();
-               return sk->sk_prot->close(sk, timeout);
-       }
-
-       /* The psock may be destroyed anytime after exiting the RCU critial
-        * section so by the time we use close_fun the psock may no longer
-        * be valid. However, bpf_tcp_close is called with the sock lock
-        * held so the close hook and sk are still valid.
-        */
-       close_fun = psock->save_close;
-
        if (psock->cork) {
                free_start_sg(psock->sock, psock->cork);
                kfree(psock->cork);
@@ -378,6 +366,51 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
                }
                e = psock_map_pop(sk, psock);
        }
+}
+
+static void bpf_tcp_unhash(struct sock *sk)
+{
+       void (*unhash_fun)(struct sock *sk);
+       struct smap_psock *psock;
+
+       rcu_read_lock();
+       psock = smap_psock_sk(sk);
+       if (unlikely(!psock)) {
+               rcu_read_unlock();
+               return sk->sk_prot->unhash(sk);
+       }
+
+       /* The psock may be destroyed anytime after exiting the RCU critical
+        * section so by the time we use unhash_fun the psock may no longer
+        * be valid; save the hook first. NOTE(review): confirm the locking
+        * context of unhash callers keeps the hook and sk valid here.
+        */
+       unhash_fun = psock->save_unhash;
+       bpf_tcp_remove(sk, psock);
+       rcu_read_unlock();
+       unhash_fun(sk);
+
+}
+
+static void bpf_tcp_close(struct sock *sk, long timeout)
+{
+       void (*close_fun)(struct sock *sk, long timeout);
+       struct smap_psock *psock;
+
+       rcu_read_lock();
+       psock = smap_psock_sk(sk);
+       if (unlikely(!psock)) {
+               rcu_read_unlock();
+               return sk->sk_prot->close(sk, timeout);
+       }
+
+       /* The psock may be destroyed anytime after exiting the RCU critical
+        * section so by the time we use close_fun the psock may no longer
+        * be valid. However, bpf_tcp_close is called with the sock lock
+        * held so the close hook and sk are still valid.
+        */
+       close_fun = psock->save_close;
+       bpf_tcp_remove(sk, psock);
        rcu_read_unlock();
        close_fun(sk, timeout);
 }

Reply via email to