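Most of the arguments passed to the various diag_fill functions
are only a couple of pointer lookups away from in_skb (the netlink
request skb) or from the request nlmsghdr that is already passed
in. Instead of passing all of them down the call chain, pass
in_skb and have the fill functions derive the user namespace and
portid from NETLINK_CB(in_skb) and the sequence number from the
request header. Since this leaves inet_csk_diag_fill as a trivial
wrapper, remove it and call inet_sk_diag_fill directly.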

This patch does not remove the "bool net_admin" parameter. This
parameter can also be derived from in_skb, but removing it would
require calculating it once per socket instead of once per dump,
and because netlink_net_capable is more than just a couple of
pointer lookups, that might affect performance. If performance is
not a concern (which is likely, since dumping a socket already
involves things like copying data to userspace), it's easy to
remove this parameter as well.

Tested: Passes all existing Android sock_diag tests.
Signed-off-by: Lorenzo Colitti <lore...@google.com>
---
 include/linux/inet_diag.h |  6 ++--
 net/ipv4/inet_diag.c      | 75 +++++++++++++++++------------------------------
 net/ipv4/udp_diag.c       | 13 +++-----
 net/sctp/sctp_diag.c      | 36 ++++++++---------------
 4 files changed, 46 insertions(+), 84 deletions(-)

diff --git a/include/linux/inet_diag.h b/include/linux/inet_diag.h
index 65da430..da0777c 100644
--- a/include/linux/inet_diag.h
+++ b/include/linux/inet_diag.h
@@ -35,8 +35,7 @@ struct inet_diag_handler {
 struct inet_connection_sock;
 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
                      struct sk_buff *skb, const struct inet_diag_req_v2 *req,
-                     struct user_namespace *user_ns,
-                     u32 pid, u32 seq, u16 nlmsg_flags,
+                     const struct sk_buff *in_skb, u16 nlmsg_flags,
                      const struct nlmsghdr *unlh, bool net_admin);
 void inet_diag_dump_icsk(struct inet_hashinfo *h, struct sk_buff *skb,
                         struct netlink_callback *cb,
@@ -56,7 +55,8 @@ void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk);
 
 int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
                             struct inet_diag_msg *r, int ext,
-                            struct user_namespace *user_ns, bool net_admin);
+                            const struct sk_buff *in_skb,
+                            bool net_admin);
 
 extern int  inet_diag_register(const struct inet_diag_handler *handler);
 extern void inet_diag_unregister(const struct inet_diag_handler *handler);
diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index 1683bf5..de5bfa8 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -110,7 +110,7 @@ static size_t inet_sk_attr_size(void)
 
 int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
                             struct inet_diag_msg *r, int ext,
-                            struct user_namespace *user_ns,
+                            const struct sk_buff *in_skb,
                             bool net_admin)
 {
        const struct inet_sock *inet = inet_sk(sk);
@@ -141,7 +141,8 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
        if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark))
                goto errout;
 
-       r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
+       r->idiag_uid = from_kuid_munged(sk_user_ns(NETLINK_CB(in_skb).sk),
+                                       sock_i_uid(sk));
        r->idiag_inode = sock_i_ino(sk);
 
        return 0;
@@ -152,10 +153,8 @@ EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
 
 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
                      struct sk_buff *skb, const struct inet_diag_req_v2 *req,
-                     struct user_namespace *user_ns,
-                     u32 portid, u32 seq, u16 nlmsg_flags,
-                     const struct nlmsghdr *unlh,
-                     bool net_admin)
+                     const struct sk_buff *in_skb, u16 nlmsg_flags,
+                     const struct nlmsghdr *unlh, bool net_admin)
 {
        const struct tcp_congestion_ops *ca_ops;
        const struct inet_diag_handler *handler;
@@ -168,8 +167,8 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
        handler = inet_diag_table[req->sdiag_protocol];
        BUG_ON(!handler);
 
-       nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
-                       nlmsg_flags);
+       nlh = nlmsg_put(skb, NETLINK_CB(in_skb).portid, unlh->nlmsg_seq,
+                       unlh->nlmsg_type, sizeof(*r), nlmsg_flags);
        if (!nlh)
                return -EMSGSIZE;
 
@@ -181,7 +180,7 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
        r->idiag_timer = 0;
        r->idiag_retrans = 0;
 
-       if (inet_diag_msg_attrs_fill(sk, skb, r, ext, user_ns, net_admin))
+       if (inet_diag_msg_attrs_fill(sk, skb, r, ext, in_skb, net_admin))
                goto errout;
 
        if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
@@ -275,30 +274,20 @@ errout:
 }
 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
 
-static int inet_csk_diag_fill(struct sock *sk,
-                             struct sk_buff *skb,
-                             const struct inet_diag_req_v2 *req,
-                             struct user_namespace *user_ns,
-                             u32 portid, u32 seq, u16 nlmsg_flags,
-                             const struct nlmsghdr *unlh,
-                             bool net_admin)
-{
-       return inet_sk_diag_fill(sk, inet_csk(sk), skb, req, user_ns,
-                                portid, seq, nlmsg_flags, unlh, net_admin);
-}
-
 static int inet_twsk_diag_fill(struct sock *sk,
                               struct sk_buff *skb,
-                              u32 portid, u32 seq, u16 nlmsg_flags,
+                              const struct sk_buff *in_skb,
+                              u16 nlmsg_flags,
                               const struct nlmsghdr *unlh)
 {
        struct inet_timewait_sock *tw = inet_twsk(sk);
+       u32 portid = NETLINK_CB(in_skb).portid;
        struct inet_diag_msg *r;
        struct nlmsghdr *nlh;
        long tmo;
 
-       nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
-                       nlmsg_flags);
+       nlh = nlmsg_put(skb, portid, unlh->nlmsg_seq, unlh->nlmsg_type,
+                       sizeof(*r), nlmsg_flags);
        if (!nlh)
                return -EMSGSIZE;
 
@@ -325,16 +314,17 @@ static int inet_twsk_diag_fill(struct sock *sk,
 }
 
 static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
-                             u32 portid, u32 seq, u16 nlmsg_flags,
+                             const struct sk_buff *in_skb, u16 nlmsg_flags,
                              const struct nlmsghdr *unlh, bool net_admin)
 {
        struct request_sock *reqsk = inet_reqsk(sk);
+       u32 portid = NETLINK_CB(in_skb).portid;
        struct inet_diag_msg *r;
        struct nlmsghdr *nlh;
        long tmo;
 
-       nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
-                       nlmsg_flags);
+       nlh = nlmsg_put(skb, portid, unlh->nlmsg_seq, unlh->nlmsg_type,
+                       sizeof(*r), nlmsg_flags);
        if (!nlh)
                return -EMSGSIZE;
 
@@ -364,20 +354,18 @@ static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
 
 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
                        const struct inet_diag_req_v2 *r,
-                       struct user_namespace *user_ns,
-                       u32 portid, u32 seq, u16 nlmsg_flags,
+                       const struct sk_buff *in_skb, u16 nlmsg_flags,
                        const struct nlmsghdr *unlh, bool net_admin)
 {
        if (sk->sk_state == TCP_TIME_WAIT)
-               return inet_twsk_diag_fill(sk, skb, portid, seq,
-                                          nlmsg_flags, unlh);
+               return inet_twsk_diag_fill(sk, skb, in_skb, nlmsg_flags, unlh);
 
        if (sk->sk_state == TCP_NEW_SYN_RECV)
-               return inet_req_diag_fill(sk, skb, portid, seq,
-                                         nlmsg_flags, unlh, net_admin);
+               return inet_req_diag_fill(sk, skb, in_skb, nlmsg_flags, unlh,
+                                         net_admin);
 
-       return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
-                                 nlmsg_flags, unlh, net_admin);
+       return inet_sk_diag_fill(sk, inet_csk(sk), skb, r, in_skb, nlmsg_flags,
+                                unlh, net_admin);
 }
 
 struct sock *inet_diag_find_one_icsk(struct net *net,
@@ -444,10 +432,7 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
                goto out;
        }
 
-       err = sk_diag_fill(sk, rep, req,
-                          sk_user_ns(NETLINK_CB(in_skb).sk),
-                          NETLINK_CB(in_skb).portid,
-                          nlh->nlmsg_seq, 0, nlh,
+       err = sk_diag_fill(sk, rep, req, in_skb, 0, nlh,
                           netlink_net_capable(in_skb, CAP_NET_ADMIN));
        if (err < 0) {
                WARN_ON(err == -EMSGSIZE);
@@ -815,11 +800,8 @@ static int inet_csk_diag_dump(struct sock *sk,
        if (!inet_diag_bc_sk(bc, sk))
                return 0;
 
-       return inet_csk_diag_fill(sk, skb, r,
-                                 sk_user_ns(NETLINK_CB(cb->skb).sk),
-                                 NETLINK_CB(cb->skb).portid,
-                                 cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh,
-                                 net_admin);
+       return inet_sk_diag_fill(sk, inet_csk(sk), skb, r, cb->skb,
+                                NLM_F_MULTI, cb->nlh, net_admin);
 }
 
 static void twsk_build_assert(void)
@@ -961,10 +943,7 @@ skip_listen_ht:
                        if (!inet_diag_bc_sk(bc, sk))
                                goto next_normal;
 
-                       res = sk_diag_fill(sk, skb, r,
-                                          sk_user_ns(NETLINK_CB(cb->skb).sk),
-                                          NETLINK_CB(cb->skb).portid,
-                                          cb->nlh->nlmsg_seq, NLM_F_MULTI,
+                       res = sk_diag_fill(sk, skb, r, cb->skb, NLM_F_MULTI,
                                           cb->nlh, net_admin);
                        if (res < 0) {
                                spin_unlock_bh(lock);
diff --git a/net/ipv4/udp_diag.c b/net/ipv4/udp_diag.c
index 9a89c10..917abac 100644
--- a/net/ipv4/udp_diag.c
+++ b/net/ipv4/udp_diag.c
@@ -25,10 +25,8 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb,
        if (!inet_diag_bc_sk(bc, sk))
                return 0;
 
-       return inet_sk_diag_fill(sk, NULL, skb, req,
-                       sk_user_ns(NETLINK_CB(cb->skb).sk),
-                       NETLINK_CB(cb->skb).portid,
-                       cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh, net_admin);
+       return inet_sk_diag_fill(sk, NULL, skb, req, cb->skb, NLM_F_MULTI,
+                                cb->nlh, net_admin);
 }
 
 static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
@@ -73,11 +71,8 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
        if (!rep)
                goto out;
 
-       err = inet_sk_diag_fill(sk, NULL, rep, req,
-                          sk_user_ns(NETLINK_CB(in_skb).sk),
-                          NETLINK_CB(in_skb).portid,
-                          nlh->nlmsg_seq, 0, nlh,
-                          netlink_net_capable(in_skb, CAP_NET_ADMIN));
+       err = inet_sk_diag_fill(sk, NULL, rep, req, in_skb, 0, nlh,
+                               netlink_net_capable(in_skb, CAP_NET_ADMIN));
        if (err < 0) {
                WARN_ON(err == -EMSGSIZE);
                kfree_skb(rep);
diff --git a/net/sctp/sctp_diag.c b/net/sctp/sctp_diag.c
index 807158e3..ab2009e 100644
--- a/net/sctp/sctp_diag.c
+++ b/net/sctp/sctp_diag.c
@@ -104,12 +104,12 @@ static int inet_diag_msg_sctpaddrs_fill(struct sk_buff *skb,
 static int inet_sctp_diag_fill(struct sock *sk, struct sctp_association *asoc,
                               struct sk_buff *skb,
                               const struct inet_diag_req_v2 *req,
-                              struct user_namespace *user_ns,
-                              int portid, u32 seq, u16 nlmsg_flags,
-                              const struct nlmsghdr *unlh,
+                              const struct sk_buff *in_skb,
+                              u16 nlmsg_flags, const struct nlmsghdr *unlh,
                               bool net_admin)
 {
        struct sctp_endpoint *ep = sctp_sk(sk)->ep;
+       u32 portid = NETLINK_CB(in_skb).portid;
        struct list_head *addr_list;
        struct inet_diag_msg *r;
        struct nlmsghdr  *nlh;
@@ -117,8 +117,8 @@ static int inet_sctp_diag_fill(struct sock *sk, struct sctp_association *asoc,
        struct sctp_infox infox;
        void *info = NULL;
 
-       nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
-                       nlmsg_flags);
+       nlh = nlmsg_put(skb, portid, unlh->nlmsg_seq, unlh->nlmsg_type,
+                       sizeof(*r), nlmsg_flags);
        if (!nlh)
                return -EMSGSIZE;
 
@@ -134,7 +134,7 @@ static int inet_sctp_diag_fill(struct sock *sk, struct sctp_association *asoc,
                r->idiag_retrans = 0;
        }
 
-       if (inet_diag_msg_attrs_fill(sk, skb, r, ext, user_ns, net_admin))
+       if (inet_diag_msg_attrs_fill(sk, skb, r, ext, in_skb, net_admin))
                goto errout;
 
        if (ext & (1 << (INET_DIAG_SKMEMINFO - 1))) {
@@ -256,10 +256,7 @@ static int sctp_tsp_dump_one(struct sctp_transport *tsp, void *p)
                sk = assoc->base.sk;
                lock_sock(sk);
        }
-       err = inet_sctp_diag_fill(sk, assoc, rep, req,
-                                 sk_user_ns(NETLINK_CB(in_skb).sk),
-                                 NETLINK_CB(in_skb).portid,
-                                 nlh->nlmsg_seq, 0, nlh,
+       err = inet_sctp_diag_fill(sk, assoc, rep, req, in_skb, 0, nlh,
                                  commp->net_admin);
        release_sock(sk);
        if (err < 0) {
@@ -310,10 +307,7 @@ static int sctp_tsp_dump(struct sctp_transport *tsp, void *p)
                        goto next;
 
                if (!cb->args[3] &&
-                   inet_sctp_diag_fill(sk, NULL, skb, r,
-                                       sk_user_ns(NETLINK_CB(cb->skb).sk),
-                                       NETLINK_CB(cb->skb).portid,
-                                       cb->nlh->nlmsg_seq,
+                   inet_sctp_diag_fill(sk, NULL, skb, r, cb->skb,
                                        NLM_F_MULTI, cb->nlh,
                                        commp->net_admin) < 0) {
                        cb->args[3] = 1;
@@ -322,11 +316,8 @@ static int sctp_tsp_dump(struct sctp_transport *tsp, void *p)
                }
                cb->args[3] = 1;
 
-               if (inet_sctp_diag_fill(sk, assoc, skb, r,
-                                       sk_user_ns(NETLINK_CB(cb->skb).sk),
-                                       NETLINK_CB(cb->skb).portid,
-                                       cb->nlh->nlmsg_seq, 0, cb->nlh,
-                                       commp->net_admin) < 0) {
+               if (inet_sctp_diag_fill(sk, assoc, skb, r, cb->skb,
+                                       0, cb->nlh, commp->net_admin) < 0) {
                        err = 2;
                        goto release;
                }
@@ -377,11 +368,8 @@ static int sctp_ep_dump(struct sctp_endpoint *ep, void *p)
            r->id.idiag_dport)
                goto next;
 
-       if (inet_sctp_diag_fill(sk, NULL, skb, r,
-                               sk_user_ns(NETLINK_CB(cb->skb).sk),
-                               NETLINK_CB(cb->skb).portid,
-                               cb->nlh->nlmsg_seq, NLM_F_MULTI,
-                               cb->nlh, commp->net_admin) < 0) {
+       if (inet_sctp_diag_fill(sk, NULL, skb, r, cb->skb,
+                               NLM_F_MULTI, cb->nlh, commp->net_admin) < 0) {
                err = 2;
                goto out;
        }
-- 
2.8.0.rc3.226.g39d4020

Reply via email to