init
This commit is contained in:
59
pkg/iproute/ipip.go
Normal file
59
pkg/iproute/ipip.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package iproute
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
const (
|
||||
ipipHeaderSize = 20
|
||||
tunnelName = "tunl0"
|
||||
)
|
||||
|
||||
// NewIPIP creates an IPIP interface using the base interface
|
||||
// to derive the tunnel's MTU.
|
||||
func NewIPIP(baseIndex int) (int, error) {
|
||||
link, err := netlink.LinkByName(tunnelName)
|
||||
if err != nil {
|
||||
// If we failed to find the tunnel, then it probably simply does not exist.
|
||||
cmd := exec.Command("ip", "tunnel", "add", tunnelName, "mode", "ipip")
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return 0, fmt.Errorf("failed to create IPIP tunnel: %s", stderr.String())
|
||||
}
|
||||
link, err = netlink.LinkByName(tunnelName)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get tunnel device: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
base, err := netlink.LinkByIndex(baseIndex)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get base device: %v", err)
|
||||
}
|
||||
|
||||
mtu := base.Attrs().MTU - ipipHeaderSize
|
||||
if err = netlink.LinkSetMTU(link, mtu); err != nil {
|
||||
return 0, fmt.Errorf("failed to set tunnel MTU: %v", err)
|
||||
}
|
||||
|
||||
return link.Attrs().Index, nil
|
||||
}
|
70
pkg/iproute/iproute.go
Normal file
70
pkg/iproute/iproute.go
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package iproute
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// RemoveInterface removes an interface.
|
||||
func RemoveInterface(index int) error {
|
||||
link, err := netlink.LinkByIndex(index)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get link: %s", err)
|
||||
}
|
||||
return netlink.LinkDel(link)
|
||||
}
|
||||
|
||||
// Set sets the interface up or down.
|
||||
func Set(index int, up bool) error {
|
||||
link, err := netlink.LinkByIndex(index)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get link: %s", err)
|
||||
}
|
||||
if up {
|
||||
return netlink.LinkSetUp(link)
|
||||
}
|
||||
return netlink.LinkSetDown(link)
|
||||
}
|
||||
|
||||
// SetAddress sets the IP address of an interface.
|
||||
func SetAddress(index int, cidr *net.IPNet) error {
|
||||
link, err := netlink.LinkByIndex(index)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get link: %s", err)
|
||||
}
|
||||
addrs, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l := len(addrs)
|
||||
for _, addr := range addrs {
|
||||
if addr.IP.Equal(cidr.IP) && addr.Mask.String() == cidr.Mask.String() {
|
||||
continue
|
||||
}
|
||||
if err := netlink.AddrDel(link, &addr); err != nil {
|
||||
return fmt.Errorf("failed to delete address: %s", err)
|
||||
}
|
||||
l--
|
||||
}
|
||||
// The only address left is the desired address, so quit.
|
||||
if l == 1 {
|
||||
return nil
|
||||
}
|
||||
return netlink.AddrReplace(link, &netlink.Addr{IPNet: cidr})
|
||||
}
|
199
pkg/ipset/ipset.go
Normal file
199
pkg/ipset/ipset.go
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ipset
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Set represents an ipset.
|
||||
// Set can safely be used concurrently.
|
||||
type Set struct {
|
||||
errors chan error
|
||||
hosts map[string]struct{}
|
||||
mu sync.Mutex
|
||||
name string
|
||||
subscribed bool
|
||||
|
||||
// Make these functions fields to allow
|
||||
// for testing.
|
||||
add func(string) error
|
||||
del func(string) error
|
||||
}
|
||||
|
||||
func setExists(name string) (bool, error) {
|
||||
cmd := exec.Command("ipset", "list", "-n")
|
||||
var stderr, stdout bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
cmd.Stdout = &stdout
|
||||
if err := cmd.Run(); err != nil {
|
||||
return false, fmt.Errorf("failed to check for set %s: %s", name, stderr.String())
|
||||
}
|
||||
return bytes.Contains(stdout.Bytes(), []byte(name)), nil
|
||||
}
|
||||
|
||||
func hostInSet(set, name string) (bool, error) {
|
||||
cmd := exec.Command("ipset", "list", set)
|
||||
var stderr, stdout bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
cmd.Stdout = &stdout
|
||||
if err := cmd.Run(); err != nil {
|
||||
return false, fmt.Errorf("failed to check for host %s: %s", name, stderr.String())
|
||||
}
|
||||
return bytes.Contains(stdout.Bytes(), []byte(name)), nil
|
||||
}
|
||||
|
||||
// New generates a new ipset.
|
||||
func New(name string) *Set {
|
||||
return &Set{
|
||||
errors: make(chan error),
|
||||
hosts: make(map[string]struct{}),
|
||||
name: name,
|
||||
|
||||
add: func(ip string) error {
|
||||
ok, err := hostInSet(name, ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
cmd := exec.Command("ipset", "add", name, ip)
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to add host %s to set %s: %s", ip, name, stderr.String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
del: func(ip string) error {
|
||||
ok, err := hostInSet(name, ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
cmd := exec.Command("ipset", "del", name, ip)
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to remove host %s from set %s: %s", ip, name, stderr.String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Run watches for changes to the ipset and reconciles
|
||||
// the ipset against the desired state.
|
||||
func (s *Set) Run(stop <-chan struct{}) (<-chan error, error) {
|
||||
s.mu.Lock()
|
||||
if s.subscribed {
|
||||
s.mu.Unlock()
|
||||
return s.errors, nil
|
||||
}
|
||||
// Ensure a given instance only subscribes once.
|
||||
s.subscribed = true
|
||||
s.mu.Unlock()
|
||||
go func() {
|
||||
defer close(s.errors)
|
||||
for {
|
||||
select {
|
||||
case <-time.After(2 * time.Second):
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
ok, err := setExists(s.name)
|
||||
if err != nil {
|
||||
nonBlockingSend(s.errors, err)
|
||||
}
|
||||
// The set does not exist so wait and try again later.
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
s.mu.Lock()
|
||||
for h := range s.hosts {
|
||||
if err := s.add(h); err != nil {
|
||||
nonBlockingSend(s.errors, err)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
return s.errors, nil
|
||||
}
|
||||
|
||||
// CleanUp will clean up any hosts added to the set.
|
||||
func (s *Set) CleanUp() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for h := range s.hosts {
|
||||
if err := s.del(h); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(s.hosts, h)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set idempotently overwrites any hosts previously defined
|
||||
// for the ipset with the given hosts.
|
||||
func (s *Set) Set(hosts []net.IP) error {
|
||||
h := make(map[string]struct{})
|
||||
for _, host := range hosts {
|
||||
if host == nil {
|
||||
continue
|
||||
}
|
||||
h[host.String()] = struct{}{}
|
||||
}
|
||||
exists, err := setExists(s.name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k := range s.hosts {
|
||||
if _, ok := h[k]; !ok {
|
||||
if exists {
|
||||
if err := s.del(k); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
delete(s.hosts, k)
|
||||
}
|
||||
}
|
||||
for k := range h {
|
||||
if _, ok := s.hosts[k]; !ok {
|
||||
if exists {
|
||||
if err := s.add(k); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.hosts[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nonBlockingSend(errors chan<- error, err error) {
|
||||
select {
|
||||
case errors <- err:
|
||||
default:
|
||||
}
|
||||
}
|
92
pkg/iptables/fake.go
Normal file
92
pkg/iptables/fake.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
)
|
||||
|
||||
type statusExiter interface {
|
||||
ExitStatus() int
|
||||
}
|
||||
|
||||
var _ statusExiter = (*iptables.Error)(nil)
|
||||
var _ statusExiter = statusError(0)
|
||||
|
||||
type statusError int
|
||||
|
||||
func (s statusError) Error() string {
|
||||
return fmt.Sprintf("%d", s)
|
||||
}
|
||||
|
||||
func (s statusError) ExitStatus() int {
|
||||
return int(s)
|
||||
}
|
||||
|
||||
type fakeClient map[string]Rule
|
||||
|
||||
var _ iptablesClient = fakeClient(nil)
|
||||
|
||||
func (f fakeClient) AppendUnique(table, chain string, spec ...string) error {
|
||||
r := &rule{table, chain, spec, nil}
|
||||
f[r.String()] = r
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeClient) Delete(table, chain string, spec ...string) error {
|
||||
r := &rule{table, chain, spec, nil}
|
||||
delete(f, r.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeClient) Exists(table, chain string, spec ...string) (bool, error) {
|
||||
r := &rule{table, chain, spec, nil}
|
||||
_, ok := f[r.String()]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (f fakeClient) ClearChain(table, name string) error {
|
||||
c := &chain{table, name, nil}
|
||||
for k := range f {
|
||||
if strings.HasPrefix(k, c.String()) {
|
||||
delete(f, k)
|
||||
}
|
||||
}
|
||||
f[c.String()] = c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeClient) DeleteChain(table, name string) error {
|
||||
c := &chain{table, name, nil}
|
||||
for k := range f {
|
||||
if strings.HasPrefix(k, c.String()) {
|
||||
return fmt.Errorf("cannot delete chain %s; rules exist", name)
|
||||
}
|
||||
}
|
||||
delete(f, c.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeClient) NewChain(table, name string) error {
|
||||
c := &chain{table, name, nil}
|
||||
if _, ok := f[c.String()]; ok {
|
||||
return statusError(1)
|
||||
}
|
||||
f[c.String()] = c
|
||||
return nil
|
||||
}
|
289
pkg/iptables/iptables.go
Normal file
289
pkg/iptables/iptables.go
Normal file
@@ -0,0 +1,289 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
)
|
||||
|
||||
type iptablesClient interface {
|
||||
AppendUnique(string, string, ...string) error
|
||||
Delete(string, string, ...string) error
|
||||
Exists(string, string, ...string) (bool, error)
|
||||
ClearChain(string, string) error
|
||||
DeleteChain(string, string) error
|
||||
NewChain(string, string) error
|
||||
}
|
||||
|
||||
// rule represents an iptables rule.
|
||||
type rule struct {
|
||||
table string
|
||||
chain string
|
||||
spec []string
|
||||
client iptablesClient
|
||||
}
|
||||
|
||||
func (r *rule) Add() error {
|
||||
if err := r.client.AppendUnique(r.table, r.chain, r.spec...); err != nil {
|
||||
return fmt.Errorf("failed to add iptables rule: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rule) Delete() error {
|
||||
// Ignore the returned error as an error likely means
|
||||
// that the rule doesn't exist, which is fine.
|
||||
r.client.Delete(r.table, r.chain, r.spec...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rule) Exists() (bool, error) {
|
||||
return r.client.Exists(r.table, r.chain, r.spec...)
|
||||
}
|
||||
|
||||
func (r *rule) String() string {
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s", r.table, r.chain, strings.Join(r.spec, "_"))
|
||||
}
|
||||
|
||||
// chain represents an iptables chain.
|
||||
type chain struct {
|
||||
table string
|
||||
chain string
|
||||
client iptablesClient
|
||||
}
|
||||
|
||||
func (c *chain) Add() error {
|
||||
if err := c.client.ClearChain(c.table, c.chain); err != nil {
|
||||
return fmt.Errorf("failed to add iptables chain: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *chain) Delete() error {
|
||||
// The chain must be empty before it can be deleted.
|
||||
if err := c.client.ClearChain(c.table, c.chain); err != nil {
|
||||
return fmt.Errorf("failed to clear iptables chain: %v", err)
|
||||
}
|
||||
// Ignore the returned error as an error likely means
|
||||
// that the chain doesn't exist, which is fine.
|
||||
c.client.DeleteChain(c.table, c.chain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *chain) Exists() (bool, error) {
|
||||
// The code for "chain already exists".
|
||||
existsErr := 1
|
||||
err := c.client.NewChain(c.table, c.chain)
|
||||
se, ok := err.(statusExiter)
|
||||
switch {
|
||||
case err == nil:
|
||||
// If there was no error adding a new chain, then it did not exist.
|
||||
// Delete it and return false.
|
||||
c.client.DeleteChain(c.table, c.chain)
|
||||
return false, nil
|
||||
case ok && se.ExitStatus() == existsErr:
|
||||
return true, nil
|
||||
default:
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *chain) String() string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s_%s", c.table, c.chain)
|
||||
}
|
||||
|
||||
// Rule is an interface for interacting with iptables objects.
|
||||
type Rule interface {
|
||||
Add() error
|
||||
Delete() error
|
||||
Exists() (bool, error)
|
||||
String() string
|
||||
}
|
||||
|
||||
// Controller is able to reconcile a given set of iptables rules.
|
||||
type Controller struct {
|
||||
client iptablesClient
|
||||
errors chan error
|
||||
rules map[string]Rule
|
||||
mu sync.Mutex
|
||||
subscribed bool
|
||||
}
|
||||
|
||||
// New generates a new iptables rules controller.
|
||||
// It expects an IP address length to determine
|
||||
// whether to operate in IPv4 or IPv6 mode.
|
||||
func New(ipLength int) (*Controller, error) {
|
||||
p := iptables.ProtocolIPv4
|
||||
if ipLength == net.IPv6len {
|
||||
p = iptables.ProtocolIPv6
|
||||
}
|
||||
client, err := iptables.NewWithProtocol(p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create iptables client: %v", err)
|
||||
}
|
||||
return &Controller{
|
||||
client: client,
|
||||
errors: make(chan error),
|
||||
rules: make(map[string]Rule),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run watches for changes to iptables rules and reconciles
|
||||
// the rules against the desired state.
|
||||
func (c *Controller) Run(stop <-chan struct{}) (<-chan error, error) {
|
||||
c.mu.Lock()
|
||||
if c.subscribed {
|
||||
c.mu.Unlock()
|
||||
return c.errors, nil
|
||||
}
|
||||
// Ensure a given instance only subscribes once.
|
||||
c.subscribed = true
|
||||
c.mu.Unlock()
|
||||
go func() {
|
||||
defer close(c.errors)
|
||||
for {
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
for _, r := range c.rules {
|
||||
ok, err := r.Exists()
|
||||
if err != nil {
|
||||
nonBlockingSend(c.errors, fmt.Errorf("failed to check if rule exists: %v", err))
|
||||
}
|
||||
if !ok {
|
||||
if err := r.Add(); err != nil {
|
||||
nonBlockingSend(c.errors, fmt.Errorf("failed to add rule: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
return c.errors, nil
|
||||
}
|
||||
|
||||
// Set idempotently overwrites any iptables rules previously defined
|
||||
// for the controller with the given set of rules.
|
||||
func (c *Controller) Set(rules []Rule) error {
|
||||
r := make(map[string]struct{})
|
||||
for i := range rules {
|
||||
if rules[i] == nil {
|
||||
continue
|
||||
}
|
||||
switch v := rules[i].(type) {
|
||||
case *rule:
|
||||
v.client = c.client
|
||||
case *chain:
|
||||
v.client = c.client
|
||||
}
|
||||
r[rules[i].String()] = struct{}{}
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for k, rule := range c.rules {
|
||||
if _, ok := r[k]; !ok {
|
||||
if err := rule.Delete(); err != nil {
|
||||
return fmt.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
delete(c.rules, k)
|
||||
}
|
||||
}
|
||||
// Iterate over the slice rather than the map
|
||||
// to ensure the rules are added in order.
|
||||
for _, rule := range rules {
|
||||
if _, ok := c.rules[rule.String()]; !ok {
|
||||
if err := rule.Add(); err != nil {
|
||||
return fmt.Errorf("failed to add rule: %v", err)
|
||||
}
|
||||
c.rules[rule.String()] = rule
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanUp will clean up any rules created by the controller.
|
||||
func (c *Controller) CleanUp() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for k, rule := range c.rules {
|
||||
if err := rule.Delete(); err != nil {
|
||||
return fmt.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
delete(c.rules, k)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncapsulateRules returns a set of iptables rules that are necessary
|
||||
// when traffic between nodes must be encapsulated.
|
||||
func EncapsulateRules(nodes []*net.IPNet) []Rule {
|
||||
var rules []Rule
|
||||
for _, n := range nodes {
|
||||
// Accept encapsulated traffic from peers.
|
||||
rules = append(rules, &rule{"filter", "INPUT", []string{"-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-s", n.IP.String(), "-p", "4", "-j", "ACCEPT"}, nil})
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
// ForwardRules returns a set of iptables rules that are necessary
|
||||
// when traffic must be forwarded for the overlay.
|
||||
func ForwardRules(subnet *net.IPNet) []Rule {
|
||||
s := subnet.String()
|
||||
return []Rule{
|
||||
// Forward traffic to and from the overlay.
|
||||
&rule{"filter", "FORWARD", []string{"-s", s, "-j", "ACCEPT"}, nil},
|
||||
&rule{"filter", "FORWARD", []string{"-d", s, "-j", "ACCEPT"}, nil},
|
||||
}
|
||||
}
|
||||
|
||||
// MasqueradeRules returns a set of iptables rules that are necessary
|
||||
// when traffic must be masqueraded for Kilo.
|
||||
func MasqueradeRules(subnet, localPodSubnet *net.IPNet, remotePodSubnet []*net.IPNet) []Rule {
|
||||
var rules []Rule
|
||||
rules = append(rules, &chain{"mangle", "KILO-MARK", nil})
|
||||
rules = append(rules, &rule{"mangle", "PREROUTING", []string{"-m", "comment", "--comment", "Kilo: jump to mark chain", "-i", "kilo+", "-j", "KILO-MARK"}, nil})
|
||||
rules = append(rules, &rule{"mangle", "KILO-MARK", []string{"-m", "comment", "--comment", "Kilo: do not mark packets destined for the local Pod subnet", "-d", localPodSubnet.String(), "-j", "RETURN"}, nil})
|
||||
if subnet != nil {
|
||||
rules = append(rules, &rule{"mangle", "KILO-MARK", []string{"-m", "comment", "--comment", "Kilo: do not mark packets destined for the local private subnet", "-d", subnet.String(), "-j", "RETURN"}, nil})
|
||||
}
|
||||
rules = append(rules, &rule{"mangle", "KILO-MARK", []string{"-m", "comment", "--comment", "Kilo: remaining packets should be marked for NAT", "-j", "MARK", "--set-xmark", "0x1107/0x1107"}, nil})
|
||||
rules = append(rules, &rule{"nat", "POSTROUTING", []string{"-m", "comment", "--comment", "Kilo: NAT packets from Kilo interface", "-m", "mark", "--mark", "0x1107/0x1107", "-j", "MASQUERADE"}, nil})
|
||||
for _, r := range remotePodSubnet {
|
||||
rules = append(rules, &rule{"nat", "POSTROUTING", []string{"-m", "comment", "--comment", "Kilo: NAT packets from local pod subnet to remote pod subnets", "-s", localPodSubnet.String(), "-d", r.String(), "-j", "MASQUERADE"}, nil})
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
func nonBlockingSend(errors chan<- error, err error) {
|
||||
select {
|
||||
case errors <- err:
|
||||
default:
|
||||
}
|
||||
}
|
101
pkg/iptables/iptables_test.go
Normal file
101
pkg/iptables/iptables_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
var rules = []Rule{
|
||||
&rule{"filter", "FORWARD", []string{"-s", "10.4.0.0/16", "-j", "ACCEPT"}, nil},
|
||||
&rule{"filter", "FORWARD", []string{"-d", "10.4.0.0/16", "-j", "ACCEPT"}, nil},
|
||||
}
|
||||
|
||||
func newController() *Controller {
|
||||
return &Controller{
|
||||
rules: make(map[string]Rule),
|
||||
}
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
rules []Rule
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
rules: nil,
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
rules: []Rule{rules[0]},
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
rules: []Rule{rules[0], rules[1]},
|
||||
},
|
||||
} {
|
||||
backend := make(map[string]Rule)
|
||||
controller := newController()
|
||||
controller.client = fakeClient(backend)
|
||||
if err := controller.Set(tc.rules); err != nil {
|
||||
t.Fatalf("test case %q: got unexpected error: %v", tc.name, err)
|
||||
}
|
||||
for _, r := range tc.rules {
|
||||
r1 := backend[r.String()]
|
||||
r2 := controller.rules[r.String()]
|
||||
if r.String() != r1.String() || r.String() != r2.String() {
|
||||
t.Errorf("test case %q: expected all rules to be equal: expected %v, got %v and %v", tc.name, r, r1, r2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanUp(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
rules []Rule
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
rules: nil,
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
rules: []Rule{rules[0]},
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
rules: []Rule{rules[0], rules[1]},
|
||||
},
|
||||
} {
|
||||
backend := make(map[string]Rule)
|
||||
controller := newController()
|
||||
controller.client = fakeClient(backend)
|
||||
if err := controller.Set(tc.rules); err != nil {
|
||||
t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
|
||||
}
|
||||
if err := controller.CleanUp(); err != nil {
|
||||
t.Errorf("test case %q: got unexpected error: %v", tc.name, err)
|
||||
}
|
||||
for _, r := range tc.rules {
|
||||
r1 := backend[r.String()]
|
||||
r2 := controller.rules[r.String()]
|
||||
if r1 != nil || r2 != nil {
|
||||
t.Errorf("test case %q: expected all rules to be nil: expected got %v and %v", tc.name, r1, r2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
229
pkg/k8s/backend.go
Normal file
229
pkg/k8s/backend.go
Normal file
@@ -0,0 +1,229 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package k8s
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"k8s.io/api/core/v1"
|
||||
"k8s.io/apimachinery/pkg/labels"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"k8s.io/apimachinery/pkg/util/strategicpatch"
|
||||
v1informers "k8s.io/client-go/informers/core/v1"
|
||||
"k8s.io/client-go/kubernetes"
|
||||
v1listers "k8s.io/client-go/listers/core/v1"
|
||||
"k8s.io/client-go/tools/cache"
|
||||
|
||||
"github.com/squat/kilo/pkg/mesh"
|
||||
)
|
||||
|
||||
const (
|
||||
// Backend is the name of this mesh backend.
|
||||
Backend = "kubernetes"
|
||||
externalIPAnnotationKey = "kilo.squat.ai/external-ip"
|
||||
forceExternalIPAnnotationKey = "kilo.squat.ai/force-external-ip"
|
||||
internalIPAnnotationKey = "kilo.squat.ai/internal-ip"
|
||||
keyAnnotationKey = "kilo.squat.ai/key"
|
||||
leaderAnnotationKey = "kilo.squat.ai/leader"
|
||||
locationAnnotationKey = "kilo.squat.ai/location"
|
||||
regionLabelKey = "failure-domain.beta.kubernetes.io/region"
|
||||
jsonPatchSlash = "~1"
|
||||
jsonRemovePatch = `{"op": "remove", "path": "%s"}`
|
||||
)
|
||||
|
||||
type backend struct {
|
||||
client kubernetes.Interface
|
||||
events chan *mesh.Event
|
||||
informer cache.SharedIndexInformer
|
||||
lister v1listers.NodeLister
|
||||
}
|
||||
|
||||
// New creates a new instance of a mesh.Backend.
|
||||
func New(client kubernetes.Interface) mesh.Backend {
|
||||
informer := v1informers.NewNodeInformer(client, 5*time.Minute, nil)
|
||||
|
||||
b := &backend{
|
||||
client: client,
|
||||
events: make(chan *mesh.Event),
|
||||
informer: informer,
|
||||
lister: v1listers.NewNodeLister(informer.GetIndexer()),
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// CleanUp removes configuration applied to the backend.
|
||||
func (b *backend) CleanUp(name string) error {
|
||||
patch := []byte("[" + strings.Join([]string{
|
||||
fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(externalIPAnnotationKey, "/", jsonPatchSlash, 1))),
|
||||
fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(internalIPAnnotationKey, "/", jsonPatchSlash, 1))),
|
||||
fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(keyAnnotationKey, "/", jsonPatchSlash, 1))),
|
||||
}, ",") + "]")
|
||||
if _, err := b.client.CoreV1().Nodes().Patch(name, types.JSONPatchType, patch); err != nil {
|
||||
return fmt.Errorf("failed to patch node: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get gets a single Node by name.
|
||||
func (b *backend) Get(name string) (*mesh.Node, error) {
|
||||
n, err := b.lister.Get(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return translateNode(n), nil
|
||||
}
|
||||
|
||||
// Init initializes the backend; for this backend that means
|
||||
// syncing the informer cache.
|
||||
func (b *backend) Init(stop <-chan struct{}) error {
|
||||
go b.informer.Run(stop)
|
||||
if ok := cache.WaitForCacheSync(stop, func() bool {
|
||||
return b.informer.HasSynced()
|
||||
}); !ok {
|
||||
return errors.New("failed to start sync node cache")
|
||||
}
|
||||
b.informer.AddEventHandler(
|
||||
cache.ResourceEventHandlerFuncs{
|
||||
AddFunc: func(obj interface{}) {
|
||||
n, ok := obj.(*v1.Node)
|
||||
if !ok {
|
||||
// Failed to decode Node; ignoring...
|
||||
return
|
||||
}
|
||||
b.events <- &mesh.Event{Type: mesh.AddEvent, Node: translateNode(n)}
|
||||
},
|
||||
UpdateFunc: func(_, obj interface{}) {
|
||||
n, ok := obj.(*v1.Node)
|
||||
if !ok {
|
||||
// Failed to decode Node; ignoring...
|
||||
return
|
||||
}
|
||||
b.events <- &mesh.Event{Type: mesh.UpdateEvent, Node: translateNode(n)}
|
||||
},
|
||||
DeleteFunc: func(obj interface{}) {
|
||||
n, ok := obj.(*v1.Node)
|
||||
if !ok {
|
||||
// Failed to decode Node; ignoring...
|
||||
return
|
||||
}
|
||||
b.events <- &mesh.Event{Type: mesh.DeleteEvent, Node: translateNode(n)}
|
||||
},
|
||||
},
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// List gets all the Nodes in the cluster.
|
||||
func (b *backend) List() ([]*mesh.Node, error) {
|
||||
ns, err := b.lister.List(labels.Everything())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodes := make([]*mesh.Node, len(ns))
|
||||
for i := range ns {
|
||||
nodes[i] = translateNode(ns[i])
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// Set sets the fields of a node.
|
||||
func (b *backend) Set(name string, node *mesh.Node) error {
|
||||
old, err := b.lister.Get(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find node: %v", err)
|
||||
}
|
||||
n := old.DeepCopy()
|
||||
n.ObjectMeta.Annotations[externalIPAnnotationKey] = node.ExternalIP.String()
|
||||
n.ObjectMeta.Annotations[internalIPAnnotationKey] = node.InternalIP.String()
|
||||
n.ObjectMeta.Annotations[keyAnnotationKey] = string(node.Key)
|
||||
oldData, err := json.Marshal(old)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newData, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
patch, err := strategicpatch.CreateTwoWayMergePatch(oldData, newData, v1.Node{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create patch for node %q: %v", n.Name, err)
|
||||
}
|
||||
if _, err = b.client.CoreV1().Nodes().Patch(name, types.StrategicMergePatchType, patch); err != nil {
|
||||
return fmt.Errorf("failed to patch node: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Watch returns a chan of node events.
|
||||
func (b *backend) Watch() <-chan *mesh.Event {
|
||||
return b.events
|
||||
}
|
||||
|
||||
// translateNode translates a Kubernetes Node to a mesh.Node.
|
||||
func translateNode(node *v1.Node) *mesh.Node {
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
_, subnet, err := net.ParseCIDR(node.Spec.PodCIDR)
|
||||
// The subnet should only ever fail to parse if the pod CIDR has not been set,
|
||||
// so in this case set the subnet to nil and let the node be updated.
|
||||
if err != nil {
|
||||
subnet = nil
|
||||
}
|
||||
_, leader := node.ObjectMeta.Annotations[leaderAnnotationKey]
|
||||
// Allow the region to be overridden by an explicit location.
|
||||
location, ok := node.ObjectMeta.Annotations[locationAnnotationKey]
|
||||
if !ok {
|
||||
location = node.ObjectMeta.Labels[regionLabelKey]
|
||||
}
|
||||
// Allow the external IP to be overridden.
|
||||
externalIP, ok := node.ObjectMeta.Annotations[forceExternalIPAnnotationKey]
|
||||
if !ok {
|
||||
externalIP = node.ObjectMeta.Annotations[externalIPAnnotationKey]
|
||||
}
|
||||
return &mesh.Node{
|
||||
// ExternalIP and InternalIP should only ever fail to parse if the
|
||||
// remote node's mesh has not yet set its IP address;
|
||||
// in this case the IP will be nil and
|
||||
// the mesh can wait for the node to be updated.
|
||||
ExternalIP: normalizeIP(externalIP),
|
||||
InternalIP: normalizeIP(node.ObjectMeta.Annotations[internalIPAnnotationKey]),
|
||||
Key: []byte(node.ObjectMeta.Annotations[keyAnnotationKey]),
|
||||
Leader: leader,
|
||||
Location: location,
|
||||
Name: node.Name,
|
||||
Subnet: subnet,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeIP(ip string) *net.IPNet {
|
||||
i, ipNet, _ := net.ParseCIDR(ip)
|
||||
if ipNet == nil {
|
||||
return ipNet
|
||||
}
|
||||
if ip4 := i.To4(); ip4 != nil {
|
||||
ipNet.IP = ip4
|
||||
return ipNet
|
||||
}
|
||||
ipNet.IP = i.To16()
|
||||
return ipNet
|
||||
}
|
145
pkg/k8s/backend_test.go
Normal file
145
pkg/k8s/backend_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package k8s
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
"k8s.io/api/core/v1"
|
||||
|
||||
"github.com/squat/kilo/pkg/mesh"
|
||||
)
|
||||
|
||||
func TestTranslateNode(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
annotations map[string]string
|
||||
labels map[string]string
|
||||
out *mesh.Node
|
||||
subnet string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
annotations: nil,
|
||||
out: &mesh.Node{},
|
||||
},
|
||||
{
|
||||
name: "invalid ip",
|
||||
annotations: map[string]string{
|
||||
externalIPAnnotationKey: "10.0.0.1",
|
||||
internalIPAnnotationKey: "10.0.0.1",
|
||||
},
|
||||
out: &mesh.Node{},
|
||||
},
|
||||
{
|
||||
name: "valid ip",
|
||||
annotations: map[string]string{
|
||||
externalIPAnnotationKey: "10.0.0.1/24",
|
||||
internalIPAnnotationKey: "10.0.0.2/32",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
ExternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid subnet",
|
||||
annotations: map[string]string{},
|
||||
out: &mesh.Node{},
|
||||
subnet: "foo",
|
||||
},
|
||||
{
|
||||
name: "normalize subnet",
|
||||
annotations: map[string]string{},
|
||||
out: &mesh.Node{
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
subnet: "10.2.0.1/24",
|
||||
},
|
||||
{
|
||||
name: "valid subnet",
|
||||
annotations: map[string]string{},
|
||||
out: &mesh.Node{
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
subnet: "10.2.1.0/24",
|
||||
},
|
||||
{
|
||||
name: "region",
|
||||
labels: map[string]string{
|
||||
regionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Location: "a",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "region override",
|
||||
annotations: map[string]string{
|
||||
locationAnnotationKey: "b",
|
||||
},
|
||||
labels: map[string]string{
|
||||
regionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Location: "b",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "external IP override",
|
||||
annotations: map[string]string{
|
||||
externalIPAnnotationKey: "10.0.0.1/24",
|
||||
forceExternalIPAnnotationKey: "10.0.0.2/24",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
ExternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complete",
|
||||
annotations: map[string]string{
|
||||
externalIPAnnotationKey: "10.0.0.1/24",
|
||||
forceExternalIPAnnotationKey: "10.0.0.2/24",
|
||||
internalIPAnnotationKey: "10.0.0.2/32",
|
||||
keyAnnotationKey: "foo",
|
||||
leaderAnnotationKey: "",
|
||||
locationAnnotationKey: "b",
|
||||
},
|
||||
labels: map[string]string{
|
||||
regionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
ExternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
|
||||
Key: []byte("foo"),
|
||||
Leader: true,
|
||||
Location: "b",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
subnet: "10.2.1.0/24",
|
||||
},
|
||||
} {
|
||||
n := &v1.Node{}
|
||||
n.ObjectMeta.Annotations = tc.annotations
|
||||
n.ObjectMeta.Labels = tc.labels
|
||||
n.Spec.PodCIDR = tc.subnet
|
||||
node := translateNode(n)
|
||||
if diff := pretty.Compare(node, tc.out); diff != "" {
|
||||
t.Errorf("test case %q: got diff: %v", tc.name, diff)
|
||||
}
|
||||
}
|
||||
}
|
101
pkg/mesh/graph.go
Normal file
101
pkg/mesh/graph.go
Normal file
@@ -0,0 +1,101 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/awalterschulze/gographviz"
|
||||
)
|
||||
|
||||
// Dot generates a Graphviz graph of the Topology in DOT fomat.
|
||||
func (t *Topology) Dot() (string, error) {
|
||||
g := gographviz.NewGraph()
|
||||
g.Name = "kilo"
|
||||
if err := g.AddAttr("kilo", string(gographviz.Label), graphEscape(t.subnet.String())); err != nil {
|
||||
return "", fmt.Errorf("failed to add label to graph")
|
||||
}
|
||||
if err := g.AddAttr("kilo", string(gographviz.LabelLOC), "t"); err != nil {
|
||||
return "", fmt.Errorf("failed to add label location to graph")
|
||||
}
|
||||
if err := g.AddAttr("kilo", string(gographviz.Overlap), "false"); err != nil {
|
||||
return "", fmt.Errorf("failed to disable graph overlap")
|
||||
}
|
||||
if err := g.SetDir(true); err != nil {
|
||||
return "", fmt.Errorf("failed to set direction")
|
||||
}
|
||||
leaders := make([]string, len(t.Segments))
|
||||
nodeAttrs := map[string]string{
|
||||
string(gographviz.Shape): "ellipse",
|
||||
}
|
||||
for i, s := range t.Segments {
|
||||
if err := g.AddSubGraph("kilo", subGraphName(s.Location), nil); err != nil {
|
||||
return "", fmt.Errorf("failed to add subgraph")
|
||||
}
|
||||
if err := g.AddAttr(subGraphName(s.Location), string(gographviz.Label), graphEscape(s.Location)); err != nil {
|
||||
return "", fmt.Errorf("failed to add label to subgraph")
|
||||
}
|
||||
if err := g.AddAttr(subGraphName(s.Location), string(gographviz.Style), `"dashed,rounded"`); err != nil {
|
||||
return "", fmt.Errorf("failed to add style to subgraph")
|
||||
}
|
||||
for j := range s.cidrs {
|
||||
if err := g.AddNode(subGraphName(s.Location), graphEscape(s.hostnames[j]), nodeAttrs); err != nil {
|
||||
return "", fmt.Errorf("failed to add node to subgraph")
|
||||
}
|
||||
var wg net.IP
|
||||
if j == s.leader {
|
||||
wg = s.wireGuardIP
|
||||
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Rank), "1"); err != nil {
|
||||
return "", fmt.Errorf("failed to add rank to node")
|
||||
}
|
||||
}
|
||||
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.Location, s.hostnames[j], s.cidrs[j], s.privateIPs[j], wg)); err != nil {
|
||||
return "", fmt.Errorf("failed to add label to node")
|
||||
}
|
||||
}
|
||||
meshSubGraph(g, g.Relations.SortedChildren(subGraphName(s.Location)), s.leader)
|
||||
leaders[i] = graphEscape(s.hostnames[s.leader])
|
||||
}
|
||||
meshSubGraph(g, leaders, 0)
|
||||
return g.String(), nil
|
||||
}
|
||||
|
||||
func meshSubGraph(g *gographviz.Graph, nodes []string, leader int) {
|
||||
for i := range nodes {
|
||||
if i == leader {
|
||||
continue
|
||||
}
|
||||
a := make(gographviz.Attrs)
|
||||
a[gographviz.Dir] = "both"
|
||||
g.Edges.Add(&gographviz.Edge{Src: nodes[leader], Dst: nodes[i], Dir: true, Attrs: a})
|
||||
}
|
||||
}
|
||||
|
||||
func graphEscape(s string) string {
|
||||
return fmt.Sprintf("\"%s\"", s)
|
||||
}
|
||||
|
||||
func subGraphName(name string) string {
|
||||
return graphEscape(fmt.Sprintf("cluster_%s", name))
|
||||
}
|
||||
|
||||
func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP) string {
|
||||
var wg string
|
||||
if wgIP != nil {
|
||||
wg = wgIP.String()
|
||||
}
|
||||
return graphEscape(fmt.Sprintf("%s\n%s\n%s\n%s\n%s", location, name, cidr.String(), priv.String(), wg))
|
||||
}
|
348
pkg/mesh/ip.go
Normal file
348
pkg/mesh/ip.go
Normal file
@@ -0,0 +1,348 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// getIP returns a private and public IP address for the local node.
|
||||
// It selects the private IP address in the following order:
|
||||
// - private IP to which hostname resolves
|
||||
// - private IP assigned to interface of default route
|
||||
// - private IP assigned to local interface
|
||||
// - public IP to which hostname resolves
|
||||
// - public IP assigned to interface of default route
|
||||
// - public IP assigned to local interface
|
||||
// It selects the public IP address in the following order:
|
||||
// - public IP to which hostname resolves
|
||||
// - public IP assigned to interface of default route
|
||||
// - public IP assigned to local interface
|
||||
// - private IP to which hostname resolves
|
||||
// - private IP assigned to interface of default route
|
||||
// - private IP assigned to local interface
|
||||
// - if no IP was found, return nil and an error.
|
||||
func getIP(hostname string) (*net.IPNet, *net.IPNet, error) {
|
||||
var hostPriv, hostPub []*net.IPNet
|
||||
{
|
||||
// Check IPs to which hostname resolves first.
|
||||
ips, err := ipsForHostname(hostname)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, ip := range ips {
|
||||
ok, mask, err := assignedToInterface(ip)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to search locally assigned addresses: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ip.Mask = mask
|
||||
if isPublic(ip) {
|
||||
hostPub = append(hostPub, ip)
|
||||
continue
|
||||
}
|
||||
hostPriv = append(hostPriv, ip)
|
||||
}
|
||||
sortIPs(hostPriv)
|
||||
sortIPs(hostPub)
|
||||
}
|
||||
|
||||
var defaultPriv, defaultPub []*net.IPNet
|
||||
{
|
||||
// Check IPs on interface for default route next.
|
||||
iface, err := defaultInterface()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
ips, err := ipsForInterface(iface)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isLocal(ip.IP) {
|
||||
continue
|
||||
}
|
||||
if isPublic(ip) {
|
||||
defaultPub = append(defaultPub, ip)
|
||||
continue
|
||||
}
|
||||
defaultPriv = append(defaultPriv, ip)
|
||||
}
|
||||
sortIPs(defaultPriv)
|
||||
sortIPs(defaultPub)
|
||||
}
|
||||
|
||||
var interfacePriv, interfacePub []*net.IPNet
|
||||
{
|
||||
// Finally look for IPs on all interfaces.
|
||||
ips, err := ipsForAllInterfaces()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isLocal(ip.IP) {
|
||||
continue
|
||||
}
|
||||
if isPublic(ip) {
|
||||
interfacePub = append(interfacePub, ip)
|
||||
continue
|
||||
}
|
||||
interfacePriv = append(interfacePriv, ip)
|
||||
}
|
||||
sortIPs(interfacePriv)
|
||||
sortIPs(interfacePub)
|
||||
}
|
||||
|
||||
var priv, pub []*net.IPNet
|
||||
priv = append(priv, hostPriv...)
|
||||
priv = append(priv, defaultPriv...)
|
||||
priv = append(priv, interfacePriv...)
|
||||
pub = append(pub, hostPub...)
|
||||
pub = append(pub, defaultPub...)
|
||||
pub = append(pub, interfacePub...)
|
||||
if len(priv) == 0 && len(pub) == 0 {
|
||||
return nil, nil, errors.New("no valid IP was found")
|
||||
}
|
||||
if len(priv) == 0 {
|
||||
priv = pub
|
||||
}
|
||||
if len(pub) == 0 {
|
||||
pub = priv
|
||||
}
|
||||
return priv[0], pub[0], nil
|
||||
}
|
||||
|
||||
// sortIPs sorts IPs so the result is stable.
|
||||
// It will first sort IPs by type, to prefer selecting
|
||||
// IPs of the same type, and then by value.
|
||||
func sortIPs(ips []*net.IPNet) {
|
||||
sort.Slice(ips, func(i, j int) bool {
|
||||
i4, j4 := ips[i].IP.To4(), ips[j].IP.To4()
|
||||
if i4 != nil && j4 == nil {
|
||||
return true
|
||||
}
|
||||
if j4 != nil && i4 == nil {
|
||||
return false
|
||||
}
|
||||
return ips[i].String() < ips[j].String()
|
||||
})
|
||||
}
|
||||
|
||||
func assignedToInterface(ip *net.IPNet) (bool, net.IPMask, error) {
|
||||
links, err := netlink.LinkList()
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("failed to list interfaces: %v", err)
|
||||
}
|
||||
// Sort the links for stability.
|
||||
sort.Slice(links, func(i, j int) bool {
|
||||
return links[i].Attrs().Name < links[j].Attrs().Name
|
||||
})
|
||||
for _, link := range links {
|
||||
addrs, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("failed to list addresses for %s: %v", link.Attrs().Name, err)
|
||||
}
|
||||
// Sort the IPs for stability.
|
||||
sort.Slice(addrs, func(i, j int) bool {
|
||||
return addrs[i].String() < addrs[j].String()
|
||||
})
|
||||
for i := range addrs {
|
||||
if ip.IP.Equal(addrs[i].IP) {
|
||||
return true, addrs[i].Mask, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
func isLocal(ip net.IP) bool {
|
||||
return ip.IsLoopback() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast()
|
||||
}
|
||||
|
||||
func isPublic(ip *net.IPNet) bool {
|
||||
// Check RFC 1918 addresses.
|
||||
if ip4 := ip.IP.To4(); ip4 != nil {
|
||||
switch true {
|
||||
// Check for 10.0.0.0/8.
|
||||
case ip4[0] == 10:
|
||||
return false
|
||||
// Check for 172.16.0.0/12.
|
||||
case ip4[0] == 172 && ip4[1]&0xf0 == 0x01:
|
||||
return false
|
||||
// Check for 192.168.0.0/16.
|
||||
case ip4[0] == 192 && ip4[1] == 168:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Check RFC 4193 addresses.
|
||||
if len(ip.IP) == net.IPv6len {
|
||||
switch true {
|
||||
// Check for fd00::/8.
|
||||
case ip.IP[0] == 0xfd && ip.IP[1] == 0x00:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ipsForHostname returns a slice of IPs to which the
|
||||
// given hostname resolves.
|
||||
func ipsForHostname(hostname string) ([]*net.IPNet, error) {
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
return []*net.IPNet{oneAddressCIDR(ip)}, nil
|
||||
}
|
||||
ips, err := net.LookupIP(hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to lookip IPs of hostname: %v", err)
|
||||
}
|
||||
nets := make([]*net.IPNet, len(ips))
|
||||
for i := range ips {
|
||||
nets[i] = oneAddressCIDR(ips[i])
|
||||
}
|
||||
return nets, nil
|
||||
}
|
||||
|
||||
// ipsForAllInterfaces returns a slice of IPs assigned to all the
|
||||
// interfaces on the host.
|
||||
func ipsForAllInterfaces() ([]*net.IPNet, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list interfaces: %v", err)
|
||||
}
|
||||
var nets []*net.IPNet
|
||||
for _, iface := range ifaces {
|
||||
ips, err := ipsForInterface(&iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list addresses for %s: %v", iface.Name, err)
|
||||
}
|
||||
nets = append(nets, ips...)
|
||||
}
|
||||
return nets, nil
|
||||
}
|
||||
|
||||
// ipsForInterface returns a slice of IPs assigned to the given interface.
|
||||
func ipsForInterface(iface *net.Interface) ([]*net.IPNet, error) {
|
||||
link, err := netlink.LinkByIndex(iface.Index)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get link: %s", err)
|
||||
}
|
||||
addrs, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list addresses for %s: %v", iface.Name, err)
|
||||
}
|
||||
var ips []*net.IPNet
|
||||
for _, a := range addrs {
|
||||
if a.IPNet != nil {
|
||||
ips = append(ips, a.IPNet)
|
||||
}
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// interfacesForIP returns a slice of interfaces withthe given IP.
|
||||
func interfacesForIP(ip *net.IPNet) ([]net.Interface, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list interfaces: %v", err)
|
||||
}
|
||||
var interfaces []net.Interface
|
||||
for _, iface := range ifaces {
|
||||
ips, err := ipsForInterface(&iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list addresses for %s: %v", iface.Name, err)
|
||||
}
|
||||
for i := range ips {
|
||||
if ip.IP.Equal(ips[i].IP) {
|
||||
interfaces = append(interfaces, iface)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(interfaces) == 0 {
|
||||
return nil, fmt.Errorf("no interface has %s assigned", ip.String())
|
||||
}
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
// defaultInterface returns the interface for the default route of the host.
|
||||
func defaultInterface() (*net.Interface, error) {
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Dst == nil || route.Dst.String() == "0.0.0.0/0" || route.Dst.String() == "::/0" {
|
||||
if route.LinkIndex <= 0 {
|
||||
return nil, errors.New("failed to determine interface of route")
|
||||
}
|
||||
return net.InterfaceByIndex(route.LinkIndex)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("failed to find default route")
|
||||
}
|
||||
|
||||
type allocator struct {
|
||||
bits int
|
||||
cidr *net.IPNet
|
||||
current net.IP
|
||||
}
|
||||
|
||||
func newAllocator(cidr net.IPNet) *allocator {
|
||||
_, bits := cidr.Mask.Size()
|
||||
current := make(net.IP, len(cidr.IP))
|
||||
copy(current, cidr.IP)
|
||||
if ip4 := current.To4(); ip4 != nil {
|
||||
current = ip4
|
||||
}
|
||||
|
||||
return &allocator{
|
||||
bits: bits,
|
||||
cidr: &cidr,
|
||||
current: current,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *allocator) next() *net.IPNet {
|
||||
if a.current == nil {
|
||||
return nil
|
||||
}
|
||||
for i := len(a.current) - 1; i >= 0; i-- {
|
||||
a.current[i]++
|
||||
// if we haven't overflowed, then we can exit.
|
||||
if a.current[i] != 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !a.cidr.Contains(a.current) {
|
||||
a.current = nil
|
||||
}
|
||||
ip := make(net.IP, len(a.current))
|
||||
copy(ip, a.current)
|
||||
|
||||
return &net.IPNet{IP: ip, Mask: net.CIDRMask(a.bits, a.bits)}
|
||||
}
|
75
pkg/mesh/ip_test.go
Normal file
75
pkg/mesh/ip_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSortIPs(t *testing.T) {
|
||||
ip1 := oneAddressCIDR(net.ParseIP("10.0.0.1"))
|
||||
ip2 := oneAddressCIDR(net.ParseIP("10.0.0.2"))
|
||||
ip3 := oneAddressCIDR(net.ParseIP("192.168.0.1"))
|
||||
ip4 := oneAddressCIDR(net.ParseIP("2001::7"))
|
||||
ip5 := oneAddressCIDR(net.ParseIP("fd68:da49:09da:b27f::"))
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
ips []*net.IPNet
|
||||
out []*net.IPNet
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
ips: []*net.IPNet{ip1},
|
||||
out: []*net.IPNet{ip1},
|
||||
},
|
||||
{
|
||||
name: "IPv4s",
|
||||
ips: []*net.IPNet{ip2, ip3, ip1},
|
||||
out: []*net.IPNet{ip1, ip2, ip3},
|
||||
},
|
||||
{
|
||||
name: "IPv4 and IPv6",
|
||||
ips: []*net.IPNet{ip4, ip1},
|
||||
out: []*net.IPNet{ip1, ip4},
|
||||
},
|
||||
{
|
||||
name: "IPv6s",
|
||||
ips: []*net.IPNet{ip5, ip4},
|
||||
out: []*net.IPNet{ip4, ip5},
|
||||
},
|
||||
{
|
||||
name: "all",
|
||||
ips: []*net.IPNet{ip3, ip4, ip2, ip5, ip1},
|
||||
out: []*net.IPNet{ip1, ip2, ip3, ip4, ip5},
|
||||
},
|
||||
} {
|
||||
sortIPs(tc.ips)
|
||||
equal := true
|
||||
if len(tc.ips) != len(tc.out) {
|
||||
equal = false
|
||||
} else {
|
||||
for i := range tc.ips {
|
||||
if !ipNetsEqual(tc.ips[i], tc.out[i]) {
|
||||
equal = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !equal {
|
||||
t.Errorf("test case %q: expected %s, got %s", tc.name, tc.out, tc.ips)
|
||||
}
|
||||
}
|
||||
}
|
581
pkg/mesh/mesh.go
Normal file
581
pkg/mesh/mesh.go
Normal file
@@ -0,0 +1,581 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/squat/kilo/pkg/iproute"
|
||||
"github.com/squat/kilo/pkg/ipset"
|
||||
"github.com/squat/kilo/pkg/iptables"
|
||||
"github.com/squat/kilo/pkg/route"
|
||||
"github.com/squat/kilo/pkg/wireguard"
|
||||
)
|
||||
|
||||
const resyncPeriod = 30 * time.Second
|
||||
|
||||
const (
|
||||
// KiloPath is the directory where Kilo stores its configuration.
|
||||
KiloPath = "/var/lib/kilo"
|
||||
// PrivateKeyPath is the filepath where the WireGuard private key is stored.
|
||||
PrivateKeyPath = KiloPath + "/key"
|
||||
// ConfPath is the filepath where the WireGuard configuration is stored.
|
||||
ConfPath = KiloPath + "/conf"
|
||||
)
|
||||
|
||||
// Granularity represents the abstraction level at which the network
|
||||
// should be meshed.
|
||||
type Granularity string
|
||||
|
||||
// Encapsulate identifies what packets within a location should
|
||||
// be encapsulated.
|
||||
type Encapsulate string
|
||||
|
||||
const (
|
||||
// DataCenterGranularity indicates that the network should create
|
||||
// a mesh between data-centers but not between nodes within a
|
||||
// single data-center.
|
||||
DataCenterGranularity Granularity = "data-center"
|
||||
// NodeGranularity indicates that the network should create
|
||||
// a mesh between every node.
|
||||
NodeGranularity Granularity = "node"
|
||||
// NeverEncapsulate indicates that no packets within a location
|
||||
// should be encapsulated.
|
||||
NeverEncapsulate Encapsulate = "never"
|
||||
// CrossSubnetEncapsulate indicates that only packets that
|
||||
// traverse subnets within a location should be encapsulated.
|
||||
CrossSubnetEncapsulate Encapsulate = "crosssubnet"
|
||||
// AlwaysEncapsulate indicates that all packets within a location
|
||||
// should be encapsulated.
|
||||
AlwaysEncapsulate Encapsulate = "always"
|
||||
)
|
||||
|
||||
// Node represents a node in the network.
|
||||
type Node struct {
|
||||
ExternalIP *net.IPNet
|
||||
Key []byte
|
||||
InternalIP *net.IPNet
|
||||
// Leader is a suggestion to Kilo that
|
||||
// the node wants to lead its segment.
|
||||
Leader bool
|
||||
Location string
|
||||
Name string
|
||||
Subnet *net.IPNet
|
||||
}
|
||||
|
||||
// Ready indicates whether or not the node is ready.
|
||||
func (n *Node) Ready() bool {
|
||||
return n != nil && n.ExternalIP != nil && n.Key != nil && n.InternalIP != nil && n.Subnet != nil
|
||||
}
|
||||
|
||||
// EventType describes what kind of an action an event represents.
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
// AddEvent represents an action where an item was added.
|
||||
AddEvent EventType = "add"
|
||||
// DeleteEvent represents an action where an item was removed.
|
||||
DeleteEvent EventType = "delete"
|
||||
// UpdateEvent represents an action where an item was updated.
|
||||
UpdateEvent EventType = "update"
|
||||
)
|
||||
|
||||
// Event represents an update event concerning a node in the cluster.
|
||||
type Event struct {
|
||||
Type EventType
|
||||
Node *Node
|
||||
}
|
||||
|
||||
// Backend can get nodes by name, init itself,
|
||||
// list the nodes that should be meshed,
|
||||
// set Kilo properties for a node,
|
||||
// clean up any changes applied to the backend,
|
||||
// and watch for changes to nodes.
|
||||
type Backend interface {
|
||||
CleanUp(string) error
|
||||
Get(string) (*Node, error)
|
||||
Init(<-chan struct{}) error
|
||||
List() ([]*Node, error)
|
||||
Set(string, *Node) error
|
||||
Watch() <-chan *Event
|
||||
}
|
||||
|
||||
// Mesh is able to create Kilo network meshes.
|
||||
type Mesh struct {
|
||||
Backend
|
||||
encapsulate Encapsulate
|
||||
externalIP *net.IPNet
|
||||
granularity Granularity
|
||||
hostname string
|
||||
internalIP *net.IPNet
|
||||
ipset *ipset.Set
|
||||
ipTables *iptables.Controller
|
||||
kiloIface int
|
||||
key []byte
|
||||
local bool
|
||||
port int
|
||||
priv []byte
|
||||
privIface int
|
||||
pub []byte
|
||||
pubIface int
|
||||
stop chan struct{}
|
||||
subnet *net.IPNet
|
||||
table *route.Table
|
||||
tunlIface int
|
||||
|
||||
// nodes is a mutable field in the struct
|
||||
// and needs to be guarded.
|
||||
nodes map[string]*Node
|
||||
mu sync.Mutex
|
||||
|
||||
errorCounter *prometheus.CounterVec
|
||||
nodesGuage prometheus.Gauge
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// New returns a new Mesh instance.
|
||||
func New(backend Backend, encapsulate Encapsulate, granularity Granularity, hostname string, port int, subnet *net.IPNet, local bool, logger log.Logger) (*Mesh, error) {
|
||||
if err := os.MkdirAll(KiloPath, 0700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
|
||||
}
|
||||
private, err := ioutil.ReadFile(PrivateKeyPath)
|
||||
if err != nil {
|
||||
level.Warn(logger).Log("msg", "no private key found on disk; generating one now")
|
||||
if private, err = wireguard.GenKey(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
public, err := wireguard.PubKey(private)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ioutil.WriteFile(PrivateKeyPath, private, 0600); err != nil {
|
||||
return nil, fmt.Errorf("failed to write private key to disk: %v", err)
|
||||
}
|
||||
privateIP, publicIP, err := getIP(hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find public IP: %v", err)
|
||||
}
|
||||
ifaces, err := interfacesForIP(privateIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find interface for private IP: %v", err)
|
||||
}
|
||||
privIface := ifaces[0].Index
|
||||
ifaces, err = interfacesForIP(publicIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find interface for public IP: %v", err)
|
||||
}
|
||||
pubIface := ifaces[0].Index
|
||||
kiloIface, err := wireguard.New("kilo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create WireGuard interface: %v", err)
|
||||
}
|
||||
var tunlIface int
|
||||
if encapsulate != NeverEncapsulate {
|
||||
if tunlIface, err = iproute.NewIPIP(privIface); err != nil {
|
||||
return nil, fmt.Errorf("failed to create tunnel interface: %v", err)
|
||||
}
|
||||
if err := iproute.Set(tunlIface, true); err != nil {
|
||||
return nil, fmt.Errorf("failed to set tunnel interface up: %v", err)
|
||||
}
|
||||
}
|
||||
level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the private IP address", privateIP.String()))
|
||||
level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the public IP address", publicIP.String()))
|
||||
ipTables, err := iptables.New(len(subnet.IP))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to IP tables controller: %v", err)
|
||||
}
|
||||
return &Mesh{
|
||||
Backend: backend,
|
||||
encapsulate: encapsulate,
|
||||
externalIP: publicIP,
|
||||
granularity: granularity,
|
||||
hostname: hostname,
|
||||
internalIP: privateIP,
|
||||
// This is a patch until Calico supports
|
||||
// other hosts adding IPIP iptables rules.
|
||||
ipset: ipset.New("cali40all-hosts-net"),
|
||||
ipTables: ipTables,
|
||||
kiloIface: kiloIface,
|
||||
nodes: make(map[string]*Node),
|
||||
port: port,
|
||||
priv: private,
|
||||
privIface: privIface,
|
||||
pub: public,
|
||||
pubIface: pubIface,
|
||||
local: local,
|
||||
stop: make(chan struct{}),
|
||||
subnet: subnet,
|
||||
table: route.NewTable(),
|
||||
tunlIface: tunlIface,
|
||||
errorCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "kilo_errors_total",
|
||||
Help: "Number of errors that occurred while administering the mesh.",
|
||||
}, []string{"event"}),
|
||||
nodesGuage: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "kilo_nodes",
|
||||
Help: "Number of in the mesh.",
|
||||
}),
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run starts the mesh.
|
||||
func (m *Mesh) Run() error {
|
||||
if err := m.Init(m.stop); err != nil {
|
||||
return fmt.Errorf("failed to initialize backend: %v", err)
|
||||
}
|
||||
ipsetErrors, err := m.ipset.Run(m.stop)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to watch for ipset updates: %v", err)
|
||||
}
|
||||
ipTablesErrors, err := m.ipTables.Run(m.stop)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to watch for IP tables updates: %v", err)
|
||||
}
|
||||
routeErrors, err := m.table.Run(m.stop)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to watch for route table updates: %v", err)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
var err error
|
||||
select {
|
||||
case err = <-ipsetErrors:
|
||||
case err = <-ipTablesErrors:
|
||||
case err = <-routeErrors:
|
||||
case <-m.stop:
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("run").Inc()
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer m.cleanUp()
|
||||
t := time.NewTimer(resyncPeriod)
|
||||
w := m.Watch()
|
||||
for {
|
||||
var e *Event
|
||||
select {
|
||||
case e = <-w:
|
||||
m.sync(e)
|
||||
case <-t.C:
|
||||
m.applyTopology()
|
||||
t.Reset(resyncPeriod)
|
||||
case <-m.stop:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mesh) sync(e *Event) {
|
||||
logger := log.With(m.logger, "event", e.Type)
|
||||
level.Debug(logger).Log("msg", "syncing", "event", e.Type)
|
||||
if isSelf(m.hostname, e.Node) {
|
||||
level.Debug(logger).Log("msg", "processing local node", "node", e.Node)
|
||||
m.handleLocal(e.Node)
|
||||
return
|
||||
}
|
||||
var diff bool
|
||||
m.mu.Lock()
|
||||
if !e.Node.Ready() {
|
||||
level.Debug(logger).Log("msg", "received incomplete node", "node", e.Node)
|
||||
// An existing node is no longer valid
|
||||
// so remove it from the mesh.
|
||||
if _, ok := m.nodes[e.Node.Name]; ok {
|
||||
level.Info(logger).Log("msg", "node is no longer in the mesh", "node", e.Node)
|
||||
delete(m.nodes, e.Node.Name)
|
||||
diff = true
|
||||
}
|
||||
} else {
|
||||
switch e.Type {
|
||||
case AddEvent:
|
||||
fallthrough
|
||||
case UpdateEvent:
|
||||
if !nodesAreEqual(m.nodes[e.Node.Name], e.Node) {
|
||||
m.nodes[e.Node.Name] = e.Node
|
||||
diff = true
|
||||
}
|
||||
case DeleteEvent:
|
||||
delete(m.nodes, e.Node.Name)
|
||||
diff = true
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
if diff {
|
||||
level.Info(logger).Log("node", e.Node)
|
||||
m.applyTopology()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mesh) handleLocal(n *Node) {
|
||||
// Allow the external IP to be overridden.
|
||||
if n.ExternalIP == nil {
|
||||
n.ExternalIP = m.externalIP
|
||||
}
|
||||
// Compare the given node to the calculated local node.
|
||||
// Take leader, location, and subnet from the argument, as these
|
||||
// are not determined by kilo.
|
||||
local := &Node{ExternalIP: n.ExternalIP, Key: m.pub, InternalIP: m.internalIP, Leader: n.Leader, Location: n.Location, Name: m.hostname, Subnet: n.Subnet}
|
||||
if !nodesAreEqual(n, local) {
|
||||
level.Debug(m.logger).Log("msg", "local node differs from backend")
|
||||
if err := m.Set(m.hostname, local); err != nil {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", local)
|
||||
m.errorCounter.WithLabelValues("local").Inc()
|
||||
return
|
||||
}
|
||||
level.Debug(m.logger).Log("msg", "successfully reconciled local node against backend")
|
||||
}
|
||||
m.mu.Lock()
|
||||
n = m.nodes[m.hostname]
|
||||
if n == nil {
|
||||
n = &Node{}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
if !nodesAreEqual(n, local) {
|
||||
m.mu.Lock()
|
||||
m.nodes[local.Name] = local
|
||||
m.mu.Unlock()
|
||||
m.applyTopology()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mesh) applyTopology() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// Ensure all unready nodes are removed.
|
||||
var ready float64
|
||||
for n := range m.nodes {
|
||||
if !m.nodes[n].Ready() {
|
||||
delete(m.nodes, n)
|
||||
continue
|
||||
}
|
||||
ready++
|
||||
}
|
||||
m.nodesGuage.Set(ready)
|
||||
// We cannot do anything with the topology until the local node is available.
|
||||
if m.nodes[m.hostname] == nil {
|
||||
return
|
||||
}
|
||||
t, err := NewTopology(m.nodes, m.granularity, m.hostname, m.port, m.priv, m.subnet)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
conf, err := t.Conf()
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
}
|
||||
if err := ioutil.WriteFile(ConfPath, conf, 0600); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
var private *net.IPNet
|
||||
// If we are not encapsulating packets to the local private network,
|
||||
// then pass the private IP to add an exception to the NAT rule.
|
||||
if m.encapsulate != AlwaysEncapsulate {
|
||||
private = t.privateIP
|
||||
}
|
||||
rules := iptables.MasqueradeRules(private, m.nodes[m.hostname].Subnet, t.RemoteSubnets())
|
||||
rules = append(rules, iptables.ForwardRules(m.subnet)...)
|
||||
if err := m.ipTables.Set(rules); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
if m.encapsulate != NeverEncapsulate {
|
||||
var peers []net.IP
|
||||
for _, s := range t.Segments {
|
||||
if s.Location == m.nodes[m.hostname].Location {
|
||||
peers = s.privateIPs
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := m.ipset.Set(peers); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
if m.local {
|
||||
if err := iproute.SetAddress(m.tunlIface, oneAddressCIDR(newAllocator(*m.nodes[m.hostname].Subnet).next().IP)); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.leader {
|
||||
if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
link, err := linkByIndex(m.kiloIface)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
oldConf, err := wireguard.ShowConf(link.Attrs().Name)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
// Setting the WireGuard configuration interrupts existing connections
|
||||
// so only set the configuration if it has changed.
|
||||
equal, err := wireguard.CompareConf(conf, oldConf)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
// Don't return here, simply overwrite the old configuration.
|
||||
equal = false
|
||||
}
|
||||
if !equal {
|
||||
if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := iproute.Set(m.kiloIface, true); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
} else {
|
||||
level.Debug(m.logger).Log("msg", "local node is not the leader")
|
||||
if err := iproute.Set(m.kiloIface, false); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
}
|
||||
// We need to add routes last since they may depend
|
||||
// on the WireGuard interface.
|
||||
routes := t.Routes(m.kiloIface, m.privIface, m.tunlIface, m.local, m.encapsulate)
|
||||
if err := m.table.Set(routes); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterMetrics registers Prometheus metrics on the given Prometheus
|
||||
// registerer.
|
||||
func (m *Mesh) RegisterMetrics(r prometheus.Registerer) {
|
||||
r.MustRegister(
|
||||
m.errorCounter,
|
||||
m.nodesGuage,
|
||||
)
|
||||
}
|
||||
|
||||
// Stop stops the mesh.
|
||||
func (m *Mesh) Stop() {
|
||||
close(m.stop)
|
||||
}
|
||||
|
||||
func (m *Mesh) cleanUp() {
|
||||
if err := m.ipTables.CleanUp(); err != nil {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up IP tables: %v", err))
|
||||
m.errorCounter.WithLabelValues("cleanUp").Inc()
|
||||
}
|
||||
if err := m.table.CleanUp(); err != nil {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up routes: %v", err))
|
||||
m.errorCounter.WithLabelValues("cleanUp").Inc()
|
||||
}
|
||||
if err := os.Remove(PrivateKeyPath); err != nil {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to delete private key: %v", err))
|
||||
m.errorCounter.WithLabelValues("cleanUp").Inc()
|
||||
}
|
||||
if err := os.Remove(ConfPath); err != nil {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to delete configuration file: %v", err))
|
||||
m.errorCounter.WithLabelValues("cleanUp").Inc()
|
||||
}
|
||||
if err := iproute.RemoveInterface(m.kiloIface); err != nil {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove wireguard interface: %v", err))
|
||||
m.errorCounter.WithLabelValues("cleanUp").Inc()
|
||||
}
|
||||
if err := m.CleanUp(m.hostname); err != nil {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up backend: %v", err))
|
||||
m.errorCounter.WithLabelValues("cleanUp").Inc()
|
||||
}
|
||||
if err := m.ipset.CleanUp(); err != nil {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up ipset: %v", err))
|
||||
m.errorCounter.WithLabelValues("cleanUp").Inc()
|
||||
}
|
||||
}
|
||||
|
||||
func isSelf(hostname string, node *Node) bool {
|
||||
return node != nil && node.Name == hostname
|
||||
}
|
||||
|
||||
func nodesAreEqual(a, b *Node) bool {
|
||||
if !(a != nil) == (b != nil) {
|
||||
return false
|
||||
}
|
||||
if a == b {
|
||||
return true
|
||||
}
|
||||
return ipNetsEqual(a.ExternalIP, b.ExternalIP) && string(a.Key) == string(b.Key) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet)
|
||||
}
|
||||
|
||||
func ipNetsEqual(a, b *net.IPNet) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if (a != nil) != (b != nil) {
|
||||
return false
|
||||
}
|
||||
if a.Mask.String() != b.Mask.String() {
|
||||
return false
|
||||
}
|
||||
return a.IP.Equal(b.IP)
|
||||
}
|
||||
|
||||
func subnetsEqual(a, b *net.IPNet) bool {
|
||||
if a.Mask.String() != b.Mask.String() {
|
||||
return false
|
||||
}
|
||||
if !a.Contains(b.IP) {
|
||||
return false
|
||||
}
|
||||
if !b.Contains(a.IP) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func linkByIndex(index int) (netlink.Link, error) {
|
||||
link, err := netlink.LinkByIndex(index)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
return link, nil
|
||||
}
|
146
pkg/mesh/mesh_test.go
Normal file
146
pkg/mesh/mesh_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewAllocator(t *testing.T) {
|
||||
_, c1, err := net.ParseCIDR("10.1.0.0/16")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse CIDR: %v", err)
|
||||
}
|
||||
a1 := newAllocator(*c1)
|
||||
_, c2, err := net.ParseCIDR("10.1.0.0/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse CIDR: %v", err)
|
||||
}
|
||||
a2 := newAllocator(*c2)
|
||||
_, c3, err := net.ParseCIDR("10.1.0.0/31")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse CIDR: %v", err)
|
||||
}
|
||||
a3 := newAllocator(*c3)
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
a *allocator
|
||||
next string
|
||||
}{
|
||||
{
|
||||
name: "10.1.0.0/16 first",
|
||||
a: a1,
|
||||
next: "10.1.0.1/32",
|
||||
},
|
||||
{
|
||||
name: "10.1.0.0/16 second",
|
||||
a: a1,
|
||||
next: "10.1.0.2/32",
|
||||
},
|
||||
{
|
||||
name: "10.1.0.0/32",
|
||||
a: a2,
|
||||
next: "<nil>",
|
||||
},
|
||||
{
|
||||
name: "10.1.0.0/31 first",
|
||||
a: a3,
|
||||
next: "10.1.0.1/32",
|
||||
},
|
||||
{
|
||||
name: "10.1.0.0/31 second",
|
||||
a: a3,
|
||||
next: "<nil>",
|
||||
},
|
||||
} {
|
||||
next := tc.a.next()
|
||||
if next.String() != tc.next {
|
||||
t.Errorf("test case %q: expected %s, got %s", tc.name, tc.next, next.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReady(t *testing.T) {
|
||||
internalIP := oneAddressCIDR(net.ParseIP("1.1.1.1"))
|
||||
externalIP := oneAddressCIDR(net.ParseIP("2.2.2.2"))
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
node *Node
|
||||
ready bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
node: nil,
|
||||
ready: false,
|
||||
},
|
||||
{
|
||||
name: "empty fields",
|
||||
node: &Node{},
|
||||
ready: false,
|
||||
},
|
||||
{
|
||||
name: "empty external IP",
|
||||
node: &Node{
|
||||
InternalIP: internalIP,
|
||||
Key: []byte{},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
ready: false,
|
||||
},
|
||||
{
|
||||
name: "empty internal IP",
|
||||
node: &Node{
|
||||
ExternalIP: externalIP,
|
||||
Key: []byte{},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
ready: false,
|
||||
},
|
||||
{
|
||||
name: "empty key",
|
||||
node: &Node{
|
||||
ExternalIP: externalIP,
|
||||
InternalIP: internalIP,
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
ready: false,
|
||||
},
|
||||
{
|
||||
name: "empty subnet",
|
||||
node: &Node{
|
||||
ExternalIP: externalIP,
|
||||
InternalIP: internalIP,
|
||||
Key: []byte{},
|
||||
},
|
||||
ready: false,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
node: &Node{
|
||||
ExternalIP: externalIP,
|
||||
InternalIP: internalIP,
|
||||
Key: []byte{},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
ready: true,
|
||||
},
|
||||
} {
|
||||
ready := tc.node.Ready()
|
||||
if ready != tc.ready {
|
||||
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.ready, ready)
|
||||
}
|
||||
}
|
||||
}
|
334
pkg/mesh/topology.go
Normal file
334
pkg/mesh/topology.go
Normal file
@@ -0,0 +1,334 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
confTemplate = template.Must(template.New("").Parse(`[Interface]
|
||||
PrivateKey = {{.Key}}
|
||||
ListenPort = {{.Port}}
|
||||
{{range .Segments -}}
|
||||
{{if ne .Location $.Location}}
|
||||
[Peer]
|
||||
PublicKey = {{.Key}}
|
||||
Endpoint = {{.Endpoint}}:{{$.Port}}
|
||||
AllowedIPs = {{.AllowedIPs}}
|
||||
{{end}}
|
||||
{{- end -}}
|
||||
`))
|
||||
)
|
||||
|
||||
// Topology represents the logical structure of the overlay network.
|
||||
type Topology struct {
|
||||
// Some fields need to be exported so that the template can read them.
|
||||
Key string
|
||||
Port int
|
||||
// Location is the logical location of the local host.
|
||||
Location string
|
||||
Segments []*segment
|
||||
|
||||
// hostname is the hostname of the local host.
|
||||
hostname string
|
||||
// leader represents whether or not the local host
|
||||
// is the segment leader.
|
||||
leader bool
|
||||
// subnet is the entire subnet from which IPs
|
||||
// for the WireGuard interfaces will be allocated.
|
||||
subnet *net.IPNet
|
||||
// privateIP is the private IP address of the local node.
|
||||
privateIP *net.IPNet
|
||||
// wireGuardCIDR is the allocated CIDR of the WireGuard
|
||||
// interface of the local node. If the local node is not
|
||||
// the leader, then it is nil.
|
||||
wireGuardCIDR *net.IPNet
|
||||
}
|
||||
|
||||
type segment struct {
|
||||
// Some fields need to be exported so that the template can read them.
|
||||
AllowedIPs string
|
||||
Endpoint string
|
||||
Key string
|
||||
// Location is the logical location of this segment.
|
||||
Location string
|
||||
|
||||
// cidrs is a slice of subnets of all peers in the segment.
|
||||
cidrs []*net.IPNet
|
||||
// hostnames is a slice of the hostnames of the peers in the segment.
|
||||
hostnames []string
|
||||
// leader is the index of the leader of the segment.
|
||||
leader int
|
||||
// privateIPs is a slice of private IPs of all peers in the segment.
|
||||
privateIPs []net.IP
|
||||
// wireGuardIP is the allocated IP address of the WireGuard
|
||||
// interface on the leader of the segment.
|
||||
wireGuardIP net.IP
|
||||
}
|
||||
|
||||
// NewTopology creates a new Topology struct from a given set of nodes.
|
||||
func NewTopology(nodes map[string]*Node, granularity Granularity, hostname string, port int, key []byte, subnet *net.IPNet) (*Topology, error) {
|
||||
topoMap := make(map[string][]*Node)
|
||||
for _, node := range nodes {
|
||||
var location string
|
||||
switch granularity {
|
||||
case DataCenterGranularity:
|
||||
location = node.Location
|
||||
case NodeGranularity:
|
||||
location = node.Name
|
||||
}
|
||||
topoMap[location] = append(topoMap[location], node)
|
||||
}
|
||||
var localLocation string
|
||||
switch granularity {
|
||||
case DataCenterGranularity:
|
||||
localLocation = nodes[hostname].Location
|
||||
case NodeGranularity:
|
||||
localLocation = hostname
|
||||
}
|
||||
|
||||
t := Topology{Key: strings.TrimSpace(string(key)), Port: port, hostname: hostname, Location: localLocation, subnet: subnet, privateIP: nodes[hostname].InternalIP}
|
||||
for location := range topoMap {
|
||||
// Sort the location so the result is stable.
|
||||
sort.Slice(topoMap[location], func(i, j int) bool {
|
||||
return topoMap[location][i].Name < topoMap[location][j].Name
|
||||
})
|
||||
leader := findLeader(topoMap[location])
|
||||
if location == localLocation && topoMap[location][leader].Name == hostname {
|
||||
t.leader = true
|
||||
}
|
||||
var allowedIPs []string
|
||||
var cidrs []*net.IPNet
|
||||
var hostnames []string
|
||||
var privateIPs []net.IP
|
||||
for _, node := range topoMap[location] {
|
||||
// Allowed IPs should include:
|
||||
// - the node's allocated subnet
|
||||
// - the node's WireGuard IP
|
||||
// - the node's internal IP
|
||||
allowedIPs = append(allowedIPs, node.Subnet.String(), oneAddressCIDR(node.InternalIP.IP).String())
|
||||
cidrs = append(cidrs, node.Subnet)
|
||||
hostnames = append(hostnames, node.Name)
|
||||
privateIPs = append(privateIPs, node.InternalIP.IP)
|
||||
}
|
||||
t.Segments = append(t.Segments, &segment{
|
||||
AllowedIPs: strings.Join(allowedIPs, ", "),
|
||||
Endpoint: topoMap[location][leader].ExternalIP.IP.String(),
|
||||
Key: strings.TrimSpace(string(topoMap[location][leader].Key)),
|
||||
Location: location,
|
||||
cidrs: cidrs,
|
||||
hostnames: hostnames,
|
||||
leader: leader,
|
||||
privateIPs: privateIPs,
|
||||
})
|
||||
}
|
||||
// Sort the Topology so the result is stable.
|
||||
sort.Slice(t.Segments, func(i, j int) bool {
|
||||
return t.Segments[i].Location < t.Segments[j].Location
|
||||
})
|
||||
|
||||
// Allocate IPs to the segment leaders in a stable, coordination-free manner.
|
||||
a := newAllocator(*subnet)
|
||||
for _, segment := range t.Segments {
|
||||
ipNet := a.next()
|
||||
if ipNet == nil {
|
||||
return nil, errors.New("failed to allocate an IP address; ran out of IP addresses")
|
||||
}
|
||||
segment.wireGuardIP = ipNet.IP
|
||||
segment.AllowedIPs = fmt.Sprintf("%s, %s", segment.AllowedIPs, ipNet.String())
|
||||
if t.leader && segment.Location == t.Location {
|
||||
t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: t.subnet.Mask}
|
||||
}
|
||||
}
|
||||
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
// RemoteSubnets identifies the subnets of the hosts in segments different than the host's.
|
||||
func (t *Topology) RemoteSubnets() []*net.IPNet {
|
||||
var remote []*net.IPNet
|
||||
for _, s := range t.Segments {
|
||||
if s == nil || s.Location == t.Location {
|
||||
continue
|
||||
}
|
||||
remote = append(remote, s.cidrs...)
|
||||
}
|
||||
return remote
|
||||
}
|
||||
|
||||
// Routes generates a slice of routes for a given Topology.
|
||||
func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, encapsulate Encapsulate) []*netlink.Route {
|
||||
var routes []*netlink.Route
|
||||
if !t.leader {
|
||||
// Find the leader for this segment.
|
||||
var leader net.IP
|
||||
for _, segment := range t.Segments {
|
||||
if segment.Location == t.Location {
|
||||
leader = segment.privateIPs[segment.leader]
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, segment := range t.Segments {
|
||||
// First, add a route to the WireGuard IP of the segment.
|
||||
routes = append(routes, encapsulateRoute(&netlink.Route{
|
||||
Dst: oneAddressCIDR(segment.wireGuardIP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: leader,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
}, encapsulate, t.privateIP, tunlIface))
|
||||
// Add routes for the current segment if local is true.
|
||||
if segment.Location == t.Location {
|
||||
if local {
|
||||
for i := range segment.cidrs {
|
||||
// Don't add routes for the local node.
|
||||
if segment.privateIPs[i].Equal(t.privateIP.IP) {
|
||||
continue
|
||||
}
|
||||
routes = append(routes, encapsulateRoute(&netlink.Route{
|
||||
Dst: segment.cidrs[i],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: segment.privateIPs[i],
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
}, encapsulate, t.privateIP, tunlIface))
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
for i := range segment.cidrs {
|
||||
// Add routes to the Pod CIDRs of nodes in other segments.
|
||||
routes = append(routes, encapsulateRoute(&netlink.Route{
|
||||
Dst: segment.cidrs[i],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: leader,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
}, encapsulate, t.privateIP, tunlIface))
|
||||
// Add routes to the private IPs of nodes in other segments.
|
||||
// Number of CIDRs and private IPs always match so
|
||||
// we can reuse the loop.
|
||||
routes = append(routes, encapsulateRoute(&netlink.Route{
|
||||
Dst: oneAddressCIDR(segment.privateIPs[i]),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: leader,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
}, encapsulate, t.privateIP, tunlIface))
|
||||
}
|
||||
}
|
||||
return routes
|
||||
}
|
||||
for _, segment := range t.Segments {
|
||||
// Add routes for the current segment if local is true.
|
||||
if segment.Location == t.Location {
|
||||
if local {
|
||||
for i := range segment.cidrs {
|
||||
// Don't add routes for the local node.
|
||||
if segment.privateIPs[i].Equal(t.privateIP.IP) {
|
||||
continue
|
||||
}
|
||||
routes = append(routes, encapsulateRoute(&netlink.Route{
|
||||
Dst: segment.cidrs[i],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: segment.privateIPs[i],
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
}, encapsulate, t.privateIP, tunlIface))
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
for i := range segment.cidrs {
|
||||
// Add routes to the Pod CIDRs of nodes in other segments.
|
||||
routes = append(routes, &netlink.Route{
|
||||
Dst: segment.cidrs[i],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: segment.wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
})
|
||||
// Add routes to the private IPs of nodes in other segments.
|
||||
// Number of CIDRs and private IPs always match so
|
||||
// we can reuse the loop.
|
||||
routes = append(routes, &netlink.Route{
|
||||
Dst: oneAddressCIDR(segment.privateIPs[i]),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: segment.wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
})
|
||||
}
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
func encapsulateRoute(route *netlink.Route, encapsulate Encapsulate, subnet *net.IPNet, tunlIface int) *netlink.Route {
|
||||
if encapsulate == AlwaysEncapsulate || (encapsulate == CrossSubnetEncapsulate && !subnet.Contains(route.Gw)) {
|
||||
route.LinkIndex = tunlIface
|
||||
}
|
||||
return route
|
||||
}
|
||||
|
||||
// Conf generates a WireGuard configuration file for a given Topology.
|
||||
func (t *Topology) Conf() ([]byte, error) {
|
||||
conf := new(bytes.Buffer)
|
||||
if err := confTemplate.Execute(conf, t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conf.Bytes(), nil
|
||||
}
|
||||
|
||||
// oneAddressCIDR takes an IP address and returns a CIDR
|
||||
// that contains only that address.
|
||||
func oneAddressCIDR(ip net.IP) *net.IPNet {
|
||||
return &net.IPNet{IP: ip, Mask: net.CIDRMask(len(ip)*8, len(ip)*8)}
|
||||
}
|
||||
|
||||
// findLeader selects a leader for the nodes in a segment;
|
||||
// it will select the first node that says it should lead
|
||||
// or the first node in the segment if none have volunteered,
|
||||
// always preferring those with a public external IP address,
|
||||
func findLeader(nodes []*Node) int {
|
||||
var leaders, public []int
|
||||
for i := range nodes {
|
||||
if nodes[i].Leader {
|
||||
if isPublic(nodes[i].ExternalIP) {
|
||||
return i
|
||||
}
|
||||
leaders = append(leaders, i)
|
||||
}
|
||||
if isPublic(nodes[i].ExternalIP) {
|
||||
public = append(public, i)
|
||||
}
|
||||
}
|
||||
if len(leaders) != 0 {
|
||||
return leaders[0]
|
||||
}
|
||||
if len(public) != 0 {
|
||||
return public[0]
|
||||
}
|
||||
return 0
|
||||
}
|
982
pkg/mesh/topology_test.go
Normal file
982
pkg/mesh/topology_test.go
Normal file
@@ -0,0 +1,982 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func allowedIPs(ips ...string) string {
|
||||
return strings.Join(ips, ", ")
|
||||
}
|
||||
|
||||
func setup(t *testing.T) (map[string]*Node, []byte, int, *net.IPNet) {
|
||||
key := []byte("private")
|
||||
port := 51820
|
||||
_, kiloNet, err := net.ParseCIDR("10.4.0.0/16")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse Kilo subnet CIDR: %v", err)
|
||||
}
|
||||
ip, e1, err := net.ParseCIDR("10.1.0.1/16")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse external IP CIDR: %v", err)
|
||||
}
|
||||
e1.IP = ip
|
||||
ip, e2, err := net.ParseCIDR("10.1.0.2/16")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse external IP CIDR: %v", err)
|
||||
}
|
||||
e2.IP = ip
|
||||
ip, e3, err := net.ParseCIDR("10.1.0.3/16")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse external IP CIDR: %v", err)
|
||||
}
|
||||
e3.IP = ip
|
||||
ip, i1, err := net.ParseCIDR("192.168.0.1/24")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse internal IP CIDR: %v", err)
|
||||
}
|
||||
i1.IP = ip
|
||||
ip, i2, err := net.ParseCIDR("192.168.0.2/24")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse internal IP CIDR: %v", err)
|
||||
}
|
||||
i2.IP = ip
|
||||
nodes := map[string]*Node{
|
||||
"a": {
|
||||
Name: "a",
|
||||
ExternalIP: e1,
|
||||
InternalIP: i1,
|
||||
Location: "1",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Key: []byte("key1"),
|
||||
},
|
||||
"b": {
|
||||
Name: "b",
|
||||
ExternalIP: e2,
|
||||
InternalIP: i1,
|
||||
Location: "2",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Key: []byte("key2"),
|
||||
},
|
||||
"c": {
|
||||
Name: "c",
|
||||
ExternalIP: e3,
|
||||
InternalIP: i2,
|
||||
// Same location a node b.
|
||||
Location: "2",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.3.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Key: []byte("key3"),
|
||||
},
|
||||
}
|
||||
return nodes, key, port, kiloNet
|
||||
}
|
||||
|
||||
func TestNewTopology(t *testing.T) {
|
||||
nodes, key, port, kiloNet := setup(t)
|
||||
|
||||
w1 := net.ParseIP("10.4.0.1").To4()
|
||||
w2 := net.ParseIP("10.4.0.2").To4()
|
||||
w3 := net.ParseIP("10.4.0.3").To4()
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
granularity Granularity
|
||||
hostname string
|
||||
result *Topology
|
||||
}{
|
||||
{
|
||||
name: "datacenter from a",
|
||||
granularity: DataCenterGranularity,
|
||||
hostname: nodes["a"].Name,
|
||||
result: &Topology{
|
||||
hostname: nodes["a"].Name,
|
||||
leader: true,
|
||||
Location: nodes["a"].Location,
|
||||
subnet: kiloNet,
|
||||
privateIP: nodes["a"].InternalIP,
|
||||
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
|
||||
Segments: []*segment{
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
|
||||
Endpoint: nodes["a"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["a"].Key),
|
||||
Location: nodes["a"].Location,
|
||||
cidrs: []*net.IPNet{nodes["a"].Subnet},
|
||||
hostnames: []string{"a"},
|
||||
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.2/32"),
|
||||
Endpoint: nodes["b"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["b"].Key),
|
||||
Location: nodes["b"].Location,
|
||||
cidrs: []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
|
||||
hostnames: []string{"b", "c"},
|
||||
privateIPs: []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
|
||||
wireGuardIP: w2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "datacenter from b",
|
||||
granularity: DataCenterGranularity,
|
||||
hostname: nodes["b"].Name,
|
||||
result: &Topology{
|
||||
hostname: nodes["b"].Name,
|
||||
leader: true,
|
||||
Location: nodes["b"].Location,
|
||||
subnet: kiloNet,
|
||||
privateIP: nodes["b"].InternalIP,
|
||||
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
|
||||
Segments: []*segment{
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
|
||||
Endpoint: nodes["a"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["a"].Key),
|
||||
Location: nodes["a"].Location,
|
||||
cidrs: []*net.IPNet{nodes["a"].Subnet},
|
||||
hostnames: []string{"a"},
|
||||
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.2/32"),
|
||||
Endpoint: nodes["b"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["b"].Key),
|
||||
Location: nodes["b"].Location,
|
||||
cidrs: []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
|
||||
hostnames: []string{"b", "c"},
|
||||
privateIPs: []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
|
||||
wireGuardIP: w2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "datacenter from c",
|
||||
granularity: DataCenterGranularity,
|
||||
hostname: nodes["c"].Name,
|
||||
result: &Topology{
|
||||
hostname: nodes["c"].Name,
|
||||
leader: false,
|
||||
Location: nodes["b"].Location,
|
||||
subnet: kiloNet,
|
||||
privateIP: nodes["c"].InternalIP,
|
||||
wireGuardCIDR: nil,
|
||||
Segments: []*segment{
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
|
||||
Endpoint: nodes["a"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["a"].Key),
|
||||
Location: nodes["a"].Location,
|
||||
cidrs: []*net.IPNet{nodes["a"].Subnet},
|
||||
hostnames: []string{"a"},
|
||||
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.2/32"),
|
||||
Endpoint: nodes["b"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["b"].Key),
|
||||
Location: nodes["b"].Location,
|
||||
cidrs: []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
|
||||
hostnames: []string{"b", "c"},
|
||||
privateIPs: []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
|
||||
wireGuardIP: w2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "node from a",
|
||||
granularity: NodeGranularity,
|
||||
hostname: nodes["a"].Name,
|
||||
result: &Topology{
|
||||
hostname: nodes["a"].Name,
|
||||
leader: true,
|
||||
Location: nodes["a"].Name,
|
||||
subnet: kiloNet,
|
||||
privateIP: nodes["a"].InternalIP,
|
||||
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
|
||||
Segments: []*segment{
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
|
||||
Endpoint: nodes["a"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["a"].Key),
|
||||
Location: nodes["a"].Name,
|
||||
cidrs: []*net.IPNet{nodes["a"].Subnet},
|
||||
hostnames: []string{"a"},
|
||||
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", "10.4.0.2/32"),
|
||||
Endpoint: nodes["b"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["b"].Key),
|
||||
Location: nodes["b"].Name,
|
||||
cidrs: []*net.IPNet{nodes["b"].Subnet},
|
||||
hostnames: []string{"b"},
|
||||
privateIPs: []net.IP{nodes["b"].InternalIP.IP},
|
||||
wireGuardIP: w2,
|
||||
},
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.3/32"),
|
||||
Endpoint: nodes["c"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["c"].Key),
|
||||
Location: nodes["c"].Name,
|
||||
cidrs: []*net.IPNet{nodes["c"].Subnet},
|
||||
hostnames: []string{"c"},
|
||||
privateIPs: []net.IP{nodes["c"].InternalIP.IP},
|
||||
wireGuardIP: w3,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "node from b",
|
||||
granularity: NodeGranularity,
|
||||
hostname: nodes["b"].Name,
|
||||
result: &Topology{
|
||||
hostname: nodes["b"].Name,
|
||||
leader: true,
|
||||
Location: nodes["b"].Name,
|
||||
subnet: kiloNet,
|
||||
privateIP: nodes["b"].InternalIP,
|
||||
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
|
||||
Segments: []*segment{
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
|
||||
Endpoint: nodes["a"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["a"].Key),
|
||||
Location: nodes["a"].Name,
|
||||
cidrs: []*net.IPNet{nodes["a"].Subnet},
|
||||
hostnames: []string{"a"},
|
||||
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", "10.4.0.2/32"),
|
||||
Endpoint: nodes["b"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["b"].Key),
|
||||
Location: nodes["b"].Name,
|
||||
cidrs: []*net.IPNet{nodes["b"].Subnet},
|
||||
hostnames: []string{"b"},
|
||||
privateIPs: []net.IP{nodes["b"].InternalIP.IP},
|
||||
wireGuardIP: w2,
|
||||
},
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.3/32"),
|
||||
Endpoint: nodes["c"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["c"].Key),
|
||||
Location: nodes["c"].Name,
|
||||
cidrs: []*net.IPNet{nodes["c"].Subnet},
|
||||
hostnames: []string{"c"},
|
||||
privateIPs: []net.IP{nodes["c"].InternalIP.IP},
|
||||
wireGuardIP: w3,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "node from c",
|
||||
granularity: NodeGranularity,
|
||||
hostname: nodes["c"].Name,
|
||||
result: &Topology{
|
||||
hostname: nodes["c"].Name,
|
||||
leader: true,
|
||||
Location: nodes["c"].Name,
|
||||
subnet: kiloNet,
|
||||
privateIP: nodes["c"].InternalIP,
|
||||
wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)},
|
||||
Segments: []*segment{
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["a"].Subnet.String(), "192.168.0.1/32", "10.4.0.1/32"),
|
||||
Endpoint: nodes["a"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["a"].Key),
|
||||
Location: nodes["a"].Name,
|
||||
cidrs: []*net.IPNet{nodes["a"].Subnet},
|
||||
hostnames: []string{"a"},
|
||||
privateIPs: []net.IP{nodes["a"].InternalIP.IP},
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["b"].Subnet.String(), "192.168.0.1/32", "10.4.0.2/32"),
|
||||
Endpoint: nodes["b"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["b"].Key),
|
||||
Location: nodes["b"].Name,
|
||||
cidrs: []*net.IPNet{nodes["b"].Subnet},
|
||||
hostnames: []string{"b"},
|
||||
privateIPs: []net.IP{nodes["b"].InternalIP.IP},
|
||||
wireGuardIP: w2,
|
||||
},
|
||||
{
|
||||
AllowedIPs: allowedIPs(nodes["c"].Subnet.String(), "192.168.0.2/32", "10.4.0.3/32"),
|
||||
Endpoint: nodes["c"].ExternalIP.IP.String(),
|
||||
Key: string(nodes["c"].Key),
|
||||
Location: nodes["c"].Name,
|
||||
cidrs: []*net.IPNet{nodes["c"].Subnet},
|
||||
hostnames: []string{"c"},
|
||||
privateIPs: []net.IP{nodes["c"].InternalIP.IP},
|
||||
wireGuardIP: w3,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
tc.result.Key = string(key)
|
||||
tc.result.Port = port
|
||||
topo, err := NewTopology(nodes, tc.granularity, tc.hostname, port, key, kiloNet)
|
||||
if err != nil {
|
||||
t.Errorf("test case %q: failed to generate Topology: %v", tc.name, err)
|
||||
}
|
||||
if diff := pretty.Compare(topo, tc.result); diff != "" {
|
||||
t.Errorf("test case %q: got diff: %v", tc.name, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mustTopo(t *testing.T, nodes map[string]*Node, granularity Granularity, hostname string, port int, key []byte, subnet *net.IPNet) *Topology {
|
||||
topo, err := NewTopology(nodes, granularity, hostname, port, key, subnet)
|
||||
if err != nil {
|
||||
t.Errorf("failed to generate Topology: %v", err)
|
||||
}
|
||||
return topo
|
||||
}
|
||||
|
||||
func TestRoutes(t *testing.T) {
|
||||
nodes, key, port, kiloNet := setup(t)
|
||||
kiloIface := 0
|
||||
privIface := 1
|
||||
pubIface := 2
|
||||
mustTopoForGranularityAndHost := func(granularity Granularity, hostname string) *Topology {
|
||||
return mustTopo(t, nodes, granularity, hostname, port, key, kiloNet)
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
local bool
|
||||
topology *Topology
|
||||
result []*netlink.Route
|
||||
}{
|
||||
{
|
||||
name: "datacenter from a",
|
||||
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].cidrs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].cidrs[1],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "datacenter from b",
|
||||
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name).Segments[0].cidrs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "datacenter from c",
|
||||
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: oneAddressCIDR(mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name).Segments[0].wireGuardIP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name).Segments[0].cidrs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name).Segments[1].wireGuardIP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "node from a",
|
||||
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[1].cidrs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[2].cidrs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[2].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[2].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "node from b",
|
||||
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[0].cidrs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[2].cidrs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[2].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[2].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "node from c",
|
||||
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[0].cidrs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[1].cidrs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "datacenter from a local",
|
||||
local: true,
|
||||
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: nodes["b"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["c"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "datacenter from b local",
|
||||
local: true,
|
||||
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: nodes["a"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["c"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["c"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "datacenter from c local",
|
||||
local: true,
|
||||
topology: mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: oneAddressCIDR(mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name).Segments[0].wireGuardIP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["a"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(mustTopoForGranularityAndHost(DataCenterGranularity, nodes["c"].Name).Segments[1].wireGuardIP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["b"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "node from a local",
|
||||
local: true,
|
||||
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: nodes["b"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["c"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[2].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["a"].Name).Segments[2].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "node from b local",
|
||||
local: true,
|
||||
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: nodes["a"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["c"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[2].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["c"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["b"].Name).Segments[2].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "node from c local",
|
||||
local: true,
|
||||
topology: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name),
|
||||
result: []*netlink.Route{
|
||||
{
|
||||
Dst: nodes["a"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["a"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[0].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["b"].Subnet,
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: oneAddressCIDR(nodes["b"].InternalIP.IP),
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(NodeGranularity, nodes["c"].Name).Segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
routes := tc.topology.Routes(kiloIface, privIface, pubIface, tc.local, NeverEncapsulate)
|
||||
if diff := pretty.Compare(routes, tc.result); diff != "" {
|
||||
t.Errorf("test case %q: got diff: %v", tc.name, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConf(t *testing.T) {
|
||||
nodes, key, port, kiloNet := setup(t)
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
topology *Topology
|
||||
result string
|
||||
}{
|
||||
{
|
||||
name: "datacenter from a",
|
||||
topology: mustTopo(t, nodes, DataCenterGranularity, nodes["a"].Name, port, key, kiloNet),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key2
|
||||
Endpoint = 10.1.0.2:51820
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "datacenter from b",
|
||||
topology: mustTopo(t, nodes, DataCenterGranularity, nodes["b"].Name, port, key, kiloNet),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key1
|
||||
Endpoint = 10.1.0.1:51820
|
||||
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "datacenter from c",
|
||||
topology: mustTopo(t, nodes, DataCenterGranularity, nodes["c"].Name, port, key, kiloNet),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key1
|
||||
Endpoint = 10.1.0.1:51820
|
||||
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "node from a",
|
||||
topology: mustTopo(t, nodes, NodeGranularity, nodes["a"].Name, port, key, kiloNet),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key2
|
||||
Endpoint = 10.1.0.2:51820
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key3
|
||||
Endpoint = 10.1.0.3:51820
|
||||
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "node from b",
|
||||
topology: mustTopo(t, nodes, NodeGranularity, nodes["b"].Name, port, key, kiloNet),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key1
|
||||
Endpoint = 10.1.0.1:51820
|
||||
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key3
|
||||
Endpoint = 10.1.0.3:51820
|
||||
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "node from c",
|
||||
topology: mustTopo(t, nodes, NodeGranularity, nodes["c"].Name, port, key, kiloNet),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key1
|
||||
Endpoint = 10.1.0.1:51820
|
||||
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key2
|
||||
Endpoint = 10.1.0.2:51820
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32
|
||||
`,
|
||||
},
|
||||
} {
|
||||
conf, err := tc.topology.Conf()
|
||||
if err != nil {
|
||||
t.Errorf("test case %q: failed to generate conf: %v", tc.name, err)
|
||||
}
|
||||
if string(conf) != tc.result {
|
||||
t.Errorf("test case %q: expected %s got %s", tc.name, tc.result, string(conf))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindLeader(t *testing.T) {
|
||||
ip, e1, err := net.ParseCIDR("10.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse external IP CIDR: %v", err)
|
||||
}
|
||||
e1.IP = ip
|
||||
ip, e2, err := net.ParseCIDR("8.8.8.8/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse external IP CIDR: %v", err)
|
||||
}
|
||||
e2.IP = ip
|
||||
|
||||
nodes := []*Node{
|
||||
{
|
||||
Name: "a",
|
||||
ExternalIP: e1,
|
||||
},
|
||||
{
|
||||
Name: "b",
|
||||
ExternalIP: e2,
|
||||
},
|
||||
{
|
||||
Name: "c",
|
||||
ExternalIP: e2,
|
||||
},
|
||||
{
|
||||
Name: "d",
|
||||
ExternalIP: e1,
|
||||
Leader: true,
|
||||
},
|
||||
{
|
||||
Name: "2",
|
||||
ExternalIP: e2,
|
||||
Leader: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
nodes []*Node
|
||||
out int
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
nodes: nil,
|
||||
out: 0,
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
nodes: []*Node{nodes[0]},
|
||||
out: 0,
|
||||
},
|
||||
{
|
||||
name: "non-leaders",
|
||||
nodes: []*Node{nodes[0], nodes[1], nodes[2]},
|
||||
out: 1,
|
||||
},
|
||||
{
|
||||
name: "leaders",
|
||||
nodes: []*Node{nodes[3], nodes[4]},
|
||||
out: 1,
|
||||
},
|
||||
{
|
||||
name: "public",
|
||||
nodes: []*Node{nodes[1], nodes[2], nodes[4]},
|
||||
out: 2,
|
||||
},
|
||||
{
|
||||
name: "private",
|
||||
nodes: []*Node{nodes[0], nodes[3]},
|
||||
out: 1,
|
||||
},
|
||||
{
|
||||
name: "all",
|
||||
nodes: nodes,
|
||||
out: 4,
|
||||
},
|
||||
} {
|
||||
l := findLeader(tc.nodes)
|
||||
if l != tc.out {
|
||||
t.Errorf("test case %q: expected %d got %d", tc.name, tc.out, l)
|
||||
}
|
||||
}
|
||||
}
|
173
pkg/route/route.go
Normal file
173
pkg/route/route.go
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Table represents a routing table.
|
||||
// Table can safely be used concurrently.
|
||||
type Table struct {
|
||||
errors chan error
|
||||
mu sync.Mutex
|
||||
routes map[string]*netlink.Route
|
||||
subscribed bool
|
||||
|
||||
// Make these functions fields to allow
|
||||
// for testing.
|
||||
add func(*netlink.Route) error
|
||||
del func(*netlink.Route) error
|
||||
}
|
||||
|
||||
// NewTable generates a new table.
|
||||
func NewTable() *Table {
|
||||
return &Table{
|
||||
errors: make(chan error),
|
||||
routes: make(map[string]*netlink.Route),
|
||||
add: netlink.RouteReplace,
|
||||
del: func(r *netlink.Route) error {
|
||||
name := routeToString(r)
|
||||
if name == "" {
|
||||
return errors.New("attempting to delete invalid route")
|
||||
}
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list routes before deletion: %v", err)
|
||||
}
|
||||
for _, route := range routes {
|
||||
if routeToString(&route) == name {
|
||||
return netlink.RouteDel(r)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Run watches for changes to routes in the table and reconciles
|
||||
// the table against the desired state.
|
||||
func (t *Table) Run(stop <-chan struct{}) (<-chan error, error) {
|
||||
t.mu.Lock()
|
||||
if t.subscribed {
|
||||
t.mu.Unlock()
|
||||
return t.errors, nil
|
||||
}
|
||||
// Ensure a given instance only subscribes once.
|
||||
t.subscribed = true
|
||||
t.mu.Unlock()
|
||||
events := make(chan netlink.RouteUpdate)
|
||||
if err := netlink.RouteSubscribe(events, stop); err != nil {
|
||||
return t.errors, fmt.Errorf("failed to subscribe to route events: %v", err)
|
||||
}
|
||||
go func() {
|
||||
defer close(t.errors)
|
||||
for {
|
||||
var e netlink.RouteUpdate
|
||||
select {
|
||||
case e = <-events:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
switch e.Type {
|
||||
// Watch for deleted routes to reconcile this table's routes.
|
||||
case unix.RTM_DELROUTE:
|
||||
t.mu.Lock()
|
||||
for _, r := range t.routes {
|
||||
// If any deleted route's destination matches a destination
|
||||
// in the table, reset the corresponding route just in case.
|
||||
if r.Dst.IP.Equal(e.Route.Dst.IP) && r.Dst.Mask.String() == e.Route.Dst.Mask.String() {
|
||||
if err := t.add(r); err != nil {
|
||||
nonBlockingSend(t.errors, fmt.Errorf("failed add route: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
return t.errors, nil
|
||||
}
|
||||
|
||||
// CleanUp will clean up any routes created by the instance.
|
||||
func (t *Table) CleanUp() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for k, route := range t.routes {
|
||||
if err := t.del(route); err != nil {
|
||||
return fmt.Errorf("failed to delete route: %v", err)
|
||||
}
|
||||
delete(t.routes, k)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set idempotently overwrites any routes previously defined
|
||||
// for the table with the given set of routes.
|
||||
func (t *Table) Set(routes []*netlink.Route) error {
|
||||
r := make(map[string]*netlink.Route)
|
||||
for _, route := range routes {
|
||||
if route == nil {
|
||||
continue
|
||||
}
|
||||
r[routeToString(route)] = route
|
||||
}
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for k := range t.routes {
|
||||
if _, ok := r[k]; !ok {
|
||||
if err := t.del(t.routes[k]); err != nil {
|
||||
return fmt.Errorf("failed to delete route: %v", err)
|
||||
}
|
||||
delete(t.routes, k)
|
||||
}
|
||||
}
|
||||
for k := range r {
|
||||
if _, ok := t.routes[k]; !ok {
|
||||
if err := t.add(r[k]); err != nil {
|
||||
return fmt.Errorf("failed to add route %q: %v", routeToString(r[k]), err)
|
||||
}
|
||||
t.routes[k] = r[k]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nonBlockingSend(errors chan<- error, err error) {
|
||||
select {
|
||||
case errors <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func routeToString(route *netlink.Route) string {
|
||||
if route == nil || route.Dst == nil {
|
||||
return ""
|
||||
}
|
||||
src := "-"
|
||||
if route.Src != nil {
|
||||
src = route.Src.String()
|
||||
}
|
||||
gw := "-"
|
||||
if route.Gw != nil {
|
||||
gw = route.Gw.String()
|
||||
}
|
||||
return fmt.Sprintf("dst: %s, via: %s, src: %s, dev: %d", route.Dst.String(), gw, src, route.LinkIndex)
|
||||
}
|
262
pkg/route/route_test.go
Normal file
262
pkg/route/route_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
_, c1, err := net.ParseCIDR("10.2.0.0/24")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse CIDR: %v", err)
|
||||
}
|
||||
_, c2, err := net.ParseCIDR("10.1.0.0/24")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse CIDR: %v", err)
|
||||
}
|
||||
add := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
|
||||
return func(r *netlink.Route) error {
|
||||
backend[routeToString(r)] = r
|
||||
return nil
|
||||
}
|
||||
}
|
||||
del := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
|
||||
return func(r *netlink.Route) error {
|
||||
delete(backend, routeToString(r))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
adderr := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
|
||||
return func(r *netlink.Route) error {
|
||||
return errors.New(routeToString(r))
|
||||
}
|
||||
}
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
routes []*netlink.Route
|
||||
err bool
|
||||
add func(map[string]*netlink.Route) func(*netlink.Route) error
|
||||
del func(map[string]*netlink.Route) func(*netlink.Route) error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
routes: nil,
|
||||
err: false,
|
||||
add: add,
|
||||
del: del,
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
routes: []*netlink.Route{
|
||||
{
|
||||
Dst: c1,
|
||||
Gw: net.ParseIP("10.1.0.1"),
|
||||
},
|
||||
},
|
||||
err: false,
|
||||
add: add,
|
||||
del: del,
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
routes: []*netlink.Route{
|
||||
{
|
||||
Dst: c1,
|
||||
Gw: net.ParseIP("10.1.0.1"),
|
||||
},
|
||||
{
|
||||
Dst: c2,
|
||||
Gw: net.ParseIP("127.0.0.1"),
|
||||
},
|
||||
},
|
||||
err: false,
|
||||
add: add,
|
||||
del: del,
|
||||
},
|
||||
{
|
||||
name: "err empty",
|
||||
routes: nil,
|
||||
err: false,
|
||||
add: adderr,
|
||||
del: del,
|
||||
},
|
||||
{
|
||||
name: "err",
|
||||
routes: []*netlink.Route{
|
||||
{
|
||||
Dst: c1,
|
||||
Gw: net.ParseIP("10.1.0.1"),
|
||||
},
|
||||
{
|
||||
Dst: c2,
|
||||
Gw: net.ParseIP("127.0.0.1"),
|
||||
},
|
||||
},
|
||||
err: true,
|
||||
add: adderr,
|
||||
del: del,
|
||||
},
|
||||
} {
|
||||
backend := make(map[string]*netlink.Route)
|
||||
a := tc.add(backend)
|
||||
d := tc.del(backend)
|
||||
table := NewTable()
|
||||
table.add = a
|
||||
table.del = d
|
||||
if err := table.Set(tc.routes); (err != nil) != tc.err {
|
||||
no := "no"
|
||||
if tc.err {
|
||||
no = "an"
|
||||
}
|
||||
t.Errorf("test case %q: got unexpected result: expected %s error, got %v", tc.name, no, err)
|
||||
}
|
||||
// If no error was expected, then compare the backend to the input.
|
||||
if !tc.err {
|
||||
for _, r := range tc.routes {
|
||||
r1 := backend[routeToString(r)]
|
||||
r2 := table.routes[routeToString(r)]
|
||||
if r != r1 || r != r2 {
|
||||
t.Errorf("test case %q: expected all routes to be equal: expected %v, got %v and %v", tc.name, r, r1, r2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanUp(t *testing.T) {
|
||||
_, c1, err := net.ParseCIDR("10.2.0.0/24")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse CIDR: %v", err)
|
||||
}
|
||||
_, c2, err := net.ParseCIDR("10.1.0.0/24")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse CIDR: %v", err)
|
||||
}
|
||||
add := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
|
||||
return func(r *netlink.Route) error {
|
||||
backend[routeToString(r)] = r
|
||||
return nil
|
||||
}
|
||||
}
|
||||
del := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
|
||||
return func(r *netlink.Route) error {
|
||||
delete(backend, routeToString(r))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
delerr := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
|
||||
return func(r *netlink.Route) error {
|
||||
return errors.New(routeToString(r))
|
||||
}
|
||||
}
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
routes []*netlink.Route
|
||||
err bool
|
||||
add func(map[string]*netlink.Route) func(*netlink.Route) error
|
||||
del func(map[string]*netlink.Route) func(*netlink.Route) error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
routes: nil,
|
||||
err: false,
|
||||
add: add,
|
||||
del: del,
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
routes: []*netlink.Route{
|
||||
{
|
||||
Dst: c1,
|
||||
Gw: net.ParseIP("10.1.0.1"),
|
||||
},
|
||||
},
|
||||
err: false,
|
||||
add: add,
|
||||
del: del,
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
routes: []*netlink.Route{
|
||||
{
|
||||
Dst: c1,
|
||||
Gw: net.ParseIP("10.1.0.1"),
|
||||
},
|
||||
{
|
||||
Dst: c2,
|
||||
Gw: net.ParseIP("127.0.0.1"),
|
||||
},
|
||||
},
|
||||
err: false,
|
||||
add: add,
|
||||
del: del,
|
||||
},
|
||||
{
|
||||
name: "err empty",
|
||||
routes: nil,
|
||||
err: false,
|
||||
add: add,
|
||||
del: delerr,
|
||||
},
|
||||
{
|
||||
name: "err",
|
||||
routes: []*netlink.Route{
|
||||
{
|
||||
Dst: c1,
|
||||
Gw: net.ParseIP("10.1.0.1"),
|
||||
},
|
||||
{
|
||||
Dst: c2,
|
||||
Gw: net.ParseIP("127.0.0.1"),
|
||||
},
|
||||
},
|
||||
err: true,
|
||||
add: add,
|
||||
del: delerr,
|
||||
},
|
||||
} {
|
||||
backend := make(map[string]*netlink.Route)
|
||||
a := tc.add(backend)
|
||||
d := tc.del(backend)
|
||||
table := NewTable()
|
||||
table.add = a
|
||||
table.del = d
|
||||
if err := table.Set(tc.routes); err != nil {
|
||||
t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
|
||||
}
|
||||
if err := table.CleanUp(); (err != nil) != tc.err {
|
||||
no := "no"
|
||||
if tc.err {
|
||||
no = "an"
|
||||
}
|
||||
t.Errorf("test case %q: got unexpected result: expected %s error, got %v", tc.name, no, err)
|
||||
}
|
||||
// If no error was expected, then compare the backend to the input.
|
||||
if !tc.err {
|
||||
for _, r := range tc.routes {
|
||||
r1 := backend[routeToString(r)]
|
||||
r2 := table.routes[routeToString(r)]
|
||||
if r1 != nil || r2 != nil {
|
||||
t.Errorf("test case %q: expected all routes to be nil: expected got %v and %v", tc.name, r1, r2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
18
pkg/version/version.go
Normal file
18
pkg/version/version.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package version
|
||||
|
||||
// Version is the version of Kilo.
|
||||
var Version = "was not built properly"
|
183
pkg/wireguard/wireguard.go
Normal file
183
pkg/wireguard/wireguard.go
Normal file
@@ -0,0 +1,183 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
type wgLink struct {
|
||||
a netlink.LinkAttrs
|
||||
t string
|
||||
}
|
||||
|
||||
func (w wgLink) Attrs() *netlink.LinkAttrs {
|
||||
return &w.a
|
||||
}
|
||||
|
||||
func (w wgLink) Type() string {
|
||||
return w.t
|
||||
}
|
||||
|
||||
// New creates a new WireGuard interface.
|
||||
func New(prefix string) (int, error) {
|
||||
links, err := netlink.LinkList()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to list links: %v", err)
|
||||
}
|
||||
max := 0
|
||||
re := regexp.MustCompile(fmt.Sprintf("^%s([0-9]+)$", prefix))
|
||||
for _, link := range links {
|
||||
if matches := re.FindStringSubmatch(link.Attrs().Name); len(matches) == 2 {
|
||||
i, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
// This should never happen.
|
||||
return 0, fmt.Errorf("failed to parse digits as an integer: %v", err)
|
||||
}
|
||||
if i >= max {
|
||||
max = i + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
name := fmt.Sprintf("%s%d", prefix, max)
|
||||
wl := wgLink{a: netlink.NewLinkAttrs(), t: "wireguard"}
|
||||
wl.a.Name = name
|
||||
if err := netlink.LinkAdd(wl); err != nil {
|
||||
return 0, fmt.Errorf("failed to create interface %s: %v", name, err)
|
||||
}
|
||||
link, err := netlink.LinkByName(name)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get interface index: %v", err)
|
||||
}
|
||||
return link.Attrs().Index, nil
|
||||
}
|
||||
|
||||
// Keys generates a WireGuard private and public key-pair.
|
||||
func Keys() ([]byte, []byte, error) {
|
||||
private, err := GenKey()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
public, err := PubKey(private)
|
||||
return private, public, err
|
||||
}
|
||||
|
||||
// GenKey generates a WireGuard private key.
|
||||
func GenKey() ([]byte, error) {
|
||||
return exec.Command("wg", "genkey").Output()
|
||||
}
|
||||
|
||||
// PubKey generates a WireGuard public key for a given private key.
|
||||
func PubKey(key []byte) ([]byte, error) {
|
||||
cmd := exec.Command("wg", "pubkey")
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open pipe to stdin: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stdin.Close()
|
||||
stdin.Write(key)
|
||||
}()
|
||||
|
||||
public, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate public key: %v", err)
|
||||
}
|
||||
return public, nil
|
||||
}
|
||||
|
||||
// SetConf applies a WireGuard configuration file to the given interface.
|
||||
func SetConf(iface string, path string) error {
|
||||
cmd := exec.Command("wg", "setconf", iface, path)
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to apply the WireGuard configuration: %s", stderr.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ShowConf gets the WireGuard configuration for the given interface.
|
||||
func ShowConf(iface string) ([]byte, error) {
|
||||
cmd := exec.Command("wg", "showconf", iface)
|
||||
var stderr, stdout bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
cmd.Stdout = &stdout
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read the WireGuard configuration: %s", stderr.String())
|
||||
}
|
||||
return stdout.Bytes(), nil
|
||||
}
|
||||
|
||||
// CompareConf compares two WireGuard configurations.
|
||||
// It returns true if they are equal, false if they are not,
|
||||
// and any error that was encountered.
|
||||
// Note: CompareConf only goes one level deep, as WireGuard
|
||||
// configurations are not nested further than that.
|
||||
func CompareConf(a, b []byte) (bool, error) {
|
||||
iniA, err := ini.Load(a)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to parse configuration: %v", err)
|
||||
}
|
||||
iniB, err := ini.Load(b)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to parse configuration: %v", err)
|
||||
}
|
||||
secsA, secsB := iniA.SectionStrings(), iniB.SectionStrings()
|
||||
if len(secsA) != len(secsB) {
|
||||
return false, nil
|
||||
}
|
||||
sort.Strings(secsA)
|
||||
sort.Strings(secsB)
|
||||
var keysA, keysB []string
|
||||
var valsA, valsB []string
|
||||
for i := range secsA {
|
||||
if secsA[i] != secsB[i] {
|
||||
return false, nil
|
||||
}
|
||||
keysA, keysB = iniA.Section(secsA[i]).KeyStrings(), iniB.Section(secsB[i]).KeyStrings()
|
||||
if len(keysA) != len(keysB) {
|
||||
return false, nil
|
||||
}
|
||||
sort.Strings(keysA)
|
||||
sort.Strings(keysB)
|
||||
for j := range keysA {
|
||||
if keysA[j] != keysB[j] {
|
||||
return false, nil
|
||||
}
|
||||
valsA, valsB = iniA.Section(secsA[i]).Key(keysA[j]).Strings(","), iniB.Section(secsB[i]).Key(keysB[j]).Strings(",")
|
||||
if len(valsA) != len(valsB) {
|
||||
return false, nil
|
||||
}
|
||||
sort.Strings(valsA)
|
||||
sort.Strings(valsB)
|
||||
for k := range valsA {
|
||||
if valsA[k] != valsB[k] {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
143
pkg/wireguard/wireguard_test.go
Normal file
143
pkg/wireguard/wireguard_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompareConf(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
a []byte
|
||||
b []byte
|
||||
out bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
a: []byte{},
|
||||
b: []byte{},
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "key and value order",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(`[Interface]
|
||||
ListenPort = 51820
|
||||
PrivateKey = private
|
||||
|
||||
[Peer]
|
||||
PublicKey = key
|
||||
AllowedIPs = 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 10.2.2.0/24
|
||||
Endpoint = 10.1.0.2:51820
|
||||
`),
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "whitespace",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(`[Interface]
|
||||
PrivateKey=private
|
||||
ListenPort=51820
|
||||
[Peer]
|
||||
Endpoint=10.1.0.2:51820
|
||||
PublicKey=key
|
||||
AllowedIPs=10.2.2.0/24,192.168.0.1/32,10.2.3.0/24,192.168.0.2/32,10.4.0.2/32
|
||||
`),
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "section order",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(`[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
|
||||
[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
`),
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "one empty",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(``),
|
||||
out: false,
|
||||
},
|
||||
} {
|
||||
equal, err := CompareConf(tc.a, tc.b)
|
||||
if err != nil {
|
||||
t.Errorf("test case %q: got unexpected error: %v", tc.name, err)
|
||||
}
|
||||
if equal != tc.out {
|
||||
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user