From: Dewi Morgan <morg...@vyatta.att-mail.com>

For bound UDP sockets in a VRF, also check the sdif to get the index of
the ingress device when it is enslaved to an l3mdev. Verify the multicast
address against the enslaved device rather than the l3mdev device.

Signed-off-by: Dewi Morgan <morg...@vyatta.att-mail.com>
Signed-off-by: Mike Manning <mmann...@vyatta.att-mail.com>
---
 net/ipv6/ip6_input.c | 27 ++++++++++++++++++++++++---
 net/ipv6/udp.c       |  8 +++++---
 2 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/net/ipv6/ip6_input.c b/net/ipv6/ip6_input.c
index 108f5f88ec98..fc60f297d95b 100644
--- a/net/ipv6/ip6_input.c
+++ b/net/ipv6/ip6_input.c
@@ -325,9 +325,12 @@ static int ip6_input_finish(struct net *net, struct sock *sk, struct sk_buff *sk
 {
        const struct inet6_protocol *ipprot;
        struct inet6_dev *idev;
+       struct net_device *dev;
        unsigned int nhoff;
+       int sdif = inet6_sdif(skb);
        int nexthdr;
        bool raw;
+       bool deliver;
        bool have_final = false;
 
        /*
@@ -371,9 +374,27 @@ static int ip6_input_finish(struct net *net, struct sock *sk, struct sk_buff *sk
                        skb_postpull_rcsum(skb, skb_network_header(skb),
                                           skb_network_header_len(skb));
                        hdr = ipv6_hdr(skb);
-                       if (ipv6_addr_is_multicast(&hdr->daddr) &&
-                           !ipv6_chk_mcast_addr(skb->dev, &hdr->daddr,
-                           &hdr->saddr) &&
+
+                       /* skb->dev passed may be master dev for vrfs. */
+                       if (sdif) {
+                               rcu_read_lock();
+                               dev = dev_get_by_index_rcu(dev_net(skb->dev),
+                                                          sdif);
+                               if (!dev) {
+                                       rcu_read_unlock();
+                                       kfree_skb(skb);
+                                       return -ENODEV;
+                               }
+                       } else {
+                               dev = skb->dev;
+                       }
+
+                       deliver = ipv6_chk_mcast_addr(dev, &hdr->daddr,
+                                                     &hdr->saddr);
+                       if (sdif)
+                               rcu_read_unlock();
+
+                       if (ipv6_addr_is_multicast(&hdr->daddr) && !deliver &&
                            !ipv6_is_mld(skb, nexthdr, 
skb_network_header_len(skb)))
                                goto discard;
                }
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index e22b7dd78c9b..35f71b7a1070 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -637,7 +637,7 @@ static int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
 static bool __udp_v6_is_mcast_sock(struct net *net, struct sock *sk,
                                   __be16 loc_port, const struct in6_addr 
*loc_addr,
                                   __be16 rmt_port, const struct in6_addr 
*rmt_addr,
-                                  int dif, unsigned short hnum)
+                                  int dif, int sdif, unsigned short hnum)
 {
        struct inet_sock *inet = inet_sk(sk);
 
@@ -649,7 +649,7 @@ static bool __udp_v6_is_mcast_sock(struct net *net, struct sock *sk,
            (inet->inet_dport && inet->inet_dport != rmt_port) ||
            (!ipv6_addr_any(&sk->sk_v6_daddr) &&
                    !ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr)) ||
-           (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif) ||
+           !inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif) ||
            (!ipv6_addr_any(&sk->sk_v6_rcv_saddr) &&
                    !ipv6_addr_equal(&sk->sk_v6_rcv_saddr, loc_addr)))
                return false;
@@ -683,6 +683,7 @@ static int __udp6_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
        unsigned int offset = offsetof(typeof(*sk), sk_node);
        unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10);
        int dif = inet6_iif(skb);
+       int sdif = inet6_sdif(skb);
        struct hlist_node *node;
        struct sk_buff *nskb;
 
@@ -697,7 +698,8 @@ static int __udp6_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
 
        sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) {
                if (!__udp_v6_is_mcast_sock(net, sk, uh->dest, daddr,
-                                           uh->source, saddr, dif, hnum))
+                                           uh->source, saddr, dif, sdif,
+                                           hnum))
                        continue;
                /* If zero checksum and no_check is not on for
                 * the socket then skip it.
-- 
2.11.0

Reply via email to