Factor out the msghdr copy-in/copy-out helpers and the recvmmsg timeout check, so that they can later be used by the recvmmsg refactor.

Signed-off-by: Sabrina Dubroca <s...@queasysnail.net>
Signed-off-by: Paolo Abeni <pab...@redhat.com>
---
 include/net/sock.h | 18 ++++++++++
 net/socket.c       | 97 +++++++++++++++++++++++++++++-------------------------
 2 files changed, 70 insertions(+), 45 deletions(-)

diff --git a/include/net/sock.h b/include/net/sock.h
index 442cbb1..c92dc19 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1528,6 +1528,24 @@ int __sock_cmsg_send(struct sock *sk, struct msghdr *msg, struct cmsghdr *cmsg,
 int sock_cmsg_send(struct sock *sk, struct msghdr *msg,
 		   struct sockcm_cookie *sockc);
 
+static inline bool sock_recvmmsg_timeout(struct timespec *timeout,
+					 struct timespec64 end_time)
+{
+	struct timespec64 timeout64;
+
+	if (!timeout)
+		return false;
+
+	ktime_get_ts64(&timeout64);
+	*timeout = timespec64_to_timespec(timespec64_sub(end_time, timeout64));
+	if (timeout->tv_sec < 0) {
+		timeout->tv_sec = timeout->tv_nsec = 0;
+		return true;
+	}
+
+	return timeout->tv_nsec == 0 && timeout->tv_sec == 0;
+}
+
 /*
  * Functions to fill in entries in struct proto_ops when a protocol
  * does not implement a particular function.
diff --git a/net/socket.c b/net/socket.c
index e2584c5..9b5f360 100644
--- a/net/socket.c
+++ b/net/socket.c
@@ -1903,6 +1903,21 @@ static int copy_msghdr_from_user(struct msghdr *kmsg,
 			    UIO_FASTIOV, iov, &kmsg->msg_iter);
 }
 
+static int copy_msghdr_from_user_gen(struct msghdr *msg_sys, unsigned int flags,
+				     struct compat_msghdr __user *msg_compat,
+				     struct user_msghdr __user *msg,
+				     struct sockaddr __user **uaddr,
+				     struct iovec **iov,
+				     struct sockaddr_storage *addr)
+{
+	msg_sys->msg_name = addr;
+
+	if (MSG_CMSG_COMPAT & flags)
+		return get_compat_msghdr(msg_sys, msg_compat, uaddr, iov);
+	else
+		return copy_msghdr_from_user(msg_sys, msg, uaddr, iov);
+}
+
 static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
 			 struct msghdr *msg_sys, unsigned int flags,
 			 struct used_address *used_address,
@@ -1919,12 +1934,8 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
 	int ctl_len;
 	ssize_t err;
 
-	msg_sys->msg_name = &address;
-
-	if (MSG_CMSG_COMPAT & flags)
-		err = get_compat_msghdr(msg_sys, msg_compat, NULL, &iov);
-	else
-		err = copy_msghdr_from_user(msg_sys, msg, NULL, &iov);
+	err = copy_msghdr_from_user_gen(msg_sys, flags, msg_compat, msg, NULL,
+					&iov, &address);
 	if (err < 0)
 		return err;
 
@@ -2101,6 +2112,34 @@ int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen,
 	return __sys_sendmmsg(fd, mmsg, vlen, flags);
 }
 
+static int copy_msghdr_to_user_gen(struct msghdr *msg_sys, unsigned int flags,
+				   struct compat_msghdr __user *msg_compat,
+				   struct user_msghdr __user *msg,
+				   struct sockaddr __user *uaddr,
+				   struct sockaddr_storage *addr,
+				   unsigned long cmsgptr)
+{
+	int __user *uaddr_len = COMPAT_NAMELEN(msg);
+	int err;
+
+	if (uaddr) {
+		err = move_addr_to_user(addr, msg_sys->msg_namelen, uaddr,
+					uaddr_len);
+		if (err < 0)
+			return err;
+	}
+	err = __put_user((msg_sys->msg_flags & ~MSG_CMSG_COMPAT),
+			 COMPAT_FLAGS(msg));
+	if (err)
+		return err;
+	if (MSG_CMSG_COMPAT & flags)
+		return __put_user((unsigned long)msg_sys->msg_control - cmsgptr,
+				  &msg_compat->msg_controllen);
+	else
+		return __put_user((unsigned long)msg_sys->msg_control - cmsgptr,
+				  &msg->msg_controllen);
+}
+
 static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
 			 struct msghdr *msg_sys, unsigned int flags, int nosec)
 {
@@ -2117,14 +2156,9 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
 
 	/* user mode address pointers */
 	struct sockaddr __user *uaddr;
-	int __user *uaddr_len = COMPAT_NAMELEN(msg);
 
-	msg_sys->msg_name = &addr;
-
-	if (MSG_CMSG_COMPAT & flags)
-		err = get_compat_msghdr(msg_sys, msg_compat, &uaddr, &iov);
-	else
-		err = copy_msghdr_from_user(msg_sys, msg, &uaddr, &iov);
+	err = copy_msghdr_from_user_gen(msg_sys, flags, msg_compat, msg, &uaddr,
+					&iov, &addr);
 	if (err < 0)
 		return err;
 
@@ -2140,24 +2174,8 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
 	if (err < 0)
 		goto out_freeiov;
 	len = err;
-
-	if (uaddr != NULL) {
-		err = move_addr_to_user(&addr,
-					msg_sys->msg_namelen, uaddr,
-					uaddr_len);
-		if (err < 0)
-			goto out_freeiov;
-	}
-	err = __put_user((msg_sys->msg_flags & ~MSG_CMSG_COMPAT),
-			 COMPAT_FLAGS(msg));
-	if (err)
-		goto out_freeiov;
-	if (MSG_CMSG_COMPAT & flags)
-		err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr,
-				 &msg_compat->msg_controllen);
-	else
-		err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr,
-				 &msg->msg_controllen);
+	err = copy_msghdr_to_user_gen(msg_sys, flags, msg_compat, msg, uaddr,
+				      &addr, cmsg_ptr);
 	if (err)
 		goto out_freeiov;
 	err = len;
@@ -2209,7 +2227,6 @@ int __sys_recvmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen,
 	struct compat_mmsghdr __user *compat_entry;
 	struct msghdr msg_sys;
 	struct timespec64 end_time;
-	struct timespec64 timeout64;
 
 	if (timeout &&
 	    poll_select_set_timeout(&end_time, timeout->tv_sec,
@@ -2260,19 +2277,9 @@ int __sys_recvmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen,
 		if (flags & MSG_WAITFORONE)
 			flags |= MSG_DONTWAIT;
 
-		if (timeout) {
-			ktime_get_ts64(&timeout64);
-			*timeout = timespec64_to_timespec(
-					timespec64_sub(end_time, timeout64));
-			if (timeout->tv_sec < 0) {
-				timeout->tv_sec = timeout->tv_nsec = 0;
-				break;
-			}
-
-			/* Timeout, return less than vlen datagrams */
-			if (timeout->tv_nsec == 0 && timeout->tv_sec == 0)
-				break;
-		}
+		/* Timeout, return less than vlen datagrams */
+		if (sock_recvmmsg_timeout(timeout, end_time))
+			break;
 
 		/* Out of band data, return right away */
 		if (msg_sys.msg_flags & MSG_OOB)
-- 
1.8.3.1

Reply via email to