From: Blair Steven <blair.ste...@alliedtelesis.co.nz>

This adds support for masquerading into a restricted subset of ports:
the port set selected by the PSID (Port Set Identifier) parameters of
RFC 7597 Section 5.1. This is part of the support for MAP-E, which
allows multiple devices to share an IPv4 address by splitting the L4
port/ID space into disjoint ranges; each device masquerades into its
own range and encapsulates the IPv4 packets inside an IPv6 carrier.

Co-developed-by: Anthony Lineham <anthony.line...@alliedtelesis.co.nz>
Co-developed-by: Scott Parlane <scott.parl...@alliedtelesis.co.nz>
Signed-off-by: Blair Steven <blair.ste...@alliedtelesis.co.nz>
Signed-off-by: Anthony Lineham <anthony.line...@alliedtelesis.co.nz>
Signed-off-by: Scott Parlane <scott.parl...@alliedtelesis.co.nz>
Signed-off-by: Felix Jia <felix....@alliedtelesis.co.nz>
---
 include/net/netfilter/nf_nat_l4proto.h        |  6 +--
 .../netfilter/nf_conntrack_tuple_common.h     |  5 ++
 include/uapi/linux/netfilter/nf_nat.h         |  4 +-
 net/ipv4/netfilter/nf_nat_proto_icmp.c        | 51 ++++++++++++++++++-
 net/netfilter/nf_nat_core.c                   | 25 ++++++---
 net/netfilter/nf_nat_proto_common.c           | 51 +++++++++++++++++--
 net/netfilter/nf_nat_proto_unknown.c          |  3 +-
 7 files changed, 126 insertions(+), 19 deletions(-)

diff --git a/include/net/netfilter/nf_nat_l4proto.h b/include/net/netfilter/nf_nat_l4proto.h
index b4d6b29bca62..d3fb8f138d0a 100644
--- a/include/net/netfilter/nf_nat_l4proto.h
+++ b/include/net/netfilter/nf_nat_l4proto.h
@@ -24,8 +24,7 @@ struct nf_nat_l4proto {
-       /* Is the manipable part of the tuple between min and max incl? */
+       /* Is the manipable part of the tuple inside @range? */
        bool (*in_range)(const struct nf_conntrack_tuple *tuple,
                         enum nf_nat_manip_type maniptype,
-                        const union nf_conntrack_man_proto *min,
-                        const union nf_conntrack_man_proto *max);
+                        const struct nf_nat_range2 *range);
 
        /* Alter the per-proto part of the tuple (depending on
         * maniptype), to give a unique tuple in the given range if
@@ -67,8 +66,7 @@ extern const struct nf_nat_l4proto nf_nat_l4proto_udplite;
 
 bool nf_nat_l4proto_in_range(const struct nf_conntrack_tuple *tuple,
                             enum nf_nat_manip_type maniptype,
-                            const union nf_conntrack_man_proto *min,
-                            const union nf_conntrack_man_proto *max);
+                            const struct nf_nat_range2 *range);
 
 void nf_nat_l4proto_unique_tuple(const struct nf_nat_l3proto *l3proto,
                                 struct nf_conntrack_tuple *tuple,
diff --git a/include/uapi/linux/netfilter/nf_conntrack_tuple_common.h b/include/uapi/linux/netfilter/nf_conntrack_tuple_common.h
index 64390fac6f7e..36d16d47c2b0 100644
--- a/include/uapi/linux/netfilter/nf_conntrack_tuple_common.h
+++ b/include/uapi/linux/netfilter/nf_conntrack_tuple_common.h
@@ -39,6 +39,19 @@ union nf_conntrack_man_proto {
        struct {
                __be16 key;     /* GRE key is 32bit, PPtP only uses 16bit */
        } gre;
+       /* Port Set ID (PSID) parameters, RFC 7597 Section 5.1.
+        * NOTE(review): this 4-byte member is wider than the 16-bit
+        * members (gre.key, all), so it grows the union and every
+        * structure embedding it (e.g. nf_nat_range2's min_proto and
+        * max_proto) - a UAPI/ABI change for existing userspace; confirm.
+        * NOTE(review): psid is declared __be16, but the NAT code treats
+        * it as a host-order value - confirm the intended byte order.
+        */
+       struct {
+               __u8 psid_length;
+               __u8 offset;
+               __be16 psid;
+       } psid;
 };
 
 #define CTINFO2DIR(ctinfo) ((ctinfo) >= IP_CT_IS_REPLY ? IP_CT_DIR_REPLY : IP_CT_DIR_ORIGINAL)
diff --git a/include/uapi/linux/netfilter/nf_nat.h b/include/uapi/linux/netfilter/nf_nat.h
index 4a95c0db14d4..d145d3eca25f 100644
--- a/include/uapi/linux/netfilter/nf_nat.h
+++ b/include/uapi/linux/netfilter/nf_nat.h
@@ -11,6 +11,7 @@
 #define NF_NAT_RANGE_PERSISTENT                        (1 << 3)
 #define NF_NAT_RANGE_PROTO_RANDOM_FULLY                (1 << 4)
 #define NF_NAT_RANGE_PROTO_OFFSET              (1 << 5)
+#define NF_NAT_RANGE_PSID                      (1 << 6) /* RFC 7597 PSID port-set mode */
 
 #define NF_NAT_RANGE_PROTO_RANDOM_ALL          \
        (NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
@@ -18,7 +19,8 @@
 #define NF_NAT_RANGE_MASK                                      \
        (NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED |  \
         NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PERSISTENT |  \
-        NF_NAT_RANGE_PROTO_RANDOM_FULLY | NF_NAT_RANGE_PROTO_OFFSET)
+        NF_NAT_RANGE_PROTO_RANDOM_FULLY | NF_NAT_RANGE_PROTO_OFFSET |  \
+        NF_NAT_RANGE_PSID)
 
 struct nf_nat_ipv4_range {
        unsigned int                    flags;
diff --git a/net/ipv4/netfilter/nf_nat_proto_icmp.c b/net/ipv4/netfilter/nf_nat_proto_icmp.c
index 6d7cf1d79baf..39efac4930b6 100644
--- a/net/ipv4/netfilter/nf_nat_proto_icmp.c
+++ b/net/ipv4/netfilter/nf_nat_proto_icmp.c
@@ -20,9 +20,32 @@
 static bool
 icmp_in_range(const struct nf_conntrack_tuple *tuple,
              enum nf_nat_manip_type maniptype,
-             const union nf_conntrack_man_proto *min,
-             const union nf_conntrack_man_proto *max)
+             const struct nf_nat_range2 *range)
 {
+       const union nf_conntrack_man_proto *min = &range->min_proto;
+       const union nf_conntrack_man_proto *max = &range->max_proto;
+
+       /* PSID mode (RFC 7597 Section 5.1): the ID must lie inside the
+        * port set selected by the PSID instead of between min and max.
+        */
+       if (range->flags & NF_NAT_RANGE_PSID) {
+               /* m = number of bits in each contiguous run of the set */
+               int m = 16 - min->psid.psid_length - min->psid.offset;
+               u16 id = ntohs(tuple->src.u.icmp.id);
+
+               /* psid_length + offset > 16 would make the shifts undefined */
+               if (m < 0)
+                       return false;
+
+               /* A non-zero offset reserves the all-zero top chunk */
+               if (min->psid.offset && !(id >> (16 - min->psid.offset)))
+                       return false;
+
+               /* Compare the PSID bits, shifted down, with the PSID */
+               return ((id >> m) & ((1 << min->psid.psid_length) - 1)) ==
+                      min->psid.psid;
+       }
+
        return ntohs(tuple->src.u.icmp.id) >= ntohs(min->icmp.id) &&
               ntohs(tuple->src.u.icmp.id) <= ntohs(max->icmp.id);
 }
@@ -38,6 +52,42 @@ icmp_unique_tuple(const struct nf_nat_l3proto *l3proto,
        unsigned int range_size;
        unsigned int i;
 
+       if (range->flags & NF_NAT_RANGE_PSID) {
+               /* Walk candidate IDs of the form [t_chunk|PSID|b_chunk]
+                * (RFC 7597 Section 5.1): b_chunk is m bits wide and
+                * t_chunk is "offset" bits wide.  A non-zero offset
+                * reserves the all-zero t_chunk, leaving 2^offset - 1
+                * usable values; offset 0 gives one contiguous run of
+                * 2^m IDs.  range_size is therefore never 0.
+                */
+               int offset = range->min_proto.psid.offset;
+               int m = 16 - range->min_proto.psid.psid_length - offset;
+               unsigned int psid = range->min_proto.psid.psid;
+               unsigned int first_t = offset ? 1 : 0;
+               unsigned int start;
+
+               /* psid_length + offset > 16: no valid ID, leave as is */
+               if (m < 0)
+                       return;
+
+               range_size = ((1U << offset) - first_t) << m;
+               /* Start from the slot derived from the current ID */
+               start = ntohs(tuple->src.u.icmp.id) % range_size;
+               for (i = 0; i < range_size; i++) {
+                       unsigned int idx = (start + i) % range_size;
+                       unsigned int b_chunk = idx & ((1U << m) - 1);
+                       unsigned int t_chunk = (idx >> m) + first_t;
+                       u16 cand = (t_chunk << (16 - offset)) |
+                                  (psid << m) | b_chunk;
+
+                       tuple->src.u.icmp.id = htons(cand);
+                       if (!nf_nat_used_tuple(tuple, ct))
+                               return;
+               }
+               /* Port set exhausted: keep the last candidate */
+               return;
+       }
+
        range_size = ntohs(range->max_proto.icmp.id) -
                     ntohs(range->min_proto.icmp.id) + 1;
        /* If no range specified... */
diff --git a/net/netfilter/nf_nat_core.c b/net/netfilter/nf_nat_core.c
index e2b196054dfc..18e39af3838d 100644
--- a/net/netfilter/nf_nat_core.c
+++ b/net/netfilter/nf_nat_core.c
@@ -187,9 +187,10 @@ static int in_range(const struct nf_nat_l3proto *l3proto,
            !l3proto->in_range(tuple, range))
                return 0;
 
-       if (!(range->flags & NF_NAT_RANGE_PROTO_SPECIFIED) ||
-           l4proto->in_range(tuple, NF_NAT_MANIP_SRC,
-                             &range->min_proto, &range->max_proto))
-               return 1;
-
-       return 0;
+       /* In PSID mode the port/id must always lie inside the PSID port
+        * set, even when no explicit protocol range was specified.
+        */
+       if (range->flags & (NF_NAT_RANGE_PROTO_SPECIFIED | NF_NAT_RANGE_PSID))
+               return l4proto->in_range(tuple, NF_NAT_MANIP_SRC, range);
+
+       return 1;
@@ -369,11 +375,18 @@ get_unique_tuple(struct nf_conntrack_tuple *tuple,
 
        /* Only bother mapping if it's not already in range and unique */
        if (!(range->flags & NF_NAT_RANGE_PROTO_RANDOM_ALL)) {
-               if (range->flags & NF_NAT_RANGE_PROTO_SPECIFIED) {
+               /* Check the port/id when a protocol range is specified or,
+                * in PSID mode, when in_range() rejects the original tuple.
+                */
+               if (range->flags & NF_NAT_RANGE_PROTO_SPECIFIED ||
+                   (range->flags & NF_NAT_RANGE_PSID &&
+                    !in_range(l3proto, l4proto, orig_tuple, range))) {
+                       /* Keep the port/id if in range and unused (or fixed).
+                        * NOTE(review): in PSID mode min_proto holds the PSID
+                        * parameters, so "min == max" is not a 1-port test.
+                        */
                        if (!(range->flags & NF_NAT_RANGE_PROTO_OFFSET) &&
-                           l4proto->in_range(tuple, maniptype,
-                                 &range->min_proto,
-                                 &range->max_proto) &&
+                           l4proto->in_range(tuple, maniptype, range) &&
                            (range->min_proto.all == range->max_proto.all ||
                             !nf_nat_used_tuple(tuple, ct)))
                                goto out;
diff --git a/net/netfilter/nf_nat_proto_common.c b/net/netfilter/nf_nat_proto_common.c
index 5d849d835561..4ca3b2715e7c 100644
--- a/net/netfilter/nf_nat_proto_common.c
+++ b/net/netfilter/nf_nat_proto_common.c
@@ -19,16 +19,38 @@
 
 bool nf_nat_l4proto_in_range(const struct nf_conntrack_tuple *tuple,
                             enum nf_nat_manip_type maniptype,
-                            const union nf_conntrack_man_proto *min,
-                            const union nf_conntrack_man_proto *max)
+                            const struct nf_nat_range2 *range)
 {
        __be16 port;
+       const union nf_conntrack_man_proto *min = &range->min_proto;
+       const union nf_conntrack_man_proto *max = &range->max_proto;
 
        if (maniptype == NF_NAT_MANIP_SRC)
                port = tuple->src.u.all;
        else
                port = tuple->dst.u.all;
 
+       /* PSID mode (RFC 7597 Section 5.1): the port must lie inside the
+        * port set selected by the PSID instead of between min and max.
+        */
+       if (range->flags & NF_NAT_RANGE_PSID) {
+               /* m = number of bits in each contiguous run of the set */
+               int m = 16 - min->psid.psid_length - min->psid.offset;
+               u16 p = ntohs(port);
+
+               /* psid_length + offset > 16 would make the shifts undefined */
+               if (m < 0)
+                       return false;
+
+               /* A non-zero offset reserves the all-zero top chunk */
+               if (min->psid.offset && !(p >> (16 - min->psid.offset)))
+                       return false;
+
+               /* Compare the PSID bits, shifted down, with the PSID */
+               return ((p >> m) & ((1 << min->psid.psid_length) - 1)) ==
+                      min->psid.psid;
+       }
+
        return ntohs(port) >= ntohs(min->all) &&
               ntohs(port) <= ntohs(max->all);
 }
@@ -46,9 +60,46 @@ void nf_nat_l4proto_unique_tuple(const struct nf_nat_l3proto *l3proto,
        u_int16_t off;
 
        if (maniptype == NF_NAT_MANIP_SRC)
                portptr = &tuple->src.u.all;
        else
                portptr = &tuple->dst.u.all;
+
+       if (range->flags & NF_NAT_RANGE_PSID) {
+               /* Walk candidate ports of the form [t_chunk|PSID|b_chunk]
+                * (RFC 7597 Section 5.1): b_chunk is m bits wide and
+                * t_chunk is "offset" bits wide.  A non-zero offset
+                * reserves the all-zero t_chunk (ports 0-1023 with the
+                * default offset of 6), leaving 2^offset - 1 usable
+                * values; offset 0 gives one contiguous run of 2^m
+                * ports.  range_size is therefore never 0.
+                */
+               int offset = range->min_proto.psid.offset;
+               int m = 16 - range->min_proto.psid.psid_length - offset;
+               unsigned int psid = range->min_proto.psid.psid;
+               unsigned int first_t = offset ? 1 : 0;
+               unsigned int start;
+
+               /* psid_length + offset > 16: no valid port, leave as is */
+               if (m < 0)
+                       return;
+
+               range_size = ((1U << offset) - first_t) << m;
+               /* Start from the slot derived from the current port */
+               start = ntohs(*portptr) % range_size;
+               for (i = 0; i < range_size; i++) {
+                       unsigned int idx = (start + i) % range_size;
+                       unsigned int b_chunk = idx & ((1U << m) - 1);
+                       unsigned int t_chunk = (idx >> m) + first_t;
+                       u16 cand = (t_chunk << (16 - offset)) |
+                                  (psid << m) | b_chunk;
+
+                       *portptr = htons(cand);
+                       if (!nf_nat_used_tuple(tuple, ct))
+                               return;
+               }
+               /* Port set exhausted: keep the last candidate */
+               return;
+       }
 
        /* If no range specified... */
        if (!(range->flags & NF_NAT_RANGE_PROTO_SPECIFIED)) {
diff --git a/net/netfilter/nf_nat_proto_unknown.c b/net/netfilter/nf_nat_proto_unknown.c
index c5db3e251232..82140ffc706d 100644
--- a/net/netfilter/nf_nat_proto_unknown.c
+++ b/net/netfilter/nf_nat_proto_unknown.c
@@ -19,8 +19,10 @@
 
 static bool unknown_in_range(const struct nf_conntrack_tuple *tuple,
                             enum nf_nat_manip_type manip_type,
-                            const union nf_conntrack_man_proto *min,
-                            const union nf_conntrack_man_proto *max)
+                            const struct nf_nat_range2 *range)
 {
+       /* The tuple is not inspected, so every tuple is in range - also in
+        * PSID mode, where this traffic is thus not confined to the set.
+        */
        return true;
 }
-- 
2.19.1

Reply via email to