diff --git a/pkg/encapsulation/cilium.go b/pkg/encapsulation/cilium.go index 9e34292..bfbb327 100644 --- a/pkg/encapsulation/cilium.go +++ b/pkg/encapsulation/cilium.go @@ -96,8 +96,8 @@ func (f *cilium) Init(_ int) error { } // Rules is a no-op. -func (f *cilium) Rules(_ []*net.IPNet) []iptables.Rule { - return nil +func (f *cilium) Rules(_ []*net.IPNet) iptables.RuleSet { + return iptables.RuleSet{} } // Set is a no-op. diff --git a/pkg/encapsulation/encapsulation.go b/pkg/encapsulation/encapsulation.go index 21e698a..77fab57 100644 --- a/pkg/encapsulation/encapsulation.go +++ b/pkg/encapsulation/encapsulation.go @@ -49,7 +49,7 @@ type Encapsulator interface { Gw(net.IP, net.IP, *net.IPNet) net.IP Index() int Init(int) error - Rules([]*net.IPNet) []iptables.Rule + Rules([]*net.IPNet) iptables.RuleSet Set(*net.IPNet) error Strategy() Strategy } diff --git a/pkg/encapsulation/flannel.go b/pkg/encapsulation/flannel.go index e08af61..9375b8f 100644 --- a/pkg/encapsulation/flannel.go +++ b/pkg/encapsulation/flannel.go @@ -95,8 +95,8 @@ func (f *flannel) Init(_ int) error { } // Rules is a no-op. -func (f *flannel) Rules(_ []*net.IPNet) []iptables.Rule { - return nil +func (f *flannel) Rules(_ []*net.IPNet) iptables.RuleSet { + return iptables.RuleSet{} } // Set is a no-op. diff --git a/pkg/encapsulation/ipip.go b/pkg/encapsulation/ipip.go index d92b39f..d6d2632 100644 --- a/pkg/encapsulation/ipip.go +++ b/pkg/encapsulation/ipip.go @@ -65,20 +65,20 @@ func (i *ipip) Init(base int) error { // Rules returns a set of iptables rules that are necessary // when traffic between nodes must be encapsulated. -func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule { - var rules []iptables.Rule +func (i *ipip) Rules(nodes []*net.IPNet) iptables.RuleSet { + rules := iptables.RuleSet{} proto := ipipProtocolName() - rules = append(rules, iptables.NewIPv4Chain("filter", "KILO-IPIP")) - rules = append(rules, iptables.NewIPv6Chain("filter", "KILO-IPIP")) - rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) - rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) + rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Chain("filter", "KILO-IPIP")) + rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Chain("filter", "KILO-IPIP")) + rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) + rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) for _, n := range nodes { // Accept encapsulated traffic from peers. - rules = append(rules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT")) } // Drop all other IPIP traffic. - rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) - rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) + rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) + rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) return rules } diff --git a/pkg/encapsulation/noop.go b/pkg/encapsulation/noop.go index d5b3906..ad9818d 100644 --- a/pkg/encapsulation/noop.go +++ b/pkg/encapsulation/noop.go @@ -44,8 +44,8 @@ func (n Noop) Init(_ int) error { } // Rules will also do nothing. -func (n Noop) Rules(_ []*net.IPNet) []iptables.Rule { - return nil +func (n Noop) Rules(_ []*net.IPNet) iptables.RuleSet { + return iptables.RuleSet{} } // Set will also do nothing. diff --git a/pkg/iptables/fake.go b/pkg/iptables/fake.go index 24c97dd..2a51169 100644 --- a/pkg/iptables/fake.go +++ b/pkg/iptables/fake.go @@ -46,6 +46,20 @@ type fakeClient struct { var _ Client = &fakeClient{} +func (f *fakeClient) Insert(table, chain string, pos int, spec ...string) error { + atomic.AddUint64(&f.calls, 1) + exists, err := f.Exists(table, chain, spec...) + if err != nil { + return err + } + if exists { + return nil + } + // FIXME obey pos! + f.storage = append([]Rule{&rule{table: table, chain: chain, spec: spec}}, f.storage...) + return nil +} + func (f *fakeClient) AppendUnique(table, chain string, spec ...string) error { atomic.AddUint64(&f.calls, 1) exists, err := f.Exists(table, chain, spec...) diff --git a/pkg/iptables/iptables.go b/pkg/iptables/iptables.go index 4cad47c..faeaedc 100644 --- a/pkg/iptables/iptables.go +++ b/pkg/iptables/iptables.go @@ -64,6 +64,7 @@ func GetProtocol(ip net.IP) Protocol { // Client represents any type that can administer iptables rules. type Client interface { AppendUnique(table string, chain string, rule ...string) error + Insert(table string, chain string, pos int, rule ...string) error Delete(table string, chain string, rule ...string) error Exists(table string, chain string, rule ...string) (bool, error) List(table string, chain string) ([]string, error) @@ -75,7 +76,8 @@ type Client interface { // Rule is an interface for interacting with iptables objects. type Rule interface { - Add(Client) error + Append(Client) error + Prepend(Client) error Delete(Client) error Exists(Client) (bool, error) String() string @@ -106,7 +108,14 @@ func NewIPv6Rule(table, chain string, spec ...string) Rule { return &rule{table, chain, spec, ProtocolIPv6} } -func (r *rule) Add(client Client) error { +func (r *rule) Prepend(client Client) error { + if err := client.Insert(r.table, r.chain, 1, r.spec...); err != nil { + return fmt.Errorf("failed to add iptables rule: %v", err) + } + return nil +} + +func (r *rule) Append(client Client) error { if err := client.AppendUnique(r.table, r.chain, r.spec...); err != nil { return fmt.Errorf("failed to add iptables rule: %v", err) } @@ -162,7 +171,11 @@ func NewIPv6Chain(table, name string) Rule { return &chain{table, name, ProtocolIPv6} } -func (c *chain) Add(client Client) error { +func (c *chain) Prepend(client Client) error { + return c.Append(client) +} + +func (c *chain) Append(client Client) error { // Note: `ClearChain` creates a chain if it does not exist. if err := client.ClearChain(c.table, c.chain); err != nil { return fmt.Errorf("failed to add iptables chain: %v", err) @@ -224,8 +237,9 @@ type Controller struct { registerer prometheus.Registerer sync.Mutex - rules []Rule - subscribed bool + appendRules []Rule + prependRules []Rule + subscribed bool } // ControllerOption modifies the controller's configuration. @@ -333,14 +347,14 @@ func (c *Controller) reconcile() error { c.Lock() defer c.Unlock() var rc ruleCache - for i, r := range c.rules { + for i, r := range c.appendRules { ok, err := rc.exists(c.client(r.Proto()), r) if err != nil { return fmt.Errorf("failed to check if rule exists: %v", err) } if !ok { - level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.rules)-i)) - if err := c.resetFromIndex(i, c.rules); err != nil { + level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.appendRules)-i)) + if err := c.resetFromIndex(i, c.appendRules); err != nil { return fmt.Errorf("failed to add rule: %v", err) } break @@ -358,7 +372,7 @@ func (c *Controller) resetFromIndex(i int, rules []Rule) error { if err := rules[j].Delete(c.client(rules[j].Proto())); err != nil { return fmt.Errorf("failed to delete rule: %v", err) } - if err := rules[j].Add(c.client(rules[j].Proto())); err != nil { + if err := rules[j].Append(c.client(rules[j].Proto())); err != nil { return fmt.Errorf("failed to add rule: %v", err) } } @@ -383,34 +397,87 @@ func (c *Controller) deleteFromIndex(i int, rules *[]Rule) error { // Set idempotently overwrites any iptables rules previously defined // for the controller with the given set of rules. -func (c *Controller) Set(rules []Rule) error { +func (c *Controller) Set(rules RuleSet) error { c.Lock() defer c.Unlock() + if err := c.setAppendRules(rules.AppendRules); err != nil { + return err + } + return c.setPrependRules(rules.PrependRules) +} + +func (c *Controller) setAppendRules(appendRules []Rule) error { var i int - for ; i < len(rules); i++ { - if i < len(c.rules) { - if rules[i].String() != c.rules[i].String() { - if err := c.deleteFromIndex(i, &c.rules); err != nil { + for ; i < len(appendRules); i++ { + if i < len(c.appendRules) { + if appendRules[i].String() != c.appendRules[i].String() { + if err := c.deleteFromIndex(i, &c.appendRules); err != nil { return err } } } - if i >= len(c.rules) { - if err := rules[i].Add(c.client(rules[i].Proto())); err != nil { + if i >= len(c.appendRules) { + if err := appendRules[i].Append(c.client(appendRules[i].Proto())); err != nil { return fmt.Errorf("failed to add rule: %v", err) } - c.rules = append(c.rules, rules[i]) + c.appendRules = append(c.appendRules, appendRules[i]) } - } - return c.deleteFromIndex(i, &c.rules) + err := c.deleteFromIndex(i, &c.appendRules) + if err != nil { + return fmt.Errorf("failed to delete rule: %v", err) + } + return nil +} + +func (c *Controller) setPrependRules(prependRules []Rule) error { + for _, prependRule := range prependRules { + if !containsRule(c.prependRules, prependRule) { + if err := prependRule.Prepend(c.client(prependRule.Proto())); err != nil { + return fmt.Errorf("failed to add rule: %v", err) + } + c.prependRules = append(c.prependRules, prependRule) + } + } + for _, existingRule := range c.prependRules { + if !containsRule(prependRules, existingRule) { + if err := existingRule.Delete(c.client(existingRule.Proto())); err != nil { + return fmt.Errorf("failed to delete rule: %v", err) + } + c.prependRules = removeRule(c.prependRules, existingRule) + } + } + return nil +} + +func removeRule(rules []Rule, toRemove Rule) []Rule { + ret := make([]Rule, 0, len(rules)) + for _, rule := range rules { + if rule.String() != toRemove.String() { + ret = append(ret, rule) + } + } + return ret +} + +func containsRule(haystack []Rule, needle Rule) bool { + for _, element := range haystack { + if element.String() == needle.String() { + return true + } + } + return false } // CleanUp will clean up any rules created by the controller. func (c *Controller) CleanUp() error { c.Lock() defer c.Unlock() - return c.deleteFromIndex(0, &c.rules) + err := c.deleteFromIndex(0, &c.prependRules) + if err != nil { + return err + } + return c.deleteFromIndex(0, &c.appendRules) } func (c *Controller) client(p Protocol) Client { @@ -430,3 +497,8 @@ func nonBlockingSend(errors chan<- error, err error) { default: } } + +type RuleSet struct { + AppendRules []Rule // Rules to append to the chain - order matters. + PrependRules []Rule // Rules to prepend to the chain - order does not matter. +} diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index f1091dc..d44a5f7 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -525,7 +525,9 @@ func (m *Mesh) applyTopology() { } } - ipRules = append(m.enc.Rules(cidrs), ipRules...) + encIpRules := m.enc.Rules(cidrs) + ipRules.AppendRules = append(encIpRules.AppendRules, ipRules.AppendRules...) + ipRules.PrependRules = append(encIpRules.PrependRules, ipRules.PrependRules...) // If we are handling local routes, ensure the local // tunnel has an IP address. diff --git a/pkg/mesh/routes.go b/pkg/mesh/routes.go index 2c0cf24..2567a72 100644 --- a/pkg/mesh/routes.go +++ b/pkg/mesh/routes.go @@ -311,12 +311,12 @@ func encapsulateRoute(route *netlink.Route, encapsulate encapsulation.Strategy, } // Rules returns the iptables rules required by the local node. -func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule { - var rules []iptables.Rule - rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT")) - rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT")) +func (t *Topology) Rules(cni, iptablesForwardRule bool) iptables.RuleSet { + rules := iptables.RuleSet{} + rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Chain("nat", "KILO-NAT")) + rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Chain("nat", "KILO-NAT")) if cni { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT")) // Some linux distros or docker will set forward DROP in the filter table. // To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location. // Leader nodes will forward packets from all nodes within a location because they act as a gateway for them. @@ -326,37 +326,37 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule { if s.location == t.location { // Make sure packets to and from pod cidrs are not dropped in the forward chain. for _, c := range s.cidrs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT")) } // Make sure packets to and from allowed location IPs are not dropped in the forward chain. for _, c := range s.allowedLocationIPs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT")) } // Make sure packets to and from private IPs are not dropped in the forward chain. for _, c := range s.privateIPs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT")) } } } } else if iptablesForwardRule { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT")) } } for _, s := range t.segments { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN")) for _, aip := range s.allowedIPs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN")) + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN")) } // Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd, // Otherwise packets to these destinations will reach the destination, but never find their way back. // We only want to NAT in locations of the corresponding allowed location IPs. if t.location == s.location { for _, alip := range s.allowedLocationIPs { - rules = append(rules, + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"), ) } @@ -364,14 +364,14 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule { } for _, p := range t.peers { for _, aip := range p.AllowedIPs { - rules = append(rules, + rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"), iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"), ) } } - rules = append(rules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) - rules = append(rules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) + rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) + rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) return rules }