From: Florian Westphal <f...@strlen.de>

The lifetime of subflow sockets is already managed by RCU, so we can
switch to refcount_inc_not_zero() and, when walking the mptcp subflow
list, skip (i.e. pretend we did not find) any socket whose refcount
has already dropped to zero.

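This is required to get rid of synchronize_rcu() from mptcp_close().

As a consequence, mptcp_poll() now polls every subflow and ORs the
resulting masks instead of looking only at the first one. sendmsg,
recvmsg and getname return -ENOTCONN when no subflow can be held.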

Signed-off-by: Florian Westphal <f...@strlen.de>
---
 net/mptcp/protocol.c | 104 +++++++++++++++++++++++++++----------------
 1 file changed, 66 insertions(+), 38 deletions(-)

diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index c00e837a1766..0db4099d9c13 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -24,14 +24,50 @@ static inline bool before64(__u64 seq1, __u64 seq2)
 
 #define after64(seq2, seq1)    before64(seq1, seq2)
 
+/* Try to take a reference on the TCP socket backing @subflow.
+ *
+ * Returns true if the reference was taken, false if the socket's
+ * refcount had already dropped to zero.  Callers in this file invoke
+ * it under rcu_read_lock().
+ *
+ * NOTE(review): if tcp_prot allocates from a SLAB_TYPESAFE_BY_RCU
+ * cache, the socket may need re-validating after a successful
+ * increment - confirm.
+ */
+static bool mptcp_subflow_hold(struct subflow_context *subflow)
+{
+       struct socket *sock = mptcp_subflow_tcp_socket(subflow);
+
+       return refcount_inc_not_zero(&sock->sk->sk_refcnt);
+}
+
+/* Return the first subflow socket of @msk on which a reference could
+ * be taken, or NULL if there is none.  The caller owns the reference
+ * and must drop it with sock_put().
+ */
+static struct sock *mptcp_subflow_get_ref(const struct mptcp_sock *msk)
+{
+       struct subflow_context *subflow;
+       struct sock *ssk = NULL;
+
+       rcu_read_lock();
+       mptcp_for_each_subflow(msk, subflow) {
+               if (mptcp_subflow_hold(subflow)) {
+                       ssk = mptcp_subflow_tcp_socket(subflow)->sk;
+                       break;
+               }
+       }
+       rcu_read_unlock();
+
+       return ssk;
+}
+
 static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
        int mss_now, size_goal, poffset, ret;
        struct mptcp_ext *mpext = NULL;
-       struct subflow_context *subflow;
        struct page *page = NULL;
-       struct hlist_node *node;
        struct sk_buff *skb;
        struct sock *ssk;
        size_t psize;
@@ -42,20 +63,17 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
                return sock_sendmsg(msk->subflow, msg);
        }
 
-       rcu_read_lock();
-       node = rcu_dereference(hlist_first_rcu(&msk->conn_list));
-       subflow = hlist_entry(node, struct subflow_context, node);
-       ssk = mptcp_subflow_tcp_socket(subflow)->sk;
-       sock_hold(ssk);
-       rcu_read_unlock();
+       ssk = mptcp_subflow_get_ref(msk);
+       if (!ssk)
+               return -ENOTCONN;
 
        if (!msg_data_left(msg)) {
                pr_debug("empty send");
-               ret = sock_sendmsg(mptcp_subflow_tcp_socket(subflow), msg);
+               ret = sock_sendmsg(ssk->sk_socket, msg);
                goto put_out;
        }
 
-       pr_debug("conn_list->subflow=%p", subflow);
+       pr_debug("conn_list->subflow=%p", ssk);
 
        if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) {
                ret = -ENOTSUPP;
@@ -293,7 +311,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
        struct mptcp_sock *msk = mptcp_sk(sk);
        struct subflow_context *subflow;
        struct mptcp_read_arg arg;
-       struct hlist_node *node;
        read_descriptor_t desc;
        struct tcp_sock *tp;
        struct sock *ssk;
@@ -306,13 +323,11 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                return sock_recvmsg(msk->subflow, msg, flags);
        }
 
-       rcu_read_lock();
-       node = rcu_dereference(hlist_first_rcu(&msk->conn_list));
-       subflow = hlist_entry(node, struct subflow_context, node);
-       ssk = mptcp_subflow_tcp_socket(subflow)->sk;
-       sock_hold(ssk);
-       rcu_read_unlock();
+       ssk = mptcp_subflow_get_ref(msk);
+       if (!ssk)
+               return -ENOTCONN;
 
+       subflow = subflow_ctx(ssk);
        tp = tcp_sk(ssk);
 
        lock_sock(sk);
@@ -778,8 +793,6 @@ static int mptcp_getname(struct socket *sock, struct sockaddr *uaddr,
                         int peer)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
-       struct subflow_context *subflow;
-       struct hlist_node *node;
        struct sock *ssk;
        int ret;
 
@@ -794,14 +807,11 @@ static int mptcp_getname(struct socket *sock, struct sockaddr *uaddr,
         * is connected and there are multiple subflows is not defined.
         * For now just use the first subflow on the list.
         */
-       rcu_read_lock();
-       node = rcu_dereference(hlist_first_rcu(&msk->conn_list));
-       subflow = hlist_entry(node, struct subflow_context, node);
-       ssk = mptcp_subflow_tcp_socket(subflow)->sk;
-       sock_hold(ssk);
-       rcu_read_unlock();
+       ssk = mptcp_subflow_get_ref(msk);
+       if (!ssk)
+               return -ENOTCONN;
 
-       ret = inet_getname(mptcp_subflow_tcp_socket(subflow), uaddr, peer);
+       ret = inet_getname(ssk->sk_socket, uaddr, peer);
        sock_put(ssk);
        return ret;
 }
@@ -837,26 +847,49 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
 static __poll_t mptcp_poll(struct file *file, struct socket *sock,
                           struct poll_table_struct *wait)
 {
-       const struct mptcp_sock *msk;
        struct subflow_context *subflow;
+       const struct mptcp_sock *msk;
        struct sock *sk = sock->sk;
-       struct hlist_node *node;
-       struct sock *ssk;
-       __poll_t ret;
+       __poll_t ret = 0;
+       unsigned int i;
 
        msk = mptcp_sk(sk);
        if (msk->subflow)
                return tcp_poll(file, msk->subflow, wait);
 
-       rcu_read_lock();
-       node = rcu_dereference(hlist_first_rcu(&msk->conn_list));
-       subflow = hlist_entry(node, struct subflow_context, node);
-       ssk = mptcp_subflow_tcp_socket(subflow)->sk;
-       sock_hold(ssk);
-       rcu_read_unlock();
+       /* Poll every subflow and OR the results.  The RCU read lock is
+        * dropped around each tcp_poll() call, so every pass re-walks the
+        * list, skips the first @i entries (already visited) and holds the
+        * next subflow whose socket refcount is not yet zero.  Entries that
+        * cannot be held still count as visited, so no subflow is polled
+        * twice unless the list is modified concurrently.
+        */
+       i = 0;
+       for (;;) {
+               struct sock *ssk = NULL;
+               unsigned int pos = 0;
+
+               rcu_read_lock();
+               mptcp_for_each_subflow(msk, subflow) {
+                       if (pos++ < i)
+                               continue;
+
+                       if (mptcp_subflow_hold(subflow)) {
+                               ssk = mptcp_subflow_tcp_socket(subflow)->sk;
+                               break;
+                       }
+               }
+               rcu_read_unlock();
+
+               if (!ssk)
+                       break;
+
+               /* the next pass resumes after the subflow held here */
+               i = pos;
+               ret |= tcp_poll(file, ssk->sk_socket, wait);
+               sock_put(ssk);
+       }
 
-       ret = tcp_poll(file, ssk->sk_socket, wait);
-       sock_put(ssk);
        return ret;
 }
 
-- 
2.22.0

Reply via email to