migrate to golang.zx2c4.com/wireguard/wgctrl

This commit introduces the usage of wgctrl.
It avoids the usage of exec calls of the wg command
and parsing the output of `wg show`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
leonnicolas
2021-09-20 15:47:47 +02:00
parent 797133f272
commit b1927990c2
21 changed files with 1266 additions and 1335 deletions

View File

@@ -18,6 +18,8 @@ import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/wireguard"
)
@@ -54,8 +56,9 @@ const (
// Node represents a node in the network.
type Node struct {
Endpoint *wireguard.Endpoint
Key []byte
KiloEndpoint *wireguard.Endpoint
Endpoint *net.UDPAddr
Key wgtypes.Key
NoInternalIP bool
InternalIP *net.IPNet
// LastSeen is a Unix time for the last time
@@ -66,18 +69,25 @@ type Node struct {
Leader bool
Location string
Name string
PersistentKeepalive int
PersistentKeepalive time.Duration
Subnet *net.IPNet
WireGuardIP *net.IPNet
DiscoveredEndpoints map[string]*wireguard.Endpoint
AllowedLocationIPs []*net.IPNet
// DiscoveredEndpoints cannot be DNS endpoints, only net.UDPAddr.
DiscoveredEndpoints map[string]*net.UDPAddr
AllowedLocationIPs []net.IPNet
Granularity Granularity
}
// Ready indicates whether or not the node is ready.
func (n *Node) Ready() bool {
// Nodes that are not leaders will not have WireGuardIPs, so it is not required.
return n != nil && n.Endpoint != nil && !(n.Endpoint.IP == nil && n.Endpoint.DNS == "") && n.Endpoint.Port != 0 && n.Key != nil && n.Subnet != nil && time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second)
return n != nil &&
n.KiloEndpoint != nil &&
!(n.KiloEndpoint.IP == nil && n.KiloEndpoint.DNS == "") &&
n.KiloEndpoint.Port != 0 &&
n.Key != wgtypes.Key{} &&
n.Subnet != nil &&
time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second)
}
// Peer represents a peer in the network.
@@ -92,7 +102,10 @@ type Peer struct {
// will not declare their endpoint and instead allow it to be
// discovered.
func (p *Peer) Ready() bool {
return p != nil && p.AllowedIPs != nil && len(p.AllowedIPs) != 0 && p.PublicKey != nil
return p != nil &&
p.AllowedIPs != nil &&
len(p.AllowedIPs) != 0 &&
p.PublicKey != wgtypes.Key{} // If Key was not set, it will be wgtypes.Key{}.
}
// EventType describes what kind of an action an event represents.

View File

@@ -65,7 +65,7 @@ func (t *Topology) Dot() (string, error) {
var endpoint *wireguard.Endpoint
if j == s.leader {
wg = s.wireGuardIP
endpoint = s.endpoint
endpoint = s.kiloEndpoint
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Rank), "1"); err != nil {
return "", fmt.Errorf("failed to add rank to node")
}

View File

@@ -1,4 +1,4 @@
// Copyright 2019 the Kilo authors
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -30,6 +30,8 @@ import (
"github.com/go-kit/kit/log/level"
"github.com/prometheus/client_golang/prometheus"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/encapsulation"
"github.com/squat/kilo/pkg/iproute"
@@ -43,8 +45,6 @@ const (
kiloPath = "/var/lib/kilo"
// privateKeyPath is the filepath where the WireGuard private key is stored.
privateKeyPath = kiloPath + "/key"
// confPath is the filepath where the WireGuard configuration is stored.
confPath = kiloPath + "/conf"
)
// Mesh is able to create Kilo network meshes.
@@ -60,12 +60,13 @@ type Mesh struct {
internalIP *net.IPNet
ipTables *iptables.Controller
kiloIface int
kiloIfaceName string
key []byte
local bool
port uint32
priv []byte
port int
priv wgtypes.Key
privIface int
pub []byte
pub wgtypes.Key
resyncPeriod time.Duration
iptablesForwardRule bool
stop chan struct{}
@@ -88,23 +89,24 @@ type Mesh struct {
}
// New returns a new Mesh instance.
func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port uint32, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, mtu uint, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, logger log.Logger) (*Mesh, error) {
func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port int, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, mtu uint, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, logger log.Logger) (*Mesh, error) {
if err := os.MkdirAll(kiloPath, 0700); err != nil {
return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
}
private, err := ioutil.ReadFile(privateKeyPath)
private = bytes.Trim(private, "\n")
privateB, err := ioutil.ReadFile(privateKeyPath)
privateB = bytes.Trim(privateB, "\n")
private, err := wgtypes.ParseKey(string(privateB))
if err != nil {
level.Warn(logger).Log("msg", "no private key found on disk; generating one now")
if private, err = wireguard.GenKey(); err != nil {
if private, err = wgtypes.GenerateKey(); err != nil {
return nil, err
}
}
public, err := wireguard.PubKey(private)
public := private.PublicKey()
if err != nil {
return nil, err
}
if err := ioutil.WriteFile(privateKeyPath, private, 0600); err != nil {
if err := ioutil.WriteFile(privateKeyPath, []byte(private.String()), 0600); err != nil {
return nil, fmt.Errorf("failed to write private key to disk: %v", err)
}
cniIndex, err := cniDeviceIndex()
@@ -168,6 +170,7 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
internalIP: privateIP,
ipTables: ipTables,
kiloIface: kiloIface,
kiloIfaceName: iface,
nodes: make(map[string]*Node),
peers: make(map[string]*Peer),
port: port,
@@ -314,7 +317,7 @@ func (m *Mesh) syncPeers(e *PeerEvent) {
var diff bool
m.mu.Lock()
// Peers are indexed by public key.
key := string(e.Peer.PublicKey)
key := e.Peer.PublicKey.String()
if !e.Peer.Ready() {
// Trace non ready peer with their presence in the mesh.
_, ok := m.peers[key]
@@ -324,8 +327,8 @@ func (m *Mesh) syncPeers(e *PeerEvent) {
case AddEvent:
fallthrough
case UpdateEvent:
if e.Old != nil && key != string(e.Old.PublicKey) {
delete(m.peers, string(e.Old.PublicKey))
if e.Old != nil && key != e.Old.PublicKey.String() {
delete(m.peers, e.Old.PublicKey.String())
diff = true
}
if !peersAreEqual(m.peers[key], e.Peer) {
@@ -367,8 +370,8 @@ func (m *Mesh) checkIn() {
func (m *Mesh) handleLocal(n *Node) {
// Allow the IPs to be overridden.
if n.Endpoint == nil || (n.Endpoint.DNS == "" && n.Endpoint.IP == nil) {
n.Endpoint = &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: m.externalIP.IP}, Port: m.port}
if n.KiloEndpoint == nil || (n.KiloEndpoint.DNS == "" && n.KiloEndpoint.IP == nil) {
n.KiloEndpoint = &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: m.externalIP.IP}, Port: m.port}
}
if n.InternalIP == nil && !n.NoInternalIP {
n.InternalIP = m.internalIP
@@ -377,7 +380,7 @@ func (m *Mesh) handleLocal(n *Node) {
// Take leader, location, and subnet from the argument, as these
// are not determined by kilo.
local := &Node{
Endpoint: n.Endpoint,
KiloEndpoint: n.KiloEndpoint,
Key: m.pub,
NoInternalIP: n.NoInternalIP,
InternalIP: n.InternalIP,
@@ -462,22 +465,26 @@ func (m *Mesh) applyTopology() {
m.errorCounter.WithLabelValues("apply").Inc()
return
}
// Find the old configuration.
oldConfDump, err := wireguard.ShowDump(link.Attrs().Name)
wgClient, err := wgctrl.New()
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
oldConf, err := wireguard.ParseDump(oldConfDump)
defer wgClient.Close()
// wgDevice is the current configuration of the wg interface.
wgDevice, err := wgClient.Device(m.kiloIfaceName)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger)
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].KiloEndpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
@@ -489,19 +496,8 @@ func (m *Mesh) applyTopology() {
} else {
m.wireGuardIP = nil
}
conf := t.Conf()
buf, err := conf.Bytes()
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
if err := ioutil.WriteFile(confPath, buf, 0600); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
ipRules := t.Rules(m.cni, m.iptablesForwardRule)
// If we are handling local routes, ensure the local
// tunnel has an IP address and IPIP traffic is allowed.
if m.enc.Strategy() != encapsulation.Never && m.local {
@@ -540,10 +536,12 @@ func (m *Mesh) applyTopology() {
}
// Setting the WireGuard configuration interrupts existing connections
// so only set the configuration if it has changed.
equal := conf.Equal(oldConf)
conf := t.Conf()
equal, diff := conf.Equal(wgDevice)
if !equal {
level.Info(m.logger).Log("msg", "WireGuard configurations are different")
if err := wireguard.SetConf(link.Attrs().Name, confPath); err != nil {
level.Info(m.logger).Log("msg", "WireGuard configurations are different", "diff", diff)
level.Debug(m.logger).Log("changing wg config", "config", conf.WGConfig())
if err := wgClient.ConfigureDevice(m.kiloIfaceName, conf.WGConfig()); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
@@ -598,10 +596,6 @@ func (m *Mesh) cleanUp() {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up routes: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := os.Remove(confPath); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to delete configuration file: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if m.cleanUpIface {
if err := iproute.RemoveInterface(m.kiloIface); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err))
@@ -631,10 +625,13 @@ func (m *Mesh) resolveEndpoints() error {
}
// If the node is ready, then the endpoint is not nil
// but it may not have a DNS name.
if m.nodes[k].Endpoint.DNS == "" {
if m.nodes[k].KiloEndpoint.DNS == "" {
continue
}
if err := resolveEndpoint(m.nodes[k].Endpoint); err != nil {
if u, err := net.ResolveUDPAddr("udp", m.nodes[k].KiloEndpoint.String()); err == nil {
m.nodes[k].Endpoint = u
m.nodes[k].KiloEndpoint.IP = u.IP
} else {
return err
}
}
@@ -645,33 +642,19 @@ func (m *Mesh) resolveEndpoints() error {
continue
}
// Peers may have nil endpoints.
if m.peers[k].Endpoint == nil || m.peers[k].Endpoint.DNS == "" {
if m.peers[k].KiloEndpoint == nil || m.peers[k].KiloEndpoint.DNS == "" {
continue
}
if err := resolveEndpoint(m.peers[k].Endpoint); err != nil {
if u, err := net.ResolveUDPAddr("udp", m.peers[k].KiloEndpoint.String()); err == nil {
m.peers[k].Endpoint = u
m.peers[k].KiloEndpoint.IP = u.IP
} else {
return err
}
}
return nil
}
func resolveEndpoint(endpoint *wireguard.Endpoint) error {
ips, err := net.LookupIP(endpoint.DNS)
if err != nil {
return fmt.Errorf("failed to look up DNS name %q: %v", endpoint.DNS, err)
}
nets := make([]*net.IPNet, len(ips), len(ips))
for i := range ips {
nets[i] = oneAddressCIDR(ips[i])
}
sortIPs(nets)
if len(nets) == 0 {
return fmt.Errorf("did not find any addresses for DNS name %q", endpoint.DNS)
}
endpoint.IP = nets[0].IP
return nil
}
func isSelf(hostname string, node *Node) bool {
return node != nil && node.Name == hostname
}
@@ -685,13 +668,24 @@ func nodesAreEqual(a, b *Node) bool {
}
// Check the DNS name first since this package
// is doing the DNS resolution.
if !a.Endpoint.Equal(b.Endpoint, true) {
if !a.KiloEndpoint.Equal(b.KiloEndpoint, true) {
return false
}
// Ignore LastSeen when comparing equality we want to check if the nodes are
// equivalent. However, we do want to check if LastSeen has transitioned
// between valid and invalid.
return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive && discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) && ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs) && a.Granularity == b.Granularity
return a.Key.String() == b.Key.String() &&
ipNetsEqual(a.WireGuardIP, b.WireGuardIP) &&
ipNetsEqual(a.InternalIP, b.InternalIP) &&
a.Leader == b.Leader &&
a.Location == b.Location &&
a.Name == b.Name &&
subnetsEqual(a.Subnet, b.Subnet) &&
a.Ready() == b.Ready() &&
a.PersistentKeepalive == b.PersistentKeepalive &&
discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) &&
ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs) &&
a.Granularity == b.Granularity
}
func peersAreEqual(a, b *Peer) bool {
@@ -703,18 +697,22 @@ func peersAreEqual(a, b *Peer) bool {
}
// Check the DNS name first since this package
// is doing the DNS resolution.
if !a.Endpoint.Equal(b.Endpoint, true) {
if !a.KiloEndpoint.Equal(b.KiloEndpoint, true) {
return false
}
if len(a.AllowedIPs) != len(b.AllowedIPs) {
return false
}
for i := range a.AllowedIPs {
if !ipNetsEqual(a.AllowedIPs[i], b.AllowedIPs[i]) {
if !ipNetsEqual(&a.AllowedIPs[i], &b.AllowedIPs[i]) {
return false
}
}
return string(a.PublicKey) == string(b.PublicKey) && string(a.PresharedKey) == string(b.PresharedKey) && a.PersistentKeepalive == b.PersistentKeepalive
return a.PublicKey.String() == b.PublicKey.String() &&
(a.PresharedKey == nil) == (b.PresharedKey == nil) &&
(a.PresharedKey == nil || a.PresharedKey.String() == b.PresharedKey.String()) &&
(a.PersistentKeepaliveInterval == nil) == (b.PersistentKeepaliveInterval == nil) &&
(a.PersistentKeepaliveInterval == nil || a.PersistentKeepaliveInterval == b.PersistentKeepaliveInterval)
}
func ipNetsEqual(a, b *net.IPNet) bool {
@@ -730,12 +728,12 @@ func ipNetsEqual(a, b *net.IPNet) bool {
return a.IP.Equal(b.IP)
}
func ipNetSlicesEqual(a, b []*net.IPNet) bool {
func ipNetSlicesEqual(a, b []net.IPNet) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !ipNetsEqual(a[i], b[i]) {
if !ipNetsEqual(&a[i], &b[i]) {
return false
}
}
@@ -761,7 +759,7 @@ func subnetsEqual(a, b *net.IPNet) bool {
return true
}
func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
func discoveredEndpointsAreEqual(a, b map[string]*net.UDPAddr) bool {
if a == nil && b == nil {
return true
}
@@ -772,7 +770,7 @@ func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
return false
}
for k := range a {
if !a[k].Equal(b[k], false) {
if a[k] != b[k] {
return false
}
}
@@ -788,24 +786,26 @@ func linkByIndex(index int) (netlink.Link, error) {
}
// discoverNATEndpoints uses the node's WireGuard configuration to returns a list of the most recently discovered endpoints for all nodes and peers behind NAT so that they can roam.
func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf, logger log.Logger) map[string]*wireguard.Endpoint {
natEndpoints := make(map[string]*wireguard.Endpoint)
keys := make(map[string]*wireguard.Peer)
// Discovered endpionts will never be DNS names, because WireGuard will always resolve them to net.UDPAddr.
func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wgtypes.Device, logger log.Logger) map[string]*net.UDPAddr {
natEndpoints := make(map[string]*net.UDPAddr)
keys := make(map[string]wgtypes.Peer)
for i := range conf.Peers {
keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
keys[conf.Peers[i].PublicKey.String()] = conf.Peers[i]
}
for _, n := range nodes {
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false), "latest-handshake", peer.LatestHandshake)
if (peer.LatestHandshake != time.Time{}) {
natEndpoints[string(n.Key)] = peer.Endpoint
if peer, ok := keys[n.Key.String()]; ok && n.PersistentKeepalive != time.Duration(0) {
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", peer.Endpoint.String() == n.Endpoint.String(), "latest-handshake", peer.LastHandshakeTime)
// Don't update the endpoint, if there was never any handshake.
if !peer.LastHandshakeTime.Equal(time.Time{}) {
natEndpoints[n.Key.String()] = peer.Endpoint
}
}
}
for _, p := range peers {
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
if (peer.LatestHandshake != time.Time{}) {
natEndpoints[string(p.PublicKey)] = peer.Endpoint
if peer, ok := keys[p.PublicKey.String()]; ok && p.PersistentKeepaliveInterval != nil {
if !peer.LastHandshakeTime.Equal(time.Time{}) {
natEndpoints[p.PublicKey.String()] = peer.Endpoint
}
}
}

View File

@@ -20,8 +20,19 @@ import (
"time"
"github.com/squat/kilo/pkg/wireguard"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func mustKey() wgtypes.Key {
if k, err := wgtypes.GeneratePrivateKey(); err != nil {
panic(err.Error())
} else {
return k
}
}
var key = mustKey()
func TestReady(t *testing.T) {
internalIP := oneAddressCIDR(net.ParseIP("1.1.1.1"))
externalIP := oneAddressCIDR(net.ParseIP("2.2.2.2"))
@@ -44,7 +55,7 @@ func TestReady(t *testing.T) {
name: "empty endpoint",
node: &Node{
InternalIP: internalIP,
Key: []byte{},
Key: key,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
@@ -52,58 +63,58 @@ func TestReady(t *testing.T) {
{
name: "empty endpoint IP",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{}, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: []byte{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{}, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty endpoint port",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}},
InternalIP: internalIP,
Key: []byte{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}},
InternalIP: internalIP,
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty internal IP",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
Key: []byte{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty key",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
InternalIP: internalIP,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
InternalIP: internalIP,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty subnet",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: []byte{},
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: wgtypes.Key{},
},
ready: false,
},
{
name: "valid",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: []byte{},
LastSeen: time.Now().Unix(),
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: key,
LastSeen: time.Now().Unix(),
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: true,
},

View File

@@ -40,7 +40,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
var gw net.IP
for _, segment := range t.segments {
if segment.location == t.location {
gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
gw = enc.Gw(segment.kiloEndpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
break
}
}
@@ -113,7 +113,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
// we need to set routes for allowed location IPs over the leader in the current location.
for i := range segment.allowedLocationIPs {
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: segment.allowedLocationIPs[i],
Dst: &segment.allowedLocationIPs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: gw,
LinkIndex: privIface,
@@ -125,7 +125,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
for _, peer := range t.peers {
for i := range peer.AllowedIPs {
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: peer.AllowedIPs[i],
Dst: &peer.AllowedIPs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: gw,
LinkIndex: privIface,
@@ -196,7 +196,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
// equals the external IP. This means that the node
// is only accessible through an external IP and we
// cannot encapsulate traffic to an IP through the IP.
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP) {
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.kiloEndpoint.IP) {
continue
}
// Add routes to the private IPs of nodes in other segments.
@@ -214,7 +214,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
// we need to set routes for allowed location IPs over the wg interface.
for i := range segment.allowedLocationIPs {
routes = append(routes, &netlink.Route{
Dst: segment.allowedLocationIPs[i],
Dst: &segment.allowedLocationIPs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: segment.wireGuardIP,
LinkIndex: kiloIface,
@@ -226,7 +226,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
for _, peer := range t.peers {
for i := range peer.AllowedIPs {
routes = append(routes, &netlink.Route{
Dst: peer.AllowedIPs[i],
Dst: &peer.AllowedIPs[i],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
})

View File

@@ -75,7 +75,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@@ -89,17 +89,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -132,17 +132,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -196,21 +196,21 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
@@ -266,24 +266,24 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["d"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -309,7 +309,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@@ -337,17 +337,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -394,17 +394,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -444,7 +444,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@@ -458,17 +458,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -509,7 +509,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@@ -523,17 +523,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -574,7 +574,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@@ -588,17 +588,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -639,17 +639,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -698,17 +698,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -782,21 +782,21 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
@@ -868,21 +868,21 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: tunlIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: tunlIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: tunlIface,
@@ -918,7 +918,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@@ -946,17 +946,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -1004,17 +1004,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@@ -1055,7 +1055,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@@ -1069,17 +1069,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},

View File

@@ -1,4 +1,4 @@
// Copyright 2019 the Kilo authors
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -18,9 +18,11 @@ import (
"errors"
"net"
"sort"
"time"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/wireguard"
)
@@ -33,8 +35,8 @@ const (
// Topology represents the logical structure of the overlay network.
type Topology struct {
// key is the private key of the node creating the topology.
key []byte
port uint32
key wgtypes.Key
port int
// Location is the logical location of the local host.
location string
segments []*segment
@@ -47,7 +49,7 @@ type Topology struct {
leader bool
// persistentKeepalive is the interval in seconds of the emission
// of keepalive packets by the local node to its peers.
persistentKeepalive int
persistentKeepalive time.Duration
// privateIP is the private IP address of the local node.
privateIP *net.IPNet
// subnet is the Pod subnet of the local node.
@@ -59,15 +61,16 @@ type Topology struct {
// is equal to the Kilo subnet.
wireGuardCIDR *net.IPNet
// discoveredEndpoints is the updated map of valid discovered Endpoints
discoveredEndpoints map[string]*wireguard.Endpoint
discoveredEndpoints map[string]*net.UDPAddr
logger log.Logger
}
type segment struct {
allowedIPs []*net.IPNet
endpoint *wireguard.Endpoint
key []byte
persistentKeepalive int
allowedIPs []net.IPNet
kiloEndpoint *wireguard.Endpoint
endpoint *net.UDPAddr
key wgtypes.Key
persistentKeepalive time.Duration
// Location is the logical location of this segment.
location string
@@ -85,11 +88,11 @@ type segment struct {
// allowedLocationIPs are not part of the cluster and are not peers.
// They are directly routable from nodes within the segment.
// A classic example is a printer that ought to be routable from other locations.
allowedLocationIPs []*net.IPNet
allowedLocationIPs []net.IPNet
}
// NewTopology creates a new Topology struct from a given set of nodes and peers.
func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int, logger log.Logger) (*Topology, error) {
func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration, logger log.Logger) (*Topology, error) {
if logger == nil {
logger = log.NewNopLogger()
}
@@ -120,7 +123,18 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
localLocation = nodeLocationPrefix + hostname
}
t := Topology{key: key, port: port, hostname: hostname, location: localLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet, wireGuardCIDR: subnet, discoveredEndpoints: make(map[string]*wireguard.Endpoint), logger: logger}
t := Topology{
key: key,
port: port,
hostname: hostname,
location: localLocation,
persistentKeepalive: persistentKeepalive,
privateIP: nodes[hostname].InternalIP,
subnet: nodes[hostname].Subnet,
wireGuardCIDR: subnet,
discoveredEndpoints: make(map[string]*net.UDPAddr),
logger: logger,
}
for location := range topoMap {
// Sort the location so the result is stable.
sort.Slice(topoMap[location], func(i, j int) bool {
@@ -130,9 +144,9 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
if location == localLocation && topoMap[location][leader].Name == hostname {
t.leader = true
}
var allowedIPs []*net.IPNet
var allowedIPs []net.IPNet
allowedLocationIPsMap := make(map[string]struct{})
var allowedLocationIPs []*net.IPNet
var allowedLocationIPs []net.IPNet
var cidrs []*net.IPNet
var hostnames []string
var privateIPs []net.IP
@@ -142,7 +156,9 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
// - the node's WireGuard IP
// - the node's internal IP
// - IPs that were specified by the allowed-location-ips annotation
allowedIPs = append(allowedIPs, node.Subnet)
if node.Subnet != nil {
allowedIPs = append(allowedIPs, *node.Subnet)
}
for _, ip := range node.AllowedLocationIPs {
if _, ok := allowedLocationIPsMap[ip.String()]; !ok {
allowedLocationIPs = append(allowedLocationIPs, ip)
@@ -150,7 +166,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
}
}
if node.InternalIP != nil {
allowedIPs = append(allowedIPs, oneAddressCIDR(node.InternalIP.IP))
allowedIPs = append(allowedIPs, *oneAddressCIDR(node.InternalIP.IP))
privateIPs = append(privateIPs, node.InternalIP.IP)
}
cidrs = append(cidrs, node.Subnet)
@@ -162,6 +178,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
})
t.segments = append(t.segments, &segment{
allowedIPs: allowedIPs,
kiloEndpoint: topoMap[location][leader].KiloEndpoint,
endpoint: topoMap[location][leader].Endpoint,
key: topoMap[location][leader].Key,
persistentKeepalive: topoMap[location][leader].PersistentKeepalive,
@@ -202,7 +219,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
return nil, errors.New("failed to allocate an IP address; ran out of IP addresses")
}
segment.wireGuardIP = ipNet.IP
segment.allowedIPs = append(segment.allowedIPs, oneAddressCIDR(ipNet.IP))
segment.allowedIPs = append(segment.allowedIPs, *oneAddressCIDR(ipNet.IP))
if t.leader && segment.location == t.location {
t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask}
}
@@ -224,11 +241,11 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
return &t, nil
}
func intersect(n1, n2 *net.IPNet) bool {
func intersect(n1, n2 net.IPNet) bool {
return n1.Contains(n2.IP) || n2.Contains(n1.IP)
}
func (t *Topology) filterAllowedLocationIPs(ips []*net.IPNet, location string) (ret []*net.IPNet) {
func (t *Topology) filterAllowedLocationIPs(ips []net.IPNet, location string) (ret []net.IPNet) {
CheckIPs:
for _, ip := range ips {
for _, s := range t.segments {
@@ -270,45 +287,54 @@ CheckIPs:
return
}
func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key []byte, persistentKeepalive int) *wireguard.Endpoint {
func (t *Topology) updateEndpoint(kiloEndpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *net.UDPAddr {
// Do not update non-nat peers
if persistentKeepalive == 0 {
return endpoint
if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) {
return kiloEndpoint.UDPAddr()
}
e, ok := t.discoveredEndpoints[string(key)]
e, ok := t.discoveredEndpoints[key.String()]
if ok {
return e
}
return endpoint
return nil
}
// Conf generates a WireGuard configuration file for a given Topology.
func (t *Topology) Conf() *wireguard.Conf {
c := &wireguard.Conf{
Interface: &wireguard.Interface{
PrivateKey: t.key,
ListenPort: t.port,
Config: wgtypes.Config{
PrivateKey: &t.key,
ListenPort: &t.port,
ReplacePeers: true,
},
}
for _, s := range t.segments {
if s.location == t.location {
continue
}
peer := &wireguard.Peer{
AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
Endpoint: t.updateEndpoint(s.endpoint, s.key, s.persistentKeepalive),
PersistentKeepalive: t.persistentKeepalive,
PublicKey: s.key,
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
Endpoint: t.updateEndpoint(s.kiloEndpoint, s.key, &s.persistentKeepalive),
PersistentKeepaliveInterval: &t.persistentKeepalive,
PublicKey: s.key,
ReplaceAllowedIPs: true,
},
KiloEndpoint: s.kiloEndpoint,
}
c.Peers = append(c.Peers, peer)
}
for _, p := range t.peers {
peer := &wireguard.Peer{
AllowedIPs: p.AllowedIPs,
Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepalive),
PersistentKeepalive: t.persistentKeepalive,
PresharedKey: p.PresharedKey,
PublicKey: p.PublicKey,
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: p.AllowedIPs,
Endpoint: t.updateEndpoint(p.KiloEndpoint, p.PublicKey, p.PersistentKeepaliveInterval),
PersistentKeepaliveInterval: &t.persistentKeepalive,
PresharedKey: p.PresharedKey,
PublicKey: p.PublicKey,
ReplaceAllowedIPs: true,
},
KiloEndpoint: p.KiloEndpoint,
}
c.Peers = append(c.Peers, peer)
}
@@ -317,39 +343,46 @@ func (t *Topology) Conf() *wireguard.Conf {
// AsPeer generates the WireGuard peer configuration for the local location of the given Topology.
// This configuration can be used to configure this location as a peer of another WireGuard interface.
func (t *Topology) AsPeer() *wireguard.Peer {
func (t *Topology) AsPeer() wireguard.Peer {
for _, s := range t.segments {
if s.location != t.location {
continue
}
return &wireguard.Peer{
AllowedIPs: s.allowedIPs,
Endpoint: s.endpoint,
PublicKey: s.key,
p := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: s.allowedIPs,
PublicKey: s.key,
Endpoint: s.endpoint,
},
KiloEndpoint: s.kiloEndpoint,
}
return p
}
return nil
return wireguard.Peer{}
}
// PeerConf generates a WireGuard configuration file for a given peer in a Topology.
func (t *Topology) PeerConf(name string) *wireguard.Conf {
var pka int
var psk []byte
func (t *Topology) PeerConf(name string) wireguard.Conf {
var pka *time.Duration
var psk *wgtypes.Key
for i := range t.peers {
if t.peers[i].Name == name {
pka = t.peers[i].PersistentKeepalive
pka = t.peers[i].PersistentKeepaliveInterval
psk = t.peers[i].PresharedKey
break
}
}
c := &wireguard.Conf{}
c := wireguard.Conf{}
for _, s := range t.segments {
peer := &wireguard.Peer{
AllowedIPs: s.allowedIPs,
Endpoint: s.endpoint,
PersistentKeepalive: pka,
PresharedKey: psk,
PublicKey: s.key,
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: s.allowedIPs,
Endpoint: s.kiloEndpoint.UDPAddr(),
PersistentKeepaliveInterval: pka,
PresharedKey: psk,
PublicKey: s.key,
},
KiloEndpoint: s.kiloEndpoint,
}
c.Peers = append(c.Peers, peer)
}
@@ -357,11 +390,13 @@ func (t *Topology) PeerConf(name string) *wireguard.Conf {
if t.peers[i].Name == name {
continue
}
peer := &wireguard.Peer{
AllowedIPs: t.peers[i].AllowedIPs,
PersistentKeepalive: pka,
PublicKey: t.peers[i].PublicKey,
Endpoint: t.peers[i].Endpoint,
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: t.peers[i].AllowedIPs,
PersistentKeepaliveInterval: pka,
PublicKey: t.peers[i].PublicKey,
Endpoint: t.peers[i].Endpoint,
},
}
c.Peers = append(c.Peers, peer)
}
@@ -382,13 +417,13 @@ func findLeader(nodes []*Node) int {
var leaders, public []int
for i := range nodes {
if nodes[i].Leader {
if isPublic(nodes[i].Endpoint.IP) {
if isPublic(nodes[i].KiloEndpoint.IP) {
return i
}
leaders = append(leaders, i)
}
if isPublic(nodes[i].Endpoint.IP) {
if isPublic(nodes[i].KiloEndpoint.IP) {
public = append(public, i)
}
}
@@ -408,10 +443,13 @@ func deduplicatePeerIPs(peers []*Peer) []*Peer {
p := Peer{
Name: peer.Name,
Peer: wireguard.Peer{
Endpoint: peer.Endpoint,
PersistentKeepalive: peer.PersistentKeepalive,
PresharedKey: peer.PresharedKey,
PublicKey: peer.PublicKey,
PeerConfig: wgtypes.PeerConfig{
Endpoint: peer.Endpoint,
PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval,
PresharedKey: peer.PresharedKey,
PublicKey: peer.PublicKey,
},
KiloEndpoint: peer.KiloEndpoint,
},
}
for _, ip := range peer.AllowedIPs {

View File

@@ -18,9 +18,11 @@ import (
"net"
"strings"
"testing"
"time"
"github.com/go-kit/kit/log"
"github.com/kylelemons/godebug/pretty"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/wireguard"
)
@@ -29,17 +31,25 @@ func allowedIPs(ips ...string) string {
return strings.Join(ips, ", ")
}
func mustParseCIDR(s string) (r *net.IPNet) {
func mustParseCIDR(s string) (r net.IPNet) {
if _, ip, err := net.ParseCIDR(s); err != nil {
panic("failed to parse CIDR")
} else {
r = ip
r = *ip
}
return
}
func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) {
key := []byte("private")
var (
key1 = wgtypes.Key{'k', 'e', 'y', '1'}
key2 = wgtypes.Key{'k', 'e', 'y', '2'}
key3 = wgtypes.Key{'k', 'e', 'y', '3'}
key4 = wgtypes.Key{'k', 'e', 'y', '4'}
key5 = wgtypes.Key{'k', 'e', 'y', '5'}
)
func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) {
key := wgtypes.Key{'p', 'r', 'i', 'v'}
e1 := &net.IPNet{IP: net.ParseIP("10.1.0.1").To4(), Mask: net.CIDRMask(16, 32)}
e2 := &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(16, 32)}
e3 := &net.IPNet{IP: net.ParseIP("10.1.0.3").To4(), Mask: net.CIDRMask(16, 32)}
@@ -50,62 +60,66 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) {
nodes := map[string]*Node{
"a": {
Name: "a",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
InternalIP: i1,
Location: "1",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key1"),
Key: key1,
PersistentKeepalive: 25,
},
"b": {
Name: "b",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
InternalIP: i1,
Location: "2",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key2"),
AllowedLocationIPs: []*net.IPNet{i3},
Key: key2,
AllowedLocationIPs: []net.IPNet{*i3},
},
"c": {
Name: "c",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e3.IP}, Port: DefaultKiloPort},
InternalIP: i2,
Name: "c",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e3.IP}, Port: DefaultKiloPort},
InternalIP: i2,
// Same location as node b.
Location: "2",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.3.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key3"),
Key: key3,
},
"d": {
Name: "d",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e4.IP}, Port: DefaultKiloPort},
Name: "d",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e4.IP}, Port: DefaultKiloPort},
// Same location as node a, but without private IP
Location: "1",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.4.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key4"),
Key: key4,
},
}
peers := map[string]*Peer{
"a": {
Name: "a",
Peer: wireguard.Peer{
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.5.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.5.0.2"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.5.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.5.0.2"), Mask: net.CIDRMask(24, 32)},
},
PublicKey: key4,
},
PublicKey: []byte("key4"),
},
},
"b": {
Name: "b",
Peer: wireguard.Peer{
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)},
},
PublicKey: key5,
},
Endpoint: &wireguard.Endpoint{
KiloEndpoint: &wireguard.Endpoint{
DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("192.168.0.1")},
Port: DefaultKiloPort,
},
PublicKey: []byte("key5"),
},
},
}
@@ -138,8 +152,8 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["a"].Location,
@@ -149,8 +163,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["b"].Location,
@@ -161,8 +175,8 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@@ -189,8 +203,8 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["a"].Location,
@@ -200,8 +214,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["b"].Location,
@@ -212,8 +226,8 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@@ -240,8 +254,8 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: DefaultKiloSubnet,
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["a"].Location,
@@ -251,8 +265,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["b"].Location,
@@ -263,8 +277,8 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@@ -291,8 +305,8 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["a"].Name,
@@ -302,8 +316,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["b"].Name,
@@ -314,8 +328,8 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["c"].Endpoint,
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["c"].KiloEndpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["c"].Name,
@@ -325,8 +339,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w3,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@@ -353,8 +367,8 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["a"].Name,
@@ -364,8 +378,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["b"].Name,
@@ -376,8 +390,8 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["c"].Endpoint,
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["c"].KiloEndpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["c"].Name,
@@ -387,8 +401,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w3,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@@ -415,8 +429,8 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["a"].Name,
@@ -426,8 +440,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["b"].Name,
@@ -438,8 +452,8 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["c"].Endpoint,
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["c"].KiloEndpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["c"].Name,
@@ -449,8 +463,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w3,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@@ -477,8 +491,8 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w4, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["a"].Name,
@@ -488,8 +502,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["b"].Name,
@@ -500,8 +514,8 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["c"].Endpoint,
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["c"].KiloEndpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["c"].Name,
@@ -511,8 +525,8 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w3,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@@ -539,7 +553,7 @@ func TestNewTopology(t *testing.T) {
}
}
func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int) *Topology {
func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration) *Topology {
topo, err := NewTopology(nodes, peers, granularity, hostname, port, key, subnet, persistentKeepalive, nil)
if err != nil {
t.Errorf("failed to generate Topology: %v", err)
@@ -547,211 +561,6 @@ func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, gran
return topo
}
func TestConf(t *testing.T) {
nodes, peers, key, port := setup(t)
for _, tc := range []struct {
name string
topology *Topology
result string
}{
{
name: "logical from a",
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key2
Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 192.168.178.3/32
PersistentKeepalive = 25
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.3/32
PersistentKeepalive = 25
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
PersistentKeepalive = 25
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
PersistentKeepalive = 25
`,
},
{
name: "logical from b",
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet, nodes["b"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.3/32
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
`,
},
{
name: "logical from c",
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet, nodes["c"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.3/32
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
`,
},
{
name: "full from a",
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key2
Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32, 192.168.178.3/32
PersistentKeepalive = 25
[Peer]
PublicKey = key3
Endpoint = 10.1.0.3:51820
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
PersistentKeepalive = 25
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.4/32
PersistentKeepalive = 25
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
PersistentKeepalive = 25
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
PersistentKeepalive = 25
`,
},
{
name: "full from b",
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet, nodes["b"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer]
PublicKey = key3
Endpoint = 10.1.0.3:51820
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.4/32
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
`,
},
{
name: "full from c",
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet, nodes["c"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer]
PublicKey = key2
Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32, 192.168.178.3/32
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.4/32
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
`,
},
} {
conf := tc.topology.Conf()
if !conf.Equal(wireguard.Parse([]byte(tc.result))) {
buf, err := conf.Bytes()
if err != nil {
t.Errorf("test case %q: failed to render conf: %v", tc.name, err)
}
t.Errorf("test case %q: expected %s got %s", tc.name, tc.result, string(buf))
}
}
}
func TestFindLeader(t *testing.T) {
ip, e1, err := net.ParseCIDR("10.0.0.1/32")
if err != nil {
@@ -766,26 +575,26 @@ func TestFindLeader(t *testing.T) {
nodes := []*Node{
{
Name: "a",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
Name: "a",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
},
{
Name: "b",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Name: "b",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
},
{
Name: "c",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Name: "c",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
},
{
Name: "d",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
Leader: true,
Name: "d",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
Leader: true,
},
{
Name: "2",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Leader: true,
Name: "2",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Leader: true,
},
}
for _, tc := range []struct {
@@ -840,31 +649,38 @@ func TestDeduplicatePeerIPs(t *testing.T) {
p1 := &Peer{
Name: "1",
Peer: wireguard.Peer{
PublicKey: []byte("key1"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key1,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
},
},
},
}
p2 := &Peer{
Name: "2",
Peer: wireguard.Peer{
PublicKey: []byte("key2"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key2,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
},
},
},
}
p3 := &Peer{
Name: "3",
Peer: wireguard.Peer{
PublicKey: []byte("key3"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key3,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
},
},
},
}
@@ -872,10 +688,12 @@ func TestDeduplicatePeerIPs(t *testing.T) {
p4 := &Peer{
Name: "4",
Peer: wireguard.Peer{
PublicKey: []byte("key4"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key4,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
},
},
},
}
@@ -898,9 +716,11 @@ func TestDeduplicatePeerIPs(t *testing.T) {
{
Name: "2",
Peer: wireguard.Peer{
PublicKey: []byte("key2"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key2,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
},
},
},
},
@@ -914,9 +734,11 @@ func TestDeduplicatePeerIPs(t *testing.T) {
{
Name: "1",
Peer: wireguard.Peer{
PublicKey: []byte("key1"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key1,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
},
},
},
},
@@ -930,19 +752,25 @@ func TestDeduplicatePeerIPs(t *testing.T) {
{
Name: "2",
Peer: wireguard.Peer{
PublicKey: []byte("key2"),
PeerConfig: wgtypes.PeerConfig{
PublicKey: key2,
},
},
},
{
Name: "1",
Peer: wireguard.Peer{
PublicKey: []byte("key1"),
PeerConfig: wgtypes.PeerConfig{
PublicKey: key1,
},
},
},
{
Name: "4",
Peer: wireguard.Peer{
PublicKey: []byte("key4"),
PeerConfig: wgtypes.PeerConfig{
PublicKey: key4,
},
},
},
},
@@ -954,19 +782,23 @@ func TestDeduplicatePeerIPs(t *testing.T) {
{
Name: "4",
Peer: wireguard.Peer{
PublicKey: []byte("key4"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key4,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
},
},
},
},
{
Name: "1",
Peer: wireguard.Peer{
PublicKey: []byte("key1"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key1,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
},
},
},
},
@@ -985,12 +817,12 @@ func TestFilterAllowedIPs(t *testing.T) {
topo := mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive)
for _, tc := range []struct {
name string
allowedLocationIPs map[int][]*net.IPNet
result map[int][]*net.IPNet
allowedLocationIPs map[int][]net.IPNet
result map[int][]net.IPNet
}{
{
name: "nothing to filter",
allowedLocationIPs: map[int][]*net.IPNet{
allowedLocationIPs: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.4/32"),
},
@@ -1002,7 +834,7 @@ func TestFilterAllowedIPs(t *testing.T) {
mustParseCIDR("192.168.178.7/32"),
},
},
result: map[int][]*net.IPNet{
result: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.4/32"),
},
@@ -1017,7 +849,7 @@ func TestFilterAllowedIPs(t *testing.T) {
},
{
name: "intersections between segments",
allowedLocationIPs: map[int][]*net.IPNet{
allowedLocationIPs: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.4/32"),
mustParseCIDR("192.168.178.8/32"),
@@ -1031,7 +863,7 @@ func TestFilterAllowedIPs(t *testing.T) {
mustParseCIDR("192.168.178.4/32"),
},
},
result: map[int][]*net.IPNet{
result: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.8/32"),
},
@@ -1047,7 +879,7 @@ func TestFilterAllowedIPs(t *testing.T) {
},
{
name: "intersections with wireGuardCIDR",
allowedLocationIPs: map[int][]*net.IPNet{
allowedLocationIPs: map[int][]net.IPNet{
0: {
mustParseCIDR("10.4.0.1/32"),
mustParseCIDR("192.168.178.8/32"),
@@ -1060,7 +892,7 @@ func TestFilterAllowedIPs(t *testing.T) {
mustParseCIDR("192.168.178.7/32"),
},
},
result: map[int][]*net.IPNet{
result: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.8/32"),
},
@@ -1075,7 +907,7 @@ func TestFilterAllowedIPs(t *testing.T) {
},
{
name: "intersections with more than one allowedLocationIPs",
allowedLocationIPs: map[int][]*net.IPNet{
allowedLocationIPs: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.8/32"),
},
@@ -1086,7 +918,7 @@ func TestFilterAllowedIPs(t *testing.T) {
mustParseCIDR("192.168.178.7/24"),
},
},
result: map[int][]*net.IPNet{
result: map[int][]net.IPNet{
0: {},
1: {},
2: {