This commit is contained in:
Lucas Serven
2019-01-18 02:50:10 +01:00
commit e989f0a25f
1789 changed files with 680059 additions and 0 deletions

59
pkg/iproute/ipip.go Normal file
View File

@@ -0,0 +1,59 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package iproute
import (
"bytes"
"fmt"
"os/exec"
"github.com/vishvananda/netlink"
)
const (
ipipHeaderSize = 20
tunnelName = "tunl0"
)
// NewIPIP creates an IPIP interface using the base interface
// to derive the tunnel's MTU.
func NewIPIP(baseIndex int) (int, error) {
link, err := netlink.LinkByName(tunnelName)
if err != nil {
// If we failed to find the tunnel, then it probably simply does not exist.
cmd := exec.Command("ip", "tunnel", "add", tunnelName, "mode", "ipip")
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return 0, fmt.Errorf("failed to create IPIP tunnel: %s", stderr.String())
}
link, err = netlink.LinkByName(tunnelName)
if err != nil {
return 0, fmt.Errorf("failed to get tunnel device: %v", err)
}
}
base, err := netlink.LinkByIndex(baseIndex)
if err != nil {
return 0, fmt.Errorf("failed to get base device: %v", err)
}
mtu := base.Attrs().MTU - ipipHeaderSize
if err = netlink.LinkSetMTU(link, mtu); err != nil {
return 0, fmt.Errorf("failed to set tunnel MTU: %v", err)
}
return link.Attrs().Index, nil
}

70
pkg/iproute/iproute.go Normal file
View File

@@ -0,0 +1,70 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package iproute
import (
"fmt"
"net"
"github.com/vishvananda/netlink"
)
// RemoveInterface removes an interface.
func RemoveInterface(index int) error {
link, err := netlink.LinkByIndex(index)
if err != nil {
return fmt.Errorf("failed to get link: %s", err)
}
return netlink.LinkDel(link)
}
// Set sets the interface up or down.
func Set(index int, up bool) error {
link, err := netlink.LinkByIndex(index)
if err != nil {
return fmt.Errorf("failed to get link: %s", err)
}
if up {
return netlink.LinkSetUp(link)
}
return netlink.LinkSetDown(link)
}
// SetAddress sets the IP address of an interface.
func SetAddress(index int, cidr *net.IPNet) error {
link, err := netlink.LinkByIndex(index)
if err != nil {
return fmt.Errorf("failed to get link: %s", err)
}
addrs, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return err
}
l := len(addrs)
for _, addr := range addrs {
if addr.IP.Equal(cidr.IP) && addr.Mask.String() == cidr.Mask.String() {
continue
}
if err := netlink.AddrDel(link, &addr); err != nil {
return fmt.Errorf("failed to delete address: %s", err)
}
l--
}
// The only address left is the desired address, so quit.
if l == 1 {
return nil
}
return netlink.AddrReplace(link, &netlink.Addr{IPNet: cidr})
}

199
pkg/ipset/ipset.go Normal file
View File

@@ -0,0 +1,199 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipset
import (
"bytes"
"fmt"
"net"
"os/exec"
"sync"
"time"
)
// Set represents an ipset.
// Set can safely be used concurrently.
type Set struct {
errors chan error
hosts map[string]struct{}
mu sync.Mutex
name string
subscribed bool
// Make these functions fields to allow
// for testing.
add func(string) error
del func(string) error
}
func setExists(name string) (bool, error) {
cmd := exec.Command("ipset", "list", "-n")
var stderr, stdout bytes.Buffer
cmd.Stderr = &stderr
cmd.Stdout = &stdout
if err := cmd.Run(); err != nil {
return false, fmt.Errorf("failed to check for set %s: %s", name, stderr.String())
}
return bytes.Contains(stdout.Bytes(), []byte(name)), nil
}
func hostInSet(set, name string) (bool, error) {
cmd := exec.Command("ipset", "list", set)
var stderr, stdout bytes.Buffer
cmd.Stderr = &stderr
cmd.Stdout = &stdout
if err := cmd.Run(); err != nil {
return false, fmt.Errorf("failed to check for host %s: %s", name, stderr.String())
}
return bytes.Contains(stdout.Bytes(), []byte(name)), nil
}
// New generates a new ipset.
func New(name string) *Set {
return &Set{
errors: make(chan error),
hosts: make(map[string]struct{}),
name: name,
add: func(ip string) error {
ok, err := hostInSet(name, ip)
if err != nil {
return err
}
if !ok {
cmd := exec.Command("ipset", "add", name, ip)
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to add host %s to set %s: %s", ip, name, stderr.String())
}
}
return nil
},
del: func(ip string) error {
ok, err := hostInSet(name, ip)
if err != nil {
return err
}
if ok {
cmd := exec.Command("ipset", "del", name, ip)
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to remove host %s from set %s: %s", ip, name, stderr.String())
}
}
return nil
},
}
}
// Run watches for changes to the ipset and reconciles
// the ipset against the desired state.
func (s *Set) Run(stop <-chan struct{}) (<-chan error, error) {
s.mu.Lock()
if s.subscribed {
s.mu.Unlock()
return s.errors, nil
}
// Ensure a given instance only subscribes once.
s.subscribed = true
s.mu.Unlock()
go func() {
defer close(s.errors)
for {
select {
case <-time.After(2 * time.Second):
case <-stop:
return
}
ok, err := setExists(s.name)
if err != nil {
nonBlockingSend(s.errors, err)
}
// The set does not exist so wait and try again later.
if !ok {
continue
}
s.mu.Lock()
for h := range s.hosts {
if err := s.add(h); err != nil {
nonBlockingSend(s.errors, err)
}
}
s.mu.Unlock()
}
}()
return s.errors, nil
}
// CleanUp will clean up any hosts added to the set.
func (s *Set) CleanUp() error {
s.mu.Lock()
defer s.mu.Unlock()
for h := range s.hosts {
if err := s.del(h); err != nil {
return err
}
delete(s.hosts, h)
}
return nil
}
// Set idempotently overwrites any hosts previously defined
// for the ipset with the given hosts.
func (s *Set) Set(hosts []net.IP) error {
h := make(map[string]struct{})
for _, host := range hosts {
if host == nil {
continue
}
h[host.String()] = struct{}{}
}
exists, err := setExists(s.name)
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
for k := range s.hosts {
if _, ok := h[k]; !ok {
if exists {
if err := s.del(k); err != nil {
return err
}
}
delete(s.hosts, k)
}
}
for k := range h {
if _, ok := s.hosts[k]; !ok {
if exists {
if err := s.add(k); err != nil {
return err
}
}
s.hosts[k] = struct{}{}
}
}
return nil
}
func nonBlockingSend(errors chan<- error, err error) {
select {
case errors <- err:
default:
}
}

92
pkg/iptables/fake.go Normal file
View File

@@ -0,0 +1,92 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package iptables
import (
"fmt"
"strings"
"github.com/coreos/go-iptables/iptables"
)
type statusExiter interface {
ExitStatus() int
}
var _ statusExiter = (*iptables.Error)(nil)
var _ statusExiter = statusError(0)
type statusError int
func (s statusError) Error() string {
return fmt.Sprintf("%d", s)
}
func (s statusError) ExitStatus() int {
return int(s)
}
type fakeClient map[string]Rule
var _ iptablesClient = fakeClient(nil)
func (f fakeClient) AppendUnique(table, chain string, spec ...string) error {
r := &rule{table, chain, spec, nil}
f[r.String()] = r
return nil
}
func (f fakeClient) Delete(table, chain string, spec ...string) error {
r := &rule{table, chain, spec, nil}
delete(f, r.String())
return nil
}
func (f fakeClient) Exists(table, chain string, spec ...string) (bool, error) {
r := &rule{table, chain, spec, nil}
_, ok := f[r.String()]
return ok, nil
}
func (f fakeClient) ClearChain(table, name string) error {
c := &chain{table, name, nil}
for k := range f {
if strings.HasPrefix(k, c.String()) {
delete(f, k)
}
}
f[c.String()] = c
return nil
}
func (f fakeClient) DeleteChain(table, name string) error {
c := &chain{table, name, nil}
for k := range f {
if strings.HasPrefix(k, c.String()) {
return fmt.Errorf("cannot delete chain %s; rules exist", name)
}
}
delete(f, c.String())
return nil
}
func (f fakeClient) NewChain(table, name string) error {
c := &chain{table, name, nil}
if _, ok := f[c.String()]; ok {
return statusError(1)
}
f[c.String()] = c
return nil
}

289
pkg/iptables/iptables.go Normal file
View File

@@ -0,0 +1,289 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package iptables
import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/coreos/go-iptables/iptables"
)
type iptablesClient interface {
AppendUnique(string, string, ...string) error
Delete(string, string, ...string) error
Exists(string, string, ...string) (bool, error)
ClearChain(string, string) error
DeleteChain(string, string) error
NewChain(string, string) error
}
// rule represents an iptables rule.
type rule struct {
table string
chain string
spec []string
client iptablesClient
}
func (r *rule) Add() error {
if err := r.client.AppendUnique(r.table, r.chain, r.spec...); err != nil {
return fmt.Errorf("failed to add iptables rule: %v", err)
}
return nil
}
func (r *rule) Delete() error {
// Ignore the returned error as an error likely means
// that the rule doesn't exist, which is fine.
r.client.Delete(r.table, r.chain, r.spec...)
return nil
}
func (r *rule) Exists() (bool, error) {
return r.client.Exists(r.table, r.chain, r.spec...)
}
func (r *rule) String() string {
if r == nil {
return ""
}
return fmt.Sprintf("%s_%s_%s", r.table, r.chain, strings.Join(r.spec, "_"))
}
// chain represents an iptables chain.
type chain struct {
table string
chain string
client iptablesClient
}
func (c *chain) Add() error {
if err := c.client.ClearChain(c.table, c.chain); err != nil {
return fmt.Errorf("failed to add iptables chain: %v", err)
}
return nil
}
func (c *chain) Delete() error {
// The chain must be empty before it can be deleted.
if err := c.client.ClearChain(c.table, c.chain); err != nil {
return fmt.Errorf("failed to clear iptables chain: %v", err)
}
// Ignore the returned error as an error likely means
// that the chain doesn't exist, which is fine.
c.client.DeleteChain(c.table, c.chain)
return nil
}
func (c *chain) Exists() (bool, error) {
// The code for "chain already exists".
existsErr := 1
err := c.client.NewChain(c.table, c.chain)
se, ok := err.(statusExiter)
switch {
case err == nil:
// If there was no error adding a new chain, then it did not exist.
// Delete it and return false.
c.client.DeleteChain(c.table, c.chain)
return false, nil
case ok && se.ExitStatus() == existsErr:
return true, nil
default:
return false, err
}
}
func (c *chain) String() string {
if c == nil {
return ""
}
return fmt.Sprintf("%s_%s", c.table, c.chain)
}
// Rule is an interface for interacting with iptables objects.
type Rule interface {
Add() error
Delete() error
Exists() (bool, error)
String() string
}
// Controller is able to reconcile a given set of iptables rules.
type Controller struct {
client iptablesClient
errors chan error
rules map[string]Rule
mu sync.Mutex
subscribed bool
}
// New generates a new iptables rules controller.
// It expects an IP address length to determine
// whether to operate in IPv4 or IPv6 mode.
func New(ipLength int) (*Controller, error) {
p := iptables.ProtocolIPv4
if ipLength == net.IPv6len {
p = iptables.ProtocolIPv6
}
client, err := iptables.NewWithProtocol(p)
if err != nil {
return nil, fmt.Errorf("failed to create iptables client: %v", err)
}
return &Controller{
client: client,
errors: make(chan error),
rules: make(map[string]Rule),
}, nil
}
// Run watches for changes to iptables rules and reconciles
// the rules against the desired state.
func (c *Controller) Run(stop <-chan struct{}) (<-chan error, error) {
c.mu.Lock()
if c.subscribed {
c.mu.Unlock()
return c.errors, nil
}
// Ensure a given instance only subscribes once.
c.subscribed = true
c.mu.Unlock()
go func() {
defer close(c.errors)
for {
select {
case <-time.After(5 * time.Second):
case <-stop:
return
}
c.mu.Lock()
for _, r := range c.rules {
ok, err := r.Exists()
if err != nil {
nonBlockingSend(c.errors, fmt.Errorf("failed to check if rule exists: %v", err))
}
if !ok {
if err := r.Add(); err != nil {
nonBlockingSend(c.errors, fmt.Errorf("failed to add rule: %v", err))
}
}
}
c.mu.Unlock()
}
}()
return c.errors, nil
}
// Set idempotently overwrites any iptables rules previously defined
// for the controller with the given set of rules.
func (c *Controller) Set(rules []Rule) error {
r := make(map[string]struct{})
for i := range rules {
if rules[i] == nil {
continue
}
switch v := rules[i].(type) {
case *rule:
v.client = c.client
case *chain:
v.client = c.client
}
r[rules[i].String()] = struct{}{}
}
c.mu.Lock()
defer c.mu.Unlock()
for k, rule := range c.rules {
if _, ok := r[k]; !ok {
if err := rule.Delete(); err != nil {
return fmt.Errorf("failed to delete rule: %v", err)
}
delete(c.rules, k)
}
}
// Iterate over the slice rather than the map
// to ensure the rules are added in order.
for _, rule := range rules {
if _, ok := c.rules[rule.String()]; !ok {
if err := rule.Add(); err != nil {
return fmt.Errorf("failed to add rule: %v", err)
}
c.rules[rule.String()] = rule
}
}
return nil
}
// CleanUp will clean up any rules created by the controller.
func (c *Controller) CleanUp() error {
c.mu.Lock()
defer c.mu.Unlock()
for k, rule := range c.rules {
if err := rule.Delete(); err != nil {
return fmt.Errorf("failed to delete rule: %v", err)
}
delete(c.rules, k)
}
return nil
}
// EncapsulateRules returns a set of iptables rules that are necessary
// when traffic between nodes must be encapsulated.
func EncapsulateRules(nodes []*net.IPNet) []Rule {
var rules []Rule
for _, n := range nodes {
// Accept encapsulated traffic from peers.
rules = append(rules, &rule{"filter", "INPUT", []string{"-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-s", n.IP.String(), "-p", "4", "-j", "ACCEPT"}, nil})
}
return rules
}
// ForwardRules returns a set of iptables rules that are necessary
// when traffic must be forwarded for the overlay.
func ForwardRules(subnet *net.IPNet) []Rule {
s := subnet.String()
return []Rule{
// Forward traffic to and from the overlay.
&rule{"filter", "FORWARD", []string{"-s", s, "-j", "ACCEPT"}, nil},
&rule{"filter", "FORWARD", []string{"-d", s, "-j", "ACCEPT"}, nil},
}
}
// MasqueradeRules returns a set of iptables rules that are necessary
// when traffic must be masqueraded for Kilo.
func MasqueradeRules(subnet, localPodSubnet *net.IPNet, remotePodSubnet []*net.IPNet) []Rule {
var rules []Rule
rules = append(rules, &chain{"mangle", "KILO-MARK", nil})
rules = append(rules, &rule{"mangle", "PREROUTING", []string{"-m", "comment", "--comment", "Kilo: jump to mark chain", "-i", "kilo+", "-j", "KILO-MARK"}, nil})
rules = append(rules, &rule{"mangle", "KILO-MARK", []string{"-m", "comment", "--comment", "Kilo: do not mark packets destined for the local Pod subnet", "-d", localPodSubnet.String(), "-j", "RETURN"}, nil})
if subnet != nil {
rules = append(rules, &rule{"mangle", "KILO-MARK", []string{"-m", "comment", "--comment", "Kilo: do not mark packets destined for the local private subnet", "-d", subnet.String(), "-j", "RETURN"}, nil})
}
rules = append(rules, &rule{"mangle", "KILO-MARK", []string{"-m", "comment", "--comment", "Kilo: remaining packets should be marked for NAT", "-j", "MARK", "--set-xmark", "0x1107/0x1107"}, nil})
rules = append(rules, &rule{"nat", "POSTROUTING", []string{"-m", "comment", "--comment", "Kilo: NAT packets from Kilo interface", "-m", "mark", "--mark", "0x1107/0x1107", "-j", "MASQUERADE"}, nil})
for _, r := range remotePodSubnet {
rules = append(rules, &rule{"nat", "POSTROUTING", []string{"-m", "comment", "--comment", "Kilo: NAT packets from local pod subnet to remote pod subnets", "-s", localPodSubnet.String(), "-d", r.String(), "-j", "MASQUERADE"}, nil})
}
return rules
}
func nonBlockingSend(errors chan<- error, err error) {
select {
case errors <- err:
default:
}
}

View File

@@ -0,0 +1,101 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package iptables
import (
"testing"
)
var rules = []Rule{
&rule{"filter", "FORWARD", []string{"-s", "10.4.0.0/16", "-j", "ACCEPT"}, nil},
&rule{"filter", "FORWARD", []string{"-d", "10.4.0.0/16", "-j", "ACCEPT"}, nil},
}
func newController() *Controller {
return &Controller{
rules: make(map[string]Rule),
}
}
func TestSet(t *testing.T) {
for _, tc := range []struct {
name string
rules []Rule
}{
{
name: "empty",
rules: nil,
},
{
name: "single",
rules: []Rule{rules[0]},
},
{
name: "multiple",
rules: []Rule{rules[0], rules[1]},
},
} {
backend := make(map[string]Rule)
controller := newController()
controller.client = fakeClient(backend)
if err := controller.Set(tc.rules); err != nil {
t.Fatalf("test case %q: got unexpected error: %v", tc.name, err)
}
for _, r := range tc.rules {
r1 := backend[r.String()]
r2 := controller.rules[r.String()]
if r.String() != r1.String() || r.String() != r2.String() {
t.Errorf("test case %q: expected all rules to be equal: expected %v, got %v and %v", tc.name, r, r1, r2)
}
}
}
}
func TestCleanUp(t *testing.T) {
for _, tc := range []struct {
name string
rules []Rule
}{
{
name: "empty",
rules: nil,
},
{
name: "single",
rules: []Rule{rules[0]},
},
{
name: "multiple",
rules: []Rule{rules[0], rules[1]},
},
} {
backend := make(map[string]Rule)
controller := newController()
controller.client = fakeClient(backend)
if err := controller.Set(tc.rules); err != nil {
t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
}
if err := controller.CleanUp(); err != nil {
t.Errorf("test case %q: got unexpected error: %v", tc.name, err)
}
for _, r := range tc.rules {
r1 := backend[r.String()]
r2 := controller.rules[r.String()]
if r1 != nil || r2 != nil {
t.Errorf("test case %q: expected all rules to be nil: expected got %v and %v", tc.name, r1, r2)
}
}
}
}

229
pkg/k8s/backend.go Normal file
View File

@@ -0,0 +1,229 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package k8s
import (
"encoding/json"
"errors"
"fmt"
"net"
"path"
"strings"
"time"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/strategicpatch"
v1informers "k8s.io/client-go/informers/core/v1"
"k8s.io/client-go/kubernetes"
v1listers "k8s.io/client-go/listers/core/v1"
"k8s.io/client-go/tools/cache"
"github.com/squat/kilo/pkg/mesh"
)
const (
// Backend is the name of this mesh backend.
Backend = "kubernetes"
externalIPAnnotationKey = "kilo.squat.ai/external-ip"
forceExternalIPAnnotationKey = "kilo.squat.ai/force-external-ip"
internalIPAnnotationKey = "kilo.squat.ai/internal-ip"
keyAnnotationKey = "kilo.squat.ai/key"
leaderAnnotationKey = "kilo.squat.ai/leader"
locationAnnotationKey = "kilo.squat.ai/location"
regionLabelKey = "failure-domain.beta.kubernetes.io/region"
jsonPatchSlash = "~1"
jsonRemovePatch = `{"op": "remove", "path": "%s"}`
)
type backend struct {
client kubernetes.Interface
events chan *mesh.Event
informer cache.SharedIndexInformer
lister v1listers.NodeLister
}
// New creates a new instance of a mesh.Backend.
func New(client kubernetes.Interface) mesh.Backend {
informer := v1informers.NewNodeInformer(client, 5*time.Minute, nil)
b := &backend{
client: client,
events: make(chan *mesh.Event),
informer: informer,
lister: v1listers.NewNodeLister(informer.GetIndexer()),
}
return b
}
// CleanUp removes configuration applied to the backend.
func (b *backend) CleanUp(name string) error {
patch := []byte("[" + strings.Join([]string{
fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(externalIPAnnotationKey, "/", jsonPatchSlash, 1))),
fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(internalIPAnnotationKey, "/", jsonPatchSlash, 1))),
fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(keyAnnotationKey, "/", jsonPatchSlash, 1))),
}, ",") + "]")
if _, err := b.client.CoreV1().Nodes().Patch(name, types.JSONPatchType, patch); err != nil {
return fmt.Errorf("failed to patch node: %v", err)
}
return nil
}
// Get gets a single Node by name.
func (b *backend) Get(name string) (*mesh.Node, error) {
n, err := b.lister.Get(name)
if err != nil {
return nil, err
}
return translateNode(n), nil
}
// Init initializes the backend; for this backend that means
// syncing the informer cache.
func (b *backend) Init(stop <-chan struct{}) error {
go b.informer.Run(stop)
if ok := cache.WaitForCacheSync(stop, func() bool {
return b.informer.HasSynced()
}); !ok {
return errors.New("failed to start sync node cache")
}
b.informer.AddEventHandler(
cache.ResourceEventHandlerFuncs{
AddFunc: func(obj interface{}) {
n, ok := obj.(*v1.Node)
if !ok {
// Failed to decode Node; ignoring...
return
}
b.events <- &mesh.Event{Type: mesh.AddEvent, Node: translateNode(n)}
},
UpdateFunc: func(_, obj interface{}) {
n, ok := obj.(*v1.Node)
if !ok {
// Failed to decode Node; ignoring...
return
}
b.events <- &mesh.Event{Type: mesh.UpdateEvent, Node: translateNode(n)}
},
DeleteFunc: func(obj interface{}) {
n, ok := obj.(*v1.Node)
if !ok {
// Failed to decode Node; ignoring...
return
}
b.events <- &mesh.Event{Type: mesh.DeleteEvent, Node: translateNode(n)}
},
},
)
return nil
}
// List gets all the Nodes in the cluster.
func (b *backend) List() ([]*mesh.Node, error) {
ns, err := b.lister.List(labels.Everything())
if err != nil {
return nil, err
}
nodes := make([]*mesh.Node, len(ns))
for i := range ns {
nodes[i] = translateNode(ns[i])
}
return nodes, nil
}
// Set sets the fields of a node.
func (b *backend) Set(name string, node *mesh.Node) error {
old, err := b.lister.Get(name)
if err != nil {
return fmt.Errorf("failed to find node: %v", err)
}
n := old.DeepCopy()
n.ObjectMeta.Annotations[externalIPAnnotationKey] = node.ExternalIP.String()
n.ObjectMeta.Annotations[internalIPAnnotationKey] = node.InternalIP.String()
n.ObjectMeta.Annotations[keyAnnotationKey] = string(node.Key)
oldData, err := json.Marshal(old)
if err != nil {
return err
}
newData, err := json.Marshal(n)
if err != nil {
return err
}
patch, err := strategicpatch.CreateTwoWayMergePatch(oldData, newData, v1.Node{})
if err != nil {
return fmt.Errorf("failed to create patch for node %q: %v", n.Name, err)
}
if _, err = b.client.CoreV1().Nodes().Patch(name, types.StrategicMergePatchType, patch); err != nil {
return fmt.Errorf("failed to patch node: %v", err)
}
return nil
}
// Watch returns a chan of node events.
func (b *backend) Watch() <-chan *mesh.Event {
return b.events
}
// translateNode translates a Kubernetes Node to a mesh.Node.
func translateNode(node *v1.Node) *mesh.Node {
if node == nil {
return nil
}
_, subnet, err := net.ParseCIDR(node.Spec.PodCIDR)
// The subnet should only ever fail to parse if the pod CIDR has not been set,
// so in this case set the subnet to nil and let the node be updated.
if err != nil {
subnet = nil
}
_, leader := node.ObjectMeta.Annotations[leaderAnnotationKey]
// Allow the region to be overridden by an explicit location.
location, ok := node.ObjectMeta.Annotations[locationAnnotationKey]
if !ok {
location = node.ObjectMeta.Labels[regionLabelKey]
}
// Allow the external IP to be overridden.
externalIP, ok := node.ObjectMeta.Annotations[forceExternalIPAnnotationKey]
if !ok {
externalIP = node.ObjectMeta.Annotations[externalIPAnnotationKey]
}
return &mesh.Node{
// ExternalIP and InternalIP should only ever fail to parse if the
// remote node's mesh has not yet set its IP address;
// in this case the IP will be nil and
// the mesh can wait for the node to be updated.
ExternalIP: normalizeIP(externalIP),
InternalIP: normalizeIP(node.ObjectMeta.Annotations[internalIPAnnotationKey]),
Key: []byte(node.ObjectMeta.Annotations[keyAnnotationKey]),
Leader: leader,
Location: location,
Name: node.Name,
Subnet: subnet,
}
}
func normalizeIP(ip string) *net.IPNet {
i, ipNet, _ := net.ParseCIDR(ip)
if ipNet == nil {
return ipNet
}
if ip4 := i.To4(); ip4 != nil {
ipNet.IP = ip4
return ipNet
}
ipNet.IP = i.To16()
return ipNet
}

145
pkg/k8s/backend_test.go Normal file
View File

@@ -0,0 +1,145 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package k8s
import (
"net"
"testing"
"github.com/kylelemons/godebug/pretty"
"k8s.io/api/core/v1"
"github.com/squat/kilo/pkg/mesh"
)
func TestTranslateNode(t *testing.T) {
for _, tc := range []struct {
name string
annotations map[string]string
labels map[string]string
out *mesh.Node
subnet string
}{
{
name: "empty",
annotations: nil,
out: &mesh.Node{},
},
{
name: "invalid ip",
annotations: map[string]string{
externalIPAnnotationKey: "10.0.0.1",
internalIPAnnotationKey: "10.0.0.1",
},
out: &mesh.Node{},
},
{
name: "valid ip",
annotations: map[string]string{
externalIPAnnotationKey: "10.0.0.1/24",
internalIPAnnotationKey: "10.0.0.2/32",
},
out: &mesh.Node{
ExternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
},
},
{
name: "invalid subnet",
annotations: map[string]string{},
out: &mesh.Node{},
subnet: "foo",
},
{
name: "normalize subnet",
annotations: map[string]string{},
out: &mesh.Node{
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(24, 32)},
},
subnet: "10.2.0.1/24",
},
{
name: "valid subnet",
annotations: map[string]string{},
out: &mesh.Node{
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
},
subnet: "10.2.1.0/24",
},
{
name: "region",
labels: map[string]string{
regionLabelKey: "a",
},
out: &mesh.Node{
Location: "a",
},
},
{
name: "region override",
annotations: map[string]string{
locationAnnotationKey: "b",
},
labels: map[string]string{
regionLabelKey: "a",
},
out: &mesh.Node{
Location: "b",
},
},
{
name: "external IP override",
annotations: map[string]string{
externalIPAnnotationKey: "10.0.0.1/24",
forceExternalIPAnnotationKey: "10.0.0.2/24",
},
out: &mesh.Node{
ExternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
},
},
{
name: "complete",
annotations: map[string]string{
externalIPAnnotationKey: "10.0.0.1/24",
forceExternalIPAnnotationKey: "10.0.0.2/24",
internalIPAnnotationKey: "10.0.0.2/32",
keyAnnotationKey: "foo",
leaderAnnotationKey: "",
locationAnnotationKey: "b",
},
labels: map[string]string{
regionLabelKey: "a",
},
out: &mesh.Node{
ExternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
Key: []byte("foo"),
Leader: true,
Location: "b",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
},
subnet: "10.2.1.0/24",
},
} {
n := &v1.Node{}
n.ObjectMeta.Annotations = tc.annotations
n.ObjectMeta.Labels = tc.labels
n.Spec.PodCIDR = tc.subnet
node := translateNode(n)
if diff := pretty.Compare(node, tc.out); diff != "" {
t.Errorf("test case %q: got diff: %v", tc.name, diff)
}
}
}

101
pkg/mesh/graph.go Normal file
View File

@@ -0,0 +1,101 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mesh
import (
"fmt"
"net"
"github.com/awalterschulze/gographviz"
)
// Dot generates a Graphviz graph of the Topology in DOT fomat.
func (t *Topology) Dot() (string, error) {
g := gographviz.NewGraph()
g.Name = "kilo"
if err := g.AddAttr("kilo", string(gographviz.Label), graphEscape(t.subnet.String())); err != nil {
return "", fmt.Errorf("failed to add label to graph")
}
if err := g.AddAttr("kilo", string(gographviz.LabelLOC), "t"); err != nil {
return "", fmt.Errorf("failed to add label location to graph")
}
if err := g.AddAttr("kilo", string(gographviz.Overlap), "false"); err != nil {
return "", fmt.Errorf("failed to disable graph overlap")
}
if err := g.SetDir(true); err != nil {
return "", fmt.Errorf("failed to set direction")
}
leaders := make([]string, len(t.Segments))
nodeAttrs := map[string]string{
string(gographviz.Shape): "ellipse",
}
for i, s := range t.Segments {
if err := g.AddSubGraph("kilo", subGraphName(s.Location), nil); err != nil {
return "", fmt.Errorf("failed to add subgraph")
}
if err := g.AddAttr(subGraphName(s.Location), string(gographviz.Label), graphEscape(s.Location)); err != nil {
return "", fmt.Errorf("failed to add label to subgraph")
}
if err := g.AddAttr(subGraphName(s.Location), string(gographviz.Style), `"dashed,rounded"`); err != nil {
return "", fmt.Errorf("failed to add style to subgraph")
}
for j := range s.cidrs {
if err := g.AddNode(subGraphName(s.Location), graphEscape(s.hostnames[j]), nodeAttrs); err != nil {
return "", fmt.Errorf("failed to add node to subgraph")
}
var wg net.IP
if j == s.leader {
wg = s.wireGuardIP
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Rank), "1"); err != nil {
return "", fmt.Errorf("failed to add rank to node")
}
}
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.Location, s.hostnames[j], s.cidrs[j], s.privateIPs[j], wg)); err != nil {
return "", fmt.Errorf("failed to add label to node")
}
}
meshSubGraph(g, g.Relations.SortedChildren(subGraphName(s.Location)), s.leader)
leaders[i] = graphEscape(s.hostnames[s.leader])
}
meshSubGraph(g, leaders, 0)
return g.String(), nil
}
func meshSubGraph(g *gographviz.Graph, nodes []string, leader int) {
for i := range nodes {
if i == leader {
continue
}
a := make(gographviz.Attrs)
a[gographviz.Dir] = "both"
g.Edges.Add(&gographviz.Edge{Src: nodes[leader], Dst: nodes[i], Dir: true, Attrs: a})
}
}
func graphEscape(s string) string {
return fmt.Sprintf("\"%s\"", s)
}
func subGraphName(name string) string {
return graphEscape(fmt.Sprintf("cluster_%s", name))
}
func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP) string {
var wg string
if wgIP != nil {
wg = wgIP.String()
}
return graphEscape(fmt.Sprintf("%s\n%s\n%s\n%s\n%s", location, name, cidr.String(), priv.String(), wg))
}

348
pkg/mesh/ip.go Normal file
View File

@@ -0,0 +1,348 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mesh
import (
"errors"
"fmt"
"net"
"sort"
"github.com/vishvananda/netlink"
)
// getIP returns a private and public IP address for the local node.
// It selects the private IP address in the following order:
// - private IP to which hostname resolves
// - private IP assigned to interface of default route
// - private IP assigned to local interface
// - public IP to which hostname resolves
// - public IP assigned to interface of default route
// - public IP assigned to local interface
// It selects the public IP address in the following order:
// - public IP to which hostname resolves
// - public IP assigned to interface of default route
// - public IP assigned to local interface
// - private IP to which hostname resolves
// - private IP assigned to interface of default route
// - private IP assigned to local interface
// - if no IP was found, return nil and an error.
func getIP(hostname string) (*net.IPNet, *net.IPNet, error) {
var hostPriv, hostPub []*net.IPNet
{
// Check IPs to which hostname resolves first.
ips, err := ipsForHostname(hostname)
if err != nil {
return nil, nil, err
}
for _, ip := range ips {
ok, mask, err := assignedToInterface(ip)
if err != nil {
return nil, nil, fmt.Errorf("failed to search locally assigned addresses: %v", err)
}
if !ok {
continue
}
ip.Mask = mask
if isPublic(ip) {
hostPub = append(hostPub, ip)
continue
}
hostPriv = append(hostPriv, ip)
}
sortIPs(hostPriv)
sortIPs(hostPub)
}
var defaultPriv, defaultPub []*net.IPNet
{
// Check IPs on interface for default route next.
iface, err := defaultInterface()
if err != nil {
return nil, nil, err
}
ips, err := ipsForInterface(iface)
if err != nil {
return nil, nil, err
}
for _, ip := range ips {
if isLocal(ip.IP) {
continue
}
if isPublic(ip) {
defaultPub = append(defaultPub, ip)
continue
}
defaultPriv = append(defaultPriv, ip)
}
sortIPs(defaultPriv)
sortIPs(defaultPub)
}
var interfacePriv, interfacePub []*net.IPNet
{
// Finally look for IPs on all interfaces.
ips, err := ipsForAllInterfaces()
if err != nil {
return nil, nil, err
}
for _, ip := range ips {
if isLocal(ip.IP) {
continue
}
if isPublic(ip) {
interfacePub = append(interfacePub, ip)
continue
}
interfacePriv = append(interfacePriv, ip)
}
sortIPs(interfacePriv)
sortIPs(interfacePub)
}
var priv, pub []*net.IPNet
priv = append(priv, hostPriv...)
priv = append(priv, defaultPriv...)
priv = append(priv, interfacePriv...)
pub = append(pub, hostPub...)
pub = append(pub, defaultPub...)
pub = append(pub, interfacePub...)
if len(priv) == 0 && len(pub) == 0 {
return nil, nil, errors.New("no valid IP was found")
}
if len(priv) == 0 {
priv = pub
}
if len(pub) == 0 {
pub = priv
}
return priv[0], pub[0], nil
}
// sortIPs sorts IPs so the result is stable.
// It will first sort IPs by type, to prefer selecting
// IPs of the same type, and then by value.
func sortIPs(ips []*net.IPNet) {
sort.Slice(ips, func(i, j int) bool {
i4, j4 := ips[i].IP.To4(), ips[j].IP.To4()
if i4 != nil && j4 == nil {
return true
}
if j4 != nil && i4 == nil {
return false
}
return ips[i].String() < ips[j].String()
})
}
func assignedToInterface(ip *net.IPNet) (bool, net.IPMask, error) {
links, err := netlink.LinkList()
if err != nil {
return false, nil, fmt.Errorf("failed to list interfaces: %v", err)
}
// Sort the links for stability.
sort.Slice(links, func(i, j int) bool {
return links[i].Attrs().Name < links[j].Attrs().Name
})
for _, link := range links {
addrs, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return false, nil, fmt.Errorf("failed to list addresses for %s: %v", link.Attrs().Name, err)
}
// Sort the IPs for stability.
sort.Slice(addrs, func(i, j int) bool {
return addrs[i].String() < addrs[j].String()
})
for i := range addrs {
if ip.IP.Equal(addrs[i].IP) {
return true, addrs[i].Mask, nil
}
}
}
return false, nil, nil
}
func isLocal(ip net.IP) bool {
return ip.IsLoopback() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast()
}
func isPublic(ip *net.IPNet) bool {
// Check RFC 1918 addresses.
if ip4 := ip.IP.To4(); ip4 != nil {
switch true {
// Check for 10.0.0.0/8.
case ip4[0] == 10:
return false
// Check for 172.16.0.0/12.
case ip4[0] == 172 && ip4[1]&0xf0 == 0x01:
return false
// Check for 192.168.0.0/16.
case ip4[0] == 192 && ip4[1] == 168:
return false
default:
return true
}
}
// Check RFC 4193 addresses.
if len(ip.IP) == net.IPv6len {
switch true {
// Check for fd00::/8.
case ip.IP[0] == 0xfd && ip.IP[1] == 0x00:
return false
default:
return true
}
}
return false
}
// ipsForHostname returns a slice of IPs to which the
// given hostname resolves.
func ipsForHostname(hostname string) ([]*net.IPNet, error) {
if ip := net.ParseIP(hostname); ip != nil {
return []*net.IPNet{oneAddressCIDR(ip)}, nil
}
ips, err := net.LookupIP(hostname)
if err != nil {
return nil, fmt.Errorf("failed to lookip IPs of hostname: %v", err)
}
nets := make([]*net.IPNet, len(ips))
for i := range ips {
nets[i] = oneAddressCIDR(ips[i])
}
return nets, nil
}
// ipsForAllInterfaces returns a slice of IPs assigned to all the
// interfaces on the host.
func ipsForAllInterfaces() ([]*net.IPNet, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("failed to list interfaces: %v", err)
}
var nets []*net.IPNet
for _, iface := range ifaces {
ips, err := ipsForInterface(&iface)
if err != nil {
return nil, fmt.Errorf("failed to list addresses for %s: %v", iface.Name, err)
}
nets = append(nets, ips...)
}
return nets, nil
}
// ipsForInterface returns a slice of IPs assigned to the given interface.
func ipsForInterface(iface *net.Interface) ([]*net.IPNet, error) {
link, err := netlink.LinkByIndex(iface.Index)
if err != nil {
return nil, fmt.Errorf("failed to get link: %s", err)
}
addrs, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return nil, fmt.Errorf("failed to list addresses for %s: %v", iface.Name, err)
}
var ips []*net.IPNet
for _, a := range addrs {
if a.IPNet != nil {
ips = append(ips, a.IPNet)
}
}
return ips, nil
}
// interfacesForIP returns a slice of interfaces withthe given IP.
func interfacesForIP(ip *net.IPNet) ([]net.Interface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("failed to list interfaces: %v", err)
}
var interfaces []net.Interface
for _, iface := range ifaces {
ips, err := ipsForInterface(&iface)
if err != nil {
return nil, fmt.Errorf("failed to list addresses for %s: %v", iface.Name, err)
}
for i := range ips {
if ip.IP.Equal(ips[i].IP) {
interfaces = append(interfaces, iface)
break
}
}
}
if len(interfaces) == 0 {
return nil, fmt.Errorf("no interface has %s assigned", ip.String())
}
return interfaces, nil
}
// defaultInterface returns the interface for the default route of the host.
func defaultInterface() (*net.Interface, error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return nil, err
}
for _, route := range routes {
if route.Dst == nil || route.Dst.String() == "0.0.0.0/0" || route.Dst.String() == "::/0" {
if route.LinkIndex <= 0 {
return nil, errors.New("failed to determine interface of route")
}
return net.InterfaceByIndex(route.LinkIndex)
}
}
return nil, errors.New("failed to find default route")
}
type allocator struct {
bits int
cidr *net.IPNet
current net.IP
}
func newAllocator(cidr net.IPNet) *allocator {
_, bits := cidr.Mask.Size()
current := make(net.IP, len(cidr.IP))
copy(current, cidr.IP)
if ip4 := current.To4(); ip4 != nil {
current = ip4
}
return &allocator{
bits: bits,
cidr: &cidr,
current: current,
}
}
func (a *allocator) next() *net.IPNet {
if a.current == nil {
return nil
}
for i := len(a.current) - 1; i >= 0; i-- {
a.current[i]++
// if we haven't overflowed, then we can exit.
if a.current[i] != 0 {
break
}
}
if !a.cidr.Contains(a.current) {
a.current = nil
}
ip := make(net.IP, len(a.current))
copy(ip, a.current)
return &net.IPNet{IP: ip, Mask: net.CIDRMask(a.bits, a.bits)}
}

75
pkg/mesh/ip_test.go Normal file
View File

@@ -0,0 +1,75 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mesh
import (
"net"
"testing"
)
func TestSortIPs(t *testing.T) {
ip1 := oneAddressCIDR(net.ParseIP("10.0.0.1"))
ip2 := oneAddressCIDR(net.ParseIP("10.0.0.2"))
ip3 := oneAddressCIDR(net.ParseIP("192.168.0.1"))
ip4 := oneAddressCIDR(net.ParseIP("2001::7"))
ip5 := oneAddressCIDR(net.ParseIP("fd68:da49:09da:b27f::"))
for _, tc := range []struct {
name string
ips []*net.IPNet
out []*net.IPNet
}{
{
name: "single",
ips: []*net.IPNet{ip1},
out: []*net.IPNet{ip1},
},
{
name: "IPv4s",
ips: []*net.IPNet{ip2, ip3, ip1},
out: []*net.IPNet{ip1, ip2, ip3},
},
{
name: "IPv4 and IPv6",
ips: []*net.IPNet{ip4, ip1},
out: []*net.IPNet{ip1, ip4},
},
{
name: "IPv6s",
ips: []*net.IPNet{ip5, ip4},
out: []*net.IPNet{ip4, ip5},
},
{
name: "all",
ips: []*net.IPNet{ip3, ip4, ip2, ip5, ip1},
out: []*net.IPNet{ip1, ip2, ip3, ip4, ip5},
},
} {
sortIPs(tc.ips)
equal := true
if len(tc.ips) != len(tc.out) {
equal = false
} else {
for i := range tc.ips {
if !ipNetsEqual(tc.ips[i], tc.out[i]) {
equal = false
break
}
}
}
if !equal {
t.Errorf("test case %q: expected %s, got %s", tc.name, tc.out, tc.ips)
}
}
}

581
pkg/mesh/mesh.go Normal file
View File

@@ -0,0 +1,581 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mesh
import (
"fmt"
"io/ioutil"
"net"
"os"
"sync"
"time"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/prometheus/client_golang/prometheus"
"github.com/vishvananda/netlink"
"github.com/squat/kilo/pkg/iproute"
"github.com/squat/kilo/pkg/ipset"
"github.com/squat/kilo/pkg/iptables"
"github.com/squat/kilo/pkg/route"
"github.com/squat/kilo/pkg/wireguard"
)
const resyncPeriod = 30 * time.Second
const (
// KiloPath is the directory where Kilo stores its configuration.
KiloPath = "/var/lib/kilo"
// PrivateKeyPath is the filepath where the WireGuard private key is stored.
PrivateKeyPath = KiloPath + "/key"
// ConfPath is the filepath where the WireGuard configuration is stored.
ConfPath = KiloPath + "/conf"
)
// Granularity represents the abstraction level at which the network
// should be meshed.
type Granularity string
// Encapsulate identifies what packets within a location should
// be encapsulated.
type Encapsulate string
const (
// DataCenterGranularity indicates that the network should create
// a mesh between data-centers but not between nodes within a
// single data-center.
DataCenterGranularity Granularity = "data-center"
// NodeGranularity indicates that the network should create
// a mesh between every node.
NodeGranularity Granularity = "node"
// NeverEncapsulate indicates that no packets within a location
// should be encapsulated.
NeverEncapsulate Encapsulate = "never"
// CrossSubnetEncapsulate indicates that only packets that
// traverse subnets within a location should be encapsulated.
CrossSubnetEncapsulate Encapsulate = "crosssubnet"
// AlwaysEncapsulate indicates that all packets within a location
// should be encapsulated.
AlwaysEncapsulate Encapsulate = "always"
)
// Node represents a node in the network.
type Node struct {
ExternalIP *net.IPNet
Key []byte
InternalIP *net.IPNet
// Leader is a suggestion to Kilo that
// the node wants to lead its segment.
Leader bool
Location string
Name string
Subnet *net.IPNet
}
// Ready indicates whether or not the node is ready.
func (n *Node) Ready() bool {
return n != nil && n.ExternalIP != nil && n.Key != nil && n.InternalIP != nil && n.Subnet != nil
}
// EventType describes what kind of an action an event represents.
type EventType string
const (
// AddEvent represents an action where an item was added.
AddEvent EventType = "add"
// DeleteEvent represents an action where an item was removed.
DeleteEvent EventType = "delete"
// UpdateEvent represents an action where an item was updated.
UpdateEvent EventType = "update"
)
// Event represents an update event concerning a node in the cluster.
type Event struct {
Type EventType
Node *Node
}
// Backend can get nodes by name, init itself,
// list the nodes that should be meshed,
// set Kilo properties for a node,
// clean up any changes applied to the backend,
// and watch for changes to nodes.
type Backend interface {
CleanUp(string) error
Get(string) (*Node, error)
Init(<-chan struct{}) error
List() ([]*Node, error)
Set(string, *Node) error
Watch() <-chan *Event
}
// Mesh is able to create Kilo network meshes.
type Mesh struct {
Backend
encapsulate Encapsulate
externalIP *net.IPNet
granularity Granularity
hostname string
internalIP *net.IPNet
ipset *ipset.Set
ipTables *iptables.Controller
kiloIface int
key []byte
local bool
port int
priv []byte
privIface int
pub []byte
pubIface int
stop chan struct{}
subnet *net.IPNet
table *route.Table
tunlIface int
// nodes is a mutable field in the struct
// and needs to be guarded.
nodes map[string]*Node
mu sync.Mutex
errorCounter *prometheus.CounterVec
nodesGuage prometheus.Gauge
logger log.Logger
}
// New returns a new Mesh instance.
func New(backend Backend, encapsulate Encapsulate, granularity Granularity, hostname string, port int, subnet *net.IPNet, local bool, logger log.Logger) (*Mesh, error) {
if err := os.MkdirAll(KiloPath, 0700); err != nil {
return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
}
private, err := ioutil.ReadFile(PrivateKeyPath)
if err != nil {
level.Warn(logger).Log("msg", "no private key found on disk; generating one now")
if private, err = wireguard.GenKey(); err != nil {
return nil, err
}
}
public, err := wireguard.PubKey(private)
if err != nil {
return nil, err
}
if err := ioutil.WriteFile(PrivateKeyPath, private, 0600); err != nil {
return nil, fmt.Errorf("failed to write private key to disk: %v", err)
}
privateIP, publicIP, err := getIP(hostname)
if err != nil {
return nil, fmt.Errorf("failed to find public IP: %v", err)
}
ifaces, err := interfacesForIP(privateIP)
if err != nil {
return nil, fmt.Errorf("failed to find interface for private IP: %v", err)
}
privIface := ifaces[0].Index
ifaces, err = interfacesForIP(publicIP)
if err != nil {
return nil, fmt.Errorf("failed to find interface for public IP: %v", err)
}
pubIface := ifaces[0].Index
kiloIface, err := wireguard.New("kilo")
if err != nil {
return nil, fmt.Errorf("failed to create WireGuard interface: %v", err)
}
var tunlIface int
if encapsulate != NeverEncapsulate {
if tunlIface, err = iproute.NewIPIP(privIface); err != nil {
return nil, fmt.Errorf("failed to create tunnel interface: %v", err)
}
if err := iproute.Set(tunlIface, true); err != nil {
return nil, fmt.Errorf("failed to set tunnel interface up: %v", err)
}
}
level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the private IP address", privateIP.String()))
level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the public IP address", publicIP.String()))
ipTables, err := iptables.New(len(subnet.IP))
if err != nil {
return nil, fmt.Errorf("failed to IP tables controller: %v", err)
}
return &Mesh{
Backend: backend,
encapsulate: encapsulate,
externalIP: publicIP,
granularity: granularity,
hostname: hostname,
internalIP: privateIP,
// This is a patch until Calico supports
// other hosts adding IPIP iptables rules.
ipset: ipset.New("cali40all-hosts-net"),
ipTables: ipTables,
kiloIface: kiloIface,
nodes: make(map[string]*Node),
port: port,
priv: private,
privIface: privIface,
pub: public,
pubIface: pubIface,
local: local,
stop: make(chan struct{}),
subnet: subnet,
table: route.NewTable(),
tunlIface: tunlIface,
errorCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "kilo_errors_total",
Help: "Number of errors that occurred while administering the mesh.",
}, []string{"event"}),
nodesGuage: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "kilo_nodes",
Help: "Number of in the mesh.",
}),
logger: logger,
}, nil
}
// Run starts the mesh.
func (m *Mesh) Run() error {
if err := m.Init(m.stop); err != nil {
return fmt.Errorf("failed to initialize backend: %v", err)
}
ipsetErrors, err := m.ipset.Run(m.stop)
if err != nil {
return fmt.Errorf("failed to watch for ipset updates: %v", err)
}
ipTablesErrors, err := m.ipTables.Run(m.stop)
if err != nil {
return fmt.Errorf("failed to watch for IP tables updates: %v", err)
}
routeErrors, err := m.table.Run(m.stop)
if err != nil {
return fmt.Errorf("failed to watch for route table updates: %v", err)
}
go func() {
for {
var err error
select {
case err = <-ipsetErrors:
case err = <-ipTablesErrors:
case err = <-routeErrors:
case <-m.stop:
return
}
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("run").Inc()
}
}
}()
defer m.cleanUp()
t := time.NewTimer(resyncPeriod)
w := m.Watch()
for {
var e *Event
select {
case e = <-w:
m.sync(e)
case <-t.C:
m.applyTopology()
t.Reset(resyncPeriod)
case <-m.stop:
return nil
}
}
}
func (m *Mesh) sync(e *Event) {
logger := log.With(m.logger, "event", e.Type)
level.Debug(logger).Log("msg", "syncing", "event", e.Type)
if isSelf(m.hostname, e.Node) {
level.Debug(logger).Log("msg", "processing local node", "node", e.Node)
m.handleLocal(e.Node)
return
}
var diff bool
m.mu.Lock()
if !e.Node.Ready() {
level.Debug(logger).Log("msg", "received incomplete node", "node", e.Node)
// An existing node is no longer valid
// so remove it from the mesh.
if _, ok := m.nodes[e.Node.Name]; ok {
level.Info(logger).Log("msg", "node is no longer in the mesh", "node", e.Node)
delete(m.nodes, e.Node.Name)
diff = true
}
} else {
switch e.Type {
case AddEvent:
fallthrough
case UpdateEvent:
if !nodesAreEqual(m.nodes[e.Node.Name], e.Node) {
m.nodes[e.Node.Name] = e.Node
diff = true
}
case DeleteEvent:
delete(m.nodes, e.Node.Name)
diff = true
}
}
m.mu.Unlock()
if diff {
level.Info(logger).Log("node", e.Node)
m.applyTopology()
}
}
func (m *Mesh) handleLocal(n *Node) {
// Allow the external IP to be overridden.
if n.ExternalIP == nil {
n.ExternalIP = m.externalIP
}
// Compare the given node to the calculated local node.
// Take leader, location, and subnet from the argument, as these
// are not determined by kilo.
local := &Node{ExternalIP: n.ExternalIP, Key: m.pub, InternalIP: m.internalIP, Leader: n.Leader, Location: n.Location, Name: m.hostname, Subnet: n.Subnet}
if !nodesAreEqual(n, local) {
level.Debug(m.logger).Log("msg", "local node differs from backend")
if err := m.Set(m.hostname, local); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", local)
m.errorCounter.WithLabelValues("local").Inc()
return
}
level.Debug(m.logger).Log("msg", "successfully reconciled local node against backend")
}
m.mu.Lock()
n = m.nodes[m.hostname]
if n == nil {
n = &Node{}
}
m.mu.Unlock()
if !nodesAreEqual(n, local) {
m.mu.Lock()
m.nodes[local.Name] = local
m.mu.Unlock()
m.applyTopology()
}
}
func (m *Mesh) applyTopology() {
m.mu.Lock()
defer m.mu.Unlock()
// Ensure all unready nodes are removed.
var ready float64
for n := range m.nodes {
if !m.nodes[n].Ready() {
delete(m.nodes, n)
continue
}
ready++
}
m.nodesGuage.Set(ready)
// We cannot do anything with the topology until the local node is available.
if m.nodes[m.hostname] == nil {
return
}
t, err := NewTopology(m.nodes, m.granularity, m.hostname, m.port, m.priv, m.subnet)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
conf, err := t.Conf()
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
}
if err := ioutil.WriteFile(ConfPath, conf, 0600); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
var private *net.IPNet
// If we are not encapsulating packets to the local private network,
// then pass the private IP to add an exception to the NAT rule.
if m.encapsulate != AlwaysEncapsulate {
private = t.privateIP
}
rules := iptables.MasqueradeRules(private, m.nodes[m.hostname].Subnet, t.RemoteSubnets())
rules = append(rules, iptables.ForwardRules(m.subnet)...)
if err := m.ipTables.Set(rules); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
if m.encapsulate != NeverEncapsulate {
var peers []net.IP
for _, s := range t.Segments {
if s.Location == m.nodes[m.hostname].Location {
peers = s.privateIPs
break
}
}
if err := m.ipset.Set(peers); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
if m.local {
if err := iproute.SetAddress(m.tunlIface, oneAddressCIDR(newAllocator(*m.nodes[m.hostname].Subnet).next().IP)); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
}
}
if t.leader {
if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
link, err := linkByIndex(m.kiloIface)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
oldConf, err := wireguard.ShowConf(link.Attrs().Name)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
// Setting the WireGuard configuration interrupts existing connections
// so only set the configuration if it has changed.
equal, err := wireguard.CompareConf(conf, oldConf)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
// Don't return here, simply overwrite the old configuration.
equal = false
}
if !equal {
if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
}
if err := iproute.Set(m.kiloIface, true); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
} else {
level.Debug(m.logger).Log("msg", "local node is not the leader")
if err := iproute.Set(m.kiloIface, false); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
}
// We need to add routes last since they may depend
// on the WireGuard interface.
routes := t.Routes(m.kiloIface, m.privIface, m.tunlIface, m.local, m.encapsulate)
if err := m.table.Set(routes); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
}
}
// RegisterMetrics registers Prometheus metrics on the given Prometheus
// registerer.
func (m *Mesh) RegisterMetrics(r prometheus.Registerer) {
r.MustRegister(
m.errorCounter,
m.nodesGuage,
)
}
// Stop stops the mesh.
func (m *Mesh) Stop() {
close(m.stop)
}
func (m *Mesh) cleanUp() {
if err := m.ipTables.CleanUp(); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up IP tables: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := m.table.CleanUp(); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up routes: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := os.Remove(PrivateKeyPath); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to delete private key: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := os.Remove(ConfPath); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to delete configuration file: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := iproute.RemoveInterface(m.kiloIface); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove wireguard interface: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := m.CleanUp(m.hostname); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up backend: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := m.ipset.CleanUp(); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up ipset: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
}
func isSelf(hostname string, node *Node) bool {
return node != nil && node.Name == hostname
}
func nodesAreEqual(a, b *Node) bool {
if !(a != nil) == (b != nil) {
return false
}
if a == b {
return true
}
return ipNetsEqual(a.ExternalIP, b.ExternalIP) && string(a.Key) == string(b.Key) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet)
}
func ipNetsEqual(a, b *net.IPNet) bool {
if a == nil && b == nil {
return true
}
if (a != nil) != (b != nil) {
return false
}
if a.Mask.String() != b.Mask.String() {
return false
}
return a.IP.Equal(b.IP)
}
func subnetsEqual(a, b *net.IPNet) bool {
if a.Mask.String() != b.Mask.String() {
return false
}
if !a.Contains(b.IP) {
return false
}
if !b.Contains(a.IP) {
return false
}
return true
}
func linkByIndex(index int) (netlink.Link, error) {
link, err := netlink.LinkByIndex(index)
if err != nil {
return nil, fmt.Errorf("failed to get interface: %v", err)
}
return link, nil
}

146
pkg/mesh/mesh_test.go Normal file
View File

@@ -0,0 +1,146 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mesh
import (
"net"
"testing"
)
func TestNewAllocator(t *testing.T) {
_, c1, err := net.ParseCIDR("10.1.0.0/16")
if err != nil {
t.Fatalf("failed to parse CIDR: %v", err)
}
a1 := newAllocator(*c1)
_, c2, err := net.ParseCIDR("10.1.0.0/32")
if err != nil {
t.Fatalf("failed to parse CIDR: %v", err)
}
a2 := newAllocator(*c2)
_, c3, err := net.ParseCIDR("10.1.0.0/31")
if err != nil {
t.Fatalf("failed to parse CIDR: %v", err)
}
a3 := newAllocator(*c3)
for _, tc := range []struct {
name string
a *allocator
next string
}{
{
name: "10.1.0.0/16 first",
a: a1,
next: "10.1.0.1/32",
},
{
name: "10.1.0.0/16 second",
a: a1,
next: "10.1.0.2/32",
},
{
name: "10.1.0.0/32",
a: a2,
next: "<nil>",
},
{
name: "10.1.0.0/31 first",
a: a3,
next: "10.1.0.1/32",
},
{
name: "10.1.0.0/31 second",
a: a3,
next: "<nil>",
},
} {
next := tc.a.next()
if next.String() != tc.next {
t.Errorf("test case %q: expected %s, got %s", tc.name, tc.next, next.String())
}
}
}
func TestReady(t *testing.T) {
internalIP := oneAddressCIDR(net.ParseIP("1.1.1.1"))
externalIP := oneAddressCIDR(net.ParseIP("2.2.2.2"))
for _, tc := range []struct {
name string
node *Node
ready bool
}{
{
name: "nil",
node: nil,
ready: false,
},
{
name: "empty fields",
node: &Node{},
ready: false,
},
{
name: "empty external IP",
node: &Node{
InternalIP: internalIP,
Key: []byte{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty internal IP",
node: &Node{
ExternalIP: externalIP,
Key: []byte{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty key",
node: &Node{
ExternalIP: externalIP,
InternalIP: internalIP,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty subnet",
node: &Node{
ExternalIP: externalIP,
InternalIP: internalIP,
Key: []byte{},
},
ready: false,
},
{
name: "valid",
node: &Node{
ExternalIP: externalIP,
InternalIP: internalIP,
Key: []byte{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: true,
},
} {
ready := tc.node.Ready()
if ready != tc.ready {
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.ready, ready)
}
}
}

334
pkg/mesh/topology.go Normal file
View File

@@ -0,0 +1,334 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mesh
import (
"bytes"
"errors"
"fmt"
"net"
"sort"
"strings"
"text/template"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
var (
confTemplate = template.Must(template.New("").Parse(`[Interface]
PrivateKey = {{.Key}}
ListenPort = {{.Port}}
{{range .Segments -}}
{{if ne .Location $.Location}}
[Peer]
PublicKey = {{.Key}}
Endpoint = {{.Endpoint}}:{{$.Port}}
AllowedIPs = {{.AllowedIPs}}
{{end}}
{{- end -}}
`))
)
// Topology represents the logical structure of the overlay network.
type Topology struct {
// Some fields need to be exported so that the template can read them.
Key string
Port int
// Location is the logical location of the local host.
Location string
Segments []*segment
// hostname is the hostname of the local host.
hostname string
// leader represents whether or not the local host
// is the segment leader.
leader bool
// subnet is the entire subnet from which IPs
// for the WireGuard interfaces will be allocated.
subnet *net.IPNet
// privateIP is the private IP address of the local node.
privateIP *net.IPNet
// wireGuardCIDR is the allocated CIDR of the WireGuard
// interface of the local node. If the local node is not
// the leader, then it is nil.
wireGuardCIDR *net.IPNet
}
type segment struct {
// Some fields need to be exported so that the template can read them.
AllowedIPs string
Endpoint string
Key string
// Location is the logical location of this segment.
Location string
// cidrs is a slice of subnets of all peers in the segment.
cidrs []*net.IPNet
// hostnames is a slice of the hostnames of the peers in the segment.
hostnames []string
// leader is the index of the leader of the segment.
leader int
// privateIPs is a slice of private IPs of all peers in the segment.
privateIPs []net.IP
// wireGuardIP is the allocated IP address of the WireGuard
// interface on the leader of the segment.
wireGuardIP net.IP
}
// NewTopology creates a new Topology struct from a given set of nodes.
func NewTopology(nodes map[string]*Node, granularity Granularity, hostname string, port int, key []byte, subnet *net.IPNet) (*Topology, error) {
topoMap := make(map[string][]*Node)
for _, node := range nodes {
var location string
switch granularity {
case DataCenterGranularity:
location = node.Location
case NodeGranularity:
location = node.Name
}
topoMap[location] = append(topoMap[location], node)
}
var localLocation string
switch granularity {
case DataCenterGranularity:
localLocation = nodes[hostname].Location
case NodeGranularity:
localLocation = hostname
}
t := Topology{Key: strings.TrimSpace(string(key)), Port: port, hostname: hostname, Location: localLocation, subnet: subnet, privateIP: nodes[hostname].InternalIP}
for location := range topoMap {
// Sort the location so the result is stable.
sort.Slice(topoMap[location], func(i, j int) bool {
return topoMap[location][i].Name < topoMap[location][j].Name
})
leader := findLeader(topoMap[location])
if location == localLocation && topoMap[location][leader].Name == hostname {
t.leader = true
}
var allowedIPs []string
var cidrs []*net.IPNet
var hostnames []string
var privateIPs []net.IP
for _, node := range topoMap[location] {
// Allowed IPs should include:
// - the node's allocated subnet
// - the node's WireGuard IP
// - the node's internal IP
allowedIPs = append(allowedIPs, node.Subnet.String(), oneAddressCIDR(node.InternalIP.IP).String())
cidrs = append(cidrs, node.Subnet)
hostnames = append(hostnames, node.Name)
privateIPs = append(privateIPs, node.InternalIP.IP)
}
t.Segments = append(t.Segments, &segment{
AllowedIPs: strings.Join(allowedIPs, ", "),
Endpoint: topoMap[location][leader].ExternalIP.IP.String(),
Key: strings.TrimSpace(string(topoMap[location][leader].Key)),
Location: location,
cidrs: cidrs,
hostnames: hostnames,
leader: leader,
privateIPs: privateIPs,
})
}
// Sort the Topology so the result is stable.
sort.Slice(t.Segments, func(i, j int) bool {
return t.Segments[i].Location < t.Segments[j].Location
})
// Allocate IPs to the segment leaders in a stable, coordination-free manner.
a := newAllocator(*subnet)
for _, segment := range t.Segments {
ipNet := a.next()
if ipNet == nil {
return nil, errors.New("failed to allocate an IP address; ran out of IP addresses")
}
segment.wireGuardIP = ipNet.IP
segment.AllowedIPs = fmt.Sprintf("%s, %s", segment.AllowedIPs, ipNet.String())
if t.leader && segment.Location == t.Location {
t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: t.subnet.Mask}
}
}
return &t, nil
}
// RemoteSubnets identifies the subnets of the hosts in segments different than the host's.
func (t *Topology) RemoteSubnets() []*net.IPNet {
var remote []*net.IPNet
for _, s := range t.Segments {
if s == nil || s.Location == t.Location {
continue
}
remote = append(remote, s.cidrs...)
}
return remote
}
// Routes generates a slice of routes for a given Topology.
func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, encapsulate Encapsulate) []*netlink.Route {
var routes []*netlink.Route
if !t.leader {
// Find the leader for this segment.
var leader net.IP
for _, segment := range t.Segments {
if segment.Location == t.Location {
leader = segment.privateIPs[segment.leader]
break
}
}
for _, segment := range t.Segments {
// First, add a route to the WireGuard IP of the segment.
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: oneAddressCIDR(segment.wireGuardIP),
Flags: int(netlink.FLAG_ONLINK),
Gw: leader,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
}, encapsulate, t.privateIP, tunlIface))
// Add routes for the current segment if local is true.
if segment.Location == t.Location {
if local {
for i := range segment.cidrs {
// Don't add routes for the local node.
if segment.privateIPs[i].Equal(t.privateIP.IP) {
continue
}
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: segment.cidrs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: segment.privateIPs[i],
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
}, encapsulate, t.privateIP, tunlIface))
}
}
continue
}
for i := range segment.cidrs {
// Add routes to the Pod CIDRs of nodes in other segments.
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: segment.cidrs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: leader,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
}, encapsulate, t.privateIP, tunlIface))
// Add routes to the private IPs of nodes in other segments.
// Number of CIDRs and private IPs always match so
// we can reuse the loop.
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: oneAddressCIDR(segment.privateIPs[i]),
Flags: int(netlink.FLAG_ONLINK),
Gw: leader,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
}, encapsulate, t.privateIP, tunlIface))
}
}
return routes
}
for _, segment := range t.Segments {
// Add routes for the current segment if local is true.
if segment.Location == t.Location {
if local {
for i := range segment.cidrs {
// Don't add routes for the local node.
if segment.privateIPs[i].Equal(t.privateIP.IP) {
continue
}
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: segment.cidrs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: segment.privateIPs[i],
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
}, encapsulate, t.privateIP, tunlIface))
}
}
continue
}
for i := range segment.cidrs {
// Add routes to the Pod CIDRs of nodes in other segments.
routes = append(routes, &netlink.Route{
Dst: segment.cidrs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: segment.wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
})
// Add routes to the private IPs of nodes in other segments.
// Number of CIDRs and private IPs always match so
// we can reuse the loop.
routes = append(routes, &netlink.Route{
Dst: oneAddressCIDR(segment.privateIPs[i]),
Flags: int(netlink.FLAG_ONLINK),
Gw: segment.wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
})
}
}
return routes
}
func encapsulateRoute(route *netlink.Route, encapsulate Encapsulate, subnet *net.IPNet, tunlIface int) *netlink.Route {
if encapsulate == AlwaysEncapsulate || (encapsulate == CrossSubnetEncapsulate && !subnet.Contains(route.Gw)) {
route.LinkIndex = tunlIface
}
return route
}
// Conf generates a WireGuard configuration file for a given Topology.
func (t *Topology) Conf() ([]byte, error) {
conf := new(bytes.Buffer)
if err := confTemplate.Execute(conf, t); err != nil {
return nil, err
}
return conf.Bytes(), nil
}
// oneAddressCIDR takes an IP address and returns a CIDR
// that contains only that address.
func oneAddressCIDR(ip net.IP) *net.IPNet {
return &net.IPNet{IP: ip, Mask: net.CIDRMask(len(ip)*8, len(ip)*8)}
}
// findLeader selects a leader for the nodes in a segment;
// it will select the first node that says it should lead
// or the first node in the segment if none have volunteered,
// always preferring those with a public external IP address,
func findLeader(nodes []*Node) int {
var leaders, public []int
for i := range nodes {
if nodes[i].Leader {
if isPublic(nodes[i].ExternalIP) {
return i
}
leaders = append(leaders, i)
}
if isPublic(nodes[i].ExternalIP) {
public = append(public, i)
}
}
if len(leaders) != 0 {
return leaders[0]
}
if len(public) != 0 {
return public[0]
}
return 0
}

982
pkg/mesh/topology_test.go Normal file
View File

@@ -0,0 +1,982 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mesh
import (
"net"
"strings"
"testing"
"github.com/kylelemons/godebug/pretty"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
func allowedIPs(ips ...string) string {
return strings.Join(ips, ", ")
}
func setup(t *testing.T) (map[string]*Node, []byte, int, *net.IPNet) {
key := []byte("private")
port := 51820
_, kiloNet, err := net.ParseCIDR("10.4.0.0/16")
if err != nil {
t.Fatalf("failed to parse Kilo subnet CIDR: %v", err)
}
ip, e1, err := net.ParseCIDR("10.1.0.1/16")
if err != nil {
t.Fatalf("failed to parse external IP CIDR: %v", err)
}
e1.IP = ip
ip, e2, err := net.ParseCIDR("10.1.0.2/16")
if err != nil {
t.Fatalf("failed to parse external IP CIDR: %v", err)
}
e2.IP = ip
ip, e3, err := net.ParseCIDR("10.1.0.3/16")
if err != nil {
t.Fatalf("failed to parse external IP CIDR: %v", err)
}
e3.IP = ip
ip, i1, err := net.ParseCIDR("192.168.0.1/24")
if err != nil {
t.Fatalf("failed to parse internal IP CIDR: %v", err)
}
i1.IP = ip
ip, i2, err := net.ParseCIDR("192.168.0.2/24")
if err != nil {
t.Fatalf("failed to parse internal IP CIDR: %v", err)
}
i2.IP = ip
nodes := map[string]*Node{
"a": {
Name: "a",
ExternalIP: e1,
InternalIP: i1,
Location: "1",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key1"),
},
"b": {
Name: "b",
ExternalIP: e2,
InternalIP: i1,
Location: "2",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key2"),
},
"c": {
Name: "c",
ExternalIP: e3,
InternalIP: i2,
// Same location a node b.
Location: "2",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.3.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key3"),
},
}
return nodes, key, port, kiloNet
}
func TestNewTopology(t *testing.T) {
nodes, key, port, kiloNet := setup(t)
w1 := net.ParseIP("10.4.0.1").To4()
w2 := net.ParseIP("10.4.0.2").To4()
w3 := net.ParseIP("10.4.0.3").To4()
for _, tc := range []struct {
name string
granularity Granularity
hostname string
result *Topology
}{
{
name: "datacenter from a",
granularity: DataCenterGranularity,
hostname: nodes["a"].Name,
result: &Topology{
hostname: nodes["a"].Name,
leader: true,
Location: nodes["a"].Location,
subnet: kiloNet,
privateIP: nodes["a"].InternalIP,
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
Segments: []*segment{
{
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
Endpoint: nodes["a"].ExternalIP.IP.String(),
Key: string(nodes["a"].Key),
Location: nodes["a"].Location,
cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
wireGuardIP: w1,
},
{
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.2/32"),
Endpoint: nodes["b"].ExternalIP.IP.String(),
Key: string(nodes["b"].Key),
Location: nodes["b"].Location,
cidrs: []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
hostnames: []string{"b", "c"},
privateIPs: []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
wireGuardIP: w2,
},
},
},
},
{
name: "datacenter from b",
granularity: DataCenterGranularity,
hostname: nodes["b"].Name,
result: &Topology{
hostname: nodes["b"].Name,
leader: true,
Location: nodes["b"].Location,
subnet: kiloNet,
privateIP: nodes["b"].InternalIP,
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
Segments: []*segment{
{
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
Endpoint: nodes["a"].ExternalIP.IP.String(),
Key: string(nodes["a"].Key),
Location: nodes["a"].Location,
cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
wireGuardIP: w1,
},
{
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.2/32"),
Endpoint: nodes["b"].ExternalIP.IP.String(),
Key: string(nodes["b"].Key),
Location: nodes["b"].Location,
cidrs: []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
hostnames: []string{"b", "c"},
privateIPs: []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
wireGuardIP: w2,
},
},
},
},
{
name: "datacenter from c",
granularity: DataCenterGranularity,
hostname: nodes["c"].Name,
result: &Topology{
hostname: nodes["c"].Name,
leader: false,
Location: nodes["b"].Location,
subnet: kiloNet,
privateIP: nodes["c"].InternalIP,
wireGuardCIDR: nil,
Segments: []*segment{
{
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
Endpoint: nodes["a"].ExternalIP.IP.String(),
Key: string(nodes["a"].Key),
Location: nodes["a"].Location,
cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
wireGuardIP: w1,
},
{
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.2/32"),
Endpoint: nodes["b"].ExternalIP.IP.String(),
Key: string(nodes["b"].Key),
Location: nodes["b"].Location,
cidrs: []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
hostnames: []string{"b", "c"},
privateIPs: []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
wireGuardIP: w2,
},
},
},
},
{
name: "node from a",
granularity: NodeGranularity,
hostname: nodes["a"].Name,
result: &Topology{
hostname: nodes["a"].Name,
leader: true,
Location: nodes["a"].Name,
subnet: kiloNet,
privateIP: nodes["a"].InternalIP,
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
Segments: []*segment{
{
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
Endpoint: nodes["a"].ExternalIP.IP.String(),
Key: string(nodes["a"].Key),
Location: nodes["a"].Name,
cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
wireGuardIP: w1,
},
{
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", "10.4.0.2/32"),
Endpoint: nodes["b"].ExternalIP.IP.String(),
Key: string(nodes["b"].Key),
Location: nodes["b"].Name,
cidrs: []*net.IPNet{nodes["b"].Subnet},
hostnames: []string{"b"},
privateIPs: []net.IP{nodes["b"].InternalIP.IP},
wireGuardIP: w2,
},
{
AllowedIPs: allowedIPs(nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.3/32"),
Endpoint: nodes["c"].ExternalIP.IP.String(),
Key: string(nodes["c"].Key),
Location: nodes["c"].Name,
cidrs: []*net.IPNet{nodes["c"].Subnet},
hostnames: []string{"c"},
privateIPs: []net.IP{nodes["c"].InternalIP.IP},
wireGuardIP: w3,
},
},
},
},
{
name: "node from b",
granularity: NodeGranularity,
hostname: nodes["b"].Name,
result: &Topology{
hostname: nodes["b"].Name,
leader: true,
Location: nodes["b"].Name,
subnet: kiloNet,
privateIP: nodes["b"].InternalIP,
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
Segments: []*segment{
{
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
Endpoint: nodes["a"].ExternalIP.IP.String(),
Key: string(nodes["a"].Key),
Location: nodes["a"].Name,
cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
wireGuardIP: w1,
},
{
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", "10.4.0.2/32"),
Endpoint: nodes["b"].ExternalIP.IP.String(),
Key: string(nodes["b"].Key),
Location: nodes["b"].Name,
cidrs: []*net.IPNet{nodes["b"].Subnet},
hostnames: []string{"b"},
privateIPs: []net.IP{nodes["b"].InternalIP.IP},
wireGuardIP: w2,
},
{
AllowedIPs: allowedIPs(nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.3/32"),
Endpoint: nodes["c"].ExternalIP.IP.String(),
Key: string(nodes["c"].Key),
Location: nodes["c"].Name,
cidrs: []*net.IPNet{nodes["c"].Subnet},
hostnames: []string{"c"},
privateIPs: []net.IP{nodes["c"].InternalIP.IP},
wireGuardIP: w3,
},
},
},
},
{
name: "node from c",
granularity: NodeGranularity,
hostname: nodes["c"].Name,
result: &Topology{
hostname: nodes["c"].Name,
leader: true,
Location: nodes["c"].Name,
subnet: kiloNet,
privateIP: nodes["c"].InternalIP,
wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)},
Segments: []*segment{
{
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
Endpoint: nodes["a"].ExternalIP.IP.String(),
Key: string(nodes["a"].Key),
Location: nodes["a"].Name,
cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
wireGuardIP: w1,
},
{
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", "10.4.0.2/32"),
Endpoint: nodes["b"].ExternalIP.IP.String(),
Key: string(nodes["b"].Key),
Location: nodes["b"].Name,
cidrs: []*net.IPNet{nodes["b"].Subnet},
hostnames: []string{"b"},
privateIPs: []net.IP{nodes["b"].InternalIP.IP},
wireGuardIP: w2,
},
{
AllowedIPs: allowedIPs(nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.3/32"),
Endpoint: nodes["c"].ExternalIP.IP.String(),
Key: string(nodes["c"].Key),
Location: nodes["c"].Name,
cidrs: []*net.IPNet{nodes["c"].Subnet},
hostnames: []string{"c"},
privateIPs: []net.IP{nodes["c"].InternalIP.IP},
wireGuardIP: w3,
},
},
},
},
} {
tc.result.Key = string(key)
tc.result.Port = port
topo, err := NewTopology(nodes, tc.granularity, tc.hostname, port, key, kiloNet)
if err != nil {
t.Errorf("test case %q: failed to generate Topology: %v", tc.name, err)
}
if diff := pretty.Compare(topo, tc.result); diff != "" {
t.Errorf("test case %q: got diff: %v", tc.name, diff)
}
}
}
func mustTopo(t *testing.T, nodes map[string]*Node, granularity Granularity, hostname string, port int, key []byte, subnet *net.IPNet) *Topology {
topo, err := NewTopology(nodes, granularity, hostname, port, key, subnet)
if err != nil {
t.Errorf("failed to generate Topology: %v", err)
}
return topo
}
func TestRoutes(t *testing.T) {
nodes, key, port, kiloNet := setup(t)
kiloIface := 0
privIface := 1
pubIface := 2
mustTopoForGranularityAndHost := func(granularity Granularity, hostname string) *Topology {
return mustTopo(t, nodes, granularity, hostname, port, key, kiloNet)
}
for _, tc := range []struct {
name string
local bool
topology *Topology
result []*netlink.Route
}{
{
name: "datacenter from a",
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name),
result: []*netlink.Route{
{
Dst: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].cidrs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].cidrs[1],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "datacenter from b",
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name),
result: []*netlink.Route{
{
Dst: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name).Segments[0].cidrs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "datacenter from c",
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name),
result: []*netlink.Route{
{
Dst: oneAddressCIDR(mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name).Segments[0].wireGuardIP),
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name).Segments[0].cidrs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name).Segments[1].wireGuardIP),
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "node from a",
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name),
result: []*netlink.Route{
{
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[1].cidrs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[2].cidrs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[2].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[2].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "node from b",
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name),
result: []*netlink.Route{
{
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[0].cidrs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[2].cidrs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[2].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[2].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "node from c",
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name),
result: []*netlink.Route{
{
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[0].cidrs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[1].cidrs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "datacenter from a local",
local: true,
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name),
result: []*netlink.Route{
{
Dst: nodes["b"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["c"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "datacenter from b local",
local: true,
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name),
result: []*netlink.Route{
{
Dst: nodes["a"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["c"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["c"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "datacenter from c local",
local: true,
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name),
result: []*netlink.Route{
{
Dst: oneAddressCIDR(mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name).Segments[0].wireGuardIP),
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["a"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name).Segments[1].wireGuardIP),
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "node from a local",
local: true,
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name),
result: []*netlink.Route{
{
Dst: nodes["b"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["c"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[2].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[2].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "node from b local",
local: true,
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name),
result: []*netlink.Route{
{
Dst: nodes["a"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["c"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[2].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[2].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
{
name: "node from c local",
local: true,
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name),
result: []*netlink.Route{
{
Dst: nodes["a"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[0].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].Subnet,
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
},
},
} {
routes := tc.topology.Routes(kiloIface, privIface, pubIface, tc.local, NeverEncapsulate)
if diff := pretty.Compare(routes, tc.result); diff != "" {
t.Errorf("test case %q: got diff: %v", tc.name, diff)
}
}
}
func TestConf(t *testing.T) {
nodes, key, port, kiloNet := setup(t)
for _, tc := range []struct {
name string
topology *Topology
result string
}{
{
name: "datacenter from a",
topology: mustTopo(t, nodes, DataCenterGranularity, nodes["a"].Name, port, key, kiloNet),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key2
Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`,
},
{
name: "datacenter from b",
topology: mustTopo(t, nodes, DataCenterGranularity, nodes["b"].Name, port, key, kiloNet),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
`,
},
{
name: "datacenter from c",
topology: mustTopo(t, nodes, DataCenterGranularity, nodes["c"].Name, port, key, kiloNet),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
`,
},
{
name: "node from a",
topology: mustTopo(t, nodes, NodeGranularity, nodes["a"].Name, port, key, kiloNet),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key2
Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32
[Peer]
PublicKey = key3
Endpoint = 10.1.0.3:51820
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
`,
},
{
name: "node from b",
topology: mustTopo(t, nodes, NodeGranularity, nodes["b"].Name, port, key, kiloNet),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer]
PublicKey = key3
Endpoint = 10.1.0.3:51820
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
`,
},
{
name: "node from c",
topology: mustTopo(t, nodes, NodeGranularity, nodes["c"].Name, port, key, kiloNet),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer]
PublicKey = key2
Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32
`,
},
} {
conf, err := tc.topology.Conf()
if err != nil {
t.Errorf("test case %q: failed to generate conf: %v", tc.name, err)
}
if string(conf) != tc.result {
t.Errorf("test case %q: expected %s got %s", tc.name, tc.result, string(conf))
}
}
}
func TestFindLeader(t *testing.T) {
ip, e1, err := net.ParseCIDR("10.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse external IP CIDR: %v", err)
}
e1.IP = ip
ip, e2, err := net.ParseCIDR("8.8.8.8/32")
if err != nil {
t.Fatalf("failed to parse external IP CIDR: %v", err)
}
e2.IP = ip
nodes := []*Node{
{
Name: "a",
ExternalIP: e1,
},
{
Name: "b",
ExternalIP: e2,
},
{
Name: "c",
ExternalIP: e2,
},
{
Name: "d",
ExternalIP: e1,
Leader: true,
},
{
Name: "2",
ExternalIP: e2,
Leader: true,
},
}
for _, tc := range []struct {
name string
nodes []*Node
out int
}{
{
name: "nil",
nodes: nil,
out: 0,
},
{
name: "one",
nodes: []*Node{nodes[0]},
out: 0,
},
{
name: "non-leaders",
nodes: []*Node{nodes[0], nodes[1], nodes[2]},
out: 1,
},
{
name: "leaders",
nodes: []*Node{nodes[3], nodes[4]},
out: 1,
},
{
name: "public",
nodes: []*Node{nodes[1], nodes[2], nodes[4]},
out: 2,
},
{
name: "private",
nodes: []*Node{nodes[0], nodes[3]},
out: 1,
},
{
name: "all",
nodes: nodes,
out: 4,
},
} {
l := findLeader(tc.nodes)
if l != tc.out {
t.Errorf("test case %q: expected %d got %d", tc.name, tc.out, l)
}
}
}

173
pkg/route/route.go Normal file
View File

@@ -0,0 +1,173 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package route
import (
"errors"
"fmt"
"sync"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
// Table represents a routing table.
// Table can safely be used concurrently.
type Table struct {
errors chan error
mu sync.Mutex
routes map[string]*netlink.Route
subscribed bool
// Make these functions fields to allow
// for testing.
add func(*netlink.Route) error
del func(*netlink.Route) error
}
// NewTable generates a new table.
func NewTable() *Table {
return &Table{
errors: make(chan error),
routes: make(map[string]*netlink.Route),
add: netlink.RouteReplace,
del: func(r *netlink.Route) error {
name := routeToString(r)
if name == "" {
return errors.New("attempting to delete invalid route")
}
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to list routes before deletion: %v", err)
}
for _, route := range routes {
if routeToString(&route) == name {
return netlink.RouteDel(r)
}
}
return nil
},
}
}
// Run watches for changes to routes in the table and reconciles
// the table against the desired state.
func (t *Table) Run(stop <-chan struct{}) (<-chan error, error) {
t.mu.Lock()
if t.subscribed {
t.mu.Unlock()
return t.errors, nil
}
// Ensure a given instance only subscribes once.
t.subscribed = true
t.mu.Unlock()
events := make(chan netlink.RouteUpdate)
if err := netlink.RouteSubscribe(events, stop); err != nil {
return t.errors, fmt.Errorf("failed to subscribe to route events: %v", err)
}
go func() {
defer close(t.errors)
for {
var e netlink.RouteUpdate
select {
case e = <-events:
case <-stop:
return
}
switch e.Type {
// Watch for deleted routes to reconcile this table's routes.
case unix.RTM_DELROUTE:
t.mu.Lock()
for _, r := range t.routes {
// If any deleted route's destination matches a destination
// in the table, reset the corresponding route just in case.
if r.Dst.IP.Equal(e.Route.Dst.IP) && r.Dst.Mask.String() == e.Route.Dst.Mask.String() {
if err := t.add(r); err != nil {
nonBlockingSend(t.errors, fmt.Errorf("failed add route: %v", err))
}
}
}
t.mu.Unlock()
}
}
}()
return t.errors, nil
}
// CleanUp will clean up any routes created by the instance.
func (t *Table) CleanUp() error {
t.mu.Lock()
defer t.mu.Unlock()
for k, route := range t.routes {
if err := t.del(route); err != nil {
return fmt.Errorf("failed to delete route: %v", err)
}
delete(t.routes, k)
}
return nil
}
// Set idempotently overwrites any routes previously defined
// for the table with the given set of routes.
func (t *Table) Set(routes []*netlink.Route) error {
r := make(map[string]*netlink.Route)
for _, route := range routes {
if route == nil {
continue
}
r[routeToString(route)] = route
}
t.mu.Lock()
defer t.mu.Unlock()
for k := range t.routes {
if _, ok := r[k]; !ok {
if err := t.del(t.routes[k]); err != nil {
return fmt.Errorf("failed to delete route: %v", err)
}
delete(t.routes, k)
}
}
for k := range r {
if _, ok := t.routes[k]; !ok {
if err := t.add(r[k]); err != nil {
return fmt.Errorf("failed to add route %q: %v", routeToString(r[k]), err)
}
t.routes[k] = r[k]
}
}
return nil
}
func nonBlockingSend(errors chan<- error, err error) {
select {
case errors <- err:
default:
}
}
func routeToString(route *netlink.Route) string {
if route == nil || route.Dst == nil {
return ""
}
src := "-"
if route.Src != nil {
src = route.Src.String()
}
gw := "-"
if route.Gw != nil {
gw = route.Gw.String()
}
return fmt.Sprintf("dst: %s, via: %s, src: %s, dev: %d", route.Dst.String(), gw, src, route.LinkIndex)
}

262
pkg/route/route_test.go Normal file
View File

@@ -0,0 +1,262 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package route
import (
"errors"
"net"
"testing"
"github.com/vishvananda/netlink"
)
func TestSet(t *testing.T) {
_, c1, err := net.ParseCIDR("10.2.0.0/24")
if err != nil {
t.Fatalf("failed to parse CIDR: %v", err)
}
_, c2, err := net.ParseCIDR("10.1.0.0/24")
if err != nil {
t.Fatalf("failed to parse CIDR: %v", err)
}
add := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
return func(r *netlink.Route) error {
backend[routeToString(r)] = r
return nil
}
}
del := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
return func(r *netlink.Route) error {
delete(backend, routeToString(r))
return nil
}
}
adderr := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
return func(r *netlink.Route) error {
return errors.New(routeToString(r))
}
}
for _, tc := range []struct {
name string
routes []*netlink.Route
err bool
add func(map[string]*netlink.Route) func(*netlink.Route) error
del func(map[string]*netlink.Route) func(*netlink.Route) error
}{
{
name: "empty",
routes: nil,
err: false,
add: add,
del: del,
},
{
name: "single",
routes: []*netlink.Route{
{
Dst: c1,
Gw: net.ParseIP("10.1.0.1"),
},
},
err: false,
add: add,
del: del,
},
{
name: "multiple",
routes: []*netlink.Route{
{
Dst: c1,
Gw: net.ParseIP("10.1.0.1"),
},
{
Dst: c2,
Gw: net.ParseIP("127.0.0.1"),
},
},
err: false,
add: add,
del: del,
},
{
name: "err empty",
routes: nil,
err: false,
add: adderr,
del: del,
},
{
name: "err",
routes: []*netlink.Route{
{
Dst: c1,
Gw: net.ParseIP("10.1.0.1"),
},
{
Dst: c2,
Gw: net.ParseIP("127.0.0.1"),
},
},
err: true,
add: adderr,
del: del,
},
} {
backend := make(map[string]*netlink.Route)
a := tc.add(backend)
d := tc.del(backend)
table := NewTable()
table.add = a
table.del = d
if err := table.Set(tc.routes); (err != nil) != tc.err {
no := "no"
if tc.err {
no = "an"
}
t.Errorf("test case %q: got unexpected result: expected %s error, got %v", tc.name, no, err)
}
// If no error was expected, then compare the backend to the input.
if !tc.err {
for _, r := range tc.routes {
r1 := backend[routeToString(r)]
r2 := table.routes[routeToString(r)]
if r != r1 || r != r2 {
t.Errorf("test case %q: expected all routes to be equal: expected %v, got %v and %v", tc.name, r, r1, r2)
}
}
}
}
}
func TestCleanUp(t *testing.T) {
_, c1, err := net.ParseCIDR("10.2.0.0/24")
if err != nil {
t.Fatalf("failed to parse CIDR: %v", err)
}
_, c2, err := net.ParseCIDR("10.1.0.0/24")
if err != nil {
t.Fatalf("failed to parse CIDR: %v", err)
}
add := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
return func(r *netlink.Route) error {
backend[routeToString(r)] = r
return nil
}
}
del := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
return func(r *netlink.Route) error {
delete(backend, routeToString(r))
return nil
}
}
delerr := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
return func(r *netlink.Route) error {
return errors.New(routeToString(r))
}
}
for _, tc := range []struct {
name string
routes []*netlink.Route
err bool
add func(map[string]*netlink.Route) func(*netlink.Route) error
del func(map[string]*netlink.Route) func(*netlink.Route) error
}{
{
name: "empty",
routes: nil,
err: false,
add: add,
del: del,
},
{
name: "single",
routes: []*netlink.Route{
{
Dst: c1,
Gw: net.ParseIP("10.1.0.1"),
},
},
err: false,
add: add,
del: del,
},
{
name: "multiple",
routes: []*netlink.Route{
{
Dst: c1,
Gw: net.ParseIP("10.1.0.1"),
},
{
Dst: c2,
Gw: net.ParseIP("127.0.0.1"),
},
},
err: false,
add: add,
del: del,
},
{
name: "err empty",
routes: nil,
err: false,
add: add,
del: delerr,
},
{
name: "err",
routes: []*netlink.Route{
{
Dst: c1,
Gw: net.ParseIP("10.1.0.1"),
},
{
Dst: c2,
Gw: net.ParseIP("127.0.0.1"),
},
},
err: true,
add: add,
del: delerr,
},
} {
backend := make(map[string]*netlink.Route)
a := tc.add(backend)
d := tc.del(backend)
table := NewTable()
table.add = a
table.del = d
if err := table.Set(tc.routes); err != nil {
t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
}
if err := table.CleanUp(); (err != nil) != tc.err {
no := "no"
if tc.err {
no = "an"
}
t.Errorf("test case %q: got unexpected result: expected %s error, got %v", tc.name, no, err)
}
// If no error was expected, then compare the backend to the input.
if !tc.err {
for _, r := range tc.routes {
r1 := backend[routeToString(r)]
r2 := table.routes[routeToString(r)]
if r1 != nil || r2 != nil {
t.Errorf("test case %q: expected all routes to be nil: expected got %v and %v", tc.name, r1, r2)
}
}
}
}
}

18
pkg/version/version.go Normal file
View File

@@ -0,0 +1,18 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package version
// Version is the version of Kilo.
var Version = "was not built properly"

183
pkg/wireguard/wireguard.go Normal file
View File

@@ -0,0 +1,183 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wireguard
import (
"bytes"
"fmt"
"os/exec"
"regexp"
"sort"
"strconv"
"github.com/vishvananda/netlink"
"gopkg.in/ini.v1"
)
type wgLink struct {
a netlink.LinkAttrs
t string
}
func (w wgLink) Attrs() *netlink.LinkAttrs {
return &w.a
}
func (w wgLink) Type() string {
return w.t
}
// New creates a new WireGuard interface.
func New(prefix string) (int, error) {
links, err := netlink.LinkList()
if err != nil {
return 0, fmt.Errorf("failed to list links: %v", err)
}
max := 0
re := regexp.MustCompile(fmt.Sprintf("^%s([0-9]+)$", prefix))
for _, link := range links {
if matches := re.FindStringSubmatch(link.Attrs().Name); len(matches) == 2 {
i, err := strconv.Atoi(matches[1])
if err != nil {
// This should never happen.
return 0, fmt.Errorf("failed to parse digits as an integer: %v", err)
}
if i >= max {
max = i + 1
}
}
}
name := fmt.Sprintf("%s%d", prefix, max)
wl := wgLink{a: netlink.NewLinkAttrs(), t: "wireguard"}
wl.a.Name = name
if err := netlink.LinkAdd(wl); err != nil {
return 0, fmt.Errorf("failed to create interface %s: %v", name, err)
}
link, err := netlink.LinkByName(name)
if err != nil {
return 0, fmt.Errorf("failed to get interface index: %v", err)
}
return link.Attrs().Index, nil
}
// Keys generates a WireGuard private and public key-pair.
func Keys() ([]byte, []byte, error) {
private, err := GenKey()
if err != nil {
return nil, nil, fmt.Errorf("failed to generate private key: %v", err)
}
public, err := PubKey(private)
return private, public, err
}
// GenKey generates a WireGuard private key.
func GenKey() ([]byte, error) {
return exec.Command("wg", "genkey").Output()
}
// PubKey generates a WireGuard public key for a given private key.
func PubKey(key []byte) ([]byte, error) {
cmd := exec.Command("wg", "pubkey")
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("failed to open pipe to stdin: %v", err)
}
go func() {
defer stdin.Close()
stdin.Write(key)
}()
public, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("failed to generate public key: %v", err)
}
return public, nil
}
// SetConf applies a WireGuard configuration file to the given interface.
func SetConf(iface string, path string) error {
cmd := exec.Command("wg", "setconf", iface, path)
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to apply the WireGuard configuration: %s", stderr.String())
}
return nil
}
// ShowConf gets the WireGuard configuration for the given interface.
func ShowConf(iface string) ([]byte, error) {
cmd := exec.Command("wg", "showconf", iface)
var stderr, stdout bytes.Buffer
cmd.Stderr = &stderr
cmd.Stdout = &stdout
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("failed to read the WireGuard configuration: %s", stderr.String())
}
return stdout.Bytes(), nil
}
// CompareConf compares two WireGuard configurations.
// It returns true if they are equal, false if they are not,
// and any error that was encountered.
// Note: CompareConf only goes one level deep, as WireGuard
// configurations are not nested further than that.
func CompareConf(a, b []byte) (bool, error) {
iniA, err := ini.Load(a)
if err != nil {
return false, fmt.Errorf("failed to parse configuration: %v", err)
}
iniB, err := ini.Load(b)
if err != nil {
return false, fmt.Errorf("failed to parse configuration: %v", err)
}
secsA, secsB := iniA.SectionStrings(), iniB.SectionStrings()
if len(secsA) != len(secsB) {
return false, nil
}
sort.Strings(secsA)
sort.Strings(secsB)
var keysA, keysB []string
var valsA, valsB []string
for i := range secsA {
if secsA[i] != secsB[i] {
return false, nil
}
keysA, keysB = iniA.Section(secsA[i]).KeyStrings(), iniB.Section(secsB[i]).KeyStrings()
if len(keysA) != len(keysB) {
return false, nil
}
sort.Strings(keysA)
sort.Strings(keysB)
for j := range keysA {
if keysA[j] != keysB[j] {
return false, nil
}
valsA, valsB = iniA.Section(secsA[i]).Key(keysA[j]).Strings(","), iniB.Section(secsB[i]).Key(keysB[j]).Strings(",")
if len(valsA) != len(valsB) {
return false, nil
}
sort.Strings(valsA)
sort.Strings(valsB)
for k := range valsA {
if valsA[k] != valsB[k] {
return false, nil
}
}
}
}
return true, nil
}

View File

@@ -0,0 +1,143 @@
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wireguard
import (
"testing"
)
func TestCompareConf(t *testing.T) {
for _, tc := range []struct {
name string
a []byte
b []byte
out bool
}{
{
name: "empty",
a: []byte{},
b: []byte{},
out: true,
},
{
name: "key and value order",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
ListenPort = 51820
PrivateKey = private
[Peer]
PublicKey = key
AllowedIPs = 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 10.2.2.0/24
Endpoint = 10.1.0.2:51820
`),
out: true,
},
{
name: "whitespace",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
PrivateKey=private
ListenPort=51820
[Peer]
Endpoint=10.1.0.2:51820
PublicKey=key
AllowedIPs=10.2.2.0/24,192.168.0.1/32,10.2.3.0/24,192.168.0.2/32,10.4.0.2/32
`),
out: true,
},
{
name: "missing key",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
out: false,
},
{
name: "section order",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
[Interface]
PrivateKey = private
ListenPort = 51820
`),
out: true,
},
{
name: "one empty",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(``),
out: false,
},
} {
equal, err := CompareConf(tc.a, tc.b)
if err != nil {
t.Errorf("test case %q: got unexpected error: %v", tc.name, err)
}
if equal != tc.out {
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
}
}
}