apply suggestions from code review

Remove wireguard.Enpoint struct and use net.UDPAddr for the resolved
endpoint and addr string (dnsanme:port) if a DN was supplied.

Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
leonnicolas 2021-09-21 12:08:35 +02:00
parent bde5c3e7d1
commit b370ed3511
No known key found for this signature in database
GPG Key ID: 088D0743E2B65C07
12 changed files with 260 additions and 380 deletions

View File

@ -303,17 +303,21 @@ func translatePeer(peer *wireguard.Peer) *v1alpha1.Peer {
aips = append(aips, aip.String())
}
var endpoint *v1alpha1.PeerEndpoint
if peer.KiloEndpoint != nil && peer.KiloEndpoint.Port > 0 && (peer.KiloEndpoint.IP != nil || peer.KiloEndpoint.DNS != "") {
if (peer.Endpoint != nil && peer.Endpoint.Port > 0) || peer.Addr != "" {
var ip string
if peer.KiloEndpoint.IP != nil {
ip = peer.KiloEndpoint.IP.String()
if peer.Endpoint.IP != nil {
ip = peer.Endpoint.IP.String()
}
var dns string
if strs := strings.Split(peer.Addr, ":"); len(strs) == 2 && strs[0] != "" {
dns = strs[0]
}
endpoint = &v1alpha1.PeerEndpoint{
DNSOrIP: v1alpha1.DNSOrIP{
DNS: peer.KiloEndpoint.DNS,
DNS: dns,
IP: ip,
},
Port: uint32(peer.KiloEndpoint.Port),
Port: uint32(peer.Endpoint.Port),
}
}
var key string

View File

@ -213,7 +213,7 @@ func (nb *nodeBackend) Set(name string, node *mesh.Node) error {
return fmt.Errorf("failed to find node: %v", err)
}
n := old.DeepCopy()
n.ObjectMeta.Annotations[endpointAnnotationKey] = node.KiloEndpoint.String()
n.ObjectMeta.Annotations[endpointAnnotationKey] = node.Endpoint.String()
if node.InternalIP == nil {
n.ObjectMeta.Annotations[internalIPAnnotationKey] = ""
} else {
@ -277,9 +277,9 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
location = node.ObjectMeta.Labels[topologyLabel]
}
// Allow the endpoint to be overridden.
endpoint := parseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey])
if endpoint == nil {
endpoint = parseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey])
endpoint, addr := parseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey])
if endpoint == nil && addr == "" {
endpoint, addr = parseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey])
}
// Allow the internal IP to be overridden.
internalIP := normalizeIP(node.ObjectMeta.Annotations[forceInternalIPAnnotationKey])
@ -344,7 +344,8 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
// the mesh can wait for the node to be updated.
// It is valid for the InternalIP to be nil,
// if the given node only has public IP addresses.
KiloEndpoint: endpoint,
Endpoint: endpoint,
Addr: addr,
NoInternalIP: noInternalIP,
InternalIP: internalIP,
Key: key,
@ -378,7 +379,8 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
}
aips = append(aips, *aip)
}
var endpoint *wireguard.Endpoint
var endpoint *net.UDPAddr
var addr string
if peer.Spec.Endpoint != nil {
ip := net.ParseIP(peer.Spec.Endpoint.IP)
if ip4 := ip.To4(); ip4 != nil {
@ -386,13 +388,12 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
} else {
ip = ip.To16()
}
if peer.Spec.Endpoint.Port > 0 && (ip != nil || peer.Spec.Endpoint.DNS != "") {
endpoint = &wireguard.Endpoint{
DNSOrIP: wireguard.DNSOrIP{
DNS: peer.Spec.Endpoint.DNS,
IP: ip,
},
Port: int(peer.Spec.Endpoint.Port),
if peer.Spec.Endpoint.Port > 0 {
if ip != nil {
endpoint = &net.UDPAddr{IP: ip, Port: int(peer.Spec.Endpoint.Port)}
}
if peer.Spec.Endpoint.DNS != "" {
addr = fmt.Sprintf("%s:%d", peer.Spec.Endpoint.DNS, peer.Spec.Endpoint.Port)
}
}
}
@ -414,12 +415,12 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
Peer: wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: aips,
Endpoint: nil, // applyTopology will resolve this endpoint from the KiloEndpoint.
Endpoint: endpoint, // applyTopology will resolve this endpoint from the KiloEndpoint.
PersistentKeepaliveInterval: &pka,
PresharedKey: psk,
PublicKey: key,
},
KiloEndpoint: endpoint,
Addr: addr,
},
}
}
@ -519,14 +520,20 @@ func (pb *peerBackend) Set(name string, peer *mesh.Peer) error {
if peer.Endpoint != nil {
var ip string
if peer.Endpoint.IP != nil {
ip = peer.KiloEndpoint.IP.String()
ip = peer.Endpoint.IP.String()
}
var dns string
if peer.Addr != "" {
if strs := strings.Split(peer.Addr, ":"); len(strs) == 2 && strs[0] != "" {
dns = strs[0]
}
}
p.Spec.Endpoint = &v1alpha1.PeerEndpoint{
DNSOrIP: v1alpha1.DNSOrIP{
IP: ip,
DNS: peer.KiloEndpoint.DNS,
DNS: dns,
},
Port: uint32(peer.KiloEndpoint.Port),
Port: uint32(peer.Endpoint.Port),
}
}
if peer.PersistentKeepaliveInterval == nil {
@ -564,34 +571,33 @@ func normalizeIP(ip string) *net.IPNet {
return ipNet
}
func parseEndpoint(endpoint string) *wireguard.Endpoint {
func parseEndpoint(endpoint string) (*net.UDPAddr, string) {
if len(endpoint) == 0 {
return nil
return nil, ""
}
parts := strings.Split(endpoint, ":")
if len(parts) < 2 {
return nil
return nil, ""
}
portRaw := parts[len(parts)-1]
hostRaw := strings.Trim(strings.Join(parts[:len(parts)-1], ":"), "[]")
port, err := strconv.ParseUint(portRaw, 10, 32)
if err != nil {
return nil
return nil, ""
}
if len(validation.IsValidPortNum(int(port))) != 0 {
return nil
return nil, ""
}
ip := net.ParseIP(hostRaw)
if ip == nil {
if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 {
return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: hostRaw}, Port: int(port)}
return nil, endpoint
}
return nil
return nil, ""
}
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
} else {
ip = ip.To16()
u, err := net.ResolveUDPAddr("udp", endpoint)
if err != nil {
return nil, ""
}
return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: ip}, Port: int(port)}
return u, ""
}

View File

@ -80,8 +80,19 @@ func TestTranslateNode(t *testing.T) {
internalIPAnnotationKey: "10.0.0.2/32",
},
out: &mesh.Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort},
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort},
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
},
},
{
name: "valid ips with ipv6",
annotations: map[string]string{
endpointAnnotationKey: "[ff10::10]:51820",
internalIPAnnotationKey: "ff60::10/64",
},
out: &mesh.Node{
Endpoint: &net.UDPAddr{IP: net.ParseIP("ff10::10"), Port: mesh.DefaultKiloPort},
InternalIP: &net.IPNet{IP: net.ParseIP("ff60::10"), Mask: net.CIDRMask(64, 128)},
},
},
{
@ -134,7 +145,7 @@ func TestTranslateNode(t *testing.T) {
forceEndpointAnnotationKey: "-10.0.0.2:51821",
},
out: &mesh.Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort},
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort},
},
},
{
@ -144,7 +155,7 @@ func TestTranslateNode(t *testing.T) {
forceEndpointAnnotationKey: "10.0.0.2:51821",
},
out: &mesh.Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821},
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.2"), Port: 51821},
},
},
{
@ -203,7 +214,38 @@ func TestTranslateNode(t *testing.T) {
RegionLabelKey: "a",
},
out: &mesh.Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821},
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.2"), Port: 51821},
NoInternalIP: false,
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)},
Key: fooKey,
LastSeen: 1000000000,
Leader: true,
Location: "b",
PersistentKeepalive: 25 * time.Second,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)},
},
subnet: "10.2.1.0/24",
},
{
name: "complete with ipv6",
annotations: map[string]string{
endpointAnnotationKey: "10.0.0.1:51820",
forceEndpointAnnotationKey: "[1100::10]:51821",
forceInternalIPAnnotationKey: "10.1.0.2/32",
internalIPAnnotationKey: "10.1.0.1/32",
keyAnnotationKey: fooKey.String(),
lastSeenAnnotationKey: "1000000000",
leaderAnnotationKey: "",
locationAnnotationKey: "b",
persistentKeepaliveKey: "25",
wireGuardIPAnnotationKey: "10.4.0.1/16",
},
labels: map[string]string{
RegionLabelKey: "a",
},
out: &mesh.Node{
Endpoint: &net.UDPAddr{IP: net.ParseIP("1100::10"), Port: 51821},
NoInternalIP: false,
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)},
Key: fooKey,
@ -231,7 +273,7 @@ func TestTranslateNode(t *testing.T) {
RegionLabelKey: "a",
},
out: &mesh.Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820},
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 51820},
InternalIP: nil,
Key: fooKey,
LastSeen: 1000000000,
@ -259,7 +301,7 @@ func TestTranslateNode(t *testing.T) {
RegionLabelKey: "a",
},
out: &mesh.Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820},
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 51820},
NoInternalIP: true,
InternalIP: nil,
Key: fooKey,
@ -381,11 +423,27 @@ func TestTranslatePeer(t *testing.T) {
},
out: &mesh.Peer{
Peer: wireguard.Peer{
KiloEndpoint: &wireguard.Endpoint{
DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")},
Port: mesh.DefaultKiloPort,
},
PeerConfig: wgtypes.PeerConfig{
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort},
PersistentKeepaliveInterval: &zero,
},
},
},
},
{
name: "valid endpoint ipv6",
spec: v1alpha1.PeerSpec{
Endpoint: &v1alpha1.PeerEndpoint{
DNSOrIP: v1alpha1.DNSOrIP{
IP: "ff60::2",
},
Port: mesh.DefaultKiloPort,
},
},
out: &mesh.Peer{
Peer: wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
Endpoint: &net.UDPAddr{IP: net.ParseIP("ff60::2"), Port: mesh.DefaultKiloPort},
PersistentKeepaliveInterval: &zero,
},
},
@ -403,10 +461,7 @@ func TestTranslatePeer(t *testing.T) {
},
out: &mesh.Peer{
Peer: wireguard.Peer{
KiloEndpoint: &wireguard.Endpoint{
DNSOrIP: wireguard.DNSOrIP{DNS: "example.com"},
Port: mesh.DefaultKiloPort,
},
Addr: "example.com:51820",
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
},
@ -494,47 +549,59 @@ func TestParseEndpoint(t *testing.T) {
for _, tc := range []struct {
name string
endpoint string
out *wireguard.Endpoint
udp *net.UDPAddr
addr string
}{
{
name: "empty",
endpoint: "",
out: nil,
udp: nil,
addr: "",
},
{
name: "invalid IP",
endpoint: "10.0.0.:51820",
out: nil,
udp: nil,
addr: "",
},
{
name: "invalid hostname",
endpoint: "foo-:51820",
out: nil,
udp: nil,
addr: "",
},
{
name: "invalid port",
endpoint: "10.0.0.1:100000000",
out: nil,
udp: nil,
addr: "",
},
{
name: "valid IP",
endpoint: "10.0.0.1:51820",
out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort},
udp: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort},
addr: "",
},
{
name: "valid IPv6",
endpoint: "[ff02::114]:51820",
out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("ff02::114")}, Port: mesh.DefaultKiloPort},
udp: &net.UDPAddr{IP: net.ParseIP("ff02::114"), Port: mesh.DefaultKiloPort},
addr: "",
},
{
name: "valid hostname",
endpoint: "foo:51821",
out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: "foo"}, Port: 51821},
udp: nil,
addr: "foo:51821",
},
} {
endpoint := parseEndpoint(tc.endpoint)
if diff := pretty.Compare(endpoint, tc.out); diff != "" {
udp, addr := parseEndpoint(tc.endpoint)
if diff := pretty.Compare(udp, tc.udp); diff != "" {
t.Errorf("test case %q: got diff: %v", tc.name, diff)
}
if addr != tc.addr {
t.Errorf("test case %q: got: %q, wants: %q", tc.name, addr, tc.addr)
}
}
}

View File

@ -56,8 +56,8 @@ const (
// Node represents a node in the network.
type Node struct {
KiloEndpoint *wireguard.Endpoint
Endpoint *net.UDPAddr
Addr string // eg. dnsname:port
Key wgtypes.Key
NoInternalIP bool
InternalIP *net.IPNet
@ -82,9 +82,7 @@ type Node struct {
func (n *Node) Ready() bool {
// Nodes that are not leaders will not have WireGuardIPs, so it is not required.
return n != nil &&
n.KiloEndpoint != nil &&
!(n.KiloEndpoint.IP == nil && n.KiloEndpoint.DNS == "") &&
n.KiloEndpoint.Port != 0 &&
(n.Endpoint != nil || n.Addr != "") &&
n.Key != wgtypes.Key{} &&
n.Subnet != nil &&
time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second)

View File

@ -20,7 +20,6 @@ import (
"strings"
"github.com/awalterschulze/gographviz"
"github.com/squat/kilo/pkg/wireguard"
)
// Dot generates a Graphviz graph of the Topology in DOT fomat.
@ -62,10 +61,10 @@ func (t *Topology) Dot() (string, error) {
return "", fmt.Errorf("failed to add node to subgraph")
}
var wg net.IP
var endpoint *wireguard.Endpoint
var endpoint *net.UDPAddr
if j == s.leader {
wg = s.wireGuardIP
endpoint = s.kiloEndpoint
endpoint = s.endpoint
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Rank), "1"); err != nil {
return "", fmt.Errorf("failed to add rank to node")
}
@ -74,7 +73,7 @@ func (t *Topology) Dot() (string, error) {
if s.privateIPs != nil {
priv = s.privateIPs[j]
}
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], priv, wg, endpoint)); err != nil {
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], priv, wg, endpoint, s.addr)); err != nil {
return "", fmt.Errorf("failed to add label to node")
}
}
@ -154,7 +153,7 @@ func subGraphName(name string) string {
return graphEscape(fmt.Sprintf("cluster_location_%s", name))
}
func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *wireguard.Endpoint) string {
func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *net.UDPAddr, addr string) string {
label := []string{
location,
name,
@ -166,8 +165,14 @@ func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoi
if wgIP != nil {
label = append(label, wgIP.String())
}
if endpoint != nil {
label = append(label, endpoint.String())
var str string
if addr != "" {
str = addr
} else if endpoint != nil {
str = endpoint.String()
}
if str != "" {
label = append(label, str)
}
return graphEscape(strings.Join(label, "\\n"))
}

View File

@ -370,8 +370,8 @@ func (m *Mesh) checkIn() {
func (m *Mesh) handleLocal(n *Node) {
// Allow the IPs to be overridden.
if n.KiloEndpoint == nil || (n.KiloEndpoint.DNS == "" && n.KiloEndpoint.IP == nil) {
n.KiloEndpoint = &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: m.externalIP.IP}, Port: m.port}
if n.Endpoint == nil || n.Addr == "" {
n.Endpoint = &net.UDPAddr{IP: m.externalIP.IP, Port: m.port}
}
if n.InternalIP == nil && !n.NoInternalIP {
n.InternalIP = m.internalIP
@ -380,7 +380,7 @@ func (m *Mesh) handleLocal(n *Node) {
// Take leader, location, and subnet from the argument, as these
// are not determined by kilo.
local := &Node{
KiloEndpoint: n.KiloEndpoint,
Endpoint: n.Endpoint,
Key: m.pub,
NoInternalIP: n.NoInternalIP,
InternalIP: n.InternalIP,
@ -484,7 +484,7 @@ func (m *Mesh) applyTopology() {
natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger)
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].KiloEndpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
@ -625,12 +625,12 @@ func (m *Mesh) resolveEndpoints() error {
}
// If the node is ready, then the endpoint is not nil
// but it may not have a DNS name.
if m.nodes[k].KiloEndpoint.DNS == "" {
if m.nodes[k].Addr == "" {
continue
}
if u, err := net.ResolveUDPAddr("udp", m.nodes[k].KiloEndpoint.String()); err == nil {
if u, err := net.ResolveUDPAddr("udp", m.nodes[k].Addr); err == nil {
m.nodes[k].Endpoint = u
m.nodes[k].KiloEndpoint.IP = u.IP
m.nodes[k].Endpoint.IP = u.IP
} else {
return err
}
@ -642,12 +642,11 @@ func (m *Mesh) resolveEndpoints() error {
continue
}
// Peers may have nil endpoints.
if m.peers[k].KiloEndpoint == nil || m.peers[k].KiloEndpoint.DNS == "" {
if m.peers[k].Addr == "" {
continue
}
if u, err := net.ResolveUDPAddr("udp", m.peers[k].KiloEndpoint.String()); err == nil {
if u, err := net.ResolveUDPAddr("udp", m.peers[k].Addr); err == nil {
m.peers[k].Endpoint = u
m.peers[k].KiloEndpoint.IP = u.IP
} else {
return err
}
@ -668,7 +667,7 @@ func nodesAreEqual(a, b *Node) bool {
}
// Check the DNS name first since this package
// is doing the DNS resolution.
if !a.KiloEndpoint.Equal(b.KiloEndpoint, true) {
if a.Addr != b.Addr || a.Endpoint.String() != b.Endpoint.String() {
return false
}
// Ignore LastSeen when comparing equality we want to check if the nodes are
@ -697,7 +696,7 @@ func peersAreEqual(a, b *Peer) bool {
}
// Check the DNS name first since this package
// is doing the DNS resolution.
if !a.KiloEndpoint.Equal(b.KiloEndpoint, true) {
if a.Addr != b.Addr || a.Endpoint.String() != b.Endpoint.String() {
return false
}
if len(a.AllowedIPs) != len(b.AllowedIPs) {

View File

@ -19,7 +19,6 @@ import (
"testing"
"time"
"github.com/squat/kilo/pkg/wireguard"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -63,58 +62,58 @@ func TestReady(t *testing.T) {
{
name: "empty endpoint IP",
node: &Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{}, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
Endpoint: &net.UDPAddr{Port: DefaultKiloPort},
InternalIP: internalIP,
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty endpoint port",
node: &Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}},
InternalIP: internalIP,
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
Endpoint: &net.UDPAddr{IP: externalIP.IP},
InternalIP: internalIP,
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty internal IP",
node: &Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty key",
node: &Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
InternalIP: internalIP,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
InternalIP: internalIP,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
},
{
name: "empty subnet",
node: &Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: wgtypes.Key{},
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: wgtypes.Key{},
},
ready: false,
},
{
name: "valid",
node: &Node{
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: key,
LastSeen: time.Now().Unix(),
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
InternalIP: internalIP,
Key: key,
LastSeen: time.Now().Unix(),
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: true,
},

View File

@ -40,7 +40,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
var gw net.IP
for _, segment := range t.segments {
if segment.location == t.location {
gw = enc.Gw(segment.kiloEndpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
break
}
}
@ -196,7 +196,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
// equals the external IP. This means that the node
// is only accessible through an external IP and we
// cannot encapsulate traffic to an IP through the IP.
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.kiloEndpoint.IP) {
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP) {
continue
}
// Add routes to the private IPs of nodes in other segments.

View File

@ -67,7 +67,7 @@ type Topology struct {
type segment struct {
allowedIPs []net.IPNet
kiloEndpoint *wireguard.Endpoint
addr string
endpoint *net.UDPAddr
key wgtypes.Key
persistentKeepalive time.Duration
@ -178,7 +178,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
})
t.segments = append(t.segments, &segment{
allowedIPs: allowedIPs,
kiloEndpoint: topoMap[location][leader].KiloEndpoint,
addr: topoMap[location][leader].Addr,
endpoint: topoMap[location][leader].Endpoint,
key: topoMap[location][leader].Key,
persistentKeepalive: topoMap[location][leader].PersistentKeepalive,
@ -287,10 +287,10 @@ CheckIPs:
return
}
func (t *Topology) updateEndpoint(kiloEndpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *net.UDPAddr {
func (t *Topology) updateEndpoint(endpoint *net.UDPAddr, key wgtypes.Key, persistentKeepalive *time.Duration) *net.UDPAddr {
// Do not update non-nat peers
if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) {
return kiloEndpoint.UDPAddr()
return endpoint
}
e, ok := t.discoveredEndpoints[key.String()]
if ok {
@ -315,12 +315,12 @@ func (t *Topology) Conf() *wireguard.Conf {
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
Endpoint: t.updateEndpoint(s.kiloEndpoint, s.key, &s.persistentKeepalive),
Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
PersistentKeepaliveInterval: &t.persistentKeepalive,
PublicKey: s.key,
ReplaceAllowedIPs: true,
},
KiloEndpoint: s.kiloEndpoint,
Addr: s.addr,
}
c.Peers = append(c.Peers, peer)
}
@ -328,13 +328,13 @@ func (t *Topology) Conf() *wireguard.Conf {
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: p.AllowedIPs,
Endpoint: t.updateEndpoint(p.KiloEndpoint, p.PublicKey, p.PersistentKeepaliveInterval),
Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval),
PersistentKeepaliveInterval: &t.persistentKeepalive,
PresharedKey: p.PresharedKey,
PublicKey: p.PublicKey,
ReplaceAllowedIPs: true,
},
KiloEndpoint: p.KiloEndpoint,
Addr: p.Addr,
}
c.Peers = append(c.Peers, peer)
}
@ -354,7 +354,7 @@ func (t *Topology) AsPeer() wireguard.Peer {
PublicKey: s.key,
Endpoint: s.endpoint,
},
KiloEndpoint: s.kiloEndpoint,
Addr: s.addr,
}
return p
}
@ -377,12 +377,12 @@ func (t *Topology) PeerConf(name string) wireguard.Conf {
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: s.allowedIPs,
Endpoint: s.kiloEndpoint.UDPAddr(),
Endpoint: s.endpoint,
PersistentKeepaliveInterval: pka,
PresharedKey: psk,
PublicKey: s.key,
},
KiloEndpoint: s.kiloEndpoint,
Addr: s.addr,
}
c.Peers = append(c.Peers, peer)
}
@ -417,13 +417,13 @@ func findLeader(nodes []*Node) int {
var leaders, public []int
for i := range nodes {
if nodes[i].Leader {
if isPublic(nodes[i].KiloEndpoint.IP) {
if isPublic(nodes[i].Endpoint.IP) {
return i
}
leaders = append(leaders, i)
}
if isPublic(nodes[i].KiloEndpoint.IP) {
if nodes[i].Endpoint != nil && isPublic(nodes[i].Endpoint.IP) {
public = append(public, i)
}
}
@ -449,7 +449,7 @@ func deduplicatePeerIPs(peers []*Peer) []*Peer {
PresharedKey: peer.PresharedKey,
PublicKey: peer.PublicKey,
},
KiloEndpoint: peer.KiloEndpoint,
Addr: peer.Addr,
},
}
for _, ip := range peer.AllowedIPs {

View File

@ -60,7 +60,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
nodes := map[string]*Node{
"a": {
Name: "a",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort},
InternalIP: i1,
Location: "1",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
@ -69,7 +69,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
},
"b": {
Name: "b",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
InternalIP: i1,
Location: "2",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)},
@ -77,17 +77,17 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
AllowedLocationIPs: []net.IPNet{*i3},
},
"c": {
Name: "c",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e3.IP}, Port: DefaultKiloPort},
InternalIP: i2,
Name: "c",
Endpoint: &net.UDPAddr{IP: e3.IP, Port: DefaultKiloPort},
InternalIP: i2,
// Same location as node b.
Location: "2",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.3.0"), Mask: net.CIDRMask(24, 32)},
Key: key3,
},
"d": {
Name: "d",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e4.IP}, Port: DefaultKiloPort},
Name: "d",
Endpoint: &net.UDPAddr{IP: e4.IP, Port: DefaultKiloPort},
// Same location as node a, but without private IP
Location: "1",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.4.0"), Mask: net.CIDRMask(24, 32)},
@ -115,10 +115,10 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
{IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)},
},
PublicKey: key5,
},
KiloEndpoint: &wireguard.Endpoint{
DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("192.168.0.1")},
Port: DefaultKiloPort,
Endpoint: &net.UDPAddr{
IP: net.ParseIP("192.168.0.1"),
Port: DefaultKiloPort,
},
},
},
},
@ -153,7 +153,7 @@ func TestNewTopology(t *testing.T) {
segments: []*segment{
{
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["a"].Location,
@ -164,7 +164,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["b"].Location,
@ -176,7 +176,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@ -204,7 +204,7 @@ func TestNewTopology(t *testing.T) {
segments: []*segment{
{
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["a"].Location,
@ -215,7 +215,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["b"].Location,
@ -227,7 +227,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@ -255,7 +255,7 @@ func TestNewTopology(t *testing.T) {
segments: []*segment{
{
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["a"].Location,
@ -266,7 +266,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: logicalLocationPrefix + nodes["b"].Location,
@ -278,7 +278,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@ -306,7 +306,7 @@ func TestNewTopology(t *testing.T) {
segments: []*segment{
{
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["a"].Name,
@ -317,7 +317,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["b"].Name,
@ -329,7 +329,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["c"].KiloEndpoint,
endpoint: nodes["c"].Endpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["c"].Name,
@ -340,7 +340,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@ -368,7 +368,7 @@ func TestNewTopology(t *testing.T) {
segments: []*segment{
{
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["a"].Name,
@ -379,7 +379,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["b"].Name,
@ -391,7 +391,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["c"].KiloEndpoint,
endpoint: nodes["c"].Endpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["c"].Name,
@ -402,7 +402,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@ -430,7 +430,7 @@ func TestNewTopology(t *testing.T) {
segments: []*segment{
{
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["a"].Name,
@ -441,7 +441,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["b"].Name,
@ -453,7 +453,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["c"].KiloEndpoint,
endpoint: nodes["c"].Endpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["c"].Name,
@ -464,7 +464,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@ -492,7 +492,7 @@ func TestNewTopology(t *testing.T) {
segments: []*segment{
{
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["a"].KiloEndpoint,
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["a"].Name,
@ -503,7 +503,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["b"].KiloEndpoint,
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["b"].Name,
@ -515,7 +515,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["c"].KiloEndpoint,
endpoint: nodes["c"].Endpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["c"].Name,
@ -526,7 +526,7 @@ func TestNewTopology(t *testing.T) {
},
{
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
kiloEndpoint: nodes["d"].KiloEndpoint,
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
location: nodeLocationPrefix + nodes["d"].Name,
@ -575,26 +575,26 @@ func TestFindLeader(t *testing.T) {
nodes := []*Node{
{
Name: "a",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
Name: "a",
Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort},
},
{
Name: "b",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Name: "b",
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
},
{
Name: "c",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Name: "c",
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
},
{
Name: "d",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
Leader: true,
Name: "d",
Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort},
Leader: true,
},
{
Name: "2",
KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Leader: true,
Name: "2",
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
Leader: true,
},
}
for _, tc := range []struct {

View File

@ -43,7 +43,7 @@ const (
// Conf represents a WireGuard configuration file.
type Conf struct {
wgtypes.Config
// The Peers field is shadowed because every Peer needs the KiloEndpoint field that contains a DNS endpoint.
// The Peers field is shadowed because every Peer needs the Endpoint field that contains a DNS endpoint.
Peers []Peer
}
@ -60,16 +60,10 @@ func (c Conf) WGConfig() wgtypes.Config {
return r
}
// Interface represents the `interface` section of a WireGuard configuration.
type Interface struct {
ListenPort uint32
PrivateKey []byte
}
// Peer represents a `peer` section of a WireGuard configuration.
type Peer struct {
wgtypes.PeerConfig
KiloEndpoint *Endpoint
Addr string // eg: dnsname:port
}
// DeduplicateIPs eliminates duplicate allowed IPs.
@ -86,79 +80,6 @@ func (p *Peer) DeduplicateIPs() {
p.AllowedIPs = ips
}
// Endpoint represents an `endpoint` key of a `peer` section.
type Endpoint struct {
DNSOrIP
Port int
}
// String prints the string representation of the endpoint.
func (e *Endpoint) String() string {
if e == nil {
return ""
}
dnsOrIP := e.DNSOrIP.String()
if e.IP != nil && len(e.IP) == net.IPv6len {
dnsOrIP = "[" + dnsOrIP + "]"
}
return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
}
// UDPAddr returns the corresponding net.UDPAddr of the Endpoint or nil.
func (e *Endpoint) UDPAddr() (u *net.UDPAddr) {
if a, err := net.ResolveUDPAddr("udp", e.String()); err == nil {
u = a
}
return
}
// Equal compares two endpoints.
func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
if (e == nil) != (b == nil) {
return false
}
if e != nil {
if e.Port != b.Port {
return false
}
if DNSFirst {
// Check the DNS name first if it was resolved.
if e.DNS != b.DNS {
return false
}
if e.DNS == "" && !e.IP.Equal(b.IP) {
return false
}
} else {
// IPs take priority, so check them first.
if !e.IP.Equal(b.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if e.IP == nil && e.DNS != b.DNS {
return false
}
}
}
return true
}
// DNSOrIP represents either a DNS name or an IP address.
// IPs, as they are more specific, are preferred.
type DNSOrIP struct {
DNS string
IP net.IP
}
// String prints the string representation of the struct.
func (d DNSOrIP) String() string {
if d.IP != nil {
return d.IP.String()
}
return d.DNS
}
// Bytes renders a WireGuard configuration to bytes.
func (c Conf) Bytes() ([]byte, error) {
var err error
@ -188,13 +109,13 @@ func (c Conf) Bytes() ([]byte, error) {
if err = writeAllowedIPs(buf, p.AllowedIPs); err != nil {
return nil, fmt.Errorf("failed to write allowed IPs: %v", err)
}
if err = writeEndpoint(buf, p.KiloEndpoint); err != nil {
if err = writeEndpoint(buf, p.Endpoint, p.Addr); err != nil {
return nil, fmt.Errorf("failed to write endpoint: %v", err)
}
if p.PersistentKeepaliveInterval == nil {
p.PersistentKeepaliveInterval = new(time.Duration)
}
if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval), 10)); err != nil {
if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval/time.Second), 10)); err != nil {
return nil, fmt.Errorf("failed to write persistent keepalive: %v", err)
}
if err = writePKey(buf, presharedKeyKey, p.PresharedKey); err != nil {
@ -327,15 +248,20 @@ func writeValue(buf *bytes.Buffer, k key, v string) error {
return buf.WriteByte('\n')
}
func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error {
if e == nil {
func writeEndpoint(buf *bytes.Buffer, e *net.UDPAddr, d string) error {
str := ""
if d != "" {
str = d
} else if e != nil {
str = e.String()
} else {
return nil
}
var err error
if err = writeKey(buf, endpointKey); err != nil {
return err
}
if _, err = buf.WriteString(e.String()); err != nil {
if _, err = buf.WriteString(str); err != nil {
return err
}
return buf.WriteByte('\n')

View File

@ -1,124 +0,0 @@
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wireguard
import (
"net"
"testing"
)
func TestCompareEndpoint(t *testing.T) {
for _, tc := range []struct {
name string
a *Endpoint
b *Endpoint
dnsFirst bool
out bool
}{
{
name: "both nil",
a: nil,
b: nil,
out: true,
},
{
name: "a nil",
a: nil,
b: &Endpoint{},
out: false,
},
{
name: "b nil",
a: &Endpoint{},
b: nil,
out: false,
},
{
name: "zero",
a: &Endpoint{},
b: &Endpoint{},
out: true,
},
{
name: "diff port",
a: &Endpoint{Port: 1234},
b: &Endpoint{Port: 5678},
out: false,
},
{
name: "same IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
out: true,
},
{
name: "diff IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2")}},
out: false,
},
{
name: "same IP ignore DNS",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "b"}},
out: true,
},
{
name: "no IP check DNS",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
out: false,
},
{
name: "no IP check DNS (same)",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
out: true,
},
{
name: "DNS first, ignore IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}},
dnsFirst: true,
out: true,
},
{
name: "DNS first",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
dnsFirst: true,
out: false,
},
{
name: "DNS first, no DNS compare IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}},
dnsFirst: true,
out: false,
},
{
name: "DNS first, no DNS compare IP (same)",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
dnsFirst: true,
out: true,
},
} {
equal := tc.a.Equal(tc.b, tc.dnsFirst)
if equal != tc.out {
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
}
}
}