From 46cdd6c60cbffd725f6da3f364e46c00b3bdb21f Mon Sep 17 00:00:00 2001 From: Alex Stockinger Date: Tue, 26 Jul 2022 13:45:57 +0000 Subject: [PATCH] Make usage of `RuleSet` prettier --- pkg/encapsulation/ipip.go | 14 ++++++------ pkg/iptables/iptables.go | 29 ++++++++++++++++++------ pkg/iptables/iptables_test.go | 20 ++++++++--------- pkg/iptables/rulecache_test.go | 2 +- pkg/mesh/mesh.go | 3 +-- pkg/mesh/routes.go | 40 +++++++++++++++------------------- 6 files changed, 59 insertions(+), 49 deletions(-) diff --git a/pkg/encapsulation/ipip.go b/pkg/encapsulation/ipip.go index d6d2632..5452783 100644 --- a/pkg/encapsulation/ipip.go +++ b/pkg/encapsulation/ipip.go @@ -68,17 +68,17 @@ func (i *ipip) Init(base int) error { func (i *ipip) Rules(nodes []*net.IPNet) iptables.RuleSet { rules := iptables.RuleSet{} proto := ipipProtocolName() - rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Chain("filter", "KILO-IPIP")) - rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Chain("filter", "KILO-IPIP")) - rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) - rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) + rules.AddToAppend(iptables.NewIPv4Chain("filter", "KILO-IPIP")) + rules.AddToAppend(iptables.NewIPv6Chain("filter", "KILO-IPIP")) + rules.AddToAppend(iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) + rules.AddToAppend(iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) for _, n := range nodes { // Accept encapsulated traffic from peers. - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT")) } // Drop all other IPIP traffic. - rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) - rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) + rules.AddToAppend(iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) + rules.AddToAppend(iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) return rules } diff --git a/pkg/iptables/iptables.go b/pkg/iptables/iptables.go index e624fef..3ec6543 100644 --- a/pkg/iptables/iptables.go +++ b/pkg/iptables/iptables.go @@ -46,6 +46,11 @@ func ipv6Disabled() (bool, error) { // Protocol represents an IP protocol. type Protocol byte +type RuleSet struct { + appendRules []Rule // Rules to append to the chain - order matters. + prependRules []Rule // Rules to prepend to the chain - order does not matter. +} + const ( // ProtocolIPv4 represents the IPv4 protocol. ProtocolIPv4 Protocol = iota @@ -53,6 +58,21 @@ const ( ProtocolIPv6 ) +func (rs *RuleSet) AddToAppend(rule Rule) { + rs.appendRules = append(rs.appendRules, rule) +} + +func (rs *RuleSet) AddToPrepend(rule Rule) { + rs.prependRules = append(rs.prependRules, rule) +} + +func (rs *RuleSet) AppendRuleSet(other RuleSet) RuleSet { + return RuleSet{ + appendRules: append(rs.appendRules, other.appendRules...), + prependRules: append(rs.prependRules, other.prependRules...), + } +} + // GetProtocol will return a protocol from the length of an IP address. func GetProtocol(ip net.IP) Protocol { if len(ip) == net.IPv4len || ip.To4() != nil { @@ -423,10 +443,10 @@ func (c *Controller) deleteFromIndex(i int, rules *[]Rule) error { func (c *Controller) Set(rules RuleSet) error { c.Lock() defer c.Unlock() - if err := c.setAppendRules(rules.AppendRules); err != nil { + if err := c.setAppendRules(rules.appendRules); err != nil { return err } - return c.setPrependRules(rules.PrependRules) + return c.setPrependRules(rules.prependRules) } func (c *Controller) setAppendRules(appendRules []Rule) error { @@ -520,8 +540,3 @@ func nonBlockingSend(errors chan<- error, err error) { default: } } - -type RuleSet struct { - AppendRules []Rule // Rules to append to the chain - order matters. - PrependRules []Rule // Rules to prepend to the chain - order does not matter. -} diff --git a/pkg/iptables/iptables_test.go b/pkg/iptables/iptables_test.go index 4e9deaa..447eb82 100644 --- a/pkg/iptables/iptables_test.go +++ b/pkg/iptables/iptables_test.go @@ -43,7 +43,7 @@ func TestSet(t *testing.T) { { name: "single", sets: []RuleSet{ - {AppendRules: []Rule{appendRules[0]}}, + {appendRules: []Rule{appendRules[0]}}, }, appendOut: []Rule{appendRules[0]}, storageOut: []Rule{appendRules[0]}, @@ -51,7 +51,7 @@ func TestSet(t *testing.T) { { name: "two rules", sets: []RuleSet{ - {AppendRules: []Rule{appendRules[0], appendRules[1]}}, + {appendRules: []Rule{appendRules[0], appendRules[1]}}, }, appendOut: []Rule{appendRules[0], appendRules[1]}, storageOut: []Rule{appendRules[0], appendRules[1]}, @@ -59,8 +59,8 @@ func TestSet(t *testing.T) { { name: "multiple", sets: []RuleSet{ - {AppendRules: []Rule{appendRules[0], appendRules[1]}}, - {AppendRules: []Rule{appendRules[1]}}, + {appendRules: []Rule{appendRules[0], appendRules[1]}}, + {appendRules: []Rule{appendRules[1]}}, }, appendOut: []Rule{appendRules[1]}, storageOut: []Rule{appendRules[1]}, @@ -68,7 +68,7 @@ func TestSet(t *testing.T) { { name: "re-add", sets: []RuleSet{ - {AppendRules: []Rule{appendRules[0], appendRules[1]}}, + {appendRules: []Rule{appendRules[0], appendRules[1]}}, }, appendOut: []Rule{appendRules[0], appendRules[1]}, storageOut: []Rule{appendRules[0], appendRules[1]}, @@ -84,7 +84,7 @@ func TestSet(t *testing.T) { { name: "order", sets: []RuleSet{ - {AppendRules: []Rule{appendRules[0], appendRules[1]}}, + {appendRules: []Rule{appendRules[0], appendRules[1]}}, }, appendOut: []Rule{appendRules[0], appendRules[1]}, storageOut: []Rule{appendRules[0], appendRules[1]}, @@ -98,8 +98,8 @@ func TestSet(t *testing.T) { name: "append and prepend", sets: []RuleSet{ { - PrependRules: []Rule{prependRules[0], prependRules[1]}, - AppendRules: []Rule{appendRules[0], appendRules[1]}, + prependRules: []Rule{prependRules[0], prependRules[1]}, + appendRules: []Rule{appendRules[0], appendRules[1]}, }, }, appendOut: []Rule{appendRules[0], appendRules[1]}, @@ -184,12 +184,12 @@ func TestCleanUp(t *testing.T) { if err != nil { t.Fatalf("test case %q: got unexpected error instantiating controller: %v", tc.name, err) } - ruleSet := RuleSet{AppendRules: tc.appendRules, PrependRules: tc.prependRules} + ruleSet := RuleSet{appendRules: tc.appendRules, prependRules: tc.prependRules} if err := controller.Set(ruleSet); err != nil { t.Fatalf("test case %q: Set should not fail: %v", tc.name, err) } if len(client.storage) != len(tc.appendRules)+len(tc.prependRules) { - t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(ruleSet.AppendRules)+len(ruleSet.PrependRules), len(client.storage)) + t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(ruleSet.appendRules)+len(ruleSet.prependRules), len(client.storage)) } if err := controller.CleanUp(); err != nil { t.Errorf("test case %q: got unexpected error: %v", tc.name, err) diff --git a/pkg/iptables/rulecache_test.go b/pkg/iptables/rulecache_test.go index 7ee8e1e..3561f27 100644 --- a/pkg/iptables/rulecache_test.go +++ b/pkg/iptables/rulecache_test.go @@ -101,7 +101,7 @@ func TestRuleCache(t *testing.T) { client := &fakeClient{} controller.v4 = client controller.v6 = client - ruleSet := RuleSet{AppendRules: tc.rules} + ruleSet := RuleSet{appendRules: tc.rules} if err := controller.Set(ruleSet); err != nil { t.Fatalf("test case %q: Set should not fail: %v", tc.name, err) } diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index d44a5f7..250d0c7 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -526,8 +526,7 @@ func (m *Mesh) applyTopology() { } encIpRules := m.enc.Rules(cidrs) - ipRules.AppendRules = append(encIpRules.AppendRules, ipRules.AppendRules...) - ipRules.PrependRules = append(encIpRules.PrependRules, ipRules.PrependRules...) + ipRules = encIpRules.AppendRuleSet(ipRules) // If we are handling local routes, ensure the local // tunnel has an IP address. diff --git a/pkg/mesh/routes.go b/pkg/mesh/routes.go index 2567a72..b90d510 100644 --- a/pkg/mesh/routes.go +++ b/pkg/mesh/routes.go @@ -313,10 +313,10 @@ func encapsulateRoute(route *netlink.Route, encapsulate encapsulation.Strategy, // Rules returns the iptables rules required by the local node. func (t *Topology) Rules(cni, iptablesForwardRule bool) iptables.RuleSet { rules := iptables.RuleSet{} - rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Chain("nat", "KILO-NAT")) - rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Chain("nat", "KILO-NAT")) + rules.AddToAppend(iptables.NewIPv4Chain("nat", "KILO-NAT")) + rules.AddToAppend(iptables.NewIPv6Chain("nat", "KILO-NAT")) if cni { - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT")) // Some linux distros or docker will set forward DROP in the filter table. // To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location. // Leader nodes will forward packets from all nodes within a location because they act as a gateway for them. @@ -326,52 +326,48 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) iptables.RuleSet { if s.location == t.location { // Make sure packets to and from pod cidrs are not dropped in the forward chain. for _, c := range s.cidrs { - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT")) - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT")) } // Make sure packets to and from allowed location IPs are not dropped in the forward chain. for _, c := range s.allowedLocationIPs { - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT")) - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT")) } // Make sure packets to and from private IPs are not dropped in the forward chain. for _, c := range s.privateIPs { - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT")) - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT")) } } } } else if iptablesForwardRule { - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT")) - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT")) } } for _, s := range t.segments { - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN")) for _, aip := range s.allowedIPs { - rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN")) } // Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd, // Otherwise packets to these destinations will reach the destination, but never find their way back. // We only want to NAT in locations of the corresponding allowed location IPs. if t.location == s.location { for _, alip := range s.allowedLocationIPs { - rules.PrependRules = append(rules.PrependRules, - iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"), - ) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT")) } } } for _, p := range t.peers { for _, aip := range p.AllowedIPs { - rules.PrependRules = append(rules.PrependRules, - iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"), - iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"), - ) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN")) } } - rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) - rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) + rules.AddToAppend(iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) + rules.AddToAppend(iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) return rules }