migrate to golang.zx2c4.com/wireguard/wgctrl (#239)
* migrate to golang.zx2c4.com/wireguard/wgctrl This commit introduces the usage of wgctrl. It avoids the usage of exec calls of the wg command and parsing the output of `wg show`. Signed-off-by: leonnicolas <leonloechner@gmx.de> * vendor wgctrl Signed-off-by: leonnicolas <leonloechner@gmx.de> * apply suggestions from code review Remove wireguard.Enpoint struct and use net.UDPAddr for the resolved endpoint and addr string (dnsanme:port) if a DN was supplied. Signed-off-by: leonnicolas <leonloechner@gmx.de> * pkg/*: use wireguard.Enpoint This commit introduces the wireguard.Enpoint struct. It encapsulates a DN name with port and a net.UPDAddr. The fields are private and only accessible over exported Methods to avoid accidental modification. Also iptables.GetProtocol is improved to avoid ipv4 rules being applied by `ip6tables`. Signed-off-by: leonnicolas <leonloechner@gmx.de> * pkg/wireguard/conf_test.go: add tests for Endpoint Signed-off-by: leonnicolas <leonloechner@gmx.de> * cmd/kg/main.go: validate port range Signed-off-by: leonnicolas <leonloechner@gmx.de> * add suggestions from review Signed-off-by: leonnicolas <leonloechner@gmx.de> * pkg/mesh/mesh.go: use Equal func Implement an Equal func for Enpoint and use it instead of comparing strings. Signed-off-by: leonnicolas <leonloechner@gmx.de> * cmd/kgctl/main.go: check port range Signed-off-by: leonnicolas <leonloechner@gmx.de> * vendor Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
@@ -74,7 +74,7 @@ func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule {
|
||||
rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
|
||||
for _, n := range nodes {
|
||||
// Accept encapsulated traffic from peers.
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(n.IP)), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
|
||||
}
|
||||
// Drop all other IPIP traffic.
|
||||
rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
|
||||
|
@@ -53,11 +53,11 @@ const (
|
||||
)
|
||||
|
||||
// GetProtocol will return a protocol from the length of an IP address.
|
||||
func GetProtocol(length int) Protocol {
|
||||
if length == net.IPv6len {
|
||||
return ProtocolIPv6
|
||||
func GetProtocol(ip net.IP) Protocol {
|
||||
if len(ip) == net.IPv4len || ip.To4() != nil {
|
||||
return ProtocolIPv4
|
||||
}
|
||||
return ProtocolIPv4
|
||||
return ProtocolIPv6
|
||||
}
|
||||
|
||||
// Client represents any type that can administer iptables rules.
|
||||
|
@@ -25,13 +25,15 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
v1 "k8s.io/api/core/v1"
|
||||
apiextensions "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/labels"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"k8s.io/apimachinery/pkg/util/strategicpatch"
|
||||
"k8s.io/apimachinery/pkg/util/validation"
|
||||
v1informers "k8s.io/client-go/informers/core/v1"
|
||||
"k8s.io/client-go/kubernetes"
|
||||
v1listers "k8s.io/client-go/listers/core/v1"
|
||||
@@ -67,6 +69,8 @@ const (
|
||||
jsonRemovePatch = `{"op": "remove", "path": "%s"}`
|
||||
)
|
||||
|
||||
var logger = log.NewNopLogger()
|
||||
|
||||
type backend struct {
|
||||
nodes *nodeBackend
|
||||
peers *peerBackend
|
||||
@@ -99,10 +103,12 @@ type peerBackend struct {
|
||||
}
|
||||
|
||||
// New creates a new instance of a mesh.Backend.
|
||||
func New(c kubernetes.Interface, kc kiloclient.Interface, ec apiextensions.Interface, topologyLabel string) mesh.Backend {
|
||||
func New(c kubernetes.Interface, kc kiloclient.Interface, ec apiextensions.Interface, topologyLabel string, l log.Logger) mesh.Backend {
|
||||
ni := v1informers.NewNodeInformer(c, 5*time.Minute, nil)
|
||||
pi := v1alpha1informers.NewPeerInformer(kc, 5*time.Minute, nil)
|
||||
|
||||
logger = l
|
||||
|
||||
return &backend{
|
||||
&nodeBackend{
|
||||
client: c,
|
||||
@@ -218,7 +224,7 @@ func (nb *nodeBackend) Set(name string, node *mesh.Node) error {
|
||||
} else {
|
||||
n.ObjectMeta.Annotations[internalIPAnnotationKey] = node.InternalIP.String()
|
||||
}
|
||||
n.ObjectMeta.Annotations[keyAnnotationKey] = string(node.Key)
|
||||
n.ObjectMeta.Annotations[keyAnnotationKey] = node.Key.String()
|
||||
n.ObjectMeta.Annotations[lastSeenAnnotationKey] = strconv.FormatInt(node.LastSeen, 10)
|
||||
if node.WireGuardIP == nil {
|
||||
n.ObjectMeta.Annotations[wireGuardIPAnnotationKey] = ""
|
||||
@@ -276,9 +282,9 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
|
||||
location = node.ObjectMeta.Labels[topologyLabel]
|
||||
}
|
||||
// Allow the endpoint to be overridden.
|
||||
endpoint := parseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey])
|
||||
endpoint := wireguard.ParseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey])
|
||||
if endpoint == nil {
|
||||
endpoint = parseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey])
|
||||
endpoint = wireguard.ParseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey])
|
||||
}
|
||||
// Allow the internal IP to be overridden.
|
||||
internalIP := normalizeIP(node.ObjectMeta.Annotations[forceInternalIPAnnotationKey])
|
||||
@@ -292,13 +298,11 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
|
||||
internalIP = nil
|
||||
}
|
||||
// Set Wireguard PersistentKeepalive setting for the node.
|
||||
var persistentKeepalive int64
|
||||
if keepAlive, ok := node.ObjectMeta.Annotations[persistentKeepaliveKey]; !ok {
|
||||
persistentKeepalive = 0
|
||||
} else {
|
||||
if persistentKeepalive, err = strconv.ParseInt(keepAlive, 10, 64); err != nil {
|
||||
persistentKeepalive = 0
|
||||
}
|
||||
var persistentKeepalive = time.Duration(0)
|
||||
if keepAlive, ok := node.ObjectMeta.Annotations[persistentKeepaliveKey]; ok {
|
||||
// We can ignore the error, because p will be set to 0 if an error occures.
|
||||
p, _ := strconv.ParseInt(keepAlive, 10, 64)
|
||||
persistentKeepalive = time.Duration(p) * time.Second
|
||||
}
|
||||
var lastSeen int64
|
||||
if ls, ok := node.ObjectMeta.Annotations[lastSeenAnnotationKey]; !ok {
|
||||
@@ -308,7 +312,7 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
|
||||
lastSeen = 0
|
||||
}
|
||||
}
|
||||
var discoveredEndpoints map[string]*wireguard.Endpoint
|
||||
var discoveredEndpoints map[string]*net.UDPAddr
|
||||
if de, ok := node.ObjectMeta.Annotations[discoveredEndpointsKey]; ok {
|
||||
err := json.Unmarshal([]byte(de), &discoveredEndpoints)
|
||||
if err != nil {
|
||||
@@ -316,11 +320,11 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
|
||||
}
|
||||
}
|
||||
// Set allowed IPs for a location.
|
||||
var allowedLocationIPs []*net.IPNet
|
||||
var allowedLocationIPs []net.IPNet
|
||||
if str, ok := node.ObjectMeta.Annotations[allowedLocationIPsKey]; ok {
|
||||
for _, ip := range strings.Split(str, ",") {
|
||||
if ipnet := normalizeIP(ip); ipnet != nil {
|
||||
allowedLocationIPs = append(allowedLocationIPs, ipnet)
|
||||
allowedLocationIPs = append(allowedLocationIPs, *ipnet)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -335,6 +339,9 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO log some error or warning.
|
||||
key, _ := wgtypes.ParseKey(node.ObjectMeta.Annotations[keyAnnotationKey])
|
||||
|
||||
return &mesh.Node{
|
||||
// Endpoint and InternalIP should only ever fail to parse if the
|
||||
// remote node's agent has not yet set its IP address;
|
||||
@@ -345,12 +352,12 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
|
||||
Endpoint: endpoint,
|
||||
NoInternalIP: noInternalIP,
|
||||
InternalIP: internalIP,
|
||||
Key: []byte(node.ObjectMeta.Annotations[keyAnnotationKey]),
|
||||
Key: key,
|
||||
LastSeen: lastSeen,
|
||||
Leader: leader,
|
||||
Location: location,
|
||||
Name: node.Name,
|
||||
PersistentKeepalive: int(persistentKeepalive),
|
||||
PersistentKeepalive: persistentKeepalive,
|
||||
Subnet: subnet,
|
||||
// WireGuardIP can fail to parse if the node is not a leader or if
|
||||
// the node's agent has not yet reconciled. In either case, the IP
|
||||
@@ -367,14 +374,14 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
|
||||
if peer == nil {
|
||||
return nil
|
||||
}
|
||||
var aips []*net.IPNet
|
||||
var aips []net.IPNet
|
||||
for _, aip := range peer.Spec.AllowedIPs {
|
||||
aip := normalizeIP(aip)
|
||||
// Skip any invalid IPs.
|
||||
if aip == nil {
|
||||
continue
|
||||
}
|
||||
aips = append(aips, aip)
|
||||
aips = append(aips, *aip)
|
||||
}
|
||||
var endpoint *wireguard.Endpoint
|
||||
if peer.Spec.Endpoint != nil {
|
||||
@@ -384,36 +391,41 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
|
||||
} else {
|
||||
ip = ip.To16()
|
||||
}
|
||||
if peer.Spec.Endpoint.Port > 0 && (ip != nil || peer.Spec.Endpoint.DNS != "") {
|
||||
endpoint = &wireguard.Endpoint{
|
||||
DNSOrIP: wireguard.DNSOrIP{
|
||||
DNS: peer.Spec.Endpoint.DNS,
|
||||
IP: ip,
|
||||
},
|
||||
Port: peer.Spec.Endpoint.Port,
|
||||
if peer.Spec.Endpoint.Port > 0 {
|
||||
if ip != nil {
|
||||
endpoint = wireguard.NewEndpoint(ip, int(peer.Spec.Endpoint.Port))
|
||||
}
|
||||
if peer.Spec.Endpoint.DNS != "" {
|
||||
endpoint = wireguard.ParseEndpoint(fmt.Sprintf("%s:%d", peer.Spec.Endpoint.DNS, peer.Spec.Endpoint.Port))
|
||||
}
|
||||
}
|
||||
}
|
||||
var key []byte
|
||||
if len(peer.Spec.PublicKey) > 0 {
|
||||
key = []byte(peer.Spec.PublicKey)
|
||||
|
||||
key, err := wgtypes.ParseKey(peer.Spec.PublicKey)
|
||||
if err != nil {
|
||||
level.Error(logger).Log("msg", "failed to parse public key", "peer", peer.Name, "err", err.Error())
|
||||
}
|
||||
var psk []byte
|
||||
if len(peer.Spec.PresharedKey) > 0 {
|
||||
psk = []byte(peer.Spec.PresharedKey)
|
||||
var psk *wgtypes.Key
|
||||
if k, err := wgtypes.ParseKey(peer.Spec.PresharedKey); err != nil {
|
||||
// Set key to nil to avoid setting a key to the zero value wgtypes.Key{}
|
||||
psk = nil
|
||||
} else {
|
||||
psk = &k
|
||||
}
|
||||
var pka int
|
||||
var pka time.Duration
|
||||
if peer.Spec.PersistentKeepalive > 0 {
|
||||
pka = peer.Spec.PersistentKeepalive
|
||||
pka = time.Duration(peer.Spec.PersistentKeepalive)
|
||||
}
|
||||
return &mesh.Peer{
|
||||
Name: peer.Name,
|
||||
Peer: wireguard.Peer{
|
||||
AllowedIPs: aips,
|
||||
Endpoint: endpoint,
|
||||
PersistentKeepalive: pka,
|
||||
PresharedKey: psk,
|
||||
PublicKey: key,
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: aips,
|
||||
PersistentKeepaliveInterval: &pka,
|
||||
PresharedKey: psk,
|
||||
PublicKey: key,
|
||||
},
|
||||
Endpoint: endpoint,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -511,21 +523,25 @@ func (pb *peerBackend) Set(name string, peer *mesh.Peer) error {
|
||||
p.Spec.AllowedIPs[i] = peer.AllowedIPs[i].String()
|
||||
}
|
||||
if peer.Endpoint != nil {
|
||||
var ip string
|
||||
if peer.Endpoint.IP != nil {
|
||||
ip = peer.Endpoint.IP.String()
|
||||
}
|
||||
p.Spec.Endpoint = &v1alpha1.PeerEndpoint{
|
||||
DNSOrIP: v1alpha1.DNSOrIP{
|
||||
IP: ip,
|
||||
DNS: peer.Endpoint.DNS,
|
||||
IP: peer.Endpoint.IP().String(),
|
||||
DNS: peer.Endpoint.DNS(),
|
||||
},
|
||||
Port: peer.Endpoint.Port,
|
||||
Port: uint32(peer.Endpoint.Port()),
|
||||
}
|
||||
}
|
||||
p.Spec.PersistentKeepalive = peer.PersistentKeepalive
|
||||
p.Spec.PresharedKey = string(peer.PresharedKey)
|
||||
p.Spec.PublicKey = string(peer.PublicKey)
|
||||
if peer.PersistentKeepaliveInterval == nil {
|
||||
p.Spec.PersistentKeepalive = 0
|
||||
} else {
|
||||
p.Spec.PersistentKeepalive = int(*peer.PersistentKeepaliveInterval)
|
||||
}
|
||||
if peer.PresharedKey == nil {
|
||||
p.Spec.PresharedKey = ""
|
||||
} else {
|
||||
p.Spec.PresharedKey = peer.PresharedKey.String()
|
||||
}
|
||||
p.Spec.PublicKey = peer.PublicKey.String()
|
||||
if _, err = pb.client.KiloV1alpha1().Peers().Update(context.TODO(), p, metav1.UpdateOptions{}); err != nil {
|
||||
return fmt.Errorf("failed to update peer: %v", err)
|
||||
}
|
||||
@@ -549,35 +565,3 @@ func normalizeIP(ip string) *net.IPNet {
|
||||
ipNet.IP = i.To16()
|
||||
return ipNet
|
||||
}
|
||||
|
||||
func parseEndpoint(endpoint string) *wireguard.Endpoint {
|
||||
if len(endpoint) == 0 {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(endpoint, ":")
|
||||
if len(parts) < 2 {
|
||||
return nil
|
||||
}
|
||||
portRaw := parts[len(parts)-1]
|
||||
hostRaw := strings.Trim(strings.Join(parts[:len(parts)-1], ":"), "[]")
|
||||
port, err := strconv.ParseUint(portRaw, 10, 32)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(validation.IsValidPortNum(int(port))) != 0 {
|
||||
return nil
|
||||
}
|
||||
ip := net.ParseIP(hostRaw)
|
||||
if ip == nil {
|
||||
if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 {
|
||||
return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: hostRaw}, Port: uint32(port)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
ip = ip4
|
||||
} else {
|
||||
ip = ip.To16()
|
||||
}
|
||||
return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: ip}, Port: uint32(port)}
|
||||
}
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
// Copyright 2021 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -17,8 +17,10 @@ package k8s
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
v1 "k8s.io/api/core/v1"
|
||||
|
||||
"github.com/squat/kilo/pkg/k8s/apis/kilo/v1alpha1"
|
||||
@@ -26,6 +28,30 @@ import (
|
||||
"github.com/squat/kilo/pkg/wireguard"
|
||||
)
|
||||
|
||||
func mustKey() (k wgtypes.Key) {
|
||||
var err error
|
||||
if k, err = wgtypes.GeneratePrivateKey(); err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func mustPSKKey() (key *wgtypes.Key) {
|
||||
if k, err := wgtypes.GenerateKey(); err != nil {
|
||||
panic(err.Error())
|
||||
} else {
|
||||
key = &k
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
fooKey = mustKey()
|
||||
pskKey = mustPSKKey()
|
||||
second = time.Second
|
||||
zero = time.Duration(0)
|
||||
)
|
||||
|
||||
func TestTranslateNode(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
@@ -54,8 +80,19 @@ func TestTranslateNode(t *testing.T) {
|
||||
internalIPAnnotationKey: "10.0.0.2/32",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort},
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort),
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2").To4(), Mask: net.CIDRMask(32, 32)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid ips with ipv6",
|
||||
annotations: map[string]string{
|
||||
endpointAnnotationKey: "[ff10::10]:51820",
|
||||
internalIPAnnotationKey: "ff60::10/64",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("ff10::10").To16(), mesh.DefaultKiloPort),
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("ff60::10").To16(), Mask: net.CIDRMask(64, 128)},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -68,7 +105,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
name: "normalize subnet",
|
||||
annotations: map[string]string{},
|
||||
out: &mesh.Node{
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0").To4(), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
subnet: "10.2.0.1/24",
|
||||
},
|
||||
@@ -76,7 +113,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
name: "valid subnet",
|
||||
annotations: map[string]string{},
|
||||
out: &mesh.Node{
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0").To4(), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
subnet: "10.2.1.0/24",
|
||||
},
|
||||
@@ -108,7 +145,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
forceEndpointAnnotationKey: "-10.0.0.2:51821",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -118,7 +155,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
forceEndpointAnnotationKey: "10.0.0.2:51821",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.2").To4(), 51821),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -127,7 +164,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
persistentKeepaliveKey: "25",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
PersistentKeepalive: 25,
|
||||
PersistentKeepalive: 25 * time.Second,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -137,7 +174,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
forceInternalIPAnnotationKey: "-10.1.0.2/24",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.1").To4(), Mask: net.CIDRMask(24, 32)},
|
||||
NoInternalIP: false,
|
||||
},
|
||||
},
|
||||
@@ -148,7 +185,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
forceInternalIPAnnotationKey: "10.1.0.2/24",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(24, 32)},
|
||||
NoInternalIP: false,
|
||||
},
|
||||
},
|
||||
@@ -166,7 +203,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
forceEndpointAnnotationKey: "10.0.0.2:51821",
|
||||
forceInternalIPAnnotationKey: "10.1.0.2/32",
|
||||
internalIPAnnotationKey: "10.1.0.1/32",
|
||||
keyAnnotationKey: "foo",
|
||||
keyAnnotationKey: fooKey.String(),
|
||||
lastSeenAnnotationKey: "1000000000",
|
||||
leaderAnnotationKey: "",
|
||||
locationAnnotationKey: "b",
|
||||
@@ -177,14 +214,45 @@ func TestTranslateNode(t *testing.T) {
|
||||
RegionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.2").To4(), 51821),
|
||||
NoInternalIP: false,
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)},
|
||||
Key: []byte("foo"),
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(32, 32)},
|
||||
Key: fooKey,
|
||||
LastSeen: 1000000000,
|
||||
Leader: true,
|
||||
Location: "b",
|
||||
PersistentKeepalive: 25,
|
||||
PersistentKeepalive: 25 * time.Second,
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0").To4(), Mask: net.CIDRMask(24, 32)},
|
||||
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1").To4(), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
subnet: "10.2.1.0/24",
|
||||
},
|
||||
{
|
||||
name: "complete with ipv6",
|
||||
annotations: map[string]string{
|
||||
endpointAnnotationKey: "10.0.0.1:51820",
|
||||
forceEndpointAnnotationKey: "[1100::10]:51821",
|
||||
forceInternalIPAnnotationKey: "10.1.0.2/32",
|
||||
internalIPAnnotationKey: "10.1.0.1/32",
|
||||
keyAnnotationKey: fooKey.String(),
|
||||
lastSeenAnnotationKey: "1000000000",
|
||||
leaderAnnotationKey: "",
|
||||
locationAnnotationKey: "b",
|
||||
persistentKeepaliveKey: "25",
|
||||
wireGuardIPAnnotationKey: "10.4.0.1/16",
|
||||
},
|
||||
labels: map[string]string{
|
||||
RegionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("1100::10"), 51821),
|
||||
NoInternalIP: false,
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)},
|
||||
Key: fooKey,
|
||||
LastSeen: 1000000000,
|
||||
Leader: true,
|
||||
Location: "b",
|
||||
PersistentKeepalive: 25 * time.Second,
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
@@ -195,7 +263,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
annotations: map[string]string{
|
||||
endpointAnnotationKey: "10.0.0.1:51820",
|
||||
internalIPAnnotationKey: "",
|
||||
keyAnnotationKey: "foo",
|
||||
keyAnnotationKey: fooKey.String(),
|
||||
lastSeenAnnotationKey: "1000000000",
|
||||
locationAnnotationKey: "b",
|
||||
persistentKeepaliveKey: "25",
|
||||
@@ -205,13 +273,13 @@ func TestTranslateNode(t *testing.T) {
|
||||
RegionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1"), 51820),
|
||||
InternalIP: nil,
|
||||
Key: []byte("foo"),
|
||||
Key: fooKey,
|
||||
LastSeen: 1000000000,
|
||||
Leader: false,
|
||||
Location: "b",
|
||||
PersistentKeepalive: 25,
|
||||
PersistentKeepalive: 25 * time.Second,
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
@@ -223,7 +291,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
endpointAnnotationKey: "10.0.0.1:51820",
|
||||
internalIPAnnotationKey: "10.1.0.1/32",
|
||||
forceInternalIPAnnotationKey: "",
|
||||
keyAnnotationKey: "foo",
|
||||
keyAnnotationKey: fooKey.String(),
|
||||
lastSeenAnnotationKey: "1000000000",
|
||||
locationAnnotationKey: "b",
|
||||
persistentKeepaliveKey: "25",
|
||||
@@ -233,14 +301,14 @@ func TestTranslateNode(t *testing.T) {
|
||||
RegionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1"), 51820),
|
||||
NoInternalIP: true,
|
||||
InternalIP: nil,
|
||||
Key: []byte("foo"),
|
||||
Key: fooKey,
|
||||
LastSeen: 1000000000,
|
||||
Leader: false,
|
||||
Location: "b",
|
||||
PersistentKeepalive: 25,
|
||||
PersistentKeepalive: 25 * time.Second,
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
@@ -266,7 +334,13 @@ func TestTranslatePeer(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
out: &mesh.Peer{},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid ips",
|
||||
@@ -276,7 +350,13 @@ func TestTranslatePeer(t *testing.T) {
|
||||
"foo",
|
||||
},
|
||||
},
|
||||
out: &mesh.Peer{},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid ips",
|
||||
@@ -288,9 +368,12 @@ func TestTranslatePeer(t *testing.T) {
|
||||
},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
|
||||
},
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -305,7 +388,13 @@ func TestTranslatePeer(t *testing.T) {
|
||||
Port: mesh.DefaultKiloPort,
|
||||
},
|
||||
},
|
||||
out: &mesh.Peer{},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "only endpoint port",
|
||||
@@ -314,7 +403,13 @@ func TestTranslatePeer(t *testing.T) {
|
||||
Port: mesh.DefaultKiloPort,
|
||||
},
|
||||
},
|
||||
out: &mesh.Peer{},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid endpoint ip",
|
||||
@@ -328,10 +423,29 @@ func TestTranslatePeer(t *testing.T) {
|
||||
},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
Endpoint: &wireguard.Endpoint{
|
||||
DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")},
|
||||
Port: mesh.DefaultKiloPort,
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid endpoint ipv6",
|
||||
spec: v1alpha1.PeerSpec{
|
||||
Endpoint: &v1alpha1.PeerEndpoint{
|
||||
DNSOrIP: v1alpha1.DNSOrIP{
|
||||
IP: "ff60::2",
|
||||
},
|
||||
Port: mesh.DefaultKiloPort,
|
||||
},
|
||||
},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("ff60::2").To16(), mesh.DefaultKiloPort),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -347,9 +461,9 @@ func TestTranslatePeer(t *testing.T) {
|
||||
},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
Endpoint: &wireguard.Endpoint{
|
||||
DNSOrIP: wireguard.DNSOrIP{DNS: "example.com"},
|
||||
Port: mesh.DefaultKiloPort,
|
||||
Endpoint: wireguard.ParseEndpoint("example.com:51820"),
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -359,16 +473,25 @@ func TestTranslatePeer(t *testing.T) {
|
||||
spec: v1alpha1.PeerSpec{
|
||||
PublicKey: "",
|
||||
},
|
||||
out: &mesh.Peer{},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid key",
|
||||
spec: v1alpha1.PeerSpec{
|
||||
PublicKey: "foo",
|
||||
PublicKey: fooKey.String(),
|
||||
},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("foo"),
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: fooKey,
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -377,27 +500,38 @@ func TestTranslatePeer(t *testing.T) {
|
||||
spec: v1alpha1.PeerSpec{
|
||||
PersistentKeepalive: -1,
|
||||
},
|
||||
out: &mesh.Peer{},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid keepalive",
|
||||
spec: v1alpha1.PeerSpec{
|
||||
PersistentKeepalive: 1,
|
||||
PersistentKeepalive: 1 * int(time.Second),
|
||||
},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PersistentKeepalive: 1,
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &second,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid preshared key",
|
||||
spec: v1alpha1.PeerSpec{
|
||||
PresharedKey: "psk",
|
||||
PresharedKey: pskKey.String(),
|
||||
},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PresharedKey: []byte("psk"),
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
PresharedKey: pskKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -410,52 +544,3 @@ func TestTranslatePeer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEndpoint(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
endpoint string
|
||||
out *wireguard.Endpoint
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
endpoint: "",
|
||||
out: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid IP",
|
||||
endpoint: "10.0.0.:51820",
|
||||
out: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid hostname",
|
||||
endpoint: "foo-:51820",
|
||||
out: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid port",
|
||||
endpoint: "10.0.0.1:100000000",
|
||||
out: nil,
|
||||
},
|
||||
{
|
||||
name: "valid IP",
|
||||
endpoint: "10.0.0.1:51820",
|
||||
out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort},
|
||||
},
|
||||
{
|
||||
name: "valid IPv6",
|
||||
endpoint: "[ff02::114]:51820",
|
||||
out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("ff02::114")}, Port: mesh.DefaultKiloPort},
|
||||
},
|
||||
{
|
||||
name: "valid hostname",
|
||||
endpoint: "foo:51821",
|
||||
out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: "foo"}, Port: 51821},
|
||||
},
|
||||
} {
|
||||
endpoint := parseEndpoint(tc.endpoint)
|
||||
if diff := pretty.Compare(endpoint, tc.out); diff != "" {
|
||||
t.Errorf("test case %q: got diff: %v", tc.name, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -18,6 +18,8 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/squat/kilo/pkg/wireguard"
|
||||
)
|
||||
|
||||
@@ -55,7 +57,7 @@ const (
|
||||
// Node represents a node in the network.
|
||||
type Node struct {
|
||||
Endpoint *wireguard.Endpoint
|
||||
Key []byte
|
||||
Key wgtypes.Key
|
||||
NoInternalIP bool
|
||||
InternalIP *net.IPNet
|
||||
// LastSeen is a Unix time for the last time
|
||||
@@ -66,18 +68,23 @@ type Node struct {
|
||||
Leader bool
|
||||
Location string
|
||||
Name string
|
||||
PersistentKeepalive int
|
||||
PersistentKeepalive time.Duration
|
||||
Subnet *net.IPNet
|
||||
WireGuardIP *net.IPNet
|
||||
DiscoveredEndpoints map[string]*wireguard.Endpoint
|
||||
AllowedLocationIPs []*net.IPNet
|
||||
// DiscoveredEndpoints cannot be DNS endpoints, only net.UDPAddr.
|
||||
DiscoveredEndpoints map[string]*net.UDPAddr
|
||||
AllowedLocationIPs []net.IPNet
|
||||
Granularity Granularity
|
||||
}
|
||||
|
||||
// Ready indicates whether or not the node is ready.
|
||||
func (n *Node) Ready() bool {
|
||||
// Nodes that are not leaders will not have WireGuardIPs, so it is not required.
|
||||
return n != nil && n.Endpoint != nil && !(n.Endpoint.IP == nil && n.Endpoint.DNS == "") && n.Endpoint.Port != 0 && n.Key != nil && n.Subnet != nil && time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second)
|
||||
return n != nil &&
|
||||
n.Endpoint.Ready() &&
|
||||
n.Key != wgtypes.Key{} &&
|
||||
n.Subnet != nil &&
|
||||
time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second)
|
||||
}
|
||||
|
||||
// Peer represents a peer in the network.
|
||||
@@ -92,7 +99,10 @@ type Peer struct {
|
||||
// will not declare their endpoint and instead allow it to be
|
||||
// discovered.
|
||||
func (p *Peer) Ready() bool {
|
||||
return p != nil && p.AllowedIPs != nil && len(p.AllowedIPs) != 0 && p.PublicKey != nil
|
||||
return p != nil &&
|
||||
p.AllowedIPs != nil &&
|
||||
len(p.AllowedIPs) != 0 &&
|
||||
p.PublicKey != wgtypes.Key{} // If Key was not set, it will be wgtypes.Key{}.
|
||||
}
|
||||
|
||||
// EventType describes what kind of an action an event represents.
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
// Copyright 2021 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/awalterschulze/gographviz"
|
||||
|
||||
"github.com/squat/kilo/pkg/wireguard"
|
||||
)
|
||||
|
||||
@@ -166,8 +167,9 @@ func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoi
|
||||
if wgIP != nil {
|
||||
label = append(label, wgIP.String())
|
||||
}
|
||||
if endpoint != nil {
|
||||
label = append(label, endpoint.String())
|
||||
str := endpoint.String()
|
||||
if str != "" {
|
||||
label = append(label, str)
|
||||
}
|
||||
return graphEscape(strings.Join(label, "\\n"))
|
||||
}
|
||||
|
160
pkg/mesh/mesh.go
160
pkg/mesh/mesh.go
@@ -1,4 +1,4 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
// Copyright 2021 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -30,6 +30,8 @@ import (
|
||||
"github.com/go-kit/kit/log/level"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/squat/kilo/pkg/encapsulation"
|
||||
"github.com/squat/kilo/pkg/iproute"
|
||||
@@ -43,8 +45,6 @@ const (
|
||||
kiloPath = "/var/lib/kilo"
|
||||
// privateKeyPath is the filepath where the WireGuard private key is stored.
|
||||
privateKeyPath = kiloPath + "/key"
|
||||
// confPath is the filepath where the WireGuard configuration is stored.
|
||||
confPath = kiloPath + "/conf"
|
||||
)
|
||||
|
||||
// Mesh is able to create Kilo network meshes.
|
||||
@@ -60,12 +60,13 @@ type Mesh struct {
|
||||
internalIP *net.IPNet
|
||||
ipTables *iptables.Controller
|
||||
kiloIface int
|
||||
kiloIfaceName string
|
||||
key []byte
|
||||
local bool
|
||||
port uint32
|
||||
priv []byte
|
||||
port int
|
||||
priv wgtypes.Key
|
||||
privIface int
|
||||
pub []byte
|
||||
pub wgtypes.Key
|
||||
resyncPeriod time.Duration
|
||||
iptablesForwardRule bool
|
||||
stop chan struct{}
|
||||
@@ -88,23 +89,24 @@ type Mesh struct {
|
||||
}
|
||||
|
||||
// New returns a new Mesh instance.
|
||||
func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port uint32, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, mtu uint, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, logger log.Logger) (*Mesh, error) {
|
||||
func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port int, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, mtu uint, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, logger log.Logger) (*Mesh, error) {
|
||||
if err := os.MkdirAll(kiloPath, 0700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
|
||||
}
|
||||
private, err := ioutil.ReadFile(privateKeyPath)
|
||||
private = bytes.Trim(private, "\n")
|
||||
privateB, err := ioutil.ReadFile(privateKeyPath)
|
||||
privateB = bytes.Trim(privateB, "\n")
|
||||
private, err := wgtypes.ParseKey(string(privateB))
|
||||
if err != nil {
|
||||
level.Warn(logger).Log("msg", "no private key found on disk; generating one now")
|
||||
if private, err = wireguard.GenKey(); err != nil {
|
||||
if private, err = wgtypes.GenerateKey(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
public, err := wireguard.PubKey(private)
|
||||
public := private.PublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ioutil.WriteFile(privateKeyPath, private, 0600); err != nil {
|
||||
if err := ioutil.WriteFile(privateKeyPath, []byte(private.String()), 0600); err != nil {
|
||||
return nil, fmt.Errorf("failed to write private key to disk: %v", err)
|
||||
}
|
||||
cniIndex, err := cniDeviceIndex()
|
||||
@@ -168,6 +170,7 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
|
||||
internalIP: privateIP,
|
||||
ipTables: ipTables,
|
||||
kiloIface: kiloIface,
|
||||
kiloIfaceName: iface,
|
||||
nodes: make(map[string]*Node),
|
||||
peers: make(map[string]*Peer),
|
||||
port: port,
|
||||
@@ -314,7 +317,7 @@ func (m *Mesh) syncPeers(e *PeerEvent) {
|
||||
var diff bool
|
||||
m.mu.Lock()
|
||||
// Peers are indexed by public key.
|
||||
key := string(e.Peer.PublicKey)
|
||||
key := e.Peer.PublicKey.String()
|
||||
if !e.Peer.Ready() {
|
||||
// Trace non ready peer with their presence in the mesh.
|
||||
_, ok := m.peers[key]
|
||||
@@ -324,8 +327,8 @@ func (m *Mesh) syncPeers(e *PeerEvent) {
|
||||
case AddEvent:
|
||||
fallthrough
|
||||
case UpdateEvent:
|
||||
if e.Old != nil && key != string(e.Old.PublicKey) {
|
||||
delete(m.peers, string(e.Old.PublicKey))
|
||||
if e.Old != nil && key != e.Old.PublicKey.String() {
|
||||
delete(m.peers, e.Old.PublicKey.String())
|
||||
diff = true
|
||||
}
|
||||
if !peersAreEqual(m.peers[key], e.Peer) {
|
||||
@@ -367,8 +370,10 @@ func (m *Mesh) checkIn() {
|
||||
|
||||
func (m *Mesh) handleLocal(n *Node) {
|
||||
// Allow the IPs to be overridden.
|
||||
if n.Endpoint == nil || (n.Endpoint.DNS == "" && n.Endpoint.IP == nil) {
|
||||
n.Endpoint = &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: m.externalIP.IP}, Port: m.port}
|
||||
if !n.Endpoint.Ready() {
|
||||
e := wireguard.NewEndpoint(m.externalIP.IP, m.port)
|
||||
level.Info(m.logger).Log("msg", "overriding endpoint", "node", m.hostname, "old endpoint", n.Endpoint.String(), "new endpoint", e.String())
|
||||
n.Endpoint = e
|
||||
}
|
||||
if n.InternalIP == nil && !n.NoInternalIP {
|
||||
n.InternalIP = m.internalIP
|
||||
@@ -462,22 +467,26 @@ func (m *Mesh) applyTopology() {
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
// Find the old configuration.
|
||||
oldConfDump, err := wireguard.ShowDump(link.Attrs().Name)
|
||||
|
||||
wgClient, err := wgctrl.New()
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
oldConf, err := wireguard.ParseDump(oldConfDump)
|
||||
defer wgClient.Close()
|
||||
|
||||
// wgDevice is the current configuration of the wg interface.
|
||||
wgDevice, err := wgClient.Device(m.kiloIfaceName)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
|
||||
|
||||
natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger)
|
||||
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
|
||||
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
|
||||
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port(), m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
@@ -489,19 +498,8 @@ func (m *Mesh) applyTopology() {
|
||||
} else {
|
||||
m.wireGuardIP = nil
|
||||
}
|
||||
conf := t.Conf()
|
||||
buf, err := conf.Bytes()
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
if err := ioutil.WriteFile(confPath, buf, 0600); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
ipRules := t.Rules(m.cni, m.iptablesForwardRule)
|
||||
|
||||
// If we are handling local routes, ensure the local
|
||||
// tunnel has an IP address and IPIP traffic is allowed.
|
||||
if m.enc.Strategy() != encapsulation.Never && m.local {
|
||||
@@ -540,10 +538,12 @@ func (m *Mesh) applyTopology() {
|
||||
}
|
||||
// Setting the WireGuard configuration interrupts existing connections
|
||||
// so only set the configuration if it has changed.
|
||||
equal := conf.Equal(oldConf)
|
||||
conf := t.Conf()
|
||||
equal, diff := conf.Equal(wgDevice)
|
||||
if !equal {
|
||||
level.Info(m.logger).Log("msg", "WireGuard configurations are different")
|
||||
if err := wireguard.SetConf(link.Attrs().Name, confPath); err != nil {
|
||||
level.Info(m.logger).Log("msg", "WireGuard configurations are different", "diff", diff)
|
||||
level.Debug(m.logger).Log("msg", "changing wg config", "config", conf.WGConfig())
|
||||
if err := wgClient.ConfigureDevice(m.kiloIfaceName, conf.WGConfig()); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
@@ -598,10 +598,6 @@ func (m *Mesh) cleanUp() {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up routes: %v", err))
|
||||
m.errorCounter.WithLabelValues("cleanUp").Inc()
|
||||
}
|
||||
if err := os.Remove(confPath); err != nil {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to delete configuration file: %v", err))
|
||||
m.errorCounter.WithLabelValues("cleanUp").Inc()
|
||||
}
|
||||
if m.cleanUpIface {
|
||||
if err := iproute.RemoveInterface(m.kiloIface); err != nil {
|
||||
level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err))
|
||||
@@ -629,12 +625,8 @@ func (m *Mesh) resolveEndpoints() error {
|
||||
if !m.nodes[k].Ready() {
|
||||
continue
|
||||
}
|
||||
// If the node is ready, then the endpoint is not nil
|
||||
// but it may not have a DNS name.
|
||||
if m.nodes[k].Endpoint.DNS == "" {
|
||||
continue
|
||||
}
|
||||
if err := resolveEndpoint(m.nodes[k].Endpoint); err != nil {
|
||||
// Resolve the Endpoint
|
||||
if _, err := m.nodes[k].Endpoint.UDPAddr(true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -645,33 +637,16 @@ func (m *Mesh) resolveEndpoints() error {
|
||||
continue
|
||||
}
|
||||
// Peers may have nil endpoints.
|
||||
if m.peers[k].Endpoint == nil || m.peers[k].Endpoint.DNS == "" {
|
||||
if !m.peers[k].Endpoint.Ready() {
|
||||
continue
|
||||
}
|
||||
if err := resolveEndpoint(m.peers[k].Endpoint); err != nil {
|
||||
if _, err := m.peers[k].Endpoint.UDPAddr(true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveEndpoint(endpoint *wireguard.Endpoint) error {
|
||||
ips, err := net.LookupIP(endpoint.DNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to look up DNS name %q: %v", endpoint.DNS, err)
|
||||
}
|
||||
nets := make([]*net.IPNet, len(ips), len(ips))
|
||||
for i := range ips {
|
||||
nets[i] = oneAddressCIDR(ips[i])
|
||||
}
|
||||
sortIPs(nets)
|
||||
if len(nets) == 0 {
|
||||
return fmt.Errorf("did not find any addresses for DNS name %q", endpoint.DNS)
|
||||
}
|
||||
endpoint.IP = nets[0].IP
|
||||
return nil
|
||||
}
|
||||
|
||||
func isSelf(hostname string, node *Node) bool {
|
||||
return node != nil && node.Name == hostname
|
||||
}
|
||||
@@ -691,7 +666,18 @@ func nodesAreEqual(a, b *Node) bool {
|
||||
// Ignore LastSeen when comparing equality we want to check if the nodes are
|
||||
// equivalent. However, we do want to check if LastSeen has transitioned
|
||||
// between valid and invalid.
|
||||
return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive && discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) && ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs) && a.Granularity == b.Granularity
|
||||
return a.Key.String() == b.Key.String() &&
|
||||
ipNetsEqual(a.WireGuardIP, b.WireGuardIP) &&
|
||||
ipNetsEqual(a.InternalIP, b.InternalIP) &&
|
||||
a.Leader == b.Leader &&
|
||||
a.Location == b.Location &&
|
||||
a.Name == b.Name &&
|
||||
subnetsEqual(a.Subnet, b.Subnet) &&
|
||||
a.Ready() == b.Ready() &&
|
||||
a.PersistentKeepalive == b.PersistentKeepalive &&
|
||||
discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) &&
|
||||
ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs) &&
|
||||
a.Granularity == b.Granularity
|
||||
}
|
||||
|
||||
func peersAreEqual(a, b *Peer) bool {
|
||||
@@ -710,11 +696,15 @@ func peersAreEqual(a, b *Peer) bool {
|
||||
return false
|
||||
}
|
||||
for i := range a.AllowedIPs {
|
||||
if !ipNetsEqual(a.AllowedIPs[i], b.AllowedIPs[i]) {
|
||||
if !ipNetsEqual(&a.AllowedIPs[i], &b.AllowedIPs[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return string(a.PublicKey) == string(b.PublicKey) && string(a.PresharedKey) == string(b.PresharedKey) && a.PersistentKeepalive == b.PersistentKeepalive
|
||||
return a.PublicKey.String() == b.PublicKey.String() &&
|
||||
(a.PresharedKey == nil) == (b.PresharedKey == nil) &&
|
||||
(a.PresharedKey == nil || a.PresharedKey.String() == b.PresharedKey.String()) &&
|
||||
(a.PersistentKeepaliveInterval == nil) == (b.PersistentKeepaliveInterval == nil) &&
|
||||
(a.PersistentKeepaliveInterval == nil || a.PersistentKeepaliveInterval == b.PersistentKeepaliveInterval)
|
||||
}
|
||||
|
||||
func ipNetsEqual(a, b *net.IPNet) bool {
|
||||
@@ -730,12 +720,12 @@ func ipNetsEqual(a, b *net.IPNet) bool {
|
||||
return a.IP.Equal(b.IP)
|
||||
}
|
||||
|
||||
func ipNetSlicesEqual(a, b []*net.IPNet) bool {
|
||||
func ipNetSlicesEqual(a, b []net.IPNet) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if !ipNetsEqual(a[i], b[i]) {
|
||||
if !ipNetsEqual(&a[i], &b[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -761,7 +751,7 @@ func subnetsEqual(a, b *net.IPNet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
|
||||
func discoveredEndpointsAreEqual(a, b map[string]*net.UDPAddr) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
@@ -772,7 +762,7 @@ func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
|
||||
return false
|
||||
}
|
||||
for k := range a {
|
||||
if !a[k].Equal(b[k], false) {
|
||||
if a[k] != b[k] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -788,24 +778,26 @@ func linkByIndex(index int) (netlink.Link, error) {
|
||||
}
|
||||
|
||||
// discoverNATEndpoints uses the node's WireGuard configuration to returns a list of the most recently discovered endpoints for all nodes and peers behind NAT so that they can roam.
|
||||
func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf, logger log.Logger) map[string]*wireguard.Endpoint {
|
||||
natEndpoints := make(map[string]*wireguard.Endpoint)
|
||||
keys := make(map[string]*wireguard.Peer)
|
||||
// Discovered endpionts will never be DNS names, because WireGuard will always resolve them to net.UDPAddr.
|
||||
func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wgtypes.Device, logger log.Logger) map[string]*net.UDPAddr {
|
||||
natEndpoints := make(map[string]*net.UDPAddr)
|
||||
keys := make(map[string]wgtypes.Peer)
|
||||
for i := range conf.Peers {
|
||||
keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
|
||||
keys[conf.Peers[i].PublicKey.String()] = conf.Peers[i]
|
||||
}
|
||||
for _, n := range nodes {
|
||||
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
|
||||
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false), "latest-handshake", peer.LatestHandshake)
|
||||
if (peer.LatestHandshake != time.Time{}) {
|
||||
natEndpoints[string(n.Key)] = peer.Endpoint
|
||||
if peer, ok := keys[n.Key.String()]; ok && n.PersistentKeepalive != time.Duration(0) {
|
||||
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", peer.Endpoint.String() == n.Endpoint.String(), "latest-handshake", peer.LastHandshakeTime)
|
||||
// Don't update the endpoint, if there was never any handshake.
|
||||
if !peer.LastHandshakeTime.Equal(time.Time{}) {
|
||||
natEndpoints[n.Key.String()] = peer.Endpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, p := range peers {
|
||||
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
|
||||
if (peer.LatestHandshake != time.Time{}) {
|
||||
natEndpoints[string(p.PublicKey)] = peer.Endpoint
|
||||
if peer, ok := keys[p.PublicKey.String()]; ok && p.PersistentKeepaliveInterval != nil {
|
||||
if !peer.LastHandshakeTime.Equal(time.Time{}) {
|
||||
natEndpoints[p.PublicKey.String()] = peer.Endpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -19,9 +19,21 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/squat/kilo/pkg/wireguard"
|
||||
)
|
||||
|
||||
func mustKey() wgtypes.Key {
|
||||
if k, err := wgtypes.GeneratePrivateKey(); err != nil {
|
||||
panic(err.Error())
|
||||
} else {
|
||||
return k
|
||||
}
|
||||
}
|
||||
|
||||
var key = mustKey()
|
||||
|
||||
func TestReady(t *testing.T) {
|
||||
internalIP := oneAddressCIDR(net.ParseIP("1.1.1.1"))
|
||||
externalIP := oneAddressCIDR(net.ParseIP("2.2.2.2"))
|
||||
@@ -44,7 +56,7 @@ func TestReady(t *testing.T) {
|
||||
name: "empty endpoint",
|
||||
node: &Node{
|
||||
InternalIP: internalIP,
|
||||
Key: []byte{},
|
||||
Key: key,
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
ready: false,
|
||||
@@ -52,9 +64,9 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "empty endpoint IP",
|
||||
node: &Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(nil, DefaultKiloPort),
|
||||
InternalIP: internalIP,
|
||||
Key: []byte{},
|
||||
Key: wgtypes.Key{},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
ready: false,
|
||||
@@ -62,9 +74,9 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "empty endpoint port",
|
||||
node: &Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}},
|
||||
Endpoint: wireguard.NewEndpoint(externalIP.IP, 0),
|
||||
InternalIP: internalIP,
|
||||
Key: []byte{},
|
||||
Key: wgtypes.Key{},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
ready: false,
|
||||
@@ -72,8 +84,8 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "empty internal IP",
|
||||
node: &Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
|
||||
Key: []byte{},
|
||||
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
|
||||
Key: wgtypes.Key{},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
ready: false,
|
||||
@@ -81,7 +93,7 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "empty key",
|
||||
node: &Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
|
||||
InternalIP: internalIP,
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
@@ -90,18 +102,18 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "empty subnet",
|
||||
node: &Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
|
||||
InternalIP: internalIP,
|
||||
Key: []byte{},
|
||||
Key: wgtypes.Key{},
|
||||
},
|
||||
ready: false,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
node: &Node{
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
|
||||
InternalIP: internalIP,
|
||||
Key: []byte{},
|
||||
Key: key,
|
||||
LastSeen: time.Now().Unix(),
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
|
@@ -40,7 +40,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
|
||||
var gw net.IP
|
||||
for _, segment := range t.segments {
|
||||
if segment.location == t.location {
|
||||
gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
|
||||
gw = enc.Gw(segment.endpoint.IP(), segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
|
||||
// we need to set routes for allowed location IPs over the leader in the current location.
|
||||
for i := range segment.allowedLocationIPs {
|
||||
routes = append(routes, encapsulateRoute(&netlink.Route{
|
||||
Dst: segment.allowedLocationIPs[i],
|
||||
Dst: &segment.allowedLocationIPs[i],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: gw,
|
||||
LinkIndex: privIface,
|
||||
@@ -125,7 +125,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
|
||||
for _, peer := range t.peers {
|
||||
for i := range peer.AllowedIPs {
|
||||
routes = append(routes, encapsulateRoute(&netlink.Route{
|
||||
Dst: peer.AllowedIPs[i],
|
||||
Dst: &peer.AllowedIPs[i],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: gw,
|
||||
LinkIndex: privIface,
|
||||
@@ -196,7 +196,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
|
||||
// equals the external IP. This means that the node
|
||||
// is only accessible through an external IP and we
|
||||
// cannot encapsulate traffic to an IP through the IP.
|
||||
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP) {
|
||||
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP()) {
|
||||
continue
|
||||
}
|
||||
// Add routes to the private IPs of nodes in other segments.
|
||||
@@ -214,7 +214,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
|
||||
// we need to set routes for allowed location IPs over the wg interface.
|
||||
for i := range segment.allowedLocationIPs {
|
||||
routes = append(routes, &netlink.Route{
|
||||
Dst: segment.allowedLocationIPs[i],
|
||||
Dst: &segment.allowedLocationIPs[i],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: segment.wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
@@ -226,7 +226,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
|
||||
for _, peer := range t.peers {
|
||||
for i := range peer.AllowedIPs {
|
||||
routes = append(routes, &netlink.Route{
|
||||
Dst: peer.AllowedIPs[i],
|
||||
Dst: &peer.AllowedIPs[i],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
})
|
||||
@@ -248,7 +248,7 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
|
||||
rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
|
||||
rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
|
||||
if cni {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
|
||||
// Some linux distros or docker will set forward DROP in the filter table.
|
||||
// To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location.
|
||||
// Leader nodes will forward packets from all nodes within a location because they act as a gateway for them.
|
||||
@@ -258,30 +258,30 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
|
||||
if s.location == t.location {
|
||||
// Make sure packets to and from pod cidrs are not dropped in the forward chain.
|
||||
for _, c := range s.cidrs {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
|
||||
}
|
||||
// Make sure packets to and from allowed location IPs are not dropped in the forward chain.
|
||||
for _, c := range s.allowedLocationIPs {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
|
||||
}
|
||||
// Make sure packets to and from private IPs are not dropped in the forward chain.
|
||||
for _, c := range s.privateIPs {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if iptablesForwardRule {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
|
||||
}
|
||||
}
|
||||
for _, s := range t.segments {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(s.wireGuardIP)), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
|
||||
for _, aip := range s.allowedIPs {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
|
||||
}
|
||||
// Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd,
|
||||
// Otherwise packets to these destinations will reach the destination, but never find their way back.
|
||||
@@ -289,7 +289,7 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
|
||||
if t.location == s.location {
|
||||
for _, alip := range s.allowedLocationIPs {
|
||||
rules = append(rules,
|
||||
iptables.NewRule(iptables.GetProtocol(len(alip.IP)), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
|
||||
iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -297,8 +297,8 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
|
||||
for _, p := range t.peers {
|
||||
for _, aip := range p.AllowedIPs {
|
||||
rules = append(rules,
|
||||
iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
|
||||
iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
|
||||
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
|
||||
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@@ -75,7 +75,7 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["b"].AllowedLocationIPs[0],
|
||||
Dst: &nodes["b"].AllowedLocationIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
@@ -89,17 +89,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -132,17 +132,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -196,21 +196,21 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
@@ -266,24 +266,24 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["b"].AllowedLocationIPs[0],
|
||||
Dst: &nodes["b"].AllowedLocationIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["d"].Name).segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -309,7 +309,7 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["b"].AllowedLocationIPs[0],
|
||||
Dst: &nodes["b"].AllowedLocationIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
@@ -337,17 +337,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -394,17 +394,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -444,7 +444,7 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["b"].AllowedLocationIPs[0],
|
||||
Dst: &nodes["b"].AllowedLocationIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
@@ -458,17 +458,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -509,7 +509,7 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["b"].AllowedLocationIPs[0],
|
||||
Dst: &nodes["b"].AllowedLocationIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
@@ -523,17 +523,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -574,7 +574,7 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["b"].AllowedLocationIPs[0],
|
||||
Dst: &nodes["b"].AllowedLocationIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
@@ -588,17 +588,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -639,17 +639,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -698,17 +698,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -782,21 +782,21 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: privIface,
|
||||
@@ -868,21 +868,21 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: tunlIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: tunlIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nodes["b"].InternalIP.IP,
|
||||
LinkIndex: tunlIface,
|
||||
@@ -918,7 +918,7 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["b"].AllowedLocationIPs[0],
|
||||
Dst: &nodes["b"].AllowedLocationIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
@@ -946,17 +946,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -1004,17 +1004,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
@@ -1055,7 +1055,7 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: nodes["b"].AllowedLocationIPs[0],
|
||||
Dst: &nodes["b"].AllowedLocationIPs[0],
|
||||
Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[1].wireGuardIP,
|
||||
LinkIndex: kiloIface,
|
||||
@@ -1069,17 +1069,17 @@ func TestRoutes(t *testing.T) {
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[0],
|
||||
Dst: &peers["a"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["a"].AllowedIPs[1],
|
||||
Dst: &peers["a"].AllowedIPs[1],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
{
|
||||
Dst: peers["b"].AllowedIPs[0],
|
||||
Dst: &peers["b"].AllowedIPs[0],
|
||||
LinkIndex: kiloIface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
},
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
// Copyright 2021 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -18,9 +18,11 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/squat/kilo/pkg/wireguard"
|
||||
)
|
||||
@@ -33,8 +35,8 @@ const (
|
||||
// Topology represents the logical structure of the overlay network.
|
||||
type Topology struct {
|
||||
// key is the private key of the node creating the topology.
|
||||
key []byte
|
||||
port uint32
|
||||
key wgtypes.Key
|
||||
port int
|
||||
// Location is the logical location of the local host.
|
||||
location string
|
||||
segments []*segment
|
||||
@@ -47,7 +49,7 @@ type Topology struct {
|
||||
leader bool
|
||||
// persistentKeepalive is the interval in seconds of the emission
|
||||
// of keepalive packets by the local node to its peers.
|
||||
persistentKeepalive int
|
||||
persistentKeepalive time.Duration
|
||||
// privateIP is the private IP address of the local node.
|
||||
privateIP *net.IPNet
|
||||
// subnet is the Pod subnet of the local node.
|
||||
@@ -59,15 +61,15 @@ type Topology struct {
|
||||
// is equal to the Kilo subnet.
|
||||
wireGuardCIDR *net.IPNet
|
||||
// discoveredEndpoints is the updated map of valid discovered Endpoints
|
||||
discoveredEndpoints map[string]*wireguard.Endpoint
|
||||
discoveredEndpoints map[string]*net.UDPAddr
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
type segment struct {
|
||||
allowedIPs []*net.IPNet
|
||||
allowedIPs []net.IPNet
|
||||
endpoint *wireguard.Endpoint
|
||||
key []byte
|
||||
persistentKeepalive int
|
||||
key wgtypes.Key
|
||||
persistentKeepalive time.Duration
|
||||
// Location is the logical location of this segment.
|
||||
location string
|
||||
|
||||
@@ -85,11 +87,11 @@ type segment struct {
|
||||
// allowedLocationIPs are not part of the cluster and are not peers.
|
||||
// They are directly routable from nodes within the segment.
|
||||
// A classic example is a printer that ought to be routable from other locations.
|
||||
allowedLocationIPs []*net.IPNet
|
||||
allowedLocationIPs []net.IPNet
|
||||
}
|
||||
|
||||
// NewTopology creates a new Topology struct from a given set of nodes and peers.
|
||||
func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int, logger log.Logger) (*Topology, error) {
|
||||
func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration, logger log.Logger) (*Topology, error) {
|
||||
if logger == nil {
|
||||
logger = log.NewNopLogger()
|
||||
}
|
||||
@@ -120,7 +122,18 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
|
||||
localLocation = nodeLocationPrefix + hostname
|
||||
}
|
||||
|
||||
t := Topology{key: key, port: port, hostname: hostname, location: localLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet, wireGuardCIDR: subnet, discoveredEndpoints: make(map[string]*wireguard.Endpoint), logger: logger}
|
||||
t := Topology{
|
||||
key: key,
|
||||
port: port,
|
||||
hostname: hostname,
|
||||
location: localLocation,
|
||||
persistentKeepalive: persistentKeepalive,
|
||||
privateIP: nodes[hostname].InternalIP,
|
||||
subnet: nodes[hostname].Subnet,
|
||||
wireGuardCIDR: subnet,
|
||||
discoveredEndpoints: make(map[string]*net.UDPAddr),
|
||||
logger: logger,
|
||||
}
|
||||
for location := range topoMap {
|
||||
// Sort the location so the result is stable.
|
||||
sort.Slice(topoMap[location], func(i, j int) bool {
|
||||
@@ -130,9 +143,9 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
|
||||
if location == localLocation && topoMap[location][leader].Name == hostname {
|
||||
t.leader = true
|
||||
}
|
||||
var allowedIPs []*net.IPNet
|
||||
var allowedIPs []net.IPNet
|
||||
allowedLocationIPsMap := make(map[string]struct{})
|
||||
var allowedLocationIPs []*net.IPNet
|
||||
var allowedLocationIPs []net.IPNet
|
||||
var cidrs []*net.IPNet
|
||||
var hostnames []string
|
||||
var privateIPs []net.IP
|
||||
@@ -142,7 +155,9 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
|
||||
// - the node's WireGuard IP
|
||||
// - the node's internal IP
|
||||
// - IPs that were specified by the allowed-location-ips annotation
|
||||
allowedIPs = append(allowedIPs, node.Subnet)
|
||||
if node.Subnet != nil {
|
||||
allowedIPs = append(allowedIPs, *node.Subnet)
|
||||
}
|
||||
for _, ip := range node.AllowedLocationIPs {
|
||||
if _, ok := allowedLocationIPsMap[ip.String()]; !ok {
|
||||
allowedLocationIPs = append(allowedLocationIPs, ip)
|
||||
@@ -150,7 +165,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
|
||||
}
|
||||
}
|
||||
if node.InternalIP != nil {
|
||||
allowedIPs = append(allowedIPs, oneAddressCIDR(node.InternalIP.IP))
|
||||
allowedIPs = append(allowedIPs, *oneAddressCIDR(node.InternalIP.IP))
|
||||
privateIPs = append(privateIPs, node.InternalIP.IP)
|
||||
}
|
||||
cidrs = append(cidrs, node.Subnet)
|
||||
@@ -202,7 +217,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
|
||||
return nil, errors.New("failed to allocate an IP address; ran out of IP addresses")
|
||||
}
|
||||
segment.wireGuardIP = ipNet.IP
|
||||
segment.allowedIPs = append(segment.allowedIPs, oneAddressCIDR(ipNet.IP))
|
||||
segment.allowedIPs = append(segment.allowedIPs, *oneAddressCIDR(ipNet.IP))
|
||||
if t.leader && segment.location == t.location {
|
||||
t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask}
|
||||
}
|
||||
@@ -224,11 +239,11 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func intersect(n1, n2 *net.IPNet) bool {
|
||||
func intersect(n1, n2 net.IPNet) bool {
|
||||
return n1.Contains(n2.IP) || n2.Contains(n1.IP)
|
||||
}
|
||||
|
||||
func (t *Topology) filterAllowedLocationIPs(ips []*net.IPNet, location string) (ret []*net.IPNet) {
|
||||
func (t *Topology) filterAllowedLocationIPs(ips []net.IPNet, location string) (ret []net.IPNet) {
|
||||
CheckIPs:
|
||||
for _, ip := range ips {
|
||||
for _, s := range t.segments {
|
||||
@@ -270,14 +285,14 @@ CheckIPs:
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key []byte, persistentKeepalive int) *wireguard.Endpoint {
|
||||
func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *wireguard.Endpoint {
|
||||
// Do not update non-nat peers
|
||||
if persistentKeepalive == 0 {
|
||||
if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) {
|
||||
return endpoint
|
||||
}
|
||||
e, ok := t.discoveredEndpoints[string(key)]
|
||||
e, ok := t.discoveredEndpoints[key.String()]
|
||||
if ok {
|
||||
return e
|
||||
return wireguard.NewEndpointFromUDPAddr(e)
|
||||
}
|
||||
return endpoint
|
||||
}
|
||||
@@ -285,30 +300,37 @@ func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key []byte, pers
|
||||
// Conf generates a WireGuard configuration file for a given Topology.
|
||||
func (t *Topology) Conf() *wireguard.Conf {
|
||||
c := &wireguard.Conf{
|
||||
Interface: &wireguard.Interface{
|
||||
PrivateKey: t.key,
|
||||
ListenPort: t.port,
|
||||
Config: wgtypes.Config{
|
||||
PrivateKey: &t.key,
|
||||
ListenPort: &t.port,
|
||||
ReplacePeers: true,
|
||||
},
|
||||
}
|
||||
for _, s := range t.segments {
|
||||
if s.location == t.location {
|
||||
continue
|
||||
}
|
||||
peer := &wireguard.Peer{
|
||||
AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
|
||||
Endpoint: t.updateEndpoint(s.endpoint, s.key, s.persistentKeepalive),
|
||||
PersistentKeepalive: t.persistentKeepalive,
|
||||
PublicKey: s.key,
|
||||
peer := wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
|
||||
PersistentKeepaliveInterval: &t.persistentKeepalive,
|
||||
PublicKey: s.key,
|
||||
ReplaceAllowedIPs: true,
|
||||
},
|
||||
Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
|
||||
}
|
||||
c.Peers = append(c.Peers, peer)
|
||||
}
|
||||
for _, p := range t.peers {
|
||||
peer := &wireguard.Peer{
|
||||
AllowedIPs: p.AllowedIPs,
|
||||
Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepalive),
|
||||
PersistentKeepalive: t.persistentKeepalive,
|
||||
PresharedKey: p.PresharedKey,
|
||||
PublicKey: p.PublicKey,
|
||||
peer := wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: p.AllowedIPs,
|
||||
PersistentKeepaliveInterval: &t.persistentKeepalive,
|
||||
PresharedKey: p.PresharedKey,
|
||||
PublicKey: p.PublicKey,
|
||||
ReplaceAllowedIPs: true,
|
||||
},
|
||||
Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval),
|
||||
}
|
||||
c.Peers = append(c.Peers, peer)
|
||||
}
|
||||
@@ -322,34 +344,39 @@ func (t *Topology) AsPeer() *wireguard.Peer {
|
||||
if s.location != t.location {
|
||||
continue
|
||||
}
|
||||
return &wireguard.Peer{
|
||||
AllowedIPs: s.allowedIPs,
|
||||
Endpoint: s.endpoint,
|
||||
PublicKey: s.key,
|
||||
p := &wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: s.allowedIPs,
|
||||
PublicKey: s.key,
|
||||
},
|
||||
Endpoint: s.endpoint,
|
||||
}
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PeerConf generates a WireGuard configuration file for a given peer in a Topology.
|
||||
func (t *Topology) PeerConf(name string) *wireguard.Conf {
|
||||
var pka int
|
||||
var psk []byte
|
||||
var pka *time.Duration
|
||||
var psk *wgtypes.Key
|
||||
for i := range t.peers {
|
||||
if t.peers[i].Name == name {
|
||||
pka = t.peers[i].PersistentKeepalive
|
||||
pka = t.peers[i].PersistentKeepaliveInterval
|
||||
psk = t.peers[i].PresharedKey
|
||||
break
|
||||
}
|
||||
}
|
||||
c := &wireguard.Conf{}
|
||||
for _, s := range t.segments {
|
||||
peer := &wireguard.Peer{
|
||||
AllowedIPs: s.allowedIPs,
|
||||
Endpoint: s.endpoint,
|
||||
PersistentKeepalive: pka,
|
||||
PresharedKey: psk,
|
||||
PublicKey: s.key,
|
||||
peer := wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: s.allowedIPs,
|
||||
PersistentKeepaliveInterval: pka,
|
||||
PresharedKey: psk,
|
||||
PublicKey: s.key,
|
||||
},
|
||||
Endpoint: s.endpoint,
|
||||
}
|
||||
c.Peers = append(c.Peers, peer)
|
||||
}
|
||||
@@ -357,11 +384,13 @@ func (t *Topology) PeerConf(name string) *wireguard.Conf {
|
||||
if t.peers[i].Name == name {
|
||||
continue
|
||||
}
|
||||
peer := &wireguard.Peer{
|
||||
AllowedIPs: t.peers[i].AllowedIPs,
|
||||
PersistentKeepalive: pka,
|
||||
PublicKey: t.peers[i].PublicKey,
|
||||
Endpoint: t.peers[i].Endpoint,
|
||||
peer := wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: t.peers[i].AllowedIPs,
|
||||
PersistentKeepaliveInterval: pka,
|
||||
PublicKey: t.peers[i].PublicKey,
|
||||
},
|
||||
Endpoint: t.peers[i].Endpoint,
|
||||
}
|
||||
c.Peers = append(c.Peers, peer)
|
||||
}
|
||||
@@ -382,13 +411,13 @@ func findLeader(nodes []*Node) int {
|
||||
var leaders, public []int
|
||||
for i := range nodes {
|
||||
if nodes[i].Leader {
|
||||
if isPublic(nodes[i].Endpoint.IP) {
|
||||
if isPublic(nodes[i].Endpoint.IP()) {
|
||||
return i
|
||||
}
|
||||
leaders = append(leaders, i)
|
||||
|
||||
}
|
||||
if isPublic(nodes[i].Endpoint.IP) {
|
||||
if nodes[i].Endpoint.IP() != nil && isPublic(nodes[i].Endpoint.IP()) {
|
||||
public = append(public, i)
|
||||
}
|
||||
}
|
||||
@@ -408,10 +437,12 @@ func deduplicatePeerIPs(peers []*Peer) []*Peer {
|
||||
p := Peer{
|
||||
Name: peer.Name,
|
||||
Peer: wireguard.Peer{
|
||||
Endpoint: peer.Endpoint,
|
||||
PersistentKeepalive: peer.PersistentKeepalive,
|
||||
PresharedKey: peer.PresharedKey,
|
||||
PublicKey: peer.PublicKey,
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval,
|
||||
PresharedKey: peer.PresharedKey,
|
||||
PublicKey: peer.PublicKey,
|
||||
},
|
||||
Endpoint: peer.Endpoint,
|
||||
},
|
||||
}
|
||||
for _, ip := range peer.AllowedIPs {
|
||||
|
@@ -18,9 +18,11 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/squat/kilo/pkg/wireguard"
|
||||
)
|
||||
@@ -29,17 +31,25 @@ func allowedIPs(ips ...string) string {
|
||||
return strings.Join(ips, ", ")
|
||||
}
|
||||
|
||||
func mustParseCIDR(s string) (r *net.IPNet) {
|
||||
func mustParseCIDR(s string) (r net.IPNet) {
|
||||
if _, ip, err := net.ParseCIDR(s); err != nil {
|
||||
panic("failed to parse CIDR")
|
||||
} else {
|
||||
r = ip
|
||||
r = *ip
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) {
|
||||
key := []byte("private")
|
||||
var (
|
||||
key1 = wgtypes.Key{'k', 'e', 'y', '1'}
|
||||
key2 = wgtypes.Key{'k', 'e', 'y', '2'}
|
||||
key3 = wgtypes.Key{'k', 'e', 'y', '3'}
|
||||
key4 = wgtypes.Key{'k', 'e', 'y', '4'}
|
||||
key5 = wgtypes.Key{'k', 'e', 'y', '5'}
|
||||
)
|
||||
|
||||
func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) {
|
||||
key := wgtypes.Key{'p', 'r', 'i', 'v'}
|
||||
e1 := &net.IPNet{IP: net.ParseIP("10.1.0.1").To4(), Mask: net.CIDRMask(16, 32)}
|
||||
e2 := &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(16, 32)}
|
||||
e3 := &net.IPNet{IP: net.ParseIP("10.1.0.3").To4(), Mask: net.CIDRMask(16, 32)}
|
||||
@@ -50,62 +60,63 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) {
|
||||
nodes := map[string]*Node{
|
||||
"a": {
|
||||
Name: "a",
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
|
||||
InternalIP: i1,
|
||||
Location: "1",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Key: []byte("key1"),
|
||||
Key: key1,
|
||||
PersistentKeepalive: 25,
|
||||
},
|
||||
"b": {
|
||||
Name: "b",
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
|
||||
InternalIP: i1,
|
||||
Location: "2",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Key: []byte("key2"),
|
||||
AllowedLocationIPs: []*net.IPNet{i3},
|
||||
Key: key2,
|
||||
AllowedLocationIPs: []net.IPNet{*i3},
|
||||
},
|
||||
"c": {
|
||||
Name: "c",
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e3.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e3.IP, DefaultKiloPort),
|
||||
InternalIP: i2,
|
||||
// Same location as node b.
|
||||
Location: "2",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.3.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Key: []byte("key3"),
|
||||
Key: key3,
|
||||
},
|
||||
"d": {
|
||||
Name: "d",
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e4.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e4.IP, DefaultKiloPort),
|
||||
// Same location as node a, but without private IP
|
||||
Location: "1",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.4.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Key: []byte("key4"),
|
||||
Key: key4,
|
||||
},
|
||||
}
|
||||
peers := map[string]*Peer{
|
||||
"a": {
|
||||
Name: "a",
|
||||
Peer: wireguard.Peer{
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.5.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.5.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.5.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.5.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
PublicKey: key4,
|
||||
},
|
||||
PublicKey: []byte("key4"),
|
||||
},
|
||||
},
|
||||
"b": {
|
||||
Name: "b",
|
||||
Peer: wireguard.Peer{
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
PublicKey: key5,
|
||||
},
|
||||
Endpoint: &wireguard.Endpoint{
|
||||
DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("192.168.0.1")},
|
||||
Port: DefaultKiloPort,
|
||||
},
|
||||
PublicKey: []byte("key5"),
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("192.168.0.1"), DefaultKiloPort),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -138,7 +149,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
|
||||
segments: []*segment{
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["a"].Endpoint,
|
||||
key: nodes["a"].Key,
|
||||
persistentKeepalive: nodes["a"].PersistentKeepalive,
|
||||
@@ -149,7 +160,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["b"].Endpoint,
|
||||
key: nodes["b"].Key,
|
||||
persistentKeepalive: nodes["b"].PersistentKeepalive,
|
||||
@@ -161,7 +172,7 @@ func TestNewTopology(t *testing.T) {
|
||||
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["d"].Endpoint,
|
||||
key: nodes["d"].Key,
|
||||
persistentKeepalive: nodes["d"].PersistentKeepalive,
|
||||
@@ -189,7 +200,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
|
||||
segments: []*segment{
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["a"].Endpoint,
|
||||
key: nodes["a"].Key,
|
||||
persistentKeepalive: nodes["a"].PersistentKeepalive,
|
||||
@@ -200,7 +211,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["b"].Endpoint,
|
||||
key: nodes["b"].Key,
|
||||
persistentKeepalive: nodes["b"].PersistentKeepalive,
|
||||
@@ -212,7 +223,7 @@ func TestNewTopology(t *testing.T) {
|
||||
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["d"].Endpoint,
|
||||
key: nodes["d"].Key,
|
||||
persistentKeepalive: nodes["d"].PersistentKeepalive,
|
||||
@@ -240,7 +251,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardCIDR: DefaultKiloSubnet,
|
||||
segments: []*segment{
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["a"].Endpoint,
|
||||
key: nodes["a"].Key,
|
||||
persistentKeepalive: nodes["a"].PersistentKeepalive,
|
||||
@@ -251,7 +262,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["b"].Endpoint,
|
||||
key: nodes["b"].Key,
|
||||
persistentKeepalive: nodes["b"].PersistentKeepalive,
|
||||
@@ -263,7 +274,7 @@ func TestNewTopology(t *testing.T) {
|
||||
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["d"].Endpoint,
|
||||
key: nodes["d"].Key,
|
||||
persistentKeepalive: nodes["d"].PersistentKeepalive,
|
||||
@@ -291,7 +302,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
|
||||
segments: []*segment{
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["a"].Endpoint,
|
||||
key: nodes["a"].Key,
|
||||
persistentKeepalive: nodes["a"].PersistentKeepalive,
|
||||
@@ -302,7 +313,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["b"].Endpoint,
|
||||
key: nodes["b"].Key,
|
||||
persistentKeepalive: nodes["b"].PersistentKeepalive,
|
||||
@@ -314,7 +325,7 @@ func TestNewTopology(t *testing.T) {
|
||||
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["c"].Endpoint,
|
||||
key: nodes["c"].Key,
|
||||
persistentKeepalive: nodes["c"].PersistentKeepalive,
|
||||
@@ -325,7 +336,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w3,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["d"].Endpoint,
|
||||
key: nodes["d"].Key,
|
||||
persistentKeepalive: nodes["d"].PersistentKeepalive,
|
||||
@@ -353,7 +364,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
|
||||
segments: []*segment{
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["a"].Endpoint,
|
||||
key: nodes["a"].Key,
|
||||
persistentKeepalive: nodes["a"].PersistentKeepalive,
|
||||
@@ -364,7 +375,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["b"].Endpoint,
|
||||
key: nodes["b"].Key,
|
||||
persistentKeepalive: nodes["b"].PersistentKeepalive,
|
||||
@@ -376,7 +387,7 @@ func TestNewTopology(t *testing.T) {
|
||||
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["c"].Endpoint,
|
||||
key: nodes["c"].Key,
|
||||
persistentKeepalive: nodes["c"].PersistentKeepalive,
|
||||
@@ -387,7 +398,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w3,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["d"].Endpoint,
|
||||
key: nodes["d"].Key,
|
||||
persistentKeepalive: nodes["d"].PersistentKeepalive,
|
||||
@@ -415,7 +426,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)},
|
||||
segments: []*segment{
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["a"].Endpoint,
|
||||
key: nodes["a"].Key,
|
||||
persistentKeepalive: nodes["a"].PersistentKeepalive,
|
||||
@@ -426,7 +437,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["b"].Endpoint,
|
||||
key: nodes["b"].Key,
|
||||
persistentKeepalive: nodes["b"].PersistentKeepalive,
|
||||
@@ -438,7 +449,7 @@ func TestNewTopology(t *testing.T) {
|
||||
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["c"].Endpoint,
|
||||
key: nodes["c"].Key,
|
||||
persistentKeepalive: nodes["c"].PersistentKeepalive,
|
||||
@@ -449,7 +460,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w3,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["d"].Endpoint,
|
||||
key: nodes["d"].Key,
|
||||
persistentKeepalive: nodes["d"].PersistentKeepalive,
|
||||
@@ -477,7 +488,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardCIDR: &net.IPNet{IP: w4, Mask: net.CIDRMask(16, 32)},
|
||||
segments: []*segment{
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["a"].Endpoint,
|
||||
key: nodes["a"].Key,
|
||||
persistentKeepalive: nodes["a"].PersistentKeepalive,
|
||||
@@ -488,7 +499,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w1,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["b"].Endpoint,
|
||||
key: nodes["b"].Key,
|
||||
persistentKeepalive: nodes["b"].PersistentKeepalive,
|
||||
@@ -500,7 +511,7 @@ func TestNewTopology(t *testing.T) {
|
||||
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["c"].Endpoint,
|
||||
key: nodes["c"].Key,
|
||||
persistentKeepalive: nodes["c"].PersistentKeepalive,
|
||||
@@ -511,7 +522,7 @@ func TestNewTopology(t *testing.T) {
|
||||
wireGuardIP: w3,
|
||||
},
|
||||
{
|
||||
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
|
||||
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
|
||||
endpoint: nodes["d"].Endpoint,
|
||||
key: nodes["d"].Key,
|
||||
persistentKeepalive: nodes["d"].PersistentKeepalive,
|
||||
@@ -539,7 +550,7 @@ func TestNewTopology(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int) *Topology {
|
||||
func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration) *Topology {
|
||||
topo, err := NewTopology(nodes, peers, granularity, hostname, port, key, subnet, persistentKeepalive, nil)
|
||||
if err != nil {
|
||||
t.Errorf("failed to generate Topology: %v", err)
|
||||
@@ -547,211 +558,6 @@ func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, gran
|
||||
return topo
|
||||
}
|
||||
|
||||
func TestConf(t *testing.T) {
|
||||
nodes, peers, key, port := setup(t)
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
topology *Topology
|
||||
result string
|
||||
}{
|
||||
{
|
||||
name: "logical from a",
|
||||
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key2
|
||||
Endpoint = 10.1.0.2:51820
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 192.168.178.3/32
|
||||
PersistentKeepalive = 25
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
Endpoint = 10.1.0.4:51820
|
||||
AllowedIPs = 10.2.4.0/24, 10.4.0.3/32
|
||||
PersistentKeepalive = 25
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
|
||||
PersistentKeepalive = 25
|
||||
|
||||
[Peer]
|
||||
PublicKey = key5
|
||||
Endpoint = 192.168.0.1:51820
|
||||
AllowedIPs = 10.5.0.3/24
|
||||
PersistentKeepalive = 25
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "logical from b",
|
||||
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet, nodes["b"].PersistentKeepalive),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key1
|
||||
Endpoint = 10.1.0.1:51820
|
||||
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
Endpoint = 10.1.0.4:51820
|
||||
AllowedIPs = 10.2.4.0/24, 10.4.0.3/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
|
||||
|
||||
[Peer]
|
||||
PublicKey = key5
|
||||
Endpoint = 192.168.0.1:51820
|
||||
AllowedIPs = 10.5.0.3/24
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "logical from c",
|
||||
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet, nodes["c"].PersistentKeepalive),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key1
|
||||
Endpoint = 10.1.0.1:51820
|
||||
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
Endpoint = 10.1.0.4:51820
|
||||
AllowedIPs = 10.2.4.0/24, 10.4.0.3/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
|
||||
|
||||
[Peer]
|
||||
PublicKey = key5
|
||||
Endpoint = 192.168.0.1:51820
|
||||
AllowedIPs = 10.5.0.3/24
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "full from a",
|
||||
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key2
|
||||
Endpoint = 10.1.0.2:51820
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32, 192.168.178.3/32
|
||||
PersistentKeepalive = 25
|
||||
|
||||
[Peer]
|
||||
PublicKey = key3
|
||||
Endpoint = 10.1.0.3:51820
|
||||
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
|
||||
PersistentKeepalive = 25
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
Endpoint = 10.1.0.4:51820
|
||||
AllowedIPs = 10.2.4.0/24, 10.4.0.4/32
|
||||
PersistentKeepalive = 25
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
|
||||
PersistentKeepalive = 25
|
||||
|
||||
[Peer]
|
||||
PublicKey = key5
|
||||
Endpoint = 192.168.0.1:51820
|
||||
AllowedIPs = 10.5.0.3/24
|
||||
PersistentKeepalive = 25
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "full from b",
|
||||
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet, nodes["b"].PersistentKeepalive),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key1
|
||||
Endpoint = 10.1.0.1:51820
|
||||
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key3
|
||||
Endpoint = 10.1.0.3:51820
|
||||
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
Endpoint = 10.1.0.4:51820
|
||||
AllowedIPs = 10.2.4.0/24, 10.4.0.4/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
|
||||
|
||||
[Peer]
|
||||
PublicKey = key5
|
||||
Endpoint = 192.168.0.1:51820
|
||||
AllowedIPs = 10.5.0.3/24
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "full from c",
|
||||
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet, nodes["c"].PersistentKeepalive),
|
||||
result: `[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key1
|
||||
Endpoint = 10.1.0.1:51820
|
||||
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key2
|
||||
Endpoint = 10.1.0.2:51820
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32, 192.168.178.3/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
Endpoint = 10.1.0.4:51820
|
||||
AllowedIPs = 10.2.4.0/24, 10.4.0.4/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = key4
|
||||
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
|
||||
|
||||
[Peer]
|
||||
PublicKey = key5
|
||||
Endpoint = 192.168.0.1:51820
|
||||
AllowedIPs = 10.5.0.3/24
|
||||
`,
|
||||
},
|
||||
} {
|
||||
conf := tc.topology.Conf()
|
||||
if !conf.Equal(wireguard.Parse([]byte(tc.result))) {
|
||||
buf, err := conf.Bytes()
|
||||
if err != nil {
|
||||
t.Errorf("test case %q: failed to render conf: %v", tc.name, err)
|
||||
}
|
||||
t.Errorf("test case %q: expected %s got %s", tc.name, tc.result, string(buf))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindLeader(t *testing.T) {
|
||||
ip, e1, err := net.ParseCIDR("10.0.0.1/32")
|
||||
if err != nil {
|
||||
@@ -767,24 +573,24 @@ func TestFindLeader(t *testing.T) {
|
||||
nodes := []*Node{
|
||||
{
|
||||
Name: "a",
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
|
||||
},
|
||||
{
|
||||
Name: "b",
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
|
||||
},
|
||||
{
|
||||
Name: "c",
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
|
||||
},
|
||||
{
|
||||
Name: "d",
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
|
||||
Leader: true,
|
||||
},
|
||||
{
|
||||
Name: "2",
|
||||
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
|
||||
Leader: true,
|
||||
},
|
||||
}
|
||||
@@ -840,31 +646,38 @@ func TestDeduplicatePeerIPs(t *testing.T) {
|
||||
p1 := &Peer{
|
||||
Name: "1",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key1"),
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
|
||||
PublicKey: key1,
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
p2 := &Peer{
|
||||
Name: "2",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key2"),
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: key2,
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
p3 := &Peer{
|
||||
Name: "3",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key3"),
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: key3,
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -872,10 +685,12 @@ func TestDeduplicatePeerIPs(t *testing.T) {
|
||||
p4 := &Peer{
|
||||
Name: "4",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key4"),
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: key4,
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -898,9 +713,11 @@ func TestDeduplicatePeerIPs(t *testing.T) {
|
||||
{
|
||||
Name: "2",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key2"),
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: key2,
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -914,9 +731,11 @@ func TestDeduplicatePeerIPs(t *testing.T) {
|
||||
{
|
||||
Name: "1",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key1"),
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: key1,
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -930,19 +749,25 @@ func TestDeduplicatePeerIPs(t *testing.T) {
|
||||
{
|
||||
Name: "2",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key2"),
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: key2,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "1",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key1"),
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: key1,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "4",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key4"),
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: key4,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -954,19 +779,23 @@ func TestDeduplicatePeerIPs(t *testing.T) {
|
||||
{
|
||||
Name: "4",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key4"),
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: key4,
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "1",
|
||||
Peer: wireguard.Peer{
|
||||
PublicKey: []byte("key1"),
|
||||
AllowedIPs: []*net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PublicKey: key1,
|
||||
AllowedIPs: []net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -985,12 +814,12 @@ func TestFilterAllowedIPs(t *testing.T) {
|
||||
topo := mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive)
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
allowedLocationIPs map[int][]*net.IPNet
|
||||
result map[int][]*net.IPNet
|
||||
allowedLocationIPs map[int][]net.IPNet
|
||||
result map[int][]net.IPNet
|
||||
}{
|
||||
{
|
||||
name: "nothing to filter",
|
||||
allowedLocationIPs: map[int][]*net.IPNet{
|
||||
allowedLocationIPs: map[int][]net.IPNet{
|
||||
0: {
|
||||
mustParseCIDR("192.168.178.4/32"),
|
||||
},
|
||||
@@ -1002,7 +831,7 @@ func TestFilterAllowedIPs(t *testing.T) {
|
||||
mustParseCIDR("192.168.178.7/32"),
|
||||
},
|
||||
},
|
||||
result: map[int][]*net.IPNet{
|
||||
result: map[int][]net.IPNet{
|
||||
0: {
|
||||
mustParseCIDR("192.168.178.4/32"),
|
||||
},
|
||||
@@ -1017,7 +846,7 @@ func TestFilterAllowedIPs(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "intersections between segments",
|
||||
allowedLocationIPs: map[int][]*net.IPNet{
|
||||
allowedLocationIPs: map[int][]net.IPNet{
|
||||
0: {
|
||||
mustParseCIDR("192.168.178.4/32"),
|
||||
mustParseCIDR("192.168.178.8/32"),
|
||||
@@ -1031,7 +860,7 @@ func TestFilterAllowedIPs(t *testing.T) {
|
||||
mustParseCIDR("192.168.178.4/32"),
|
||||
},
|
||||
},
|
||||
result: map[int][]*net.IPNet{
|
||||
result: map[int][]net.IPNet{
|
||||
0: {
|
||||
mustParseCIDR("192.168.178.8/32"),
|
||||
},
|
||||
@@ -1047,7 +876,7 @@ func TestFilterAllowedIPs(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "intersections with wireGuardCIDR",
|
||||
allowedLocationIPs: map[int][]*net.IPNet{
|
||||
allowedLocationIPs: map[int][]net.IPNet{
|
||||
0: {
|
||||
mustParseCIDR("10.4.0.1/32"),
|
||||
mustParseCIDR("192.168.178.8/32"),
|
||||
@@ -1060,7 +889,7 @@ func TestFilterAllowedIPs(t *testing.T) {
|
||||
mustParseCIDR("192.168.178.7/32"),
|
||||
},
|
||||
},
|
||||
result: map[int][]*net.IPNet{
|
||||
result: map[int][]net.IPNet{
|
||||
0: {
|
||||
mustParseCIDR("192.168.178.8/32"),
|
||||
},
|
||||
@@ -1075,7 +904,7 @@ func TestFilterAllowedIPs(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "intersections with more than one allowedLocationIPs",
|
||||
allowedLocationIPs: map[int][]*net.IPNet{
|
||||
allowedLocationIPs: map[int][]net.IPNet{
|
||||
0: {
|
||||
mustParseCIDR("192.168.178.8/32"),
|
||||
},
|
||||
@@ -1086,7 +915,7 @@ func TestFilterAllowedIPs(t *testing.T) {
|
||||
mustParseCIDR("192.168.178.7/24"),
|
||||
},
|
||||
},
|
||||
result: map[int][]*net.IPNet{
|
||||
result: map[int][]net.IPNet{
|
||||
0: {},
|
||||
1: {},
|
||||
2: {
|
||||
|
@@ -15,16 +15,15 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"k8s.io/apimachinery/pkg/util/validation"
|
||||
)
|
||||
|
||||
@@ -32,10 +31,6 @@ type section string
|
||||
type key string
|
||||
|
||||
const (
|
||||
separator = "="
|
||||
dumpSeparator = "\t"
|
||||
dumpNone = "(none)"
|
||||
dumpOff = "off"
|
||||
interfaceSection section = "Interface"
|
||||
peerSection section = "Peer"
|
||||
listenPortKey key = "ListenPort"
|
||||
@@ -47,56 +42,209 @@ const (
|
||||
publicKeyKey key = "PublicKey"
|
||||
)
|
||||
|
||||
type dumpInterfaceIndex int
|
||||
|
||||
const (
|
||||
dumpInterfacePrivateKeyIndex = iota
|
||||
dumpInterfacePublicKeyIndex
|
||||
dumpInterfaceListenPortIndex
|
||||
dumpInterfaceFWMarkIndex
|
||||
dumpInterfaceLen
|
||||
)
|
||||
|
||||
type dumpPeerIndex int
|
||||
|
||||
const (
|
||||
dumpPeerPublicKeyIndex = iota
|
||||
dumpPeerPresharedKeyIndex
|
||||
dumpPeerEndpointIndex
|
||||
dumpPeerAllowedIPsIndex
|
||||
dumpPeerLatestHandshakeIndex
|
||||
dumpPeerTransferRXIndex
|
||||
dumpPeerTransferTXIndex
|
||||
dumpPeerPersistentKeepaliveIndex
|
||||
dumpPeerLen
|
||||
)
|
||||
|
||||
// Conf represents a WireGuard configuration file.
|
||||
type Conf struct {
|
||||
Interface *Interface
|
||||
Peers []*Peer
|
||||
wgtypes.Config
|
||||
// The Peers field is shadowed because every Peer needs the Endpoint field that contains a DNS endpoint.
|
||||
Peers []Peer
|
||||
}
|
||||
|
||||
// Interface represents the `interface` section of a WireGuard configuration.
|
||||
type Interface struct {
|
||||
ListenPort uint32
|
||||
PrivateKey []byte
|
||||
// WGConfig returns a wgytpes.Config from a Conf.
|
||||
func (c *Conf) WGConfig() wgtypes.Config {
|
||||
if c == nil {
|
||||
// The empty Config will do nothing, when applied.
|
||||
return wgtypes.Config{}
|
||||
}
|
||||
r := c.Config
|
||||
wgPs := make([]wgtypes.PeerConfig, len(c.Peers))
|
||||
for i, p := range c.Peers {
|
||||
wgPs[i] = p.PeerConfig
|
||||
if p.Endpoint.Resolved() {
|
||||
// We can ingore the error because we already checked if the Endpoint was resolved in the above line.
|
||||
wgPs[i].Endpoint, _ = p.Endpoint.UDPAddr(false)
|
||||
}
|
||||
wgPs[i].ReplaceAllowedIPs = true
|
||||
}
|
||||
r.Peers = wgPs
|
||||
r.ReplacePeers = true
|
||||
return r
|
||||
}
|
||||
|
||||
// Endpoint represents a WireGuard endpoint.
|
||||
type Endpoint struct {
|
||||
udpAddr *net.UDPAddr
|
||||
addr string
|
||||
}
|
||||
|
||||
// ParseEndpoint returns an Endpoint from a string.
|
||||
// The input should look like "10.0.0.0:100", "[ff10::10]:100"
|
||||
// or "example.com:100".
|
||||
func ParseEndpoint(endpoint string) *Endpoint {
|
||||
if len(endpoint) == 0 {
|
||||
return nil
|
||||
}
|
||||
hostRaw, portRaw, err := net.SplitHostPort(endpoint)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
port, err := strconv.ParseUint(portRaw, 10, 32)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(validation.IsValidPortNum(int(port))) != 0 {
|
||||
return nil
|
||||
}
|
||||
ip := net.ParseIP(hostRaw)
|
||||
if ip == nil {
|
||||
if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 {
|
||||
return &Endpoint{
|
||||
addr: endpoint,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// ResolveUDPAddr will not resolve the endpoint as long as a valid IP and port is given.
|
||||
// This should be the case here.
|
||||
u, err := net.ResolveUDPAddr("udp", endpoint)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
u.IP = cutIP(u.IP)
|
||||
return &Endpoint{
|
||||
udpAddr: u,
|
||||
}
|
||||
}
|
||||
|
||||
// NewEndpointFromUDPAddr returns an Endpoint from a net.UDPAddr.
|
||||
func NewEndpointFromUDPAddr(u *net.UDPAddr) *Endpoint {
|
||||
if u != nil {
|
||||
u.IP = cutIP(u.IP)
|
||||
}
|
||||
return &Endpoint{
|
||||
udpAddr: u,
|
||||
}
|
||||
}
|
||||
|
||||
// NewEndpoint returns an Endpoint from a net.IP and port.
|
||||
func NewEndpoint(ip net.IP, port int) *Endpoint {
|
||||
return &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: cutIP(ip),
|
||||
Port: port,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Ready return true, if the Enpoint is ready.
|
||||
// Ready means that an IP or DN and port exists.
|
||||
func (e *Endpoint) Ready() bool {
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
return (e.udpAddr != nil && e.udpAddr.IP != nil && e.udpAddr.Port > 0) || len(e.addr) > 0
|
||||
}
|
||||
|
||||
// Port returns the port of the Endpoint.
|
||||
func (e *Endpoint) Port() int {
|
||||
if !e.Ready() {
|
||||
return 0
|
||||
}
|
||||
if e.udpAddr != nil {
|
||||
return e.udpAddr.Port
|
||||
}
|
||||
// We can ignore the errors here bacause the returned port will be "".
|
||||
// This will result to Port 0 after the conversion to and int.
|
||||
_, p, _ := net.SplitHostPort(e.addr)
|
||||
port, _ := strconv.ParseUint(p, 10, 32)
|
||||
return int(port)
|
||||
}
|
||||
|
||||
// HasDNS returns true if the endpoint has a DN.
|
||||
func (e *Endpoint) HasDNS() bool {
|
||||
return e != nil && e.addr != ""
|
||||
}
|
||||
|
||||
// DNS returns the DN of the Endpoint.
|
||||
func (e *Endpoint) DNS() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
_, s, _ := net.SplitHostPort(e.addr)
|
||||
return s
|
||||
}
|
||||
|
||||
// Resolved returns true, if the DN of the Endpoint was resolved
|
||||
// or if the Endpoint has a resolved endpoint.
|
||||
func (e *Endpoint) Resolved() bool {
|
||||
return e != nil && e.udpAddr != nil
|
||||
}
|
||||
|
||||
// UDPAddr returns the UDPAddr of the Endpoint. If resolve is false,
|
||||
// UDPAddr() will not try to resolve a DN name, if the Endpoint is not yet resolved.
|
||||
func (e *Endpoint) UDPAddr(resolve bool) (*net.UDPAddr, error) {
|
||||
if !e.Ready() {
|
||||
return nil, errors.New("Enpoint is not ready")
|
||||
}
|
||||
if e.udpAddr != nil {
|
||||
// Make a copy of the UDPAddr to protect it from modification outside this package.
|
||||
h := *e.udpAddr
|
||||
return &h, nil
|
||||
}
|
||||
if !resolve {
|
||||
return nil, errors.New("Endpoint is not resolved")
|
||||
}
|
||||
var err error
|
||||
if e.udpAddr, err = net.ResolveUDPAddr("udp", e.addr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Make a copy of the UDPAddr to protect it from modification outside this package.
|
||||
h := *e.udpAddr
|
||||
return &h, nil
|
||||
}
|
||||
|
||||
// IP returns the IP address of the Enpoint or nil.
|
||||
func (e *Endpoint) IP() net.IP {
|
||||
if !e.Resolved() {
|
||||
return nil
|
||||
}
|
||||
return e.udpAddr.IP
|
||||
}
|
||||
|
||||
// String will return the endpoint as a string.
|
||||
// If a DN exists, it will take prcedence over the resolved endpoint.
|
||||
func (e *Endpoint) String() string {
|
||||
return e.StringOpt(true)
|
||||
}
|
||||
|
||||
// StringOpt will return the string of the Endpoint.
|
||||
// If dnsFirst is false, the resolved Endpoint will
|
||||
// take precedence over the DN.
|
||||
func (e *Endpoint) StringOpt(dnsFirst bool) string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.udpAddr != nil && (!dnsFirst || e.addr == "") {
|
||||
return e.udpAddr.String()
|
||||
}
|
||||
return e.addr
|
||||
}
|
||||
|
||||
// Equal will return true, if the Enpoints are equal.
|
||||
// If dnsFirst is false, the DN will only be compared if
|
||||
// the IPs are nil.
|
||||
func (e *Endpoint) Equal(b *Endpoint, dnsFirst bool) bool {
|
||||
return e.StringOpt(dnsFirst) == b.StringOpt(dnsFirst)
|
||||
}
|
||||
|
||||
// Peer represents a `peer` section of a WireGuard configuration.
|
||||
type Peer struct {
|
||||
AllowedIPs []*net.IPNet
|
||||
Endpoint *Endpoint
|
||||
PersistentKeepalive int
|
||||
PresharedKey []byte
|
||||
PublicKey []byte
|
||||
// The following fields are part of the runtime information, not the configuration.
|
||||
LatestHandshake time.Time
|
||||
wgtypes.PeerConfig
|
||||
Endpoint *Endpoint
|
||||
}
|
||||
|
||||
// DeduplicateIPs eliminates duplicate allowed IPs.
|
||||
func (p *Peer) DeduplicateIPs() {
|
||||
var ips []*net.IPNet
|
||||
var ips []net.IPNet
|
||||
seen := make(map[string]struct{})
|
||||
for _, ip := range p.AllowedIPs {
|
||||
if _, ok := seen[ip.String()]; ok {
|
||||
@@ -108,181 +256,27 @@ func (p *Peer) DeduplicateIPs() {
|
||||
p.AllowedIPs = ips
|
||||
}
|
||||
|
||||
// Endpoint represents an `endpoint` key of a `peer` section.
|
||||
type Endpoint struct {
|
||||
DNSOrIP
|
||||
Port uint32
|
||||
}
|
||||
|
||||
// String prints the string representation of the endpoint.
|
||||
func (e *Endpoint) String() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
dnsOrIP := e.DNSOrIP.String()
|
||||
if e.IP != nil && len(e.IP) == net.IPv6len {
|
||||
dnsOrIP = "[" + dnsOrIP + "]"
|
||||
}
|
||||
return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
|
||||
}
|
||||
|
||||
// Equal compares two endpoints.
|
||||
func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
|
||||
if (e == nil) != (b == nil) {
|
||||
return false
|
||||
}
|
||||
if e != nil {
|
||||
if e.Port != b.Port {
|
||||
return false
|
||||
}
|
||||
if DNSFirst {
|
||||
// Check the DNS name first if it was resolved.
|
||||
if e.DNS != b.DNS {
|
||||
return false
|
||||
}
|
||||
if e.DNS == "" && !e.IP.Equal(b.IP) {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
// IPs take priority, so check them first.
|
||||
if !e.IP.Equal(b.IP) {
|
||||
return false
|
||||
}
|
||||
// Only check the DNS name if the IP is empty.
|
||||
if e.IP == nil && e.DNS != b.DNS {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// DNSOrIP represents either a DNS name or an IP address.
|
||||
// IPs, as they are more specific, are preferred.
|
||||
type DNSOrIP struct {
|
||||
DNS string
|
||||
IP net.IP
|
||||
}
|
||||
|
||||
// String prints the string representation of the struct.
|
||||
func (d DNSOrIP) String() string {
|
||||
if d.IP != nil {
|
||||
return d.IP.String()
|
||||
}
|
||||
return d.DNS
|
||||
}
|
||||
|
||||
// Parse parses a given WireGuard configuration file and produces a Conf struct.
|
||||
func Parse(buf []byte) *Conf {
|
||||
var (
|
||||
active section
|
||||
kv []string
|
||||
c Conf
|
||||
err error
|
||||
iface *Interface
|
||||
i int
|
||||
k key
|
||||
line, v string
|
||||
peer *Peer
|
||||
port uint64
|
||||
)
|
||||
s := bufio.NewScanner(bytes.NewBuffer(buf))
|
||||
for s.Scan() {
|
||||
line = strings.TrimSpace(s.Text())
|
||||
// Skip comments.
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
// Line is a section title.
|
||||
if strings.HasPrefix(line, "[") {
|
||||
if peer != nil {
|
||||
c.Peers = append(c.Peers, peer)
|
||||
peer = nil
|
||||
}
|
||||
if iface != nil {
|
||||
c.Interface = iface
|
||||
iface = nil
|
||||
}
|
||||
active = section(strings.TrimSpace(strings.Trim(line, "[]")))
|
||||
switch active {
|
||||
case interfaceSection:
|
||||
iface = new(Interface)
|
||||
case peerSection:
|
||||
peer = new(Peer)
|
||||
}
|
||||
continue
|
||||
}
|
||||
kv = strings.SplitN(line, separator, 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
k = key(strings.TrimSpace(kv[0]))
|
||||
v = strings.TrimSpace(kv[1])
|
||||
switch active {
|
||||
case interfaceSection:
|
||||
switch k {
|
||||
case listenPortKey:
|
||||
port, err = strconv.ParseUint(v, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
iface.ListenPort = uint32(port)
|
||||
case privateKeyKey:
|
||||
iface.PrivateKey = []byte(v)
|
||||
}
|
||||
case peerSection:
|
||||
switch k {
|
||||
case allowedIPsKey:
|
||||
err = peer.parseAllowedIPs(v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
case endpointKey:
|
||||
err = peer.parseEndpoint(v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
case persistentKeepaliveKey:
|
||||
i, err = strconv.Atoi(v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
peer.PersistentKeepalive = i
|
||||
case presharedKeyKey:
|
||||
peer.PresharedKey = []byte(v)
|
||||
case publicKeyKey:
|
||||
peer.PublicKey = []byte(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if peer != nil {
|
||||
c.Peers = append(c.Peers, peer)
|
||||
}
|
||||
if iface != nil {
|
||||
c.Interface = iface
|
||||
}
|
||||
return &c
|
||||
}
|
||||
|
||||
// Bytes renders a WireGuard configuration to bytes.
|
||||
func (c *Conf) Bytes() ([]byte, error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
var err error
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 512))
|
||||
if c.Interface != nil {
|
||||
if c.PrivateKey != nil {
|
||||
if err = writeSection(buf, interfaceSection); err != nil {
|
||||
return nil, fmt.Errorf("failed to write interface: %v", err)
|
||||
}
|
||||
if err = writePKey(buf, privateKeyKey, c.Interface.PrivateKey); err != nil {
|
||||
if err = writePKey(buf, privateKeyKey, c.PrivateKey); err != nil {
|
||||
return nil, fmt.Errorf("failed to write private key: %v", err)
|
||||
}
|
||||
if err = writeValue(buf, listenPortKey, strconv.FormatUint(uint64(c.Interface.ListenPort), 10)); err != nil {
|
||||
if err = writeValue(buf, listenPortKey, strconv.Itoa(*c.ListenPort)); err != nil {
|
||||
return nil, fmt.Errorf("failed to write listen port: %v", err)
|
||||
}
|
||||
}
|
||||
for i, p := range c.Peers {
|
||||
// Add newlines to make the formatting nicer.
|
||||
if i == 0 && c.Interface != nil || i != 0 {
|
||||
if i == 0 && c.PrivateKey != nil || i != 0 {
|
||||
if err = buf.WriteByte('\n'); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -297,71 +291,103 @@ func (c *Conf) Bytes() ([]byte, error) {
|
||||
if err = writeEndpoint(buf, p.Endpoint); err != nil {
|
||||
return nil, fmt.Errorf("failed to write endpoint: %v", err)
|
||||
}
|
||||
if err = writeValue(buf, persistentKeepaliveKey, strconv.Itoa(p.PersistentKeepalive)); err != nil {
|
||||
if p.PersistentKeepaliveInterval == nil {
|
||||
p.PersistentKeepaliveInterval = new(time.Duration)
|
||||
}
|
||||
if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval/time.Second), 10)); err != nil {
|
||||
return nil, fmt.Errorf("failed to write persistent keepalive: %v", err)
|
||||
}
|
||||
if err = writePKey(buf, presharedKeyKey, p.PresharedKey); err != nil {
|
||||
return nil, fmt.Errorf("failed to write preshared key: %v", err)
|
||||
}
|
||||
if err = writePKey(buf, publicKeyKey, p.PublicKey); err != nil {
|
||||
if err = writePKey(buf, publicKeyKey, &p.PublicKey); err != nil {
|
||||
return nil, fmt.Errorf("failed to write public key: %v", err)
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Equal checks if two WireGuard configurations are equivalent.
|
||||
func (c *Conf) Equal(b *Conf) bool {
|
||||
if (c.Interface == nil) != (b.Interface == nil) {
|
||||
return false
|
||||
// Equal returns true if the Conf and wgtypes.Device are equal.
|
||||
func (c *Conf) Equal(d *wgtypes.Device) (bool, string) {
|
||||
if c == nil || d == nil {
|
||||
return c == nil && d == nil, "nil values"
|
||||
}
|
||||
if c.Interface != nil {
|
||||
if c.Interface.ListenPort != b.Interface.ListenPort || !bytes.Equal(c.Interface.PrivateKey, b.Interface.PrivateKey) {
|
||||
return false
|
||||
}
|
||||
if c.ListenPort == nil || *c.ListenPort != d.ListenPort {
|
||||
return false, fmt.Sprintf("port: old=%q, new=\"%v\"", d.ListenPort, c.ListenPort)
|
||||
}
|
||||
if len(c.Peers) != len(b.Peers) {
|
||||
return false
|
||||
if c.PrivateKey == nil || *c.PrivateKey != d.PrivateKey {
|
||||
return false, fmt.Sprintf("private key: old=\"%s...\", new=\"%s\"", d.PrivateKey.String()[0:5], c.PrivateKey.String()[0:5])
|
||||
}
|
||||
if len(c.Peers) != len(d.Peers) {
|
||||
return false, fmt.Sprintf("number of peers: old=%d, new=%d", len(d.Peers), len(c.Peers))
|
||||
}
|
||||
sortPeerConfigs(d.Peers)
|
||||
sortPeers(c.Peers)
|
||||
sortPeers(b.Peers)
|
||||
for i := range c.Peers {
|
||||
if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
|
||||
return false
|
||||
if len(c.Peers[i].AllowedIPs) != len(d.Peers[i].AllowedIPs) {
|
||||
return false, fmt.Sprintf("Peer %d allowed IP length: old=%d, new=%d", i, len(d.Peers[i].AllowedIPs), len(c.Peers[i].AllowedIPs))
|
||||
}
|
||||
sortCIDRs(c.Peers[i].AllowedIPs)
|
||||
sortCIDRs(b.Peers[i].AllowedIPs)
|
||||
sortCIDRs(d.Peers[i].AllowedIPs)
|
||||
for j := range c.Peers[i].AllowedIPs {
|
||||
if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() {
|
||||
return false
|
||||
if c.Peers[i].AllowedIPs[j].String() != d.Peers[i].AllowedIPs[j].String() {
|
||||
return false, fmt.Sprintf("Peer %d allowed IP: old=%q, new=%q", i, d.Peers[i].AllowedIPs[j].String(), c.Peers[i].AllowedIPs[j].String())
|
||||
}
|
||||
}
|
||||
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) {
|
||||
return false
|
||||
if c.Peers[i].Endpoint == nil || d.Peers[i].Endpoint == nil {
|
||||
return c.Peers[i].Endpoint == nil && d.Peers[i].Endpoint == nil, "peer endpoints: nil value"
|
||||
}
|
||||
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
|
||||
return false
|
||||
if c.Peers[i].Endpoint.StringOpt(false) != d.Peers[i].Endpoint.String() {
|
||||
return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.StringOpt(false))
|
||||
}
|
||||
|
||||
pki := time.Duration(0)
|
||||
if p := c.Peers[i].PersistentKeepaliveInterval; p != nil {
|
||||
pki = *p
|
||||
}
|
||||
psk := wgtypes.Key{}
|
||||
if p := c.Peers[i].PresharedKey; p != nil {
|
||||
psk = *p
|
||||
}
|
||||
if pki != d.Peers[i].PersistentKeepaliveInterval || psk != d.Peers[i].PresharedKey || c.Peers[i].PublicKey != d.Peers[i].PublicKey {
|
||||
return false, "persistent keepalive or pershared key"
|
||||
}
|
||||
}
|
||||
return true
|
||||
return true, ""
|
||||
}
|
||||
|
||||
func sortPeers(peers []*Peer) {
|
||||
func sortPeerConfigs(peers []wgtypes.Peer) {
|
||||
sort.Slice(peers, func(i, j int) bool {
|
||||
if bytes.Compare(peers[i].PublicKey, peers[j].PublicKey) < 0 {
|
||||
if peers[i].PublicKey.String() < peers[j].PublicKey.String() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
func sortCIDRs(cidrs []*net.IPNet) {
|
||||
func sortPeers(peers []Peer) {
|
||||
sort.Slice(peers, func(i, j int) bool {
|
||||
if peers[i].PublicKey.String() < peers[j].PublicKey.String() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
func sortCIDRs(cidrs []net.IPNet) {
|
||||
sort.Slice(cidrs, func(i, j int) bool {
|
||||
return cidrs[i].String() < cidrs[j].String()
|
||||
})
|
||||
}
|
||||
|
||||
func writeAllowedIPs(buf *bytes.Buffer, ais []*net.IPNet) error {
|
||||
func cutIP(ip net.IP) net.IP {
|
||||
if i4 := ip.To4(); i4 != nil {
|
||||
return i4
|
||||
}
|
||||
return ip.To16()
|
||||
}
|
||||
|
||||
func writeAllowedIPs(buf *bytes.Buffer, ais []net.IPNet) error {
|
||||
if len(ais) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -382,15 +408,16 @@ func writeAllowedIPs(buf *bytes.Buffer, ais []*net.IPNet) error {
|
||||
return buf.WriteByte('\n')
|
||||
}
|
||||
|
||||
func writePKey(buf *bytes.Buffer, k key, b []byte) error {
|
||||
if len(b) == 0 {
|
||||
func writePKey(buf *bytes.Buffer, k key, b *wgtypes.Key) error {
|
||||
// Print nothing if the public key was never initialized.
|
||||
if b == nil || (wgtypes.Key{}) == *b {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
if err = writeKey(buf, k); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = buf.Write(b); err != nil {
|
||||
if _, err = buf.Write([]byte(b.String())); err != nil {
|
||||
return err
|
||||
}
|
||||
return buf.WriteByte('\n')
|
||||
@@ -408,14 +435,15 @@ func writeValue(buf *bytes.Buffer, k key, v string) error {
|
||||
}
|
||||
|
||||
func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error {
|
||||
if e == nil {
|
||||
str := e.String()
|
||||
if str == "" {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
if err = writeKey(buf, endpointKey); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = buf.WriteString(e.String()); err != nil {
|
||||
if _, err = buf.WriteString(str); err != nil {
|
||||
return err
|
||||
}
|
||||
return buf.WriteByte('\n')
|
||||
@@ -443,177 +471,3 @@ func writeKey(buf *bytes.Buffer, k key) error {
|
||||
_, err = buf.WriteString(" = ")
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
errParseEndpoint = errors.New("could not parse Endpoint")
|
||||
)
|
||||
|
||||
func (p *Peer) parseEndpoint(v string) error {
|
||||
var (
|
||||
kv []string
|
||||
err error
|
||||
ip, ip4 net.IP
|
||||
port uint64
|
||||
)
|
||||
kv = strings.Split(v, ":")
|
||||
if len(kv) < 2 {
|
||||
return errParseEndpoint
|
||||
}
|
||||
port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d := DNSOrIP{}
|
||||
ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
|
||||
if ip == nil {
|
||||
if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
|
||||
return errParseEndpoint
|
||||
}
|
||||
d.DNS = kv[0]
|
||||
} else {
|
||||
if ip4 = ip.To4(); ip4 != nil {
|
||||
d.IP = ip4
|
||||
} else {
|
||||
d.IP = ip.To16()
|
||||
}
|
||||
}
|
||||
|
||||
p.Endpoint = &Endpoint{
|
||||
DNSOrIP: d,
|
||||
Port: uint32(port),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Peer) parseAllowedIPs(v string) error {
|
||||
var (
|
||||
ai *net.IPNet
|
||||
kv []string
|
||||
err error
|
||||
i int
|
||||
ip, ip4 net.IP
|
||||
)
|
||||
|
||||
kv = strings.Split(v, ",")
|
||||
for i = range kv {
|
||||
ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ip4 = ip.To4(); ip4 != nil {
|
||||
ip = ip4
|
||||
} else {
|
||||
ip = ip.To16()
|
||||
}
|
||||
ai.IP = ip
|
||||
p.AllowedIPs = append(p.AllowedIPs, ai)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseDump parses a given WireGuard dump and produces a Conf struct.
|
||||
func ParseDump(buf []byte) (*Conf, error) {
|
||||
// from man wg, show section:
|
||||
// If dump is specified, then several lines are printed;
|
||||
// the first contains in order separated by tab: private-key, public-key, listen-port, fw‐mark.
|
||||
// Subsequent lines are printed for each peer and contain in order separated by tab:
|
||||
// public-key, preshared-key, endpoint, allowed-ips, latest-handshake, transfer-rx, transfer-tx, persistent-keepalive.
|
||||
var (
|
||||
active section
|
||||
values []string
|
||||
c Conf
|
||||
err error
|
||||
iface *Interface
|
||||
peer *Peer
|
||||
port uint64
|
||||
sec int64
|
||||
pka int
|
||||
line int
|
||||
)
|
||||
// First line is Interface
|
||||
active = interfaceSection
|
||||
s := bufio.NewScanner(bytes.NewBuffer(buf))
|
||||
for s.Scan() {
|
||||
values = strings.Split(s.Text(), dumpSeparator)
|
||||
|
||||
switch active {
|
||||
case interfaceSection:
|
||||
if len(values) < dumpInterfaceLen {
|
||||
return nil, fmt.Errorf("invalid interface line: missing fields (%d < %d)", len(values), dumpInterfaceLen)
|
||||
}
|
||||
iface = new(Interface)
|
||||
for i := range values {
|
||||
switch i {
|
||||
case dumpInterfacePrivateKeyIndex:
|
||||
iface.PrivateKey = []byte(values[i])
|
||||
case dumpInterfaceListenPortIndex:
|
||||
port, err = strconv.ParseUint(values[i], 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid interface line: error parsing listen-port: %w", err)
|
||||
}
|
||||
iface.ListenPort = uint32(port)
|
||||
}
|
||||
}
|
||||
c.Interface = iface
|
||||
// Next lines are Peers
|
||||
active = peerSection
|
||||
case peerSection:
|
||||
if len(values) < dumpPeerLen {
|
||||
return nil, fmt.Errorf("invalid peer line %d: missing fields (%d < %d)", line, len(values), dumpPeerLen)
|
||||
}
|
||||
peer = new(Peer)
|
||||
|
||||
for i := range values {
|
||||
switch i {
|
||||
case dumpPeerPublicKeyIndex:
|
||||
peer.PublicKey = []byte(values[i])
|
||||
case dumpPeerPresharedKeyIndex:
|
||||
if values[i] == dumpNone {
|
||||
continue
|
||||
}
|
||||
peer.PresharedKey = []byte(values[i])
|
||||
case dumpPeerEndpointIndex:
|
||||
if values[i] == dumpNone {
|
||||
continue
|
||||
}
|
||||
err = peer.parseEndpoint(values[i])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid peer line %d: error parsing endpoint: %w", line, err)
|
||||
}
|
||||
case dumpPeerAllowedIPsIndex:
|
||||
if values[i] == dumpNone {
|
||||
continue
|
||||
}
|
||||
err = peer.parseAllowedIPs(values[i])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid peer line %d: error parsing allowed-ips: %w", line, err)
|
||||
}
|
||||
case dumpPeerLatestHandshakeIndex:
|
||||
if values[i] == "0" {
|
||||
// Use go zero value, not unix 0 timestamp.
|
||||
peer.LatestHandshake = time.Time{}
|
||||
continue
|
||||
}
|
||||
sec, err = strconv.ParseInt(values[i], 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid peer line %d: error parsing latest-handshake: %w", line, err)
|
||||
}
|
||||
peer.LatestHandshake = time.Unix(sec, 0)
|
||||
case dumpPeerPersistentKeepaliveIndex:
|
||||
if values[i] == dumpOff {
|
||||
continue
|
||||
}
|
||||
pka, err = strconv.Atoi(values[i])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid peer line %d: error parsing persistent-keepalive: %w", line, err)
|
||||
}
|
||||
peer.PersistentKeepalive = pka
|
||||
}
|
||||
}
|
||||
c.Peers = append(c.Peers, peer)
|
||||
peer = nil
|
||||
}
|
||||
line++
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
// Copyright 2021 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -21,336 +21,431 @@ import (
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
)
|
||||
|
||||
func TestCompareConf(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
func TestNewEndpoint(t *testing.T) {
|
||||
for i, tc := range []struct {
|
||||
name string
|
||||
a []byte
|
||||
b []byte
|
||||
out bool
|
||||
ip net.IP
|
||||
port int
|
||||
out *Endpoint
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
a: []byte{},
|
||||
b: []byte{},
|
||||
out: true,
|
||||
name: "no ip, no port",
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "key and value order",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PresharedKey = psk
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(`[Interface]
|
||||
ListenPort = 51820
|
||||
PrivateKey = private
|
||||
|
||||
[Peer]
|
||||
PublicKey = key
|
||||
AllowedIPs = 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 10.2.2.0/24
|
||||
PresharedKey = psk
|
||||
Endpoint = 10.1.0.2:51820
|
||||
`),
|
||||
out: true,
|
||||
name: "only port",
|
||||
ip: nil,
|
||||
port: 99,
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 99,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "whitespace",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PresharedKey = psk
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(`[Interface]
|
||||
PrivateKey=private
|
||||
ListenPort=51820
|
||||
[Peer]
|
||||
Endpoint=10.1.0.2:51820
|
||||
PresharedKey = psk
|
||||
PublicKey=key
|
||||
AllowedIPs=10.2.2.0/24,192.168.0.1/32,10.2.3.0/24,192.168.0.2/32,10.4.0.2/32
|
||||
`),
|
||||
out: true,
|
||||
name: "only ipv4",
|
||||
ip: net.ParseIP("10.0.0.0"),
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("10.0.0.0").To4(),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
out: false,
|
||||
name: "only ipv6",
|
||||
ip: net.ParseIP("ff50::10"),
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("ff50::10").To16(),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different value",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PublicKey = key2
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
out: false,
|
||||
name: "ipv4",
|
||||
ip: net.ParseIP("10.0.0.0"),
|
||||
port: 1000,
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("10.0.0.0").To4(),
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "section order",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PresharedKey = psk
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(`[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PresharedKey = psk
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
|
||||
[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
`),
|
||||
out: true,
|
||||
name: "ipv6",
|
||||
ip: net.ParseIP("ff50::10"),
|
||||
port: 1000,
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("ff50::10").To16(),
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "out of order peers",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PresharedKey = psk2
|
||||
PublicKey = key2
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PresharedKey = psk1
|
||||
PublicKey = key1
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PresharedKey = psk1
|
||||
PublicKey = key1
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PresharedKey = psk2
|
||||
PublicKey = key2
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "one empty",
|
||||
a: []byte(`[Interface]
|
||||
PrivateKey = private
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
Endpoint = 10.1.0.2:51820
|
||||
PresharedKey = psk
|
||||
PublicKey = key
|
||||
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
|
||||
`),
|
||||
b: []byte(``),
|
||||
out: false,
|
||||
name: "ipv6",
|
||||
ip: net.ParseIP("fc00:f853:ccd:e793::3"),
|
||||
port: 51820,
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("fc00:f853:ccd:e793::3").To16(),
|
||||
Port: 51820,
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
equal := Parse(tc.a).Equal(Parse(tc.b))
|
||||
if equal != tc.out {
|
||||
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
|
||||
out := NewEndpoint(tc.ip, tc.port)
|
||||
if diff := pretty.Compare(out, tc.out); diff != "" {
|
||||
t.Errorf("%d %s: got diff:\n%s\n", i, tc.name, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareEndpoint(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
a *Endpoint
|
||||
b *Endpoint
|
||||
dnsFirst bool
|
||||
out bool
|
||||
}{
|
||||
{
|
||||
name: "both nil",
|
||||
a: nil,
|
||||
b: nil,
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "a nil",
|
||||
a: nil,
|
||||
b: &Endpoint{},
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "b nil",
|
||||
a: &Endpoint{},
|
||||
b: nil,
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "zero",
|
||||
a: &Endpoint{},
|
||||
b: &Endpoint{},
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "diff port",
|
||||
a: &Endpoint{Port: 1234},
|
||||
b: &Endpoint{Port: 5678},
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "same IP",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "diff IP",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2")}},
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "same IP ignore DNS",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "b"}},
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "no IP check DNS",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "no IP check DNS (same)",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "DNS first, ignore IP",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}},
|
||||
dnsFirst: true,
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "DNS first",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
|
||||
dnsFirst: true,
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "DNS first, no DNS compare IP",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}},
|
||||
dnsFirst: true,
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "DNS first, no DNS compare IP (same)",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||
dnsFirst: true,
|
||||
out: true,
|
||||
},
|
||||
} {
|
||||
equal := tc.a.Equal(tc.b, tc.dnsFirst)
|
||||
if equal != tc.out {
|
||||
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareDumpConf(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
func TestParseEndpoint(t *testing.T) {
|
||||
for i, tc := range []struct {
|
||||
name string
|
||||
d []byte
|
||||
c []byte
|
||||
str string
|
||||
out *Endpoint
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
d: []byte{},
|
||||
c: []byte{},
|
||||
name: "no ip, no port",
|
||||
},
|
||||
{
|
||||
name: "redacted copy from wg output",
|
||||
d: []byte(`private B7qk8EMlob0nfado0ABM6HulUV607r4yqtBKjhap7S4= 51820 off
|
||||
key1 (none) 10.254.1.1:51820 100.64.1.0/24,192.168.0.125/32,10.4.0.1/32 1619012801 67048 34952 10
|
||||
key2 (none) 10.254.2.1:51820 100.64.4.0/24,10.69.76.55/32,100.64.3.0/24,10.66.25.131/32,10.4.0.2/32 1619013058 1134456 10077852 10`),
|
||||
c: []byte(`[Interface]
|
||||
ListenPort = 51820
|
||||
PrivateKey = private
|
||||
|
||||
[Peer]
|
||||
PublicKey = key1
|
||||
AllowedIPs = 100.64.1.0/24, 192.168.0.125/32, 10.4.0.1/32
|
||||
Endpoint = 10.254.1.1:51820
|
||||
PersistentKeepalive = 10
|
||||
|
||||
[Peer]
|
||||
PublicKey = key2
|
||||
AllowedIPs = 100.64.4.0/24, 10.69.76.55/32, 100.64.3.0/24, 10.66.25.131/32, 10.4.0.2/32
|
||||
Endpoint = 10.254.2.1:51820
|
||||
PersistentKeepalive = 10`),
|
||||
name: "only port",
|
||||
str: ":1000",
|
||||
},
|
||||
{
|
||||
name: "only ipv4",
|
||||
str: "10.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "only ipv6",
|
||||
str: "ff50::10",
|
||||
},
|
||||
{
|
||||
name: "ipv4",
|
||||
str: "10.0.0.0:1000",
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("10.0.0.0").To4(),
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
str: "[ff50::10]:1000",
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("ff50::10").To16(),
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
|
||||
dumpConf, _ := ParseDump(tc.d)
|
||||
conf := Parse(tc.c)
|
||||
// Equal will ignore runtime fields and only compare configuration fields.
|
||||
if !dumpConf.Equal(conf) {
|
||||
diff := pretty.Compare(dumpConf, conf)
|
||||
t.Errorf("test case %q: got diff: %v", tc.name, diff)
|
||||
out := ParseEndpoint(tc.str)
|
||||
if diff := pretty.Compare(out, tc.out); diff != "" {
|
||||
t.Errorf("ParseEndpoint %s(%d): got diff:\n%s\n", tc.name, i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEndpointFromUDPAddr(t *testing.T) {
|
||||
for i, tc := range []struct {
|
||||
name string
|
||||
u *net.UDPAddr
|
||||
out *Endpoint
|
||||
}{
|
||||
{
|
||||
name: "no ip, no port",
|
||||
out: &Endpoint{
|
||||
addr: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "only port",
|
||||
u: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
},
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
},
|
||||
addr: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "only ipv4",
|
||||
u: &net.UDPAddr{
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
},
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("10.0.0.0").To4(),
|
||||
},
|
||||
addr: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "only ipv6",
|
||||
u: &net.UDPAddr{
|
||||
IP: net.ParseIP("ff60::10"),
|
||||
},
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("ff60::10").To16(),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ipv4",
|
||||
u: &net.UDPAddr{
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
Port: 1000,
|
||||
},
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("10.0.0.0").To4(),
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
u: &net.UDPAddr{
|
||||
IP: net.ParseIP("ff50::10"),
|
||||
Port: 1000,
|
||||
},
|
||||
out: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("ff50::10").To16(),
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
out := NewEndpointFromUDPAddr(tc.u)
|
||||
if diff := pretty.Compare(out, tc.out); diff != "" {
|
||||
t.Errorf("ParseEndpoint %s(%d): got diff:\n%s\n", tc.name, i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReady(t *testing.T) {
|
||||
for i, tc := range []struct {
|
||||
name string
|
||||
in *Endpoint
|
||||
r bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
r: false,
|
||||
},
|
||||
{
|
||||
name: "no ip, no port",
|
||||
in: &Endpoint{
|
||||
addr: "",
|
||||
udpAddr: &net.UDPAddr{},
|
||||
},
|
||||
r: false,
|
||||
},
|
||||
{
|
||||
name: "only port",
|
||||
in: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
r: false,
|
||||
},
|
||||
{
|
||||
name: "only ipv4",
|
||||
in: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
},
|
||||
},
|
||||
r: false,
|
||||
},
|
||||
{
|
||||
name: "only ipv6",
|
||||
in: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("ff60::10"),
|
||||
},
|
||||
},
|
||||
r: false,
|
||||
},
|
||||
{
|
||||
name: "ipv4",
|
||||
in: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
r: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
in: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: net.ParseIP("ff50::10"),
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
r: true,
|
||||
},
|
||||
} {
|
||||
if tc.r != tc.in.Ready() {
|
||||
t.Errorf("Endpoint.Ready() %s(%d): expected=%v\tgot=%v\n", tc.name, i, tc.r, tc.in.Ready())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqual(t *testing.T) {
|
||||
for i, tc := range []struct {
|
||||
name string
|
||||
a *Endpoint
|
||||
b *Endpoint
|
||||
df bool
|
||||
r bool
|
||||
}{
|
||||
{
|
||||
name: "nil dns last",
|
||||
r: true,
|
||||
},
|
||||
{
|
||||
name: "nil dns first",
|
||||
df: true,
|
||||
r: true,
|
||||
},
|
||||
{
|
||||
name: "equal: only port",
|
||||
a: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
b: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
r: true,
|
||||
},
|
||||
{
|
||||
name: "not equal: only port",
|
||||
a: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
},
|
||||
},
|
||||
b: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1001,
|
||||
},
|
||||
},
|
||||
r: false,
|
||||
},
|
||||
{
|
||||
name: "equal dns first",
|
||||
a: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
},
|
||||
addr: "example.com:1000",
|
||||
},
|
||||
b: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
},
|
||||
addr: "example.com:1000",
|
||||
},
|
||||
r: true,
|
||||
},
|
||||
{
|
||||
name: "equal dns last",
|
||||
a: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
},
|
||||
addr: "example.com:1000",
|
||||
},
|
||||
b: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
},
|
||||
addr: "foo",
|
||||
},
|
||||
r: true,
|
||||
},
|
||||
{
|
||||
name: "unequal dns first",
|
||||
a: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
},
|
||||
addr: "example.com:1000",
|
||||
},
|
||||
b: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
},
|
||||
addr: "foo",
|
||||
},
|
||||
df: true,
|
||||
r: false,
|
||||
},
|
||||
{
|
||||
name: "unequal dns last",
|
||||
a: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
},
|
||||
addr: "foo",
|
||||
},
|
||||
b: &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
Port: 1000,
|
||||
IP: net.ParseIP("11.0.0.0"),
|
||||
},
|
||||
addr: "foo",
|
||||
},
|
||||
r: false,
|
||||
},
|
||||
{
|
||||
name: "unequal dns last empty IP",
|
||||
a: &Endpoint{
|
||||
addr: "foo",
|
||||
},
|
||||
b: &Endpoint{
|
||||
addr: "bar",
|
||||
},
|
||||
r: false,
|
||||
},
|
||||
{
|
||||
name: "equal dns last empty IP",
|
||||
a: &Endpoint{
|
||||
addr: "foo",
|
||||
},
|
||||
b: &Endpoint{
|
||||
addr: "foo",
|
||||
},
|
||||
r: true,
|
||||
},
|
||||
} {
|
||||
if out := tc.a.Equal(tc.b, tc.df); out != tc.r {
|
||||
t.Errorf("ParseEndpoint %s(%d): expected: %v\tgot: %v\n", tc.name, i, tc.r, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -18,9 +18,7 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
@@ -65,74 +63,3 @@ func New(name string, mtu uint) (int, bool, error) {
|
||||
}
|
||||
return link.Attrs().Index, true, nil
|
||||
}
|
||||
|
||||
// Keys generates a WireGuard private and public key-pair.
|
||||
func Keys() ([]byte, []byte, error) {
|
||||
private, err := GenKey()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
public, err := PubKey(private)
|
||||
return private, public, err
|
||||
}
|
||||
|
||||
// GenKey generates a WireGuard private key.
|
||||
func GenKey() ([]byte, error) {
|
||||
key, err := exec.Command("wg", "genkey").Output()
|
||||
return bytes.Trim(key, "\n"), err
|
||||
}
|
||||
|
||||
// PubKey generates a WireGuard public key for a given private key.
|
||||
func PubKey(key []byte) ([]byte, error) {
|
||||
cmd := exec.Command("wg", "pubkey")
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open pipe to stdin: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stdin.Close()
|
||||
stdin.Write(key)
|
||||
}()
|
||||
|
||||
public, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate public key: %v", err)
|
||||
}
|
||||
return bytes.Trim(public, "\n"), nil
|
||||
}
|
||||
|
||||
// SetConf applies a WireGuard configuration file to the given interface.
|
||||
func SetConf(iface string, path string) error {
|
||||
cmd := exec.Command("wg", "setconf", iface, path)
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to apply the WireGuard configuration: %s", stderr.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ShowConf gets the WireGuard configuration for the given interface.
|
||||
func ShowConf(iface string) ([]byte, error) {
|
||||
cmd := exec.Command("wg", "showconf", iface)
|
||||
var stderr, stdout bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
cmd.Stdout = &stdout
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read the WireGuard configuration: %s", stderr.String())
|
||||
}
|
||||
return stdout.Bytes(), nil
|
||||
}
|
||||
|
||||
// ShowDump gets the WireGuard configuration and runtime information for the given interface.
|
||||
func ShowDump(iface string) ([]byte, error) {
|
||||
cmd := exec.Command("wg", "show", iface, "dump")
|
||||
var stderr, stdout bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
cmd.Stdout = &stdout
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read the WireGuard dump output: %s", stderr.String())
|
||||
}
|
||||
return stdout.Bytes(), nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user