diff --git a/pkg/iptables/fake.go b/pkg/iptables/fake.go index 6f8a86a..98b28fd 100644 --- a/pkg/iptables/fake.go +++ b/pkg/iptables/fake.go @@ -16,7 +16,6 @@ package iptables import ( "fmt" - "strings" "github.com/coreos/go-iptables/iptables" ) @@ -38,55 +37,91 @@ func (s statusError) ExitStatus() int { return int(s) } -type fakeClient map[string]Rule +type fakeClient struct { + storage []Rule +} -var _ iptablesClient = fakeClient(nil) +var _ iptablesClient = &fakeClient{} -func (f fakeClient) AppendUnique(table, chain string, spec ...string) error { - r := &rule{table, chain, spec, nil} - f[r.String()] = r +func (f *fakeClient) AppendUnique(table, chain string, spec ...string) error { + exists, err := f.Exists(table, chain, spec...) + if err != nil { + return err + } + if exists { + return nil + } + f.storage = append(f.storage, &rule{table, chain, spec, nil}) return nil } -func (f fakeClient) Delete(table, chain string, spec ...string) error { +func (f *fakeClient) Delete(table, chain string, spec ...string) error { r := &rule{table, chain, spec, nil} - delete(f, r.String()) - return nil -} - -func (f fakeClient) Exists(table, chain string, spec ...string) (bool, error) { - r := &rule{table, chain, spec, nil} - _, ok := f[r.String()] - return ok, nil -} - -func (f fakeClient) ClearChain(table, name string) error { - c := &chain{table, name, nil} - for k := range f { - if strings.HasPrefix(k, c.String()) { - delete(f, k) + for i := range f.storage { + if f.storage[i].String() == r.String() { + copy(f.storage[i:], f.storage[i+1:]) + f.storage[len(f.storage)-1] = nil + f.storage = f.storage[:len(f.storage)-1] + break } } - f[c.String()] = c return nil } -func (f fakeClient) DeleteChain(table, name string) error { - c := &chain{table, name, nil} - for k := range f { - if strings.HasPrefix(k, c.String()) { +func (f *fakeClient) Exists(table, chain string, spec ...string) (bool, error) { + r := &rule{table, chain, spec, nil} + for i := range f.storage { + if f.storage[i].String() == r.String() { + return true, nil + } + } + return false, nil +} + +func (f *fakeClient) ClearChain(table, name string) error { + for i := range f.storage { + r, ok := f.storage[i].(*rule) + if !ok { + continue + } + if table == r.table && name == r.chain { + if err := f.Delete(table, name, r.spec...); err != nil { + return nil + } + } + } + return f.DeleteChain(table, name) +} + +func (f *fakeClient) DeleteChain(table, name string) error { + for i := range f.storage { + r, ok := f.storage[i].(*rule) + if !ok { + continue + } + if table == r.table && name == r.chain { return fmt.Errorf("cannot delete chain %s; rules exist", name) } } - delete(f, c.String()) + c := &chain{table, name, nil} + for i := range f.storage { + if f.storage[i].String() == c.String() { + copy(f.storage[i:], f.storage[i+1:]) + f.storage[len(f.storage)-1] = nil + f.storage = f.storage[:len(f.storage)-1] + break + } + } return nil } -func (f fakeClient) NewChain(table, name string) error { +func (f *fakeClient) NewChain(table, name string) error { c := &chain{table, name, nil} - if _, ok := f[c.String()]; ok { - return statusError(1) + for i := range f.storage { + if f.storage[i].String() == c.String() { + return statusError(1) + } } - f[c.String()] = c + f.storage = append(f.storage, c) return nil } diff --git a/pkg/iptables/iptables.go b/pkg/iptables/iptables.go index 035db8b..4c34790 100644 --- a/pkg/iptables/iptables.go +++ b/pkg/iptables/iptables.go @@ -126,10 +126,11 @@ type Rule interface { // Controller is able to reconcile a given set of iptables rules. type Controller struct { - client iptablesClient - errors chan error - rules map[string]Rule - mu sync.Mutex + client iptablesClient + errors chan error + + sync.Mutex + rules []Rule subscribed bool } @@ -148,21 +149,20 @@ func New(ipLength int) (*Controller, error) { return &Controller{ client: client, errors: make(chan error), - rules: make(map[string]Rule), }, nil } // Run watches for changes to iptables rules and reconciles // the rules against the desired state. func (c *Controller) Run(stop <-chan struct{}) (<-chan error, error) { - c.mu.Lock() + c.Lock() if c.subscribed { - c.mu.Unlock() + c.Unlock() return c.errors, nil } // Ensure a given instance only subscribes once. c.subscribed = true - c.mu.Unlock() + c.Unlock() go func() { defer close(c.errors) for { @@ -171,76 +171,100 @@ func (c *Controller) Run(stop <-chan struct{}) (<-chan error, error) { case <-stop: return } - c.mu.Lock() - for _, r := range c.rules { - ok, err := r.Exists() - if err != nil { - nonBlockingSend(c.errors, fmt.Errorf("failed to check if rule exists: %v", err)) - } - if !ok { - if err := r.Add(); err != nil { - nonBlockingSend(c.errors, fmt.Errorf("failed to add rule: %v", err)) - } - } + if err := c.reconcile(); err != nil { + nonBlockingSend(c.errors, fmt.Errorf("failed to reconcile rules: %v", err)) } - c.mu.Unlock() } }() return c.errors, nil } -// Set idempotently overwrites any iptables rules previously defined -// for the controller with the given set of rules. -func (c *Controller) Set(rules []Rule) error { - r := make(map[string]struct{}) - for i := range rules { - if rules[i] == nil { - continue +// reconcile makes sure that every rule is still in the backend. +// It does not ensure that the order in the backend is correct. +// If any rule is missing, that rule and all following rules are +// re-added. +func (c *Controller) reconcile() error { + c.Lock() + defer c.Unlock() + for i, r := range c.rules { + ok, err := r.Exists() + if err != nil { + return fmt.Errorf("failed to check if rule exists: %v", err) } - switch v := rules[i].(type) { - case *rule: - v.client = c.client - case *chain: - v.client = c.client - } - r[rules[i].String()] = struct{}{} - } - c.mu.Lock() - defer c.mu.Unlock() - for k, rule := range c.rules { - if _, ok := r[k]; !ok { - if err := rule.Delete(); err != nil { - return fmt.Errorf("failed to delete rule: %v", err) - } - delete(c.rules, k) - } - } - // Iterate over the slice rather than the map - // to ensure the rules are added in order. - for _, rule := range rules { - if _, ok := c.rules[rule.String()]; !ok { - if err := rule.Add(); err != nil { + if !ok { + if err := resetFromIndex(i, c.rules); err != nil { return fmt.Errorf("failed to add rule: %v", err) } - c.rules[rule.String()] = rule + break } } return nil } -// CleanUp will clean up any rules created by the controller. -func (c *Controller) CleanUp() error { - c.mu.Lock() - defer c.mu.Unlock() - for k, rule := range c.rules { - if err := rule.Delete(); err != nil { +// resetFromIndex re-adds all rules starting from the given index. +func resetFromIndex(i int, rules []Rule) error { + if i >= len(rules) { + return nil + } + for j := range rules[i:] { + if err := rules[j].Delete(); err != nil { return fmt.Errorf("failed to delete rule: %v", err) } - delete(c.rules, k) + if err := rules[j].Add(); err != nil { + return fmt.Errorf("failed to add rule: %v", err) + } } return nil } +// deleteFromIndex deletes all rules starting from the given index. +func deleteFromIndex(i int, rules *[]Rule) error { + if i >= len(*rules) { + return nil + } + for j := range (*rules)[i:] { + if err := (*rules)[j].Delete(); err != nil { + return fmt.Errorf("failed to delete rule: %v", err) + } + (*rules)[j] = nil + } + *rules = (*rules)[:i] + return nil +} + +// Set idempotently overwrites any iptables rules previously defined +// for the controller with the given set of rules. +func (c *Controller) Set(rules []Rule) error { + c.Lock() + defer c.Unlock() + var i int + for ; i < len(rules); i++ { + if i < len(c.rules) { + if rules[i].String() != c.rules[i].String() { + if err := deleteFromIndex(i, &c.rules); err != nil { + return err + } + } + } + if i >= len(c.rules) { + setRuleClient(rules[i], c.client) + if err := rules[i].Add(); err != nil { + return fmt.Errorf("failed to add rule: %v", err) + } + c.rules = append(c.rules, rules[i]) + } + + } + return deleteFromIndex(i, &c.rules) +} + +// CleanUp will clean up any rules created by the controller. +func (c *Controller) CleanUp() error { + c.Lock() + defer c.Unlock() + return deleteFromIndex(0, &c.rules) +} + // IPIPRules returns a set of iptables rules that are necessary // when traffic between nodes must be encapsulated with IPIP. func IPIPRules(nodes []*net.IPNet) []Rule { @@ -303,3 +327,13 @@ func nonBlockingSend(errors chan<- error, err error) { default: } } + +// setRuleClient is a helper to set the iptables client on different kinds of rules. +func setRuleClient(r Rule, c iptablesClient) { + switch v := r.(type) { + case *rule: + v.client = c + case *chain: + v.client = c + } +} diff --git a/pkg/iptables/iptables_test.go b/pkg/iptables/iptables_test.go index f30151d..ef5eeb2 100644 --- a/pkg/iptables/iptables_test.go +++ b/pkg/iptables/iptables_test.go @@ -23,41 +23,101 @@ var rules = []Rule{ &rule{"filter", "FORWARD", []string{"-d", "10.4.0.0/16", "-j", "ACCEPT"}, nil}, } -func newController() *Controller { - return &Controller{ - rules: make(map[string]Rule), - } -} - func TestSet(t *testing.T) { for _, tc := range []struct { - name string - rules []Rule + name string + sets [][]Rule + out []Rule + actions []func(iptablesClient) error }{ { - name: "empty", - rules: nil, + name: "empty", }, { - name: "single", - rules: []Rule{rules[0]}, + name: "single", + sets: [][]Rule{ + {rules[0]}, + }, + out: []Rule{rules[0]}, }, { - name: "multiple", - rules: []Rule{rules[0], rules[1]}, + name: "two rules", + sets: [][]Rule{ + {rules[0], rules[1]}, + }, + out: []Rule{rules[0], rules[1]}, + }, + { + name: "multiple", + sets: [][]Rule{ + {rules[0], rules[1]}, + {rules[1]}, + }, + out: []Rule{rules[1]}, + }, + { + name: "re-add", + sets: [][]Rule{ + {rules[0], rules[1]}, + }, + out: []Rule{rules[0], rules[1]}, + actions: []func(c iptablesClient) error{ + func(c iptablesClient) error { + setRuleClient(rules[0], c) + return rules[0].Delete() + }, + func(c iptablesClient) error { + setRuleClient(rules[1], c) + return rules[1].Delete() + }, + }, + }, + { + name: "order", + sets: [][]Rule{ + {rules[0], rules[1]}, + }, + out: []Rule{rules[0], rules[1]}, + actions: []func(c iptablesClient) error{ + func(c iptablesClient) error { + setRuleClient(rules[0], c) + return rules[0].Delete() + }, + }, }, } { - backend := make(map[string]Rule) - controller := newController() - controller.client = fakeClient(backend) - if err := controller.Set(tc.rules); err != nil { - t.Fatalf("test case %q: got unexpected error: %v", tc.name, err) + controller := &Controller{} + client := &fakeClient{} + controller.client = client + for i := range tc.sets { + if err := controller.Set(tc.sets[i]); err != nil { + t.Fatalf("test case %q: got unexpected error seting rule set %d: %v", tc.name, i, err) + } } - for _, r := range tc.rules { - r1 := backend[r.String()] - r2 := controller.rules[r.String()] - if r.String() != r1.String() || r.String() != r2.String() { - t.Errorf("test case %q: expected all rules to be equal: expected %v, got %v and %v", tc.name, r, r1, r2) + for i, f := range tc.actions { + if err := f(controller.client); err != nil { + t.Fatalf("test case %q action %d: got unexpected error %v", tc.name, i, err) + } + } + if err := controller.reconcile(); err != nil { + t.Fatalf("test case %q: got unexpected error %v", tc.name, err) + } + if len(tc.out) != len(client.storage) { + t.Errorf("test case %q: expected %d rules in storage, got %d", tc.name, len(tc.out), len(client.storage)) + } else { + for i := range tc.out { + if tc.out[i].String() != client.storage[i].String() { + t.Errorf("test case %q: expected rule %d in storage to be equal: expected %v, got %v", tc.name, i, tc.out[i], client.storage[i]) + } + } + } + if len(tc.out) != len(controller.rules) { + t.Errorf("test case %q: expected %d rules in controller, got %d", tc.name, len(tc.out), len(controller.rules)) + } else { + for i := range tc.out { + if tc.out[i].String() != controller.rules[i].String() { + t.Errorf("test case %q: expected rule %d in controller to be equal: expected %v, got %v", tc.name, i, tc.out[i], controller.rules[i]) + } } } } @@ -81,21 +141,20 @@ func TestCleanUp(t *testing.T) { rules: []Rule{rules[0], rules[1]}, }, } { - backend := make(map[string]Rule) - controller := newController() - controller.client = fakeClient(backend) + controller := &Controller{} + client := &fakeClient{} + controller.client = client if err := controller.Set(tc.rules); err != nil { t.Fatalf("test case %q: Set should not fail: %v", tc.name, err) } + if len(client.storage) != len(tc.rules) { + t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(tc.rules), len(client.storage)) + } if err := controller.CleanUp(); err != nil { t.Errorf("test case %q: got unexpected error: %v", tc.name, err) } - for _, r := range tc.rules { - r1 := backend[r.String()] - r2 := controller.rules[r.String()] - if r1 != nil || r2 != nil { - t.Errorf("test case %q: expected all rules to be nil: expected got %v and %v", tc.name, r1, r2) - } + if len(client.storage) != 0 { + t.Errorf("test case %q: expected storage to be empty, got %d rules", tc.name, len(client.storage)) } } }