From: Florian Westphal <f...@strlen.de> Subflow sockets already have their lifetime managed by RCU, so we can switch to refcount_inc_not_zero() and skip (i.e. pretend we did not find) any socket in the mptcp subflow list whose refcount has already dropped to zero.
This is required to get rid of synchronize_rcu() from mptcp_close(). Signed-off-by: Florian Westphal <f...@strlen.de> --- net/mptcp/protocol.c | 104 +++++++++++++++++++++++++++---------------- 1 file changed, 66 insertions(+), 38 deletions(-) diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index c00e837a1766..0db4099d9c13 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -24,14 +24,35 @@ static inline bool before64(__u64 seq1, __u64 seq2) #define after64(seq2, seq1) before64(seq1, seq2) +static bool mptcp_subflow_hold(struct subflow_context *subflow) +{ + struct sock *sk = mptcp_subflow_tcp_socket(subflow)->sk; + + return refcount_inc_not_zero(&sk->sk_refcnt); +} + +static struct sock *mptcp_subflow_get_ref(const struct mptcp_sock *msk) +{ + struct subflow_context *subflow; + + rcu_read_lock(); + mptcp_for_each_subflow(msk, subflow) { + if (mptcp_subflow_hold(subflow)) { + rcu_read_unlock(); + return mptcp_subflow_tcp_socket(subflow)->sk; + } + } + + rcu_read_unlock(); + return NULL; +} + static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) { struct mptcp_sock *msk = mptcp_sk(sk); int mss_now, size_goal, poffset, ret; struct mptcp_ext *mpext = NULL; - struct subflow_context *subflow; struct page *page = NULL; - struct hlist_node *node; struct sk_buff *skb; struct sock *ssk; size_t psize; @@ -42,20 +63,17 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) return sock_sendmsg(msk->subflow, msg); } - rcu_read_lock(); - node = rcu_dereference(hlist_first_rcu(&msk->conn_list)); - subflow = hlist_entry(node, struct subflow_context, node); - ssk = mptcp_subflow_tcp_socket(subflow)->sk; - sock_hold(ssk); - rcu_read_unlock(); + ssk = mptcp_subflow_get_ref(msk); + if (!ssk) + return -ENOTCONN; if (!msg_data_left(msg)) { pr_debug("empty send"); - ret = sock_sendmsg(mptcp_subflow_tcp_socket(subflow), msg); + ret = sock_sendmsg(ssk->sk_socket, msg); goto put_out; } - pr_debug("conn_list->subflow=%p", 
subflow); + pr_debug("conn_list->subflow=%p", ssk); if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) { ret = -ENOTSUPP; @@ -293,7 +311,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, struct mptcp_sock *msk = mptcp_sk(sk); struct subflow_context *subflow; struct mptcp_read_arg arg; - struct hlist_node *node; read_descriptor_t desc; struct tcp_sock *tp; struct sock *ssk; @@ -306,13 +323,11 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, return sock_recvmsg(msk->subflow, msg, flags); } - rcu_read_lock(); - node = rcu_dereference(hlist_first_rcu(&msk->conn_list)); - subflow = hlist_entry(node, struct subflow_context, node); - ssk = mptcp_subflow_tcp_socket(subflow)->sk; - sock_hold(ssk); - rcu_read_unlock(); + ssk = mptcp_subflow_get_ref(msk); + if (!ssk) + return -ENOTCONN; + subflow = subflow_ctx(ssk); tp = tcp_sk(ssk); lock_sock(sk); @@ -778,8 +793,6 @@ static int mptcp_getname(struct socket *sock, struct sockaddr *uaddr, int peer) { struct mptcp_sock *msk = mptcp_sk(sock->sk); - struct subflow_context *subflow; - struct hlist_node *node; struct sock *ssk; int ret; @@ -794,14 +807,11 @@ static int mptcp_getname(struct socket *sock, struct sockaddr *uaddr, * is connected and there are multiple subflows is not defined. * For now just use the first subflow on the list. 
*/ - rcu_read_lock(); - node = rcu_dereference(hlist_first_rcu(&msk->conn_list)); - subflow = hlist_entry(node, struct subflow_context, node); - ssk = mptcp_subflow_tcp_socket(subflow)->sk; - sock_hold(ssk); - rcu_read_unlock(); + ssk = mptcp_subflow_get_ref(msk); + if (!ssk) + return -ENOTCONN; - ret = inet_getname(mptcp_subflow_tcp_socket(subflow), uaddr, peer); + ret = inet_getname(ssk->sk_socket, uaddr, peer); sock_put(ssk); return ret; } @@ -837,26 +847,44 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, static __poll_t mptcp_poll(struct file *file, struct socket *sock, struct poll_table_struct *wait) { - const struct mptcp_sock *msk; struct subflow_context *subflow; + const struct mptcp_sock *msk; struct sock *sk = sock->sk; - struct hlist_node *node; - struct sock *ssk; - __poll_t ret; + __poll_t ret = 0; + unsigned int i; msk = mptcp_sk(sk); if (msk->subflow) return tcp_poll(file, msk->subflow, wait); - rcu_read_lock(); - node = rcu_dereference(hlist_first_rcu(&msk->conn_list)); - subflow = hlist_entry(node, struct subflow_context, node); - ssk = mptcp_subflow_tcp_socket(subflow)->sk; - sock_hold(ssk); - rcu_read_unlock(); + i = 0; + for (;;) { + struct subflow_context *tmp = NULL; + int j = 0; + + rcu_read_lock(); + mptcp_for_each_subflow(msk, subflow) { + if (j < i) { + j++; + continue; + } + + if (!mptcp_subflow_hold(subflow)) + continue; + + tmp = subflow; + i++; + break; + } + rcu_read_unlock(); + + if (!tmp) + break; + + ret |= tcp_poll(file, mptcp_subflow_tcp_socket(tmp), wait); + sock_put(mptcp_subflow_tcp_socket(tmp)->sk); + } - ret = tcp_poll(file, ssk->sk_socket, wait); - sock_put(ssk); return ret; } -- 2.22.0