skbs are extracted from the receive queue in bursts, and a single sk_rmem_alloc/forward-allocated memory update is performed for each burst. MSG_PEEK and MSG_ERRQUEUE are not supported to keep the implementation as simple as possible.
Signed-off-by: Sabrina Dubroca <s...@queasysnail.net> Signed-off-by: Paolo Abeni <pab...@redhat.com> --- include/net/udp.h | 7 +++ net/ipv4/udp.c | 121 ++++++++++++++++++++++++++++++++++++++++++++++++++++ net/ipv4/udp_impl.h | 3 ++ net/ipv4/udplite.c | 1 + net/ipv6/udp.c | 16 +++++++ net/ipv6/udp_impl.h | 3 ++ net/ipv6/udplite.c | 1 + 7 files changed, 152 insertions(+) diff --git a/include/net/udp.h b/include/net/udp.h index 1661791..2bd63c9 100644 --- a/include/net/udp.h +++ b/include/net/udp.h @@ -308,6 +308,13 @@ struct sock *__udp6_lib_lookup(struct net *net, struct sock *udp6_lib_lookup_skb(struct sk_buff *skb, __be16 sport, __be16 dport); +int __udp_recvmmsg(struct sock *sk, struct mmsghdr __user *ummsg, + unsigned int *vlen, unsigned int flags, + struct timespec *timeout, const struct timespec64 *end_time, + int (*udp_process_msg)(struct sock *sk, struct sk_buff *skb, + struct msghdr *msg, + unsigned int flags)); + /* * SNMP statistics for UDP and UDP-Lite */ diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index d99429d..44f1326 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -1467,6 +1467,126 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, return err; } +static void udp_skb_bulk_destructor(struct sock *sk, int totalsize) +{ + udp_rmem_release(sk, totalsize, 1); +} + +int __udp_recvmmsg(struct sock *sk, struct mmsghdr __user *mmsg, + unsigned int *nr, unsigned int flags, + struct timespec *timeout, const struct timespec64 *end_time, + int (*process_msg)(struct sock *sk, struct sk_buff *skb, + struct msghdr *msg, + unsigned int flags)) +{ + long timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + int datagrams = 0, err = 0, ret = 0, vlen = *nr; + struct sk_buff *skb, *next, *last; + + if (flags & (MSG_PEEK | MSG_ERRQUEUE)) + return -EOPNOTSUPP; + +again: + for (;;) { + skb = __skb_try_recv_datagram_batch(sk, flags, vlen - datagrams, + udp_skb_bulk_destructor, + &err); + if (skb) + break; + + if ((err != -EAGAIN) || !timeo || 
(flags & MSG_DONTWAIT)) + goto out; + + /* no packets, and we are supposed to wait for the next */ + if (timeout) { + long expires; + + if (sock_recvmmsg_timeout(timeout, *end_time)) + goto out; + expires = timeout->tv_sec * HZ + + (timeout->tv_nsec >> 20); + if (expires + 1 < timeo) + timeo = expires + 1; + } + + /* the queue was empty when tried to dequeue */ + last = (struct sk_buff *)&sk->sk_receive_queue; + if (__skb_wait_for_more_packets(sk, &err, &timeo, last)) + goto out; + } + + for (; skb; skb = next) { + struct recvmmsg_ctx ctx; + int len; + + next = skb->next; + err = recvmmsg_ctx_from_user(sk, mmsg, flags, datagrams, &ctx); + if (err < 0) { + kfree_skb(skb); + goto free_ctx; + } + + /* process skb's until we find a valid one */ + for (;;) { + len = process_msg(sk, skb, &ctx.msg_sys, flags); + if (len >= 0) + break; + + /* only non csum errors are propagated to the caller */ + if (len != -EINVAL) { + err = len; + goto free_ctx; + } + + if (!next) + goto free_ctx; + skb = next; + next = skb->next; + } + + err = recvmmsg_ctx_to_user(&mmsg, len, flags, &ctx); + if (err < 0) + goto free_ctx; + + /* now we're sure the skb is fully processed, we can count it */ + datagrams++; + +free_ctx: + recvmmsg_ctx_free(&ctx); + if (err < 0) + ret = err; + } + + /* only handle waitforone after processing a full batch. */ + if (datagrams && (flags & MSG_WAITFORONE)) + flags |= MSG_DONTWAIT; + + if (!ret && (datagrams < vlen)) { + cond_resched(); + goto again; + } + +out: + *nr = datagrams; + return ret < 0 ? 
ret : -EAGAIN; +} +EXPORT_SYMBOL_GPL(__udp_recvmmsg); + +static int udp_process_msg(struct sock *sk, struct sk_buff *skb, + struct msghdr *msg, unsigned int flags) +{ + return udp_process_skb(sk, skb, msg, msg_data_left(msg), flags, + &msg->msg_namelen, 0, 0, skb->peeked); +} + +int udp_recvmmsg(struct sock *sk, struct mmsghdr __user *ummsg, + unsigned int *nr, unsigned int flags, struct timespec *timeout, + const struct timespec64 *end_time) +{ + return __udp_recvmmsg(sk, ummsg, nr, flags, timeout, end_time, + udp_process_msg); +} + int __udp_disconnect(struct sock *sk, int flags) { struct inet_sock *inet = inet_sk(sk); @@ -2329,6 +2449,7 @@ struct proto udp_prot = { .getsockopt = udp_getsockopt, .sendmsg = udp_sendmsg, .recvmsg = udp_recvmsg, + .recvmmsg = udp_recvmmsg, .sendpage = udp_sendpage, .release_cb = ip4_datagram_release_cb, .hash = udp_lib_hash, diff --git a/net/ipv4/udp_impl.h b/net/ipv4/udp_impl.h index 7e0fe4b..f11d608 100644 --- a/net/ipv4/udp_impl.h +++ b/net/ipv4/udp_impl.h @@ -23,6 +23,9 @@ int compat_udp_getsockopt(struct sock *sk, int level, int optname, #endif int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, int flags, int *addr_len); +int udp_recvmmsg(struct sock *sk, struct mmsghdr __user *ummsg, + unsigned int *nr, unsigned int flags, struct timespec *timeout, + const struct timespec64 *end_time); int udp_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags); int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); diff --git a/net/ipv4/udplite.c b/net/ipv4/udplite.c index 59f10fe..a0e7fe9 100644 --- a/net/ipv4/udplite.c +++ b/net/ipv4/udplite.c @@ -49,6 +49,7 @@ struct proto udplite_prot = { .getsockopt = udp_getsockopt, .sendmsg = udp_sendmsg, .recvmsg = udp_recvmsg, + .recvmmsg = udp_recvmmsg, .sendpage = udp_sendpage, .hash = udp_lib_hash, .unhash = udp_lib_unhash, diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index 3218c64..2c034be 100644 --- a/net/ipv6/udp.c +++ 
b/net/ipv6/udp.c @@ -479,6 +479,21 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, } +static int udp6_process_msg(struct sock *sk, struct sk_buff *skb, + struct msghdr *msg, unsigned int flags) +{ + return udp6_process_skb(sk, skb, msg, msg_data_left(msg), flags, + &msg->msg_namelen, 0, 0, skb->peeked); +} + +int udpv6_recvmmsg(struct sock *sk, struct mmsghdr __user *ummsg, + unsigned int *nr, unsigned int flags, + struct timespec *timeout, const struct timespec64 *end_time) +{ + return __udp_recvmmsg(sk, ummsg, nr, flags, timeout, end_time, + udp6_process_msg); +} + void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt, u8 type, u8 code, int offset, __be32 info, struct udp_table *udptable) @@ -1443,6 +1458,7 @@ struct proto udpv6_prot = { .getsockopt = udpv6_getsockopt, .sendmsg = udpv6_sendmsg, .recvmsg = udpv6_recvmsg, + .recvmmsg = udpv6_recvmmsg, .release_cb = ip6_datagram_release_cb, .hash = udp_lib_hash, .unhash = udp_lib_unhash, diff --git a/net/ipv6/udp_impl.h b/net/ipv6/udp_impl.h index f6eb1ab..fe566db 100644 --- a/net/ipv6/udp_impl.h +++ b/net/ipv6/udp_impl.h @@ -26,6 +26,9 @@ int compat_udpv6_getsockopt(struct sock *sk, int level, int optname, int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len); int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, int flags, int *addr_len); +int udpv6_recvmmsg(struct sock *sk, struct mmsghdr __user *ummsg, + unsigned int *nr, unsigned int flags, + struct timespec *timeout, const struct timespec64 *end_time); int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); void udpv6_destroy_sock(struct sock *sk); diff --git a/net/ipv6/udplite.c b/net/ipv6/udplite.c index 2784cc3..23d80ac 100644 --- a/net/ipv6/udplite.c +++ b/net/ipv6/udplite.c @@ -45,6 +45,7 @@ struct proto udplitev6_prot = { .getsockopt = udpv6_getsockopt, .sendmsg = udpv6_sendmsg, .recvmsg = udpv6_recvmsg, + .recvmmsg = udpv6_recvmmsg, .hash = udp_lib_hash, .unhash = 
udp_lib_unhash, .get_port = udp_v6_get_port, -- 1.8.3.1