diff --git a/pkg/iptables/iptables_test.go b/pkg/iptables/iptables_test.go index ca661d7..8d1d8d8 100644 --- a/pkg/iptables/iptables_test.go +++ b/pkg/iptables/iptables_test.go @@ -16,74 +16,96 @@ package iptables import ( "testing" - - "github.com/prometheus/client_golang/prometheus" ) -var rules = []Rule{ +var appendRules = []Rule{ NewIPv4Rule("filter", "FORWARD", "-s", "10.4.0.0/16", "-j", "ACCEPT"), NewIPv4Rule("filter", "FORWARD", "-d", "10.4.0.0/16", "-j", "ACCEPT"), } +var prependRules = []Rule{ + NewIPv4Rule("filter", "FORWARD", "-s", "10.5.0.0/16", "-j", "DROP"), + NewIPv4Rule("filter", "FORWARD", "-s", "10.6.0.0/16", "-j", "DROP"), +} + func TestSet(t *testing.T) { for _, tc := range []struct { - name string - sets [][]Rule - out []Rule - actions []func(Client) error + name string + sets []RuleSet + appendOut []Rule + prependOut []Rule + storageOut []Rule + actions []func(Client) error }{ { name: "empty", }, { name: "single", - sets: [][]Rule{ - {rules[0]}, + sets: []RuleSet{ + {AppendRules: []Rule{appendRules[0]}}, }, - out: []Rule{rules[0]}, + appendOut: []Rule{appendRules[0]}, + storageOut: []Rule{appendRules[0]}, }, { name: "two rules", - sets: [][]Rule{ - {rules[0], rules[1]}, + sets: []RuleSet{ + {AppendRules: []Rule{appendRules[0], appendRules[1]}}, }, - out: []Rule{rules[0], rules[1]}, + appendOut: []Rule{appendRules[0], appendRules[1]}, + storageOut: []Rule{appendRules[0], appendRules[1]}, }, { name: "multiple", - sets: [][]Rule{ - {rules[0], rules[1]}, - {rules[1]}, + sets: []RuleSet{ + {AppendRules: []Rule{appendRules[0], appendRules[1]}}, + {AppendRules: []Rule{appendRules[1]}}, }, - out: []Rule{rules[1]}, + appendOut: []Rule{appendRules[1]}, + storageOut: []Rule{appendRules[1]}, }, { name: "re-add", - sets: [][]Rule{ - {rules[0], rules[1]}, + sets: []RuleSet{ + {AppendRules: []Rule{appendRules[0], appendRules[1]}}, }, - out: []Rule{rules[0], rules[1]}, + appendOut: []Rule{appendRules[0], appendRules[1]}, + storageOut: []Rule{appendRules[0], appendRules[1]}, actions: []func(c Client) error{ func(c Client) error { - return rules[0].Delete(c) + return appendRules[0].Delete(c) }, func(c Client) error { - return rules[1].Delete(c) + return appendRules[1].Delete(c) }, }, }, { name: "order", - sets: [][]Rule{ - {rules[0], rules[1]}, + sets: []RuleSet{ + {AppendRules: []Rule{appendRules[0], appendRules[1]}}, }, - out: []Rule{rules[0], rules[1]}, + appendOut: []Rule{appendRules[0], appendRules[1]}, + storageOut: []Rule{appendRules[0], appendRules[1]}, actions: []func(c Client) error{ func(c Client) error { - return rules[0].Delete(c) + return appendRules[0].Delete(c) }, }, }, + { + name: "append and prepend", + sets: []RuleSet{ + { + PrependRules: []Rule{prependRules[0], prependRules[1]}, + AppendRules: []Rule{appendRules[0], appendRules[1]}, + }, + }, + appendOut: []Rule{appendRules[0], appendRules[1]}, + prependOut: []Rule{prependRules[0], prependRules[1]}, + storageOut: []Rule{prependRules[1], prependRules[0], appendRules[0], appendRules[1]}, + }, } { client := &fakeClient{} controller, err := New(WithClients(client, client)) @@ -91,8 +113,7 @@ func TestSet(t *testing.T) { t.Fatalf("test case %q: got unexpected error instantiating controller: %v", tc.name, err) } for i := range tc.sets { - ruleSet := RuleSet{AppendRules: tc.sets[i]} - if err := controller.Set(ruleSet); err != nil { + if err := controller.Set(tc.sets[i]); err != nil { t.Fatalf("test case %q: got unexpected error setting rule set %d: %v", tc.name, i, err) } } @@ -104,21 +125,30 @@ func TestSet(t *testing.T) { if err := controller.reconcile(); err != nil { t.Fatalf("test case %q: got unexpected error %v", tc.name, err) } - if len(tc.out) != len(client.storage) { - t.Errorf("test case %q: expected %d rules in storage, got %d", tc.name, len(tc.out), len(client.storage)) + if len(tc.storageOut) != len(client.storage) { + t.Errorf("test case %q: expected %d rules in storage, got %d", tc.name, len(tc.appendOut), len(client.storage)) } else { - for i := range tc.out { - if tc.out[i].String() != client.storage[i].String() { - t.Errorf("test case %q: expected rule %d in storage to be equal: expected %v, got %v", tc.name, i, tc.out[i], client.storage[i]) + for i := range tc.storageOut { + if tc.storageOut[i].String() != client.storage[i].String() { + t.Errorf("test case %q: expected rule %d in storage to be equal: expected %v, got %v", tc.name, i, tc.storageOut[i], client.storage[i]) } } } - if len(tc.out) != len(controller.appendRules) { - t.Errorf("test case %q: expected %d rules in controller, got %d", tc.name, len(tc.out), len(controller.appendRules)) + if len(tc.appendOut) != len(controller.appendRules) { + t.Errorf("test case %q: expected %d appendRules in controller, got %d", tc.name, len(tc.appendOut), len(controller.appendRules)) } else { - for i := range tc.out { - if tc.out[i].String() != controller.appendRules[i].String() { - t.Errorf("test case %q: expected rule %d in controller to be equal: expected %v, got %v", tc.name, i, tc.out[i], controller.appendRules[i]) + for i := range tc.appendOut { + if tc.appendOut[i].String() != controller.appendRules[i].String() { + t.Errorf("test case %q: expected appendRule %d in controller to be equal: expected %v, got %v", tc.name, i, tc.appendOut[i], controller.appendRules[i]) + } + } + } + if len(tc.prependOut) != len(controller.prependRules) { + t.Errorf("test case %q: expected %d prependRules in controller, got %d", tc.name, len(tc.prependOut), len(controller.prependRules)) + } else { + for i := range tc.prependOut { + if tc.prependOut[i].String() != controller.prependRules[i].String() { + t.Errorf("test case %q: expected prependRule %d in controller to be equal: expected %v, got %v", tc.name, i, tc.prependOut[i], controller.prependRules[i]) } } } @@ -127,20 +157,26 @@ func TestSet(t *testing.T) { func TestCleanUp(t *testing.T) { for _, tc := range []struct { - name string - rules []Rule + name string + appendRules []Rule + prependRules []Rule }{ { - name: "empty", - rules: nil, + name: "empty", + appendRules: nil, }, { - name: "single", - rules: []Rule{rules[0]}, + name: "single append", + appendRules: []Rule{appendRules[0]}, }, { - name: "multiple", - rules: []Rule{rules[0], rules[1]}, + name: "multiple append", + appendRules: []Rule{appendRules[0], appendRules[1]}, + }, + { + name: "multiple append and prepend", + appendRules: []Rule{appendRules[0], appendRules[1]}, + prependRules: []Rule{prependRules[0], prependRules[1]}, }, } { client := &fakeClient{} @@ -148,12 +184,12 @@ func TestCleanUp(t *testing.T) { if err != nil { t.Fatalf("test case %q: got unexpected error instantiating controller: %v", tc.name, err) } - ruleSet := RuleSet{AppendRules: tc.rules} + ruleSet := RuleSet{AppendRules: tc.appendRules, PrependRules: tc.prependRules} if err := controller.Set(ruleSet); err != nil { t.Fatalf("test case %q: Set should not fail: %v", tc.name, err) } - if len(client.storage) != len(tc.rules) { - t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(ruleSet.AppendRules), len(client.storage)) + if len(client.storage) != len(tc.appendRules)+len(tc.prependRules) { + t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(ruleSet.AppendRules)+len(ruleSet.PrependRules), len(client.storage)) } if err := controller.CleanUp(); err != nil { t.Errorf("test case %q: got unexpected error: %v", tc.name, err) diff --git a/pkg/iptables/rulecache_test.go b/pkg/iptables/rulecache_test.go index d9ed4c5..7ee8e1e 100644 --- a/pkg/iptables/rulecache_test.go +++ b/pkg/iptables/rulecache_test.go @@ -29,21 +29,21 @@ func TestRuleCache(t *testing.T) { { name: "empty", rules: nil, - check: []Rule{rules[0]}, + check: []Rule{appendRules[0]}, out: []bool{false}, calls: 1, }, { name: "single negative", - rules: []Rule{rules[1]}, - check: []Rule{rules[0]}, + rules: []Rule{appendRules[1]}, + check: []Rule{appendRules[0]}, out: []bool{false}, calls: 1, }, { name: "single positive", - rules: []Rule{rules[1]}, - check: []Rule{rules[1]}, + rules: []Rule{appendRules[1]}, + check: []Rule{appendRules[1]}, out: []bool{true}, calls: 1, }, @@ -56,29 +56,29 @@ func TestRuleCache(t *testing.T) { }, { name: "rule on chain means chain exists", - rules: []Rule{rules[0]}, - check: []Rule{rules[0], &chain{"filter", "FORWARD", ProtocolIPv4}}, + rules: []Rule{appendRules[0]}, + check: []Rule{appendRules[0], &chain{"filter", "FORWARD", ProtocolIPv4}}, out: []bool{true, true}, calls: 1, }, { name: "rule on chain does not mean table is fully populated", - rules: []Rule{rules[0], &chain{"filter", "INPUT", ProtocolIPv4}}, - check: []Rule{rules[0], &chain{"filter", "OUTPUT", ProtocolIPv4}, &chain{"filter", "INPUT", ProtocolIPv4}}, + rules: []Rule{appendRules[0], &chain{"filter", "INPUT", ProtocolIPv4}}, + check: []Rule{appendRules[0], &chain{"filter", "OUTPUT", ProtocolIPv4}, &chain{"filter", "INPUT", ProtocolIPv4}}, out: []bool{true, false, true}, calls: 2, }, { name: "multiple rules on chain", - rules: []Rule{rules[0], rules[1]}, - check: []Rule{rules[0], rules[1], &chain{"filter", "FORWARD", ProtocolIPv4}}, + rules: []Rule{appendRules[0], appendRules[1]}, + check: []Rule{appendRules[0], appendRules[1], &chain{"filter", "FORWARD", ProtocolIPv4}}, out: []bool{true, true, true}, calls: 1, }, { name: "checking rule on chain does not mean chain exists", rules: nil, - check: []Rule{rules[0], &chain{"filter", "FORWARD", ProtocolIPv4}}, + check: []Rule{appendRules[0], &chain{"filter", "FORWARD", ProtocolIPv4}}, out: []bool{false, false}, calls: 2, },