Split iptables rules into append and prepend rules
This commit is contained in:
		| @@ -46,6 +46,20 @@ type fakeClient struct { | ||||
|  | ||||
| var _ Client = &fakeClient{} | ||||
|  | ||||
| func (f *fakeClient) Insert(table, chain string, pos int, spec ...string) error { | ||||
| 	atomic.AddUint64(&f.calls, 1) | ||||
| 	exists, err := f.Exists(table, chain, spec...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if exists { | ||||
| 		return nil | ||||
| 	} | ||||
| 	// FIXME obey pos! | ||||
| 	f.storage = append([]Rule{&rule{table: table, chain: chain, spec: spec}}, f.storage...) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (f *fakeClient) AppendUnique(table, chain string, spec ...string) error { | ||||
| 	atomic.AddUint64(&f.calls, 1) | ||||
| 	exists, err := f.Exists(table, chain, spec...) | ||||
|   | ||||
| @@ -64,6 +64,7 @@ func GetProtocol(ip net.IP) Protocol { | ||||
| // Client represents any type that can administer iptables rules. | ||||
| type Client interface { | ||||
| 	AppendUnique(table string, chain string, rule ...string) error | ||||
| 	Insert(table string, chain string, pos int, rule ...string) error | ||||
| 	Delete(table string, chain string, rule ...string) error | ||||
| 	Exists(table string, chain string, rule ...string) (bool, error) | ||||
| 	List(table string, chain string) ([]string, error) | ||||
| @@ -75,7 +76,8 @@ type Client interface { | ||||
|  | ||||
| // Rule is an interface for interacting with iptables objects. | ||||
| type Rule interface { | ||||
| 	Add(Client) error | ||||
| 	Append(Client) error | ||||
| 	Prepend(Client) error | ||||
| 	Delete(Client) error | ||||
| 	Exists(Client) (bool, error) | ||||
| 	String() string | ||||
| @@ -106,7 +108,14 @@ func NewIPv6Rule(table, chain string, spec ...string) Rule { | ||||
| 	return &rule{table, chain, spec, ProtocolIPv6} | ||||
| } | ||||
|  | ||||
| func (r *rule) Add(client Client) error { | ||||
| func (r *rule) Prepend(client Client) error { | ||||
| 	if err := client.Insert(r.table, r.chain, 1, r.spec...); err != nil { | ||||
| 		return fmt.Errorf("failed to add iptables rule: %v", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *rule) Append(client Client) error { | ||||
| 	if err := client.AppendUnique(r.table, r.chain, r.spec...); err != nil { | ||||
| 		return fmt.Errorf("failed to add iptables rule: %v", err) | ||||
| 	} | ||||
| @@ -162,7 +171,11 @@ func NewIPv6Chain(table, name string) Rule { | ||||
| 	return &chain{table, name, ProtocolIPv6} | ||||
| } | ||||
|  | ||||
| func (c *chain) Add(client Client) error { | ||||
| func (c *chain) Prepend(client Client) error { | ||||
| 	return c.Append(client) | ||||
| } | ||||
|  | ||||
| func (c *chain) Append(client Client) error { | ||||
| 	// Note: `ClearChain` creates a chain if it does not exist. | ||||
| 	if err := client.ClearChain(c.table, c.chain); err != nil { | ||||
| 		return fmt.Errorf("failed to add iptables chain: %v", err) | ||||
| @@ -224,8 +237,9 @@ type Controller struct { | ||||
| 	registerer   prometheus.Registerer | ||||
|  | ||||
| 	sync.Mutex | ||||
| 	rules      []Rule | ||||
| 	subscribed bool | ||||
| 	appendRules  []Rule | ||||
| 	prependRules []Rule | ||||
| 	subscribed   bool | ||||
| } | ||||
|  | ||||
| // ControllerOption modifies the controller's configuration. | ||||
| @@ -333,14 +347,14 @@ func (c *Controller) reconcile() error { | ||||
| 	c.Lock() | ||||
| 	defer c.Unlock() | ||||
| 	var rc ruleCache | ||||
| 	for i, r := range c.rules { | ||||
| 	for i, r := range c.appendRules { | ||||
| 		ok, err := rc.exists(c.client(r.Proto()), r) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("failed to check if rule exists: %v", err) | ||||
| 		} | ||||
| 		if !ok { | ||||
| 			level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.rules)-i)) | ||||
| 			if err := c.resetFromIndex(i, c.rules); err != nil { | ||||
| 			level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.appendRules)-i)) | ||||
| 			if err := c.resetFromIndex(i, c.appendRules); err != nil { | ||||
| 				return fmt.Errorf("failed to add rule: %v", err) | ||||
| 			} | ||||
| 			break | ||||
| @@ -358,7 +372,7 @@ func (c *Controller) resetFromIndex(i int, rules []Rule) error { | ||||
| 		if err := rules[j].Delete(c.client(rules[j].Proto())); err != nil { | ||||
| 			return fmt.Errorf("failed to delete rule: %v", err) | ||||
| 		} | ||||
| 		if err := rules[j].Add(c.client(rules[j].Proto())); err != nil { | ||||
| 		if err := rules[j].Append(c.client(rules[j].Proto())); err != nil { | ||||
| 			return fmt.Errorf("failed to add rule: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| @@ -383,34 +397,87 @@ func (c *Controller) deleteFromIndex(i int, rules *[]Rule) error { | ||||
|  | ||||
| // Set idempotently overwrites any iptables rules previously defined | ||||
| // for the controller with the given set of rules. | ||||
| func (c *Controller) Set(rules []Rule) error { | ||||
| func (c *Controller) Set(rules RuleSet) error { | ||||
| 	c.Lock() | ||||
| 	defer c.Unlock() | ||||
| 	if err := c.setAppendRules(rules.AppendRules); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return c.setPrependRules(rules.PrependRules) | ||||
| } | ||||
|  | ||||
| func (c *Controller) setAppendRules(appendRules []Rule) error { | ||||
| 	var i int | ||||
| 	for ; i < len(rules); i++ { | ||||
| 		if i < len(c.rules) { | ||||
| 			if rules[i].String() != c.rules[i].String() { | ||||
| 				if err := c.deleteFromIndex(i, &c.rules); err != nil { | ||||
| 	for ; i < len(appendRules); i++ { | ||||
| 		if i < len(c.appendRules) { | ||||
| 			if appendRules[i].String() != c.appendRules[i].String() { | ||||
| 				if err := c.deleteFromIndex(i, &c.appendRules); err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		if i >= len(c.rules) { | ||||
| 			if err := rules[i].Add(c.client(rules[i].Proto())); err != nil { | ||||
| 		if i >= len(c.appendRules) { | ||||
| 			if err := appendRules[i].Append(c.client(appendRules[i].Proto())); err != nil { | ||||
| 				return fmt.Errorf("failed to add rule: %v", err) | ||||
| 			} | ||||
| 			c.rules = append(c.rules, rules[i]) | ||||
| 			c.appendRules = append(c.appendRules, appendRules[i]) | ||||
| 		} | ||||
|  | ||||
| 	} | ||||
| 	return c.deleteFromIndex(i, &c.rules) | ||||
| 	err := c.deleteFromIndex(i, &c.appendRules) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to delete rule: %v", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *Controller) setPrependRules(prependRules []Rule) error { | ||||
| 	for _, prependRule := range prependRules { | ||||
| 		if !containsRule(c.prependRules, prependRule) { | ||||
| 			if err := prependRule.Prepend(c.client(prependRule.Proto())); err != nil { | ||||
| 				return fmt.Errorf("failed to add rule: %v", err) | ||||
| 			} | ||||
| 			c.prependRules = append(c.prependRules, prependRule) | ||||
| 		} | ||||
| 	} | ||||
| 	for _, existingRule := range c.prependRules { | ||||
| 		if !containsRule(prependRules, existingRule) { | ||||
| 			if err := existingRule.Delete(c.client(existingRule.Proto())); err != nil { | ||||
| 				return fmt.Errorf("failed to delete rule: %v", err) | ||||
| 			} | ||||
| 			c.prependRules = removeRule(c.prependRules, existingRule) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func removeRule(rules []Rule, toRemove Rule) []Rule { | ||||
| 	ret := make([]Rule, 0, len(rules)) | ||||
| 	for _, rule := range rules { | ||||
| 		if rule.String() != toRemove.String() { | ||||
| 			ret = append(ret, rule) | ||||
| 		} | ||||
| 	} | ||||
| 	return ret | ||||
| } | ||||
|  | ||||
| func containsRule(haystack []Rule, needle Rule) bool { | ||||
| 	for _, element := range haystack { | ||||
| 		if element.String() == needle.String() { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // CleanUp will clean up any rules created by the controller. | ||||
| func (c *Controller) CleanUp() error { | ||||
| 	c.Lock() | ||||
| 	defer c.Unlock() | ||||
| 	return c.deleteFromIndex(0, &c.rules) | ||||
| 	err := c.deleteFromIndex(0, &c.prependRules) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return c.deleteFromIndex(0, &c.appendRules) | ||||
| } | ||||
|  | ||||
| func (c *Controller) client(p Protocol) Client { | ||||
| @@ -430,3 +497,8 @@ func nonBlockingSend(errors chan<- error, err error) { | ||||
| 	default: | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type RuleSet struct { | ||||
| 	AppendRules  []Rule // Rules to append to the chain - order matters. | ||||
| 	PrependRules []Rule // Rules to prepend to the chain - order does not matter. | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user