Make usage of RuleSet prettier

This commit is contained in:
Alex Stockinger 2022-07-26 13:45:57 +00:00
parent 378dafffe8
commit 46cdd6c60c
6 changed files with 59 additions and 49 deletions

View File

@ -68,17 +68,17 @@ func (i *ipip) Init(base int) error {
func (i *ipip) Rules(nodes []*net.IPNet) iptables.RuleSet {
rules := iptables.RuleSet{}
proto := ipipProtocolName()
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Chain("filter", "KILO-IPIP"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Chain("filter", "KILO-IPIP"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
rules.AddToAppend(iptables.NewIPv4Chain("filter", "KILO-IPIP"))
rules.AddToAppend(iptables.NewIPv6Chain("filter", "KILO-IPIP"))
rules.AddToAppend(iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
rules.AddToAppend(iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
for _, n := range nodes {
// Accept encapsulated traffic from peers.
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
}
// Drop all other IPIP traffic.
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
rules.AddToAppend(iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
rules.AddToAppend(iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
return rules
}

View File

@ -46,6 +46,11 @@ func ipv6Disabled() (bool, error) {
// Protocol represents an IP protocol.
type Protocol byte
type RuleSet struct {
appendRules []Rule // Rules to append to the chain - order matters.
prependRules []Rule // Rules to prepend to the chain - order does not matter.
}
const (
// ProtocolIPv4 represents the IPv4 protocol.
ProtocolIPv4 Protocol = iota
@ -53,6 +58,21 @@ const (
ProtocolIPv6
)
func (rs *RuleSet) AddToAppend(rule Rule) {
rs.appendRules = append(rs.appendRules, rule)
}
func (rs *RuleSet) AddToPrepend(rule Rule) {
rs.prependRules = append(rs.prependRules, rule)
}
func (rs *RuleSet) AppendRuleSet(other RuleSet) RuleSet {
return RuleSet{
appendRules: append(rs.appendRules, other.appendRules...),
prependRules: append(rs.prependRules, other.prependRules...),
}
}
// GetProtocol will return a protocol from the length of an IP address.
func GetProtocol(ip net.IP) Protocol {
if len(ip) == net.IPv4len || ip.To4() != nil {
@ -423,10 +443,10 @@ func (c *Controller) deleteFromIndex(i int, rules *[]Rule) error {
func (c *Controller) Set(rules RuleSet) error {
c.Lock()
defer c.Unlock()
if err := c.setAppendRules(rules.AppendRules); err != nil {
if err := c.setAppendRules(rules.appendRules); err != nil {
return err
}
return c.setPrependRules(rules.PrependRules)
return c.setPrependRules(rules.prependRules)
}
func (c *Controller) setAppendRules(appendRules []Rule) error {
@ -520,8 +540,3 @@ func nonBlockingSend(errors chan<- error, err error) {
default:
}
}
type RuleSet struct {
AppendRules []Rule // Rules to append to the chain - order matters.
PrependRules []Rule // Rules to prepend to the chain - order does not matter.
}

View File

@ -43,7 +43,7 @@ func TestSet(t *testing.T) {
{
name: "single",
sets: []RuleSet{
{AppendRules: []Rule{appendRules[0]}},
{appendRules: []Rule{appendRules[0]}},
},
appendOut: []Rule{appendRules[0]},
storageOut: []Rule{appendRules[0]},
@ -51,7 +51,7 @@ func TestSet(t *testing.T) {
{
name: "two rules",
sets: []RuleSet{
{AppendRules: []Rule{appendRules[0], appendRules[1]}},
{appendRules: []Rule{appendRules[0], appendRules[1]}},
},
appendOut: []Rule{appendRules[0], appendRules[1]},
storageOut: []Rule{appendRules[0], appendRules[1]},
@ -59,8 +59,8 @@ func TestSet(t *testing.T) {
{
name: "multiple",
sets: []RuleSet{
{AppendRules: []Rule{appendRules[0], appendRules[1]}},
{AppendRules: []Rule{appendRules[1]}},
{appendRules: []Rule{appendRules[0], appendRules[1]}},
{appendRules: []Rule{appendRules[1]}},
},
appendOut: []Rule{appendRules[1]},
storageOut: []Rule{appendRules[1]},
@ -68,7 +68,7 @@ func TestSet(t *testing.T) {
{
name: "re-add",
sets: []RuleSet{
{AppendRules: []Rule{appendRules[0], appendRules[1]}},
{appendRules: []Rule{appendRules[0], appendRules[1]}},
},
appendOut: []Rule{appendRules[0], appendRules[1]},
storageOut: []Rule{appendRules[0], appendRules[1]},
@ -84,7 +84,7 @@ func TestSet(t *testing.T) {
{
name: "order",
sets: []RuleSet{
{AppendRules: []Rule{appendRules[0], appendRules[1]}},
{appendRules: []Rule{appendRules[0], appendRules[1]}},
},
appendOut: []Rule{appendRules[0], appendRules[1]},
storageOut: []Rule{appendRules[0], appendRules[1]},
@ -98,8 +98,8 @@ func TestSet(t *testing.T) {
name: "append and prepend",
sets: []RuleSet{
{
PrependRules: []Rule{prependRules[0], prependRules[1]},
AppendRules: []Rule{appendRules[0], appendRules[1]},
prependRules: []Rule{prependRules[0], prependRules[1]},
appendRules: []Rule{appendRules[0], appendRules[1]},
},
},
appendOut: []Rule{appendRules[0], appendRules[1]},
@ -184,12 +184,12 @@ func TestCleanUp(t *testing.T) {
if err != nil {
t.Fatalf("test case %q: got unexpected error instantiating controller: %v", tc.name, err)
}
ruleSet := RuleSet{AppendRules: tc.appendRules, PrependRules: tc.prependRules}
ruleSet := RuleSet{appendRules: tc.appendRules, prependRules: tc.prependRules}
if err := controller.Set(ruleSet); err != nil {
t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
}
if len(client.storage) != len(tc.appendRules)+len(tc.prependRules) {
t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(ruleSet.AppendRules)+len(ruleSet.PrependRules), len(client.storage))
t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(ruleSet.appendRules)+len(ruleSet.prependRules), len(client.storage))
}
if err := controller.CleanUp(); err != nil {
t.Errorf("test case %q: got unexpected error: %v", tc.name, err)

View File

@ -101,7 +101,7 @@ func TestRuleCache(t *testing.T) {
client := &fakeClient{}
controller.v4 = client
controller.v6 = client
ruleSet := RuleSet{AppendRules: tc.rules}
ruleSet := RuleSet{appendRules: tc.rules}
if err := controller.Set(ruleSet); err != nil {
t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
}

View File

@ -526,8 +526,7 @@ func (m *Mesh) applyTopology() {
}
encIpRules := m.enc.Rules(cidrs)
ipRules.AppendRules = append(encIpRules.AppendRules, ipRules.AppendRules...)
ipRules.PrependRules = append(encIpRules.PrependRules, ipRules.PrependRules...)
ipRules = encIpRules.AppendRuleSet(ipRules)
// If we are handling local routes, ensure the local
// tunnel has an IP address.

View File

@ -313,10 +313,10 @@ func encapsulateRoute(route *netlink.Route, encapsulate encapsulation.Strategy,
// Rules returns the iptables rules required by the local node.
func (t *Topology) Rules(cni, iptablesForwardRule bool) iptables.RuleSet {
rules := iptables.RuleSet{}
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
rules.AddToAppend(iptables.NewIPv4Chain("nat", "KILO-NAT"))
rules.AddToAppend(iptables.NewIPv6Chain("nat", "KILO-NAT"))
if cni {
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
// Some linux distros or docker will set forward DROP in the filter table.
// To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location.
// Leader nodes will forward packets from all nodes within a location because they act as a gateway for them.
@ -326,52 +326,48 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) iptables.RuleSet {
if s.location == t.location {
// Make sure packets to and from pod cidrs are not dropped in the forward chain.
for _, c := range s.cidrs {
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
}
// Make sure packets to and from allowed location IPs are not dropped in the forward chain.
for _, c := range s.allowedLocationIPs {
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
}
// Make sure packets to and from private IPs are not dropped in the forward chain.
for _, c := range s.privateIPs {
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
}
}
}
} else if iptablesForwardRule {
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
}
}
for _, s := range t.segments {
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
for _, aip := range s.allowedIPs {
rules.PrependRules = append(rules.PrependRules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
}
// Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd,
// Otherwise packets to these destinations will reach the destination, but never find their way back.
// We only want to NAT in locations of the corresponding allowed location IPs.
if t.location == s.location {
for _, alip := range s.allowedLocationIPs {
rules.PrependRules = append(rules.PrependRules,
iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
)
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"))
}
}
}
for _, p := range t.peers {
for _, aip := range p.AllowedIPs {
rules.PrependRules = append(rules.PrependRules,
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
)
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"))
rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"))
}
}
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
rules.AppendRules = append(rules.AppendRules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
rules.AddToAppend(iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
rules.AddToAppend(iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
return rules
}