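Merge the two very similar functions sock_map_update_elem and
sock_hash_update_elem into one. Both look up the socket behind the
file descriptor in value and perform the same checks. They differ
only in the final call:
- sockmap passes *(u32 *)key to sock_map_update_common.
- sockhash passes key unchanged to sock_hash_update_common.

Dispatch on map->map_type in sock_map_update_elem, and use it for
sock_hash_ops as well.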

Signed-off-by: Lorenz Bauer <l...@cloudflare.com>
---
 net/core/sock_map.c | 53 ++++++++-------------------------------------
 1 file changed, 9 insertions(+), 44 deletions(-)

diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index abe4bac40db9..f464a0ebc871 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -559,10 +559,12 @@ static bool sock_map_sk_state_allowed(const struct sock *sk)
        return false;
 }
 
-static int sock_map_update_elem(struct bpf_map *map, void *key,
-                               void *value, u64 flags)
+static int sock_hash_update_common(struct bpf_map *map, void *key,
+                                  struct sock *sk, u64 flags);
+
+int sock_map_update_elem(struct bpf_map *map, void *key,
+                        void *value, u64 flags)
 {
-       u32 idx = *(u32 *)key;
        struct socket *sock;
        struct sock *sk;
        int ret;
@@ -591,8 +593,10 @@ static int sock_map_update_elem(struct bpf_map *map, void *key,
        sock_map_sk_acquire(sk);
        if (!sock_map_sk_state_allowed(sk))
                ret = -EOPNOTSUPP;
+       else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
+               ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
        else
-               ret = sock_map_update_common(map, idx, sk, flags);
+               ret = sock_hash_update_common(map, key, sk, flags);
        sock_map_sk_release(sk);
 out:
        fput(sock->file);
@@ -909,45 +913,6 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
        return ret;
 }
 
-static int sock_hash_update_elem(struct bpf_map *map, void *key,
-                                void *value, u64 flags)
-{
-       struct socket *sock;
-       struct sock *sk;
-       int ret;
-       u64 ufd;
-
-       if (map->value_size == sizeof(u64))
-               ufd = *(u64 *)value;
-       else
-               ufd = *(u32 *)value;
-       if (ufd > S32_MAX)
-               return -EINVAL;
-
-       sock = sockfd_lookup(ufd, &ret);
-       if (!sock)
-               return ret;
-       sk = sock->sk;
-       if (!sk) {
-               ret = -EINVAL;
-               goto out;
-       }
-       if (!sock_map_sk_is_suitable(sk)) {
-               ret = -EOPNOTSUPP;
-               goto out;
-       }
-
-       sock_map_sk_acquire(sk);
-       if (!sock_map_sk_state_allowed(sk))
-               ret = -EOPNOTSUPP;
-       else
-               ret = sock_hash_update_common(map, key, sk, flags);
-       sock_map_sk_release(sk);
-out:
-       fput(sock->file);
-       return ret;
-}
-
 static int sock_hash_get_next_key(struct bpf_map *map, void *key,
                                  void *key_next)
 {
@@ -1216,7 +1181,7 @@ const struct bpf_map_ops sock_hash_ops = {
        .map_alloc              = sock_hash_alloc,
        .map_free               = sock_hash_free,
        .map_get_next_key       = sock_hash_get_next_key,
-       .map_update_elem        = sock_hash_update_elem,
+       .map_update_elem        = sock_map_update_elem,
        .map_delete_elem        = sock_hash_delete_elem,
        .map_lookup_elem        = sock_hash_lookup,
        .map_lookup_elem_sys_only = sock_hash_lookup_sys,
-- 
2.25.1

Reply via email to