pkg/iptables: enable simultaneous ipv4 and ipv6

This commit enables simultaneously managing IPv4 and IPv6 iptables
rules. This makes it possible to have peers with IPv6 allowed IPs in an
otherwise IPv4 stack and vice versa.

Signed-off-by: Lucas Servén Marín <lserven@gmail.com>
This commit is contained in:
Lucas Servén Marín 2020-03-12 15:48:01 +01:00
parent 8e8eb1a213
commit b668c1ec3e
No known key found for this signature in database
GPG Key ID: 586FEAF680DA74AD
7 changed files with 108 additions and 43 deletions

View File

@ -9,7 +9,7 @@ FROM $FROM
ARG GOARCH ARG GOARCH
LABEL maintainer="squat <lserven@gmail.com>" LABEL maintainer="squat <lserven@gmail.com>"
RUN echo -e "https://dl-3.alpinelinux.org/alpine/edge/main\nhttps://dl-3.alpinelinux.org/alpine/edge/community" > /etc/apk/repositories && \ RUN echo -e "https://dl-3.alpinelinux.org/alpine/edge/main\nhttps://dl-3.alpinelinux.org/alpine/edge/community" > /etc/apk/repositories && \
apk add --no-cache ipset iptables wireguard-tools apk add --no-cache ipset iptables ip6tables wireguard-tools
COPY --from=cni bridge host-local loopback portmap /opt/cni/bin/ COPY --from=cni bridge host-local loopback portmap /opt/cni/bin/
COPY bin/$GOARCH/kg /opt/bin/ COPY bin/$GOARCH/kg /opt/bin/
ENTRYPOINT ["/opt/bin/kg"] ENTRYPOINT ["/opt/bin/kg"]

View File

@ -67,14 +67,17 @@ func (i *ipip) Init(base int) error {
// when traffic between nodes must be encapsulated. // when traffic between nodes must be encapsulated.
func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule { func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule {
var rules []iptables.Rule var rules []iptables.Rule
rules = append(rules, iptables.NewChain("filter", "KILO-IPIP")) rules = append(rules, iptables.NewIPv4Chain("filter", "KILO-IPIP"))
rules = append(rules, iptables.NewRule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-p", "4", "-j", "KILO-IPIP")) rules = append(rules, iptables.NewIPv6Chain("filter", "KILO-IPIP"))
rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-p", "4", "-j", "KILO-IPIP"))
rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-p", "4", "-j", "KILO-IPIP"))
for _, n := range nodes { for _, n := range nodes {
// Accept encapsulated traffic from peers. // Accept encapsulated traffic from peers.
rules = append(rules, iptables.NewRule("filter", "KILO-IPIP", "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-s", n.IP.String(), "-j", "ACCEPT")) rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(n.IP)), "filter", "KILO-IPIP", "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-s", n.IP.String(), "-j", "ACCEPT"))
} }
// Drop all other IPIP traffic. // Drop all other IPIP traffic.
rules = append(rules, iptables.NewRule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-p", "4", "-j", "DROP")) rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-p", "4", "-j", "DROP"))
rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-p", "4", "-j", "DROP"))
return rules return rules
} }

View File

@ -51,12 +51,12 @@ func (f *fakeClient) AppendUnique(table, chain string, spec ...string) error {
if exists { if exists {
return nil return nil
} }
f.storage = append(f.storage, &rule{table, chain, spec}) f.storage = append(f.storage, &rule{table: table, chain: chain, spec: spec})
return nil return nil
} }
func (f *fakeClient) Delete(table, chain string, spec ...string) error { func (f *fakeClient) Delete(table, chain string, spec ...string) error {
r := &rule{table, chain, spec} r := &rule{table: table, chain: chain, spec: spec}
for i := range f.storage { for i := range f.storage {
if f.storage[i].String() == r.String() { if f.storage[i].String() == r.String() {
copy(f.storage[i:], f.storage[i+1:]) copy(f.storage[i:], f.storage[i+1:])
@ -69,7 +69,7 @@ func (f *fakeClient) Delete(table, chain string, spec ...string) error {
} }
func (f *fakeClient) Exists(table, chain string, spec ...string) (bool, error) { func (f *fakeClient) Exists(table, chain string, spec ...string) (bool, error) {
r := &rule{table, chain, spec} r := &rule{table: table, chain: chain, spec: spec}
for i := range f.storage { for i := range f.storage {
if f.storage[i].String() == r.String() { if f.storage[i].String() == r.String() {
return true, nil return true, nil
@ -103,7 +103,7 @@ func (f *fakeClient) DeleteChain(table, name string) error {
return fmt.Errorf("cannot delete chain %s; rules exist", name) return fmt.Errorf("cannot delete chain %s; rules exist", name)
} }
} }
c := &chain{table, name} c := &chain{table: table, chain: name}
for i := range f.storage { for i := range f.storage {
if f.storage[i].String() == c.String() { if f.storage[i].String() == c.String() {
copy(f.storage[i:], f.storage[i+1:]) copy(f.storage[i:], f.storage[i+1:])
@ -116,7 +116,7 @@ func (f *fakeClient) DeleteChain(table, name string) error {
} }
func (f *fakeClient) NewChain(table, name string) error { func (f *fakeClient) NewChain(table, name string) error {
c := &chain{table, name} c := &chain{table: table, chain: name}
for i := range f.storage { for i := range f.storage {
if f.storage[i].String() == c.String() { if f.storage[i].String() == c.String() {
return statusError(1) return statusError(1)

View File

@ -24,6 +24,24 @@ import (
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
) )
// Protocol represents an IP protocol.
type Protocol byte
const (
// ProtocolIPv4 represents the IPv4 protocol.
ProtocolIPv4 Protocol = iota
// ProtocolIPv6 represents the IPv6 protocol.
ProtocolIPv6
)
// GetProtocol will return a protocol from the length of an IP address.
func GetProtocol(length int) Protocol {
if length == net.IPv6len {
return ProtocolIPv6
}
return ProtocolIPv4
}
// Client represents any type that can administer iptables rules. // Client represents any type that can administer iptables rules.
type Client interface { type Client interface {
AppendUnique(table string, chain string, rule ...string) error AppendUnique(table string, chain string, rule ...string) error
@ -40,6 +58,7 @@ type Rule interface {
Delete(Client) error Delete(Client) error
Exists(Client) (bool, error) Exists(Client) (bool, error)
String() string String() string
Proto() Protocol
} }
// rule represents an iptables rule. // rule represents an iptables rule.
@ -47,11 +66,23 @@ type rule struct {
table string table string
chain string chain string
spec []string spec []string
proto Protocol
} }
// NewRule creates a new iptables rule in the given table and chain. // NewRule creates a new iptables or ip6tables rule in the given table and chain
func NewRule(table, chain string, spec ...string) Rule { // depending on the given protocol.
return &rule{table, chain, spec} func NewRule(proto Protocol, table, chain string, spec ...string) Rule {
return &rule{table, chain, spec, proto}
}
// NewIPv4Rule creates a new iptables rule in the given table and chain.
func NewIPv4Rule(table, chain string, spec ...string) Rule {
return &rule{table, chain, spec, ProtocolIPv4}
}
// NewIPv6Rule creates a new ip6tables rule in the given table and chain.
func NewIPv6Rule(table, chain string, spec ...string) Rule {
return &rule{table, chain, spec, ProtocolIPv6}
} }
func (r *rule) Add(client Client) error { func (r *rule) Add(client Client) error {
@ -79,15 +110,25 @@ func (r *rule) String() string {
return fmt.Sprintf("%s_%s_%s", r.table, r.chain, strings.Join(r.spec, "_")) return fmt.Sprintf("%s_%s_%s", r.table, r.chain, strings.Join(r.spec, "_"))
} }
func (r *rule) Proto() Protocol {
return r.proto
}
// chain represents an iptables chain. // chain represents an iptables chain.
type chain struct { type chain struct {
table string table string
chain string chain string
proto Protocol
} }
// NewChain creates a new iptables chain in the given table. // NewIPv4Chain creates a new iptables chain in the given table.
func NewChain(table, name string) Rule { func NewIPv4Chain(table, name string) Rule {
return &chain{table, name} return &chain{table, name, ProtocolIPv4}
}
// NewIPv6Chain creates a new ip6tables chain in the given table.
func NewIPv6Chain(table, name string) Rule {
return &chain{table, name, ProtocolIPv6}
} }
func (c *chain) Add(client Client) error { func (c *chain) Add(client Client) error {
@ -133,9 +174,14 @@ func (c *chain) String() string {
return fmt.Sprintf("%s_%s", c.table, c.chain) return fmt.Sprintf("%s_%s", c.table, c.chain)
} }
func (c *chain) Proto() Protocol {
return c.proto
}
// Controller is able to reconcile a given set of iptables rules. // Controller is able to reconcile a given set of iptables rules.
type Controller struct { type Controller struct {
client Client v4 Client
v6 Client
errors chan error errors chan error
sync.Mutex sync.Mutex
@ -146,17 +192,18 @@ type Controller struct {
// New generates a new iptables rules controller. // New generates a new iptables rules controller.
// It expects an IP address length to determine // It expects an IP address length to determine
// whether to operate in IPv4 or IPv6 mode. // whether to operate in IPv4 or IPv6 mode.
func New(ipLength int) (*Controller, error) { func New() (*Controller, error) {
p := iptables.ProtocolIPv4 v4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if ipLength == net.IPv6len {
p = iptables.ProtocolIPv6
}
client, err := iptables.NewWithProtocol(p)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create iptables client: %v", err) return nil, fmt.Errorf("failed to create iptables IPv4 client: %v", err)
}
v6, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
return nil, fmt.Errorf("failed to create iptables IPv6 client: %v", err)
} }
return &Controller{ return &Controller{
client: client, v4: v4,
v6: v6,
errors: make(chan error), errors: make(chan error),
}, nil }, nil
} }
@ -196,7 +243,7 @@ func (c *Controller) reconcile() error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
for i, r := range c.rules { for i, r := range c.rules {
ok, err := r.Exists(c.client) ok, err := r.Exists(c.client(r.Proto()))
if err != nil { if err != nil {
return fmt.Errorf("failed to check if rule exists: %v", err) return fmt.Errorf("failed to check if rule exists: %v", err)
} }
@ -216,10 +263,10 @@ func (c *Controller) resetFromIndex(i int, rules []Rule) error {
return nil return nil
} }
for j := i; j < len(rules); j++ { for j := i; j < len(rules); j++ {
if err := rules[j].Delete(c.client); err != nil { if err := rules[j].Delete(c.client(rules[j].Proto())); err != nil {
return fmt.Errorf("failed to delete rule: %v", err) return fmt.Errorf("failed to delete rule: %v", err)
} }
if err := rules[j].Add(c.client); err != nil { if err := rules[j].Add(c.client(rules[j].Proto())); err != nil {
return fmt.Errorf("failed to add rule: %v", err) return fmt.Errorf("failed to add rule: %v", err)
} }
} }
@ -232,7 +279,7 @@ func (c *Controller) deleteFromIndex(i int, rules *[]Rule) error {
return nil return nil
} }
for j := i; j < len(*rules); j++ { for j := i; j < len(*rules); j++ {
if err := (*rules)[j].Delete(c.client); err != nil { if err := (*rules)[j].Delete(c.client((*rules)[j].Proto())); err != nil {
return fmt.Errorf("failed to delete rule: %v", err) return fmt.Errorf("failed to delete rule: %v", err)
} }
(*rules)[j] = nil (*rules)[j] = nil
@ -256,7 +303,7 @@ func (c *Controller) Set(rules []Rule) error {
} }
} }
if i >= len(c.rules) { if i >= len(c.rules) {
if err := rules[i].Add(c.client); err != nil { if err := rules[i].Add(c.client(rules[i].Proto())); err != nil {
return fmt.Errorf("failed to add rule: %v", err) return fmt.Errorf("failed to add rule: %v", err)
} }
c.rules = append(c.rules, rules[i]) c.rules = append(c.rules, rules[i])
@ -273,6 +320,17 @@ func (c *Controller) CleanUp() error {
return c.deleteFromIndex(0, &c.rules) return c.deleteFromIndex(0, &c.rules)
} }
func (c *Controller) client(p Protocol) Client {
switch p {
case ProtocolIPv4:
return c.v4
case ProtocolIPv6:
return c.v6
default:
panic("unknown protocol")
}
}
func nonBlockingSend(errors chan<- error, err error) { func nonBlockingSend(errors chan<- error, err error) {
select { select {
case errors <- err: case errors <- err:

View File

@ -19,8 +19,8 @@ import (
) )
var rules = []Rule{ var rules = []Rule{
&rule{"filter", "FORWARD", []string{"-s", "10.4.0.0/16", "-j", "ACCEPT"}}, NewIPv4Rule("filter", "FORWARD", "-s", "10.4.0.0/16", "-j", "ACCEPT"),
&rule{"filter", "FORWARD", []string{"-d", "10.4.0.0/16", "-j", "ACCEPT"}}, NewIPv4Rule("filter", "FORWARD", "-d", "10.4.0.0/16", "-j", "ACCEPT"),
} }
func TestSet(t *testing.T) { func TestSet(t *testing.T) {
@ -85,14 +85,15 @@ func TestSet(t *testing.T) {
} { } {
controller := &Controller{} controller := &Controller{}
client := &fakeClient{} client := &fakeClient{}
controller.client = client controller.v4 = client
controller.v6 = client
for i := range tc.sets { for i := range tc.sets {
if err := controller.Set(tc.sets[i]); err != nil { if err := controller.Set(tc.sets[i]); err != nil {
t.Fatalf("test case %q: got unexpected error seting rule set %d: %v", tc.name, i, err) t.Fatalf("test case %q: got unexpected error seting rule set %d: %v", tc.name, i, err)
} }
} }
for i, f := range tc.actions { for i, f := range tc.actions {
if err := f(controller.client); err != nil { if err := f(controller.v4); err != nil {
t.Fatalf("test case %q action %d: got unexpected error %v", tc.name, i, err) t.Fatalf("test case %q action %d: got unexpected error %v", tc.name, i, err)
} }
} }
@ -140,7 +141,8 @@ func TestCleanUp(t *testing.T) {
} { } {
controller := &Controller{} controller := &Controller{}
client := &fakeClient{} client := &fakeClient{}
controller.client = client controller.v4 = client
controller.v6 = client
if err := controller.Set(tc.rules); err != nil { if err := controller.Set(tc.rules); err != nil {
t.Fatalf("test case %q: Set should not fail: %v", tc.name, err) t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
} }

View File

@ -252,7 +252,7 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
} }
level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the private IP address", privateIP.String())) level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the private IP address", privateIP.String()))
level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the public IP address", publicIP.String())) level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the public IP address", publicIP.String()))
ipTables, err := iptables.New(len(subnet.IP)) ipTables, err := iptables.New()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to IP tables controller: %v", err) return nil, fmt.Errorf("failed to IP tables controller: %v", err)
} }

View File

@ -436,25 +436,27 @@ func (t *Topology) PeerConf(name string) *wireguard.Conf {
// Rules returns the iptables rules required by the local node. // Rules returns the iptables rules required by the local node.
func (t *Topology) Rules(cni bool) []iptables.Rule { func (t *Topology) Rules(cni bool) []iptables.Rule {
var rules []iptables.Rule var rules []iptables.Rule
rules = append(rules, iptables.NewChain("nat", "KILO-NAT")) rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
if cni { if cni {
rules = append(rules, iptables.NewRule("nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", t.subnet.String(), "-j", "KILO-NAT")) rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", t.subnet.String(), "-j", "KILO-NAT"))
} }
for _, s := range t.segments { for _, s := range t.segments {
rules = append(rules, iptables.NewRule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-d", s.wireGuardIP.String(), "-j", "RETURN")) rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(s.wireGuardIP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-d", s.wireGuardIP.String(), "-j", "RETURN"))
for _, aip := range s.allowedIPs { for _, aip := range s.allowedIPs {
rules = append(rules, iptables.NewRule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-d", aip.String(), "-j", "RETURN")) rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-d", aip.String(), "-j", "RETURN"))
} }
} }
for _, p := range t.peers { for _, p := range t.peers {
for _, aip := range p.AllowedIPs { for _, aip := range p.AllowedIPs {
rules = append(rules, rules = append(rules,
iptables.NewRule("nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", aip.String(), "-j", "KILO-NAT"), iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", aip.String(), "-j", "KILO-NAT"),
iptables.NewRule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-d", aip.String(), "-j", "RETURN"), iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-d", aip.String(), "-j", "RETURN"),
) )
} }
} }
rules = append(rules, iptables.NewRule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) rules = append(rules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
rules = append(rules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
return rules return rules
} }