From: Paolo Abeni <pab...@redhat.com>

The locking scheme implied by sendmsg(), recvmsg(), etc.
requires acquiring the msk's socket lock before manipulating
the msk internal state.

Additionally, we can't acquire the msk->subflow socket lock while holding
the msk lock, due to mptcp_finish_connect().

Many socket operations do not enforce the required locking; e.g. we have
several patterns like:

        if (msk->subflow)
                // do something with msk->subflow

or:

        if (!msk->subflow)
                // allocate msk->subflow

all without holding any lock.

They can race with each other and with mptcp_finish_connect(), causing
use-after-free (UAF), NULL pointer dereferences and/or memory leaks.

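Ensure that all mptcp socket operations access and manipulate
msk->subflow under the msk socket lock. To avoid breaking the locking
assumption introduced by mptcp_finish_connect() while still preventing UAF
issues, acquire a reference to msk->subflow where needed, and drop it once
the subflow operation completes.

While at it, move the request_mptcp/request_cksum/request_version
initialization from mptcp_subflow_create_socket() into the new
mptcp_socket_create_get() helper, and set request_cksum to 0, as
checksums are not supported.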

Signed-off-by: Paolo Abeni <pab...@redhat.com>
Signed-off-by: Peter Krystad <peter.krys...@linux.intel.com>
---
 net/mptcp/protocol.c | 82 +++++++++++++++++++++++++++++++++-----------
 net/mptcp/subflow.c  |  3 --
 2 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 32d9963c492d..8512cf5e0e0f 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -178,6 +178,7 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
        struct sock *ssk;
        long timeo;
 
+       pr_debug("msk=%p", msk);
        lock_sock(sk);
        ssock = __mptcp_fallback_get_ref(msk);
        if (ssock) {
@@ -846,38 +847,72 @@ static struct proto mptcp_prot = {
        .no_autobind    = 1,
 };
 
+static struct socket *mptcp_socket_create_get(struct mptcp_sock *msk)
+{
+       struct mptcp_subflow_context *subflow;
+       struct sock *sk = (struct sock *)msk;
+       struct socket *ssock;
+       int err;
+
+       lock_sock(sk);
+       ssock = __mptcp_fallback_get_ref(msk);
+       if (ssock)
+               goto release;
+
+       err = mptcp_subflow_create_socket(sk, &ssock);
+       if (err) {
+               ssock = ERR_PTR(err);
+               goto release;
+       }
+
+       msk->subflow = ssock;
+       subflow = mptcp_subflow_ctx(msk->subflow->sk);
+       subflow->request_mptcp = 1; /* @@ if MPTCP enabled */
+       subflow->request_cksum = 0; /* checksum not supported */
+       subflow->request_version = 0; /* only v0 supported */
+
+       sock_hold(ssock->sk);
+
+release:
+       release_sock(sk);
+       return ssock;
+}
+
 static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int 
addr_len)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
+       struct socket *ssock;
        int err = -ENOTSUPP;
 
        if (uaddr->sa_family != AF_INET) // @@ allow only IPv4 for now
                return err;
 
-       if (!msk->subflow) {
-               err = mptcp_subflow_create_socket(sock->sk, &msk->subflow);
-               if (err)
-                       return err;
-       }
-       return inet_bind(msk->subflow, uaddr, addr_len);
+       ssock = mptcp_socket_create_get(msk);
+       if (IS_ERR(ssock))
+               return PTR_ERR(ssock);
+
+       err = inet_bind(ssock, uaddr, addr_len);
+       sock_put(ssock->sk);
+       return err;
 }
 
 static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
                                int addr_len, int flags)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
+       struct socket *ssock;
        int err = -ENOTSUPP;
 
        if (uaddr->sa_family != AF_INET) // @@ allow only IPv4 for now
                return err;
 
-       if (!msk->subflow) {
-               err = mptcp_subflow_create_socket(sock->sk, &msk->subflow);
-               if (err)
-                       return err;
-       }
+       ssock = mptcp_socket_create_get(msk);
+       if (IS_ERR(ssock))
+               return PTR_ERR(ssock);
 
-       return inet_stream_connect(msk->subflow, uaddr, addr_len, flags);
+       err = inet_stream_connect(ssock, uaddr, addr_len, flags);
+       sock_put(ssock->sk);
+       return err;
 }
 
 static int mptcp_getname(struct socket *sock, struct sockaddr *uaddr,
@@ -929,29 +964,36 @@ static int mptcp_getname(struct socket *sock, struct sockaddr *uaddr,
 static int mptcp_listen(struct socket *sock, int backlog)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
+       struct socket *ssock;
        int err;
 
        pr_debug("msk=%p", msk);
 
-       if (!msk->subflow) {
-               err = mptcp_subflow_create_socket(sock->sk, &msk->subflow);
-               if (err)
-                       return err;
-       }
-       return inet_listen(msk->subflow, backlog);
+       ssock = mptcp_socket_create_get(msk);
+       if (IS_ERR(ssock))
+               return PTR_ERR(ssock);
+
+       err = inet_listen(ssock, backlog);
+       sock_put(ssock->sk);
+       return err;
 }
 
 static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
                               int flags, bool kern)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
+       struct socket *ssock;
+       int err;
 
        pr_debug("msk=%p", msk);
 
-       if (!msk->subflow)
+       ssock = mptcp_fallback_get_ref(msk);
+       if (!ssock)
                return -EINVAL;
 
-       return inet_accept(sock, newsock, flags, kern);
+       err = inet_accept(sock, newsock, flags, kern);
+       sock_put(ssock->sk);
+       return err;
 }
 
 static __poll_t mptcp_poll(struct file *file, struct socket *sock,
diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c
index 1c3330ab2f30..04f232ff1df0 100644
--- a/net/mptcp/subflow.c
+++ b/net/mptcp/subflow.c
@@ -293,9 +293,6 @@ int mptcp_subflow_create_socket(struct sock *sk, struct socket **new_sock)
        *new_sock = sf;
        sock_hold(sk);
        subflow->conn = sk;
-       subflow->request_mptcp = 1; // @@ if MPTCP enabled
-       subflow->request_cksum = 1; // @@ if checksum enabled
-       subflow->request_version = 0;
 
        return 0;
 }
-- 
2.23.0

Reply via email to