Split iptables rules into append and prepend rules

This commit is contained in:
Alex Stockinger 2022-07-25 11:46:20 +02:00
parent 1921c6a212
commit c28a1a24d8
9 changed files with 144 additions and 56 deletions

View File

@ -96,8 +96,8 @@ func (f *cilium) Init(_ int) error {
}
// Rules is a no-op.
func (f *cilium) Rules(_ []*net.IPNet) []iptables.Rule {
return nil
func (f *cilium) Rules(_ []*net.IPNet) iptables.RuleSet {
return iptables.RuleSet{}
}
// Set is a no-op.

View File

@ -49,7 +49,7 @@ type Encapsulator interface {
Gw(net.IP, net.IP, *net.IPNet) net.IP
Index() int
Init(int) error
Rules([]*net.IPNet) []iptables.Rule
Rules([]*net.IPNet) iptables.RuleSet
Set(*net.IPNet) error
Strategy() Strategy
}

View File

@ -95,8 +95,8 @@ func (f *flannel) Init(_ int) error {
}
// Rules is a no-op.
func (f *flannel) Rules(_ []*net.IPNet) []iptables.Rule {
return nil
func (f *flannel) Rules(_ []*net.IPNet) iptables.RuleSet {
return iptables.RuleSet{}
}
// Set is a no-op.

View File

@ -65,20 +65,20 @@ func (i *ipip) Init(base int) error {
// Rules returns a set of iptables rules that are necessary
// when traffic between nodes must be encapsulated.
func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule {
var rules []iptables.Rule
func (i *ipip) Rules(nodes []*net.IPNet) iptables.RuleSet {
rules := iptables.RuleSet{}
proto := ipipProtocolName()
rules = append(rules, iptables.NewIPv4Chain("filter", "KILO-IPIP"))
rules = append(rules, iptables.NewIPv6Chain("filter", "KILO-IPIP"))
rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Chain("filter", "KILO-IPIP"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Chain("filter", "KILO-IPIP"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
for _, n := range nodes {
// Accept encapsulated traffic from peers.
rules = append(rules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
}
// Drop all other IPIP traffic.
rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
return rules
}

View File

@ -44,8 +44,8 @@ func (n Noop) Init(_ int) error {
}
// Rules will also do nothing.
func (n Noop) Rules(_ []*net.IPNet) []iptables.Rule {
return nil
func (n Noop) Rules(_ []*net.IPNet) iptables.RuleSet {
return iptables.RuleSet{}
}
// Set will also do nothing.

View File

@ -46,6 +46,20 @@ type fakeClient struct {
var _ Client = &fakeClient{}
func (f *fakeClient) Insert(table, chain string, pos int, spec ...string) error {
atomic.AddUint64(&f.calls, 1)
exists, err := f.Exists(table, chain, spec...)
if err != nil {
return err
}
if exists {
return nil
}
// FIXME obey pos!
f.storage = append([]Rule{&rule{table: table, chain: chain, spec: spec}}, f.storage...)
return nil
}
func (f *fakeClient) AppendUnique(table, chain string, spec ...string) error {
atomic.AddUint64(&f.calls, 1)
exists, err := f.Exists(table, chain, spec...)

View File

@ -64,6 +64,7 @@ func GetProtocol(ip net.IP) Protocol {
// Client represents any type that can administer iptables rules.
type Client interface {
AppendUnique(table string, chain string, rule ...string) error
Insert(table string, chain string, pos int, rule ...string) error
Delete(table string, chain string, rule ...string) error
Exists(table string, chain string, rule ...string) (bool, error)
List(table string, chain string) ([]string, error)
@ -75,7 +76,8 @@ type Client interface {
// Rule is an interface for interacting with iptables objects.
type Rule interface {
Add(Client) error
Append(Client) error
Prepend(Client) error
Delete(Client) error
Exists(Client) (bool, error)
String() string
@ -106,7 +108,14 @@ func NewIPv6Rule(table, chain string, spec ...string) Rule {
return &rule{table, chain, spec, ProtocolIPv6}
}
func (r *rule) Add(client Client) error {
func (r *rule) Prepend(client Client) error {
if err := client.Insert(r.table, r.chain, 1, r.spec...); err != nil {
return fmt.Errorf("failed to add iptables rule: %v", err)
}
return nil
}
func (r *rule) Append(client Client) error {
if err := client.AppendUnique(r.table, r.chain, r.spec...); err != nil {
return fmt.Errorf("failed to add iptables rule: %v", err)
}
@ -162,7 +171,11 @@ func NewIPv6Chain(table, name string) Rule {
return &chain{table, name, ProtocolIPv6}
}
func (c *chain) Add(client Client) error {
func (c *chain) Prepend(client Client) error {
return c.Append(client)
}
func (c *chain) Append(client Client) error {
// Note: `ClearChain` creates a chain if it does not exist.
if err := client.ClearChain(c.table, c.chain); err != nil {
return fmt.Errorf("failed to add iptables chain: %v", err)
@ -224,8 +237,9 @@ type Controller struct {
registerer prometheus.Registerer
sync.Mutex
rules []Rule
subscribed bool
appendRules []Rule
prependRules []Rule
subscribed bool
}
// ControllerOption modifies the controller's configuration.
@ -333,14 +347,14 @@ func (c *Controller) reconcile() error {
c.Lock()
defer c.Unlock()
var rc ruleCache
for i, r := range c.rules {
for i, r := range c.appendRules {
ok, err := rc.exists(c.client(r.Proto()), r)
if err != nil {
return fmt.Errorf("failed to check if rule exists: %v", err)
}
if !ok {
level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.rules)-i))
if err := c.resetFromIndex(i, c.rules); err != nil {
level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.appendRules)-i))
if err := c.resetFromIndex(i, c.appendRules); err != nil {
return fmt.Errorf("failed to add rule: %v", err)
}
break
@ -358,7 +372,7 @@ func (c *Controller) resetFromIndex(i int, rules []Rule) error {
if err := rules[j].Delete(c.client(rules[j].Proto())); err != nil {
return fmt.Errorf("failed to delete rule: %v", err)
}
if err := rules[j].Add(c.client(rules[j].Proto())); err != nil {
if err := rules[j].Append(c.client(rules[j].Proto())); err != nil {
return fmt.Errorf("failed to add rule: %v", err)
}
}
@ -383,34 +397,87 @@ func (c *Controller) deleteFromIndex(i int, rules *[]Rule) error {
// Set idempotently overwrites any iptables rules previously defined
// for the controller with the given set of rules.
func (c *Controller) Set(rules []Rule) error {
func (c *Controller) Set(rules RuleSet) error {
c.Lock()
defer c.Unlock()
if err := c.setAppendRules(rules.AppendRules); err != nil {
return err
}
return c.setPrependRules(rules.PrependRules)
}
func (c *Controller) setAppendRules(appendRules []Rule) error {
var i int
for ; i < len(rules); i++ {
if i < len(c.rules) {
if rules[i].String() != c.rules[i].String() {
if err := c.deleteFromIndex(i, &c.rules); err != nil {
for ; i < len(appendRules); i++ {
if i < len(c.appendRules) {
if appendRules[i].String() != c.appendRules[i].String() {
if err := c.deleteFromIndex(i, &c.appendRules); err != nil {
return err
}
}
}
if i >= len(c.rules) {
if err := rules[i].Add(c.client(rules[i].Proto())); err != nil {
if i >= len(c.appendRules) {
if err := appendRules[i].Append(c.client(appendRules[i].Proto())); err != nil {
return fmt.Errorf("failed to add rule: %v", err)
}
c.rules = append(c.rules, rules[i])
c.appendRules = append(c.appendRules, appendRules[i])
}
}
return c.deleteFromIndex(i, &c.rules)
err := c.deleteFromIndex(i, &c.appendRules)
if err != nil {
return fmt.Errorf("failed to delete rule: %v", err)
}
return nil
}
func (c *Controller) setPrependRules(prependRules []Rule) error {
for _, prependRule := range prependRules {
if !containsRule(c.prependRules, prependRule) {
if err := prependRule.Prepend(c.client(prependRule.Proto())); err != nil {
return fmt.Errorf("failed to add rule: %v", err)
}
c.prependRules = append(c.prependRules, prependRule)
}
}
for _, existingRule := range c.prependRules {
if !containsRule(prependRules, existingRule) {
if err := existingRule.Delete(c.client(existingRule.Proto())); err != nil {
return fmt.Errorf("failed to delete rule: %v", err)
}
c.prependRules = removeRule(c.prependRules, existingRule)
}
}
return nil
}
func removeRule(rules []Rule, toRemove Rule) []Rule {
ret := make([]Rule, 0, len(rules))
for _, rule := range rules {
if rule.String() != toRemove.String() {
ret = append(ret, rule)
}
}
return ret
}
func containsRule(haystack []Rule, needle Rule) bool {
for _, element := range haystack {
if element.String() == needle.String() {
return true
}
}
return false
}
// CleanUp will clean up any rules created by the controller.
func (c *Controller) CleanUp() error {
c.Lock()
defer c.Unlock()
return c.deleteFromIndex(0, &c.rules)
err := c.deleteFromIndex(0, &c.prependRules)
if err != nil {
return err
}
return c.deleteFromIndex(0, &c.appendRules)
}
func (c *Controller) client(p Protocol) Client {
@ -430,3 +497,8 @@ func nonBlockingSend(errors chan<- error, err error) {
default:
}
}
type RuleSet struct {
AppendRules []Rule // Rules to append to the chain - order matters.
PrependRules []Rule // Rules to prepend to the chain - order does not matter.
}

View File

@ -525,7 +525,9 @@ func (m *Mesh) applyTopology() {
}
}
ipRules = append(m.enc.Rules(cidrs), ipRules...)
encIpRules := m.enc.Rules(cidrs)
ipRules.AppendRules = append(encIpRules.AppendRules, ipRules.AppendRules...)
ipRules.PrependRules = append(encIpRules.PrependRules, ipRules.PrependRules...)
// If we are handling local routes, ensure the local
// tunnel has an IP address.

View File

@ -311,12 +311,12 @@ func encapsulateRoute(route *netlink.Route, encapsulate encapsulation.Strategy,
}
// Rules returns the iptables rules required by the local node.
func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
var rules []iptables.Rule
rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
func (t *Topology) Rules(cni, iptablesForwardRule bool) iptables.RuleSet {
rules := iptables.RuleSet{}
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
if cni {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
// Some linux distros or docker will set forward DROP in the filter table.
// To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location.
// Leader nodes will forward packets from all nodes within a location because they act as a gateway for them.
@ -326,37 +326,37 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
if s.location == t.location {
// Make sure packets to and from pod cidrs are not dropped in the forward chain.
for _, c := range s.cidrs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
}
// Make sure packets to and from allowed location IPs are not dropped in the forward chain.
for _, c := range s.allowedLocationIPs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
}
// Make sure packets to and from private IPs are not dropped in the forward chain.
for _, c := range s.privateIPs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
}
}
}
} else if iptablesForwardRule {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
}
}
for _, s := range t.segments {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
for _, aip := range s.allowedIPs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
}
// Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd,
// Otherwise packets to these destinations will reach the destination, but never find their way back.
// We only want to NAT in locations of the corresponding allowed location IPs.
if t.location == s.location {
for _, alip := range s.allowedLocationIPs {
rules = append(rules,
rules.PrependRules = append(rules.PrependRules,
iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
)
}
@ -364,14 +364,14 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
}
for _, p := range t.peers {
for _, aip := range p.AllowedIPs {
rules = append(rules,
rules.PrependRules = append(rules.PrependRules,
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
)
}
}
rules = append(rules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
rules = append(rules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
return rules
}