pkg/iptables: enable simultaneous ipv4 and ipv6

This commit enables simultaneously managing IPv4 and IPv6 iptables
rules. This makes it possible to have peers with IPv6 allowed IPs in an
otherwise IPv4 stack and vice versa.

Signed-off-by: Lucas Servén Marín <lserven@gmail.com>
This commit is contained in:
Lucas Servén Marín 2020-03-12 15:48:01 +01:00
parent 8e8eb1a213
commit b668c1ec3e
No known key found for this signature in database
GPG Key ID: 586FEAF680DA74AD
7 changed files with 108 additions and 43 deletions

View File

@ -9,7 +9,7 @@ FROM $FROM
ARG GOARCH
LABEL maintainer="squat <lserven@gmail.com>"
RUN echo -e "https://dl-3.alpinelinux.org/alpine/edge/main\nhttps://dl-3.alpinelinux.org/alpine/edge/community" > /etc/apk/repositories && \
apk add --no-cache ipset iptables wireguard-tools
apk add --no-cache ipset iptables ip6tables wireguard-tools
COPY --from=cni bridge host-local loopback portmap /opt/cni/bin/
COPY bin/$GOARCH/kg /opt/bin/
ENTRYPOINT ["/opt/bin/kg"]

View File

@ -67,14 +67,17 @@ func (i *ipip) Init(base int) error {
// when traffic between nodes must be encapsulated.
func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule {
var rules []iptables.Rule
rules = append(rules, iptables.NewChain("filter", "KILO-IPIP"))
rules = append(rules, iptables.NewRule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-p", "4", "-j", "KILO-IPIP"))
rules = append(rules, iptables.NewIPv4Chain("filter", "KILO-IPIP"))
rules = append(rules, iptables.NewIPv6Chain("filter", "KILO-IPIP"))
rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-p", "4", "-j", "KILO-IPIP"))
rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-p", "4", "-j", "KILO-IPIP"))
for _, n := range nodes {
// Accept encapsulated traffic from peers.
rules = append(rules, iptables.NewRule("filter", "KILO-IPIP", "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-s", n.IP.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(n.IP)), "filter", "KILO-IPIP", "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-s", n.IP.String(), "-j", "ACCEPT"))
}
// Drop all other IPIP traffic.
rules = append(rules, iptables.NewRule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-p", "4", "-j", "DROP"))
rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-p", "4", "-j", "DROP"))
rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-p", "4", "-j", "DROP"))
return rules
}

View File

@ -51,12 +51,12 @@ func (f *fakeClient) AppendUnique(table, chain string, spec ...string) error {
if exists {
return nil
}
f.storage = append(f.storage, &rule{table, chain, spec})
f.storage = append(f.storage, &rule{table: table, chain: chain, spec: spec})
return nil
}
func (f *fakeClient) Delete(table, chain string, spec ...string) error {
r := &rule{table, chain, spec}
r := &rule{table: table, chain: chain, spec: spec}
for i := range f.storage {
if f.storage[i].String() == r.String() {
copy(f.storage[i:], f.storage[i+1:])
@ -69,7 +69,7 @@ func (f *fakeClient) Delete(table, chain string, spec ...string) error {
}
func (f *fakeClient) Exists(table, chain string, spec ...string) (bool, error) {
r := &rule{table, chain, spec}
r := &rule{table: table, chain: chain, spec: spec}
for i := range f.storage {
if f.storage[i].String() == r.String() {
return true, nil
@ -103,7 +103,7 @@ func (f *fakeClient) DeleteChain(table, name string) error {
return fmt.Errorf("cannot delete chain %s; rules exist", name)
}
}
c := &chain{table, name}
c := &chain{table: table, chain: name}
for i := range f.storage {
if f.storage[i].String() == c.String() {
copy(f.storage[i:], f.storage[i+1:])
@ -116,7 +116,7 @@ func (f *fakeClient) DeleteChain(table, name string) error {
}
func (f *fakeClient) NewChain(table, name string) error {
c := &chain{table, name}
c := &chain{table: table, chain: name}
for i := range f.storage {
if f.storage[i].String() == c.String() {
return statusError(1)

View File

@ -24,6 +24,24 @@ import (
"github.com/coreos/go-iptables/iptables"
)
// Protocol represents an IP protocol.
type Protocol byte
const (
// ProtocolIPv4 represents the IPv4 protocol.
ProtocolIPv4 Protocol = iota
// ProtocolIPv6 represents the IPv6 protocol.
ProtocolIPv6
)
// GetProtocol will return a protocol from the length of an IP address.
func GetProtocol(length int) Protocol {
if length == net.IPv6len {
return ProtocolIPv6
}
return ProtocolIPv4
}
// Client represents any type that can administer iptables rules.
type Client interface {
AppendUnique(table string, chain string, rule ...string) error
@ -40,6 +58,7 @@ type Rule interface {
Delete(Client) error
Exists(Client) (bool, error)
String() string
Proto() Protocol
}
// rule represents an iptables rule.
@ -47,11 +66,23 @@ type rule struct {
table string
chain string
spec []string
proto Protocol
}
// NewRule creates a new iptables rule in the given table and chain.
func NewRule(table, chain string, spec ...string) Rule {
return &rule{table, chain, spec}
// NewRule creates a new iptables or ip6tables rule in the given table and chain
// depending on the given protocol.
func NewRule(proto Protocol, table, chain string, spec ...string) Rule {
return &rule{table, chain, spec, proto}
}
// NewIPv4Rule creates a new iptables rule in the given table and chain.
func NewIPv4Rule(table, chain string, spec ...string) Rule {
return &rule{table, chain, spec, ProtocolIPv4}
}
// NewIPv6Rule creates a new ip6tables rule in the given table and chain.
func NewIPv6Rule(table, chain string, spec ...string) Rule {
return &rule{table, chain, spec, ProtocolIPv6}
}
func (r *rule) Add(client Client) error {
@ -79,15 +110,25 @@ func (r *rule) String() string {
return fmt.Sprintf("%s_%s_%s", r.table, r.chain, strings.Join(r.spec, "_"))
}
func (r *rule) Proto() Protocol {
return r.proto
}
// chain represents an iptables chain.
type chain struct {
table string
chain string
proto Protocol
}
// NewChain creates a new iptables chain in the given table.
func NewChain(table, name string) Rule {
return &chain{table, name}
// NewIPv4Chain creates a new iptables chain in the given table.
func NewIPv4Chain(table, name string) Rule {
return &chain{table, name, ProtocolIPv4}
}
// NewIPv6Chain creates a new ip6tables chain in the given table.
func NewIPv6Chain(table, name string) Rule {
return &chain{table, name, ProtocolIPv6}
}
func (c *chain) Add(client Client) error {
@ -133,9 +174,14 @@ func (c *chain) String() string {
return fmt.Sprintf("%s_%s", c.table, c.chain)
}
func (c *chain) Proto() Protocol {
return c.proto
}
// Controller is able to reconcile a given set of iptables rules.
type Controller struct {
client Client
v4 Client
v6 Client
errors chan error
sync.Mutex
@ -146,17 +192,18 @@ type Controller struct {
// New generates a new iptables rules controller.
// It expects an IP address length to determine
// whether to operate in IPv4 or IPv6 mode.
func New(ipLength int) (*Controller, error) {
p := iptables.ProtocolIPv4
if ipLength == net.IPv6len {
p = iptables.ProtocolIPv6
}
client, err := iptables.NewWithProtocol(p)
func New() (*Controller, error) {
v4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("failed to create iptables client: %v", err)
return nil, fmt.Errorf("failed to create iptables IPv4 client: %v", err)
}
v6, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
return nil, fmt.Errorf("failed to create iptables IPv6 client: %v", err)
}
return &Controller{
client: client,
v4: v4,
v6: v6,
errors: make(chan error),
}, nil
}
@ -196,7 +243,7 @@ func (c *Controller) reconcile() error {
c.Lock()
defer c.Unlock()
for i, r := range c.rules {
ok, err := r.Exists(c.client)
ok, err := r.Exists(c.client(r.Proto()))
if err != nil {
return fmt.Errorf("failed to check if rule exists: %v", err)
}
@ -216,10 +263,10 @@ func (c *Controller) resetFromIndex(i int, rules []Rule) error {
return nil
}
for j := i; j < len(rules); j++ {
if err := rules[j].Delete(c.client); err != nil {
if err := rules[j].Delete(c.client(rules[j].Proto())); err != nil {
return fmt.Errorf("failed to delete rule: %v", err)
}
if err := rules[j].Add(c.client); err != nil {
if err := rules[j].Add(c.client(rules[j].Proto())); err != nil {
return fmt.Errorf("failed to add rule: %v", err)
}
}
@ -232,7 +279,7 @@ func (c *Controller) deleteFromIndex(i int, rules *[]Rule) error {
return nil
}
for j := i; j < len(*rules); j++ {
if err := (*rules)[j].Delete(c.client); err != nil {
if err := (*rules)[j].Delete(c.client((*rules)[j].Proto())); err != nil {
return fmt.Errorf("failed to delete rule: %v", err)
}
(*rules)[j] = nil
@ -256,7 +303,7 @@ func (c *Controller) Set(rules []Rule) error {
}
}
if i >= len(c.rules) {
if err := rules[i].Add(c.client); err != nil {
if err := rules[i].Add(c.client(rules[i].Proto())); err != nil {
return fmt.Errorf("failed to add rule: %v", err)
}
c.rules = append(c.rules, rules[i])
@ -273,6 +320,17 @@ func (c *Controller) CleanUp() error {
return c.deleteFromIndex(0, &c.rules)
}
func (c *Controller) client(p Protocol) Client {
switch p {
case ProtocolIPv4:
return c.v4
case ProtocolIPv6:
return c.v6
default:
panic("unknown protocol")
}
}
func nonBlockingSend(errors chan<- error, err error) {
select {
case errors <- err:

View File

@ -19,8 +19,8 @@ import (
)
var rules = []Rule{
&rule{"filter", "FORWARD", []string{"-s", "10.4.0.0/16", "-j", "ACCEPT"}},
&rule{"filter", "FORWARD", []string{"-d", "10.4.0.0/16", "-j", "ACCEPT"}},
NewIPv4Rule("filter", "FORWARD", "-s", "10.4.0.0/16", "-j", "ACCEPT"),
NewIPv4Rule("filter", "FORWARD", "-d", "10.4.0.0/16", "-j", "ACCEPT"),
}
func TestSet(t *testing.T) {
@ -85,14 +85,15 @@ func TestSet(t *testing.T) {
} {
controller := &Controller{}
client := &fakeClient{}
controller.client = client
controller.v4 = client
controller.v6 = client
for i := range tc.sets {
if err := controller.Set(tc.sets[i]); err != nil {
t.Fatalf("test case %q: got unexpected error seting rule set %d: %v", tc.name, i, err)
}
}
for i, f := range tc.actions {
if err := f(controller.client); err != nil {
if err := f(controller.v4); err != nil {
t.Fatalf("test case %q action %d: got unexpected error %v", tc.name, i, err)
}
}
@ -140,7 +141,8 @@ func TestCleanUp(t *testing.T) {
} {
controller := &Controller{}
client := &fakeClient{}
controller.client = client
controller.v4 = client
controller.v6 = client
if err := controller.Set(tc.rules); err != nil {
t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
}

View File

@ -252,7 +252,7 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
}
level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the private IP address", privateIP.String()))
level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the public IP address", publicIP.String()))
ipTables, err := iptables.New(len(subnet.IP))
ipTables, err := iptables.New()
if err != nil {
return nil, fmt.Errorf("failed to IP tables controller: %v", err)
}

View File

@ -436,25 +436,27 @@ func (t *Topology) PeerConf(name string) *wireguard.Conf {
// Rules returns the iptables rules required by the local node.
func (t *Topology) Rules(cni bool) []iptables.Rule {
var rules []iptables.Rule
rules = append(rules, iptables.NewChain("nat", "KILO-NAT"))
rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
if cni {
rules = append(rules, iptables.NewRule("nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", t.subnet.String(), "-j", "KILO-NAT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", t.subnet.String(), "-j", "KILO-NAT"))
}
for _, s := range t.segments {
rules = append(rules, iptables.NewRule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-d", s.wireGuardIP.String(), "-j", "RETURN"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(s.wireGuardIP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-d", s.wireGuardIP.String(), "-j", "RETURN"))
for _, aip := range s.allowedIPs {
rules = append(rules, iptables.NewRule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-d", aip.String(), "-j", "RETURN"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-d", aip.String(), "-j", "RETURN"))
}
}
for _, p := range t.peers {
for _, aip := range p.AllowedIPs {
rules = append(rules,
iptables.NewRule("nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", aip.String(), "-j", "KILO-NAT"),
iptables.NewRule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-d", aip.String(), "-j", "RETURN"),
iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", aip.String(), "-j", "KILO-NAT"),
iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-d", aip.String(), "-j", "RETURN"),
)
}
}
rules = append(rules, iptables.NewRule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
rules = append(rules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
rules = append(rules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
return rules
}