From: Cong Wang <[email protected]>

Now that both AF_UNIX and UDP support sockmap and redirection,
we can safely update the sock type checks for them accordingly.

Cc: John Fastabend <[email protected]>
Cc: Daniel Borkmann <[email protected]>
Cc: Jakub Sitnicki <[email protected]>
Cc: Lorenz Bauer <[email protected]>
Signed-off-by: Cong Wang <[email protected]>
---
 net/core/skmsg.c    |  3 ++-
 net/core/sock_map.c | 15 ++++++++++++---
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index 8e3edbdf4c7c..a502137f7bc2 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -667,7 +667,8 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
 
        write_lock_bh(&sk->sk_callback_lock);
 
-       if (inet_csk_has_ulp(sk)) {
+       if ((sk->sk_family == AF_INET || sk->sk_family == AF_INET6) &&
+           inet_csk_has_ulp(sk)) {
                psock = ERR_PTR(-EINVAL);
                goto out;
        }
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index 255067e5c73a..7e56a3ec7a57 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -544,14 +544,22 @@ static bool sk_is_udp(const struct sock *sk)
               sk->sk_protocol == IPPROTO_UDP;
 }
 
+static bool sk_is_unix(const struct sock *sk)
+{
+       return sk->sk_type == SOCK_DGRAM && sk->sk_family == AF_UNIX;
+}
+
 static bool sock_map_redirect_allowed(const struct sock *sk)
 {
-       return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
+       if (sk_is_tcp(sk))
+               return sk->sk_state != TCP_LISTEN;
+       else
+               return sk->sk_state == TCP_ESTABLISHED;
 }
 
 static bool sock_map_sk_is_suitable(const struct sock *sk)
 {
-       return sk_is_tcp(sk) || sk_is_udp(sk);
+       return !!sk->sk_prot->update_proto;
 }
 
 static bool sock_map_sk_state_allowed(const struct sock *sk)
@@ -560,7 +568,8 @@ static bool sock_map_sk_state_allowed(const struct sock *sk)
                return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
        else if (sk_is_udp(sk))
                return sk_hashed(sk);
-
+       else if (sk_is_unix(sk))
+               return sk->sk_state == TCP_ESTABLISHED;
        return false;
 }
 
-- 
2.25.1

Reply via email to