From b370ed351159b3cf8c8b2afc3fb57abea69c645d Mon Sep 17 00:00:00 2001 From: leonnicolas Date: Tue, 21 Sep 2021 12:08:35 +0200 Subject: [PATCH] apply suggestions from code review Remove wireguard.Enpoint struct and use net.UDPAddr for the resolved endpoint and addr string (dnsanme:port) if a DN was supplied. Signed-off-by: leonnicolas --- cmd/kgctl/showconf.go | 14 +++-- pkg/k8s/backend.go | 66 +++++++++++--------- pkg/k8s/backend_test.go | 117 ++++++++++++++++++++++++++-------- pkg/mesh/backend.go | 6 +- pkg/mesh/graph.go | 19 +++--- pkg/mesh/mesh.go | 23 ++++--- pkg/mesh/mesh_test.go | 45 +++++++------- pkg/mesh/routes.go | 4 +- pkg/mesh/topology.go | 28 ++++----- pkg/mesh/topology_test.go | 96 ++++++++++++++-------------- pkg/wireguard/conf.go | 98 ++++------------------------- pkg/wireguard/conf_test.go | 124 ------------------------------------- 12 files changed, 260 insertions(+), 380 deletions(-) delete mode 100644 pkg/wireguard/conf_test.go diff --git a/cmd/kgctl/showconf.go b/cmd/kgctl/showconf.go index 9e877d5..1cd76d4 100644 --- a/cmd/kgctl/showconf.go +++ b/cmd/kgctl/showconf.go @@ -303,17 +303,21 @@ func translatePeer(peer *wireguard.Peer) *v1alpha1.Peer { aips = append(aips, aip.String()) } var endpoint *v1alpha1.PeerEndpoint - if peer.KiloEndpoint != nil && peer.KiloEndpoint.Port > 0 && (peer.KiloEndpoint.IP != nil || peer.KiloEndpoint.DNS != "") { + if (peer.Endpoint != nil && peer.Endpoint.Port > 0) || peer.Addr != "" { var ip string - if peer.KiloEndpoint.IP != nil { - ip = peer.KiloEndpoint.IP.String() + if peer.Endpoint.IP != nil { + ip = peer.Endpoint.IP.String() + } + var dns string + if strs := strings.Split(peer.Addr, ":"); len(strs) == 2 && strs[0] != "" { + dns = strs[0] } endpoint = &v1alpha1.PeerEndpoint{ DNSOrIP: v1alpha1.DNSOrIP{ - DNS: peer.KiloEndpoint.DNS, + DNS: dns, IP: ip, }, - Port: uint32(peer.KiloEndpoint.Port), + Port: uint32(peer.Endpoint.Port), } } var key string diff --git a/pkg/k8s/backend.go b/pkg/k8s/backend.go index 07bc06a..4c9bbc5 100644 --- a/pkg/k8s/backend.go +++ b/pkg/k8s/backend.go @@ -213,7 +213,7 @@ func (nb *nodeBackend) Set(name string, node *mesh.Node) error { return fmt.Errorf("failed to find node: %v", err) } n := old.DeepCopy() - n.ObjectMeta.Annotations[endpointAnnotationKey] = node.KiloEndpoint.String() + n.ObjectMeta.Annotations[endpointAnnotationKey] = node.Endpoint.String() if node.InternalIP == nil { n.ObjectMeta.Annotations[internalIPAnnotationKey] = "" } else { @@ -277,9 +277,9 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node { location = node.ObjectMeta.Labels[topologyLabel] } // Allow the endpoint to be overridden. - endpoint := parseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey]) - if endpoint == nil { - endpoint = parseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey]) + endpoint, addr := parseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey]) + if endpoint == nil && addr == "" { + endpoint, addr = parseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey]) } // Allow the internal IP to be overridden. internalIP := normalizeIP(node.ObjectMeta.Annotations[forceInternalIPAnnotationKey]) @@ -344,7 +344,8 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node { // the mesh can wait for the node to be updated. // It is valid for the InternalIP to be nil, // if the given node only has public IP addresses. - KiloEndpoint: endpoint, + Endpoint: endpoint, + Addr: addr, NoInternalIP: noInternalIP, InternalIP: internalIP, Key: key, @@ -378,7 +379,8 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer { } aips = append(aips, *aip) } - var endpoint *wireguard.Endpoint + var endpoint *net.UDPAddr + var addr string if peer.Spec.Endpoint != nil { ip := net.ParseIP(peer.Spec.Endpoint.IP) if ip4 := ip.To4(); ip4 != nil { @@ -386,13 +388,12 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer { } else { ip = ip.To16() } - if peer.Spec.Endpoint.Port > 0 && (ip != nil || peer.Spec.Endpoint.DNS != "") { - endpoint = &wireguard.Endpoint{ - DNSOrIP: wireguard.DNSOrIP{ - DNS: peer.Spec.Endpoint.DNS, - IP: ip, - }, - Port: int(peer.Spec.Endpoint.Port), + if peer.Spec.Endpoint.Port > 0 { + if ip != nil { + endpoint = &net.UDPAddr{IP: ip, Port: int(peer.Spec.Endpoint.Port)} + } + if peer.Spec.Endpoint.DNS != "" { + addr = fmt.Sprintf("%s:%d", peer.Spec.Endpoint.DNS, peer.Spec.Endpoint.Port) } } } @@ -414,12 +415,12 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer { Peer: wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: aips, - Endpoint: nil, // applyTopology will resolve this endpoint from the KiloEndpoint. + Endpoint: endpoint, // applyTopology will resolve this endpoint from the KiloEndpoint. PersistentKeepaliveInterval: &pka, PresharedKey: psk, PublicKey: key, }, - KiloEndpoint: endpoint, + Addr: addr, }, } } @@ -519,14 +520,20 @@ func (pb *peerBackend) Set(name string, peer *mesh.Peer) error { if peer.Endpoint != nil { var ip string if peer.Endpoint.IP != nil { - ip = peer.KiloEndpoint.IP.String() + ip = peer.Endpoint.IP.String() + } + var dns string + if peer.Addr != "" { + if strs := strings.Split(peer.Addr, ":"); len(strs) == 2 && strs[0] != "" { + dns = strs[0] + } } p.Spec.Endpoint = &v1alpha1.PeerEndpoint{ DNSOrIP: v1alpha1.DNSOrIP{ IP: ip, - DNS: peer.KiloEndpoint.DNS, + DNS: dns, }, - Port: uint32(peer.KiloEndpoint.Port), + Port: uint32(peer.Endpoint.Port), } } if peer.PersistentKeepaliveInterval == nil { @@ -564,34 +571,33 @@ func normalizeIP(ip string) *net.IPNet { return ipNet } -func parseEndpoint(endpoint string) *wireguard.Endpoint { +func parseEndpoint(endpoint string) (*net.UDPAddr, string) { if len(endpoint) == 0 { - return nil + return nil, "" } parts := strings.Split(endpoint, ":") if len(parts) < 2 { - return nil + return nil, "" } portRaw := parts[len(parts)-1] hostRaw := strings.Trim(strings.Join(parts[:len(parts)-1], ":"), "[]") port, err := strconv.ParseUint(portRaw, 10, 32) if err != nil { - return nil + return nil, "" } if len(validation.IsValidPortNum(int(port))) != 0 { - return nil + return nil, "" } ip := net.ParseIP(hostRaw) if ip == nil { if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 { - return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: hostRaw}, Port: int(port)} + return nil, endpoint } - return nil + return nil, "" } - if ip4 := ip.To4(); ip4 != nil { - ip = ip4 - } else { - ip = ip.To16() + u, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return nil, "" } - return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: ip}, Port: int(port)} + return u, "" } diff --git a/pkg/k8s/backend_test.go b/pkg/k8s/backend_test.go index ad18306..1cc5898 100644 --- a/pkg/k8s/backend_test.go +++ b/pkg/k8s/backend_test.go @@ -80,8 +80,19 @@ func TestTranslateNode(t *testing.T) { internalIPAnnotationKey: "10.0.0.2/32", }, out: &mesh.Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort}, - InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)}, + Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort}, + InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)}, + }, + }, + { + name: "valid ips with ipv6", + annotations: map[string]string{ + endpointAnnotationKey: "[ff10::10]:51820", + internalIPAnnotationKey: "ff60::10/64", + }, + out: &mesh.Node{ + Endpoint: &net.UDPAddr{IP: net.ParseIP("ff10::10"), Port: mesh.DefaultKiloPort}, + InternalIP: &net.IPNet{IP: net.ParseIP("ff60::10"), Mask: net.CIDRMask(64, 128)}, }, }, { @@ -134,7 +145,7 @@ func TestTranslateNode(t *testing.T) { forceEndpointAnnotationKey: "-10.0.0.2:51821", }, out: &mesh.Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort}, + Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort}, }, }, { @@ -144,7 +155,7 @@ func TestTranslateNode(t *testing.T) { forceEndpointAnnotationKey: "10.0.0.2:51821", }, out: &mesh.Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821}, + Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.2"), Port: 51821}, }, }, { @@ -203,7 +214,38 @@ func TestTranslateNode(t *testing.T) { RegionLabelKey: "a", }, out: &mesh.Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821}, + Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.2"), Port: 51821}, + NoInternalIP: false, + InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)}, + Key: fooKey, + LastSeen: 1000000000, + Leader: true, + Location: "b", + PersistentKeepalive: 25 * time.Second, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)}, + WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)}, + }, + subnet: "10.2.1.0/24", + }, + { + name: "complete with ipv6", + annotations: map[string]string{ + endpointAnnotationKey: "10.0.0.1:51820", + forceEndpointAnnotationKey: "[1100::10]:51821", + forceInternalIPAnnotationKey: "10.1.0.2/32", + internalIPAnnotationKey: "10.1.0.1/32", + keyAnnotationKey: fooKey.String(), + lastSeenAnnotationKey: "1000000000", + leaderAnnotationKey: "", + locationAnnotationKey: "b", + persistentKeepaliveKey: "25", + wireGuardIPAnnotationKey: "10.4.0.1/16", + }, + labels: map[string]string{ + RegionLabelKey: "a", + }, + out: &mesh.Node{ + Endpoint: &net.UDPAddr{IP: net.ParseIP("1100::10"), Port: 51821}, NoInternalIP: false, InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)}, Key: fooKey, @@ -231,7 +273,7 @@ func TestTranslateNode(t *testing.T) { RegionLabelKey: "a", }, out: &mesh.Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820}, + Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 51820}, InternalIP: nil, Key: fooKey, LastSeen: 1000000000, @@ -259,7 +301,7 @@ func TestTranslateNode(t *testing.T) { RegionLabelKey: "a", }, out: &mesh.Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820}, + Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 51820}, NoInternalIP: true, InternalIP: nil, Key: fooKey, @@ -381,11 +423,27 @@ func TestTranslatePeer(t *testing.T) { }, out: &mesh.Peer{ Peer: wireguard.Peer{ - KiloEndpoint: &wireguard.Endpoint{ - DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, - Port: mesh.DefaultKiloPort, - }, PeerConfig: wgtypes.PeerConfig{ + Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort}, + PersistentKeepaliveInterval: &zero, + }, + }, + }, + }, + { + name: "valid endpoint ipv6", + spec: v1alpha1.PeerSpec{ + Endpoint: &v1alpha1.PeerEndpoint{ + DNSOrIP: v1alpha1.DNSOrIP{ + IP: "ff60::2", + }, + Port: mesh.DefaultKiloPort, + }, + }, + out: &mesh.Peer{ + Peer: wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + Endpoint: &net.UDPAddr{IP: net.ParseIP("ff60::2"), Port: mesh.DefaultKiloPort}, PersistentKeepaliveInterval: &zero, }, }, @@ -403,10 +461,7 @@ func TestTranslatePeer(t *testing.T) { }, out: &mesh.Peer{ Peer: wireguard.Peer{ - KiloEndpoint: &wireguard.Endpoint{ - DNSOrIP: wireguard.DNSOrIP{DNS: "example.com"}, - Port: mesh.DefaultKiloPort, - }, + Addr: "example.com:51820", PeerConfig: wgtypes.PeerConfig{ PersistentKeepaliveInterval: &zero, }, @@ -494,47 +549,59 @@ func TestParseEndpoint(t *testing.T) { for _, tc := range []struct { name string endpoint string - out *wireguard.Endpoint + udp *net.UDPAddr + addr string }{ { name: "empty", endpoint: "", - out: nil, + udp: nil, + addr: "", }, { name: "invalid IP", endpoint: "10.0.0.:51820", - out: nil, + udp: nil, + addr: "", }, { name: "invalid hostname", endpoint: "foo-:51820", - out: nil, + udp: nil, + addr: "", }, { name: "invalid port", endpoint: "10.0.0.1:100000000", - out: nil, + udp: nil, + addr: "", }, { name: "valid IP", endpoint: "10.0.0.1:51820", - out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort}, + udp: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort}, + addr: "", }, { name: "valid IPv6", endpoint: "[ff02::114]:51820", - out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("ff02::114")}, Port: mesh.DefaultKiloPort}, + udp: &net.UDPAddr{IP: net.ParseIP("ff02::114"), Port: mesh.DefaultKiloPort}, + addr: "", }, { name: "valid hostname", endpoint: "foo:51821", - out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: "foo"}, Port: 51821}, + udp: nil, + addr: "foo:51821", }, } { - endpoint := parseEndpoint(tc.endpoint) - if diff := pretty.Compare(endpoint, tc.out); diff != "" { + udp, addr := parseEndpoint(tc.endpoint) + if diff := pretty.Compare(udp, tc.udp); diff != "" { t.Errorf("test case %q: got diff: %v", tc.name, diff) } + if addr != tc.addr { + t.Errorf("test case %q: got: %q, wants: %q", tc.name, addr, tc.addr) + } + } } diff --git a/pkg/mesh/backend.go b/pkg/mesh/backend.go index d30eb88..f4bc152 100644 --- a/pkg/mesh/backend.go +++ b/pkg/mesh/backend.go @@ -56,8 +56,8 @@ const ( // Node represents a node in the network. type Node struct { - KiloEndpoint *wireguard.Endpoint Endpoint *net.UDPAddr + Addr string // eg. dnsname:port Key wgtypes.Key NoInternalIP bool InternalIP *net.IPNet @@ -82,9 +82,7 @@ type Node struct { func (n *Node) Ready() bool { // Nodes that are not leaders will not have WireGuardIPs, so it is not required. return n != nil && - n.KiloEndpoint != nil && - !(n.KiloEndpoint.IP == nil && n.KiloEndpoint.DNS == "") && - n.KiloEndpoint.Port != 0 && + (n.Endpoint != nil || n.Addr != "") && n.Key != wgtypes.Key{} && n.Subnet != nil && time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second) diff --git a/pkg/mesh/graph.go b/pkg/mesh/graph.go index 84b50a1..1b6133f 100644 --- a/pkg/mesh/graph.go +++ b/pkg/mesh/graph.go @@ -20,7 +20,6 @@ import ( "strings" "github.com/awalterschulze/gographviz" - "github.com/squat/kilo/pkg/wireguard" ) // Dot generates a Graphviz graph of the Topology in DOT fomat. @@ -62,10 +61,10 @@ func (t *Topology) Dot() (string, error) { return "", fmt.Errorf("failed to add node to subgraph") } var wg net.IP - var endpoint *wireguard.Endpoint + var endpoint *net.UDPAddr if j == s.leader { wg = s.wireGuardIP - endpoint = s.kiloEndpoint + endpoint = s.endpoint if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Rank), "1"); err != nil { return "", fmt.Errorf("failed to add rank to node") } @@ -74,7 +73,7 @@ func (t *Topology) Dot() (string, error) { if s.privateIPs != nil { priv = s.privateIPs[j] } - if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], priv, wg, endpoint)); err != nil { + if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], priv, wg, endpoint, s.addr)); err != nil { return "", fmt.Errorf("failed to add label to node") } } @@ -154,7 +153,7 @@ func subGraphName(name string) string { return graphEscape(fmt.Sprintf("cluster_location_%s", name)) } -func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *wireguard.Endpoint) string { +func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *net.UDPAddr, addr string) string { label := []string{ location, name, @@ -166,8 +165,14 @@ func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoi if wgIP != nil { label = append(label, wgIP.String()) } - if endpoint != nil { - label = append(label, endpoint.String()) + var str string + if addr != "" { + str = addr + } else if endpoint != nil { + str = endpoint.String() + } + if str != "" { + label = append(label, str) } return graphEscape(strings.Join(label, "\\n")) } diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index 34ffe2a..cbb3697 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -370,8 +370,8 @@ func (m *Mesh) checkIn() { func (m *Mesh) handleLocal(n *Node) { // Allow the IPs to be overridden. - if n.KiloEndpoint == nil || (n.KiloEndpoint.DNS == "" && n.KiloEndpoint.IP == nil) { - n.KiloEndpoint = &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: m.externalIP.IP}, Port: m.port} + if n.Endpoint == nil || n.Addr == "" { + n.Endpoint = &net.UDPAddr{IP: m.externalIP.IP, Port: m.port} } if n.InternalIP == nil && !n.NoInternalIP { n.InternalIP = m.internalIP @@ -380,7 +380,7 @@ func (m *Mesh) handleLocal(n *Node) { // Take leader, location, and subnet from the argument, as these // are not determined by kilo. local := &Node{ - KiloEndpoint: n.KiloEndpoint, + Endpoint: n.Endpoint, Key: m.pub, NoInternalIP: n.NoInternalIP, InternalIP: n.InternalIP, @@ -484,7 +484,7 @@ func (m *Mesh) applyTopology() { natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger) nodes[m.hostname].DiscoveredEndpoints = natEndpoints - t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].KiloEndpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger) + t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger) if err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() @@ -625,12 +625,12 @@ func (m *Mesh) resolveEndpoints() error { } // If the node is ready, then the endpoint is not nil // but it may not have a DNS name. - if m.nodes[k].KiloEndpoint.DNS == "" { + if m.nodes[k].Addr == "" { continue } - if u, err := net.ResolveUDPAddr("udp", m.nodes[k].KiloEndpoint.String()); err == nil { + if u, err := net.ResolveUDPAddr("udp", m.nodes[k].Addr); err == nil { m.nodes[k].Endpoint = u - m.nodes[k].KiloEndpoint.IP = u.IP + m.nodes[k].Endpoint.IP = u.IP } else { return err } @@ -642,12 +642,11 @@ func (m *Mesh) resolveEndpoints() error { continue } // Peers may have nil endpoints. - if m.peers[k].KiloEndpoint == nil || m.peers[k].KiloEndpoint.DNS == "" { + if m.peers[k].Addr == "" { continue } - if u, err := net.ResolveUDPAddr("udp", m.peers[k].KiloEndpoint.String()); err == nil { + if u, err := net.ResolveUDPAddr("udp", m.peers[k].Addr); err == nil { m.peers[k].Endpoint = u - m.peers[k].KiloEndpoint.IP = u.IP } else { return err } @@ -668,7 +667,7 @@ func nodesAreEqual(a, b *Node) bool { } // Check the DNS name first since this package // is doing the DNS resolution. - if !a.KiloEndpoint.Equal(b.KiloEndpoint, true) { + if a.Addr != b.Addr || a.Endpoint.String() != b.Endpoint.String() { return false } // Ignore LastSeen when comparing equality we want to check if the nodes are @@ -697,7 +696,7 @@ func peersAreEqual(a, b *Peer) bool { } // Check the DNS name first since this package // is doing the DNS resolution. - if !a.KiloEndpoint.Equal(b.KiloEndpoint, true) { + if a.Addr != b.Addr || a.Endpoint.String() != b.Endpoint.String() { return false } if len(a.AllowedIPs) != len(b.AllowedIPs) { diff --git a/pkg/mesh/mesh_test.go b/pkg/mesh/mesh_test.go index ea5c396..d4d80d7 100644 --- a/pkg/mesh/mesh_test.go +++ b/pkg/mesh/mesh_test.go @@ -19,7 +19,6 @@ import ( "testing" "time" - "github.com/squat/kilo/pkg/wireguard" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -63,58 +62,58 @@ func TestReady(t *testing.T) { { name: "empty endpoint IP", node: &Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{}, Port: DefaultKiloPort}, - InternalIP: internalIP, - Key: wgtypes.Key{}, - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + Endpoint: &net.UDPAddr{Port: DefaultKiloPort}, + InternalIP: internalIP, + Key: wgtypes.Key{}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: false, }, { name: "empty endpoint port", node: &Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}}, - InternalIP: internalIP, - Key: wgtypes.Key{}, - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + Endpoint: &net.UDPAddr{IP: externalIP.IP}, + InternalIP: internalIP, + Key: wgtypes.Key{}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: false, }, { name: "empty internal IP", node: &Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, - Key: wgtypes.Key{}, - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort}, + Key: wgtypes.Key{}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: false, }, { name: "empty key", node: &Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, - InternalIP: internalIP, - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort}, + InternalIP: internalIP, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: false, }, { name: "empty subnet", node: &Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, - InternalIP: internalIP, - Key: wgtypes.Key{}, + Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort}, + InternalIP: internalIP, + Key: wgtypes.Key{}, }, ready: false, }, { name: "valid", node: &Node{ - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, - InternalIP: internalIP, - Key: key, - LastSeen: time.Now().Unix(), - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort}, + InternalIP: internalIP, + Key: key, + LastSeen: time.Now().Unix(), + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: true, }, diff --git a/pkg/mesh/routes.go b/pkg/mesh/routes.go index e3bc4b0..950da0c 100644 --- a/pkg/mesh/routes.go +++ b/pkg/mesh/routes.go @@ -40,7 +40,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface var gw net.IP for _, segment := range t.segments { if segment.location == t.location { - gw = enc.Gw(segment.kiloEndpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader]) + gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader]) break } } @@ -196,7 +196,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface // equals the external IP. This means that the node // is only accessible through an external IP and we // cannot encapsulate traffic to an IP through the IP. - if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.kiloEndpoint.IP) { + if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP) { continue } // Add routes to the private IPs of nodes in other segments. diff --git a/pkg/mesh/topology.go b/pkg/mesh/topology.go index 628df81..24b56c4 100644 --- a/pkg/mesh/topology.go +++ b/pkg/mesh/topology.go @@ -67,7 +67,7 @@ type Topology struct { type segment struct { allowedIPs []net.IPNet - kiloEndpoint *wireguard.Endpoint + addr string endpoint *net.UDPAddr key wgtypes.Key persistentKeepalive time.Duration @@ -178,7 +178,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra }) t.segments = append(t.segments, &segment{ allowedIPs: allowedIPs, - kiloEndpoint: topoMap[location][leader].KiloEndpoint, + addr: topoMap[location][leader].Addr, endpoint: topoMap[location][leader].Endpoint, key: topoMap[location][leader].Key, persistentKeepalive: topoMap[location][leader].PersistentKeepalive, @@ -287,10 +287,10 @@ CheckIPs: return } -func (t *Topology) updateEndpoint(kiloEndpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *net.UDPAddr { +func (t *Topology) updateEndpoint(endpoint *net.UDPAddr, key wgtypes.Key, persistentKeepalive *time.Duration) *net.UDPAddr { // Do not update non-nat peers if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) { - return kiloEndpoint.UDPAddr() + return endpoint } e, ok := t.discoveredEndpoints[key.String()] if ok { @@ -315,12 +315,12 @@ func (t *Topology) Conf() *wireguard.Conf { peer := wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...), - Endpoint: t.updateEndpoint(s.kiloEndpoint, s.key, &s.persistentKeepalive), + Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive), PersistentKeepaliveInterval: &t.persistentKeepalive, PublicKey: s.key, ReplaceAllowedIPs: true, }, - KiloEndpoint: s.kiloEndpoint, + Addr: s.addr, } c.Peers = append(c.Peers, peer) } @@ -328,13 +328,13 @@ func (t *Topology) Conf() *wireguard.Conf { peer := wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: p.AllowedIPs, - Endpoint: t.updateEndpoint(p.KiloEndpoint, p.PublicKey, p.PersistentKeepaliveInterval), + Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval), PersistentKeepaliveInterval: &t.persistentKeepalive, PresharedKey: p.PresharedKey, PublicKey: p.PublicKey, ReplaceAllowedIPs: true, }, - KiloEndpoint: p.KiloEndpoint, + Addr: p.Addr, } c.Peers = append(c.Peers, peer) } @@ -354,7 +354,7 @@ func (t *Topology) AsPeer() wireguard.Peer { PublicKey: s.key, Endpoint: s.endpoint, }, - KiloEndpoint: s.kiloEndpoint, + Addr: s.addr, } return p } @@ -377,12 +377,12 @@ func (t *Topology) PeerConf(name string) wireguard.Conf { peer := wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: s.allowedIPs, - Endpoint: s.kiloEndpoint.UDPAddr(), + Endpoint: s.endpoint, PersistentKeepaliveInterval: pka, PresharedKey: psk, PublicKey: s.key, }, - KiloEndpoint: s.kiloEndpoint, + Addr: s.addr, } c.Peers = append(c.Peers, peer) } @@ -417,13 +417,13 @@ func findLeader(nodes []*Node) int { var leaders, public []int for i := range nodes { if nodes[i].Leader { - if isPublic(nodes[i].KiloEndpoint.IP) { + if isPublic(nodes[i].Endpoint.IP) { return i } leaders = append(leaders, i) } - if isPublic(nodes[i].KiloEndpoint.IP) { + if nodes[i].Endpoint != nil && isPublic(nodes[i].Endpoint.IP) { public = append(public, i) } } @@ -449,7 +449,7 @@ func deduplicatePeerIPs(peers []*Peer) []*Peer { PresharedKey: peer.PresharedKey, PublicKey: peer.PublicKey, }, - KiloEndpoint: peer.KiloEndpoint, + Addr: peer.Addr, }, } for _, ip := range peer.AllowedIPs { diff --git a/pkg/mesh/topology_test.go b/pkg/mesh/topology_test.go index c6c5ec0..d05fea5 100644 --- a/pkg/mesh/topology_test.go +++ b/pkg/mesh/topology_test.go @@ -60,7 +60,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) nodes := map[string]*Node{ "a": { Name: "a", - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, + Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort}, InternalIP: i1, Location: "1", Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)}, @@ -69,7 +69,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) }, "b": { Name: "b", - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, + Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort}, InternalIP: i1, Location: "2", Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)}, @@ -77,17 +77,17 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) AllowedLocationIPs: []net.IPNet{*i3}, }, "c": { - Name: "c", - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e3.IP}, Port: DefaultKiloPort}, - InternalIP: i2, + Name: "c", + Endpoint: &net.UDPAddr{IP: e3.IP, Port: DefaultKiloPort}, + InternalIP: i2, // Same location as node b. Location: "2", Subnet: &net.IPNet{IP: net.ParseIP("10.2.3.0"), Mask: net.CIDRMask(24, 32)}, Key: key3, }, "d": { - Name: "d", - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e4.IP}, Port: DefaultKiloPort}, + Name: "d", + Endpoint: &net.UDPAddr{IP: e4.IP, Port: DefaultKiloPort}, // Same location as node a, but without private IP Location: "1", Subnet: &net.IPNet{IP: net.ParseIP("10.2.4.0"), Mask: net.CIDRMask(24, 32)}, @@ -115,10 +115,10 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) {IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)}, }, PublicKey: key5, - }, - KiloEndpoint: &wireguard.Endpoint{ - DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("192.168.0.1")}, - Port: DefaultKiloPort, + Endpoint: &net.UDPAddr{ + IP: net.ParseIP("192.168.0.1"), + Port: DefaultKiloPort, + }, }, }, }, @@ -153,7 +153,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["a"].KiloEndpoint, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: logicalLocationPrefix + nodes["a"].Location, @@ -164,7 +164,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["b"].KiloEndpoint, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: logicalLocationPrefix + nodes["b"].Location, @@ -176,7 +176,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["d"].KiloEndpoint, + endpoint: nodes["d"].Endpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -204,7 +204,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["a"].KiloEndpoint, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: logicalLocationPrefix + nodes["a"].Location, @@ -215,7 +215,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["b"].KiloEndpoint, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: logicalLocationPrefix + nodes["b"].Location, @@ -227,7 +227,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["d"].KiloEndpoint, + endpoint: nodes["d"].Endpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -255,7 +255,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["a"].KiloEndpoint, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: logicalLocationPrefix + nodes["a"].Location, @@ -266,7 +266,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["b"].KiloEndpoint, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: logicalLocationPrefix + nodes["b"].Location, @@ -278,7 +278,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["d"].KiloEndpoint, + endpoint: nodes["d"].Endpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -306,7 +306,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["a"].KiloEndpoint, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: nodeLocationPrefix + nodes["a"].Name, @@ -317,7 +317,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["b"].KiloEndpoint, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: nodeLocationPrefix + nodes["b"].Name, @@ -329,7 +329,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["c"].KiloEndpoint, + endpoint: nodes["c"].Endpoint, key: nodes["c"].Key, persistentKeepalive: nodes["c"].PersistentKeepalive, location: nodeLocationPrefix + nodes["c"].Name, @@ -340,7 +340,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["d"].KiloEndpoint, + endpoint: nodes["d"].Endpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -368,7 +368,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["a"].KiloEndpoint, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: nodeLocationPrefix + nodes["a"].Name, @@ -379,7 +379,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["b"].KiloEndpoint, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: nodeLocationPrefix + nodes["b"].Name, @@ -391,7 +391,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["c"].KiloEndpoint, + endpoint: nodes["c"].Endpoint, key: nodes["c"].Key, persistentKeepalive: nodes["c"].PersistentKeepalive, location: nodeLocationPrefix + nodes["c"].Name, @@ -402,7 +402,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["d"].KiloEndpoint, + endpoint: nodes["d"].Endpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -430,7 +430,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["a"].KiloEndpoint, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: nodeLocationPrefix + nodes["a"].Name, @@ -441,7 +441,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["b"].KiloEndpoint, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: nodeLocationPrefix + nodes["b"].Name, @@ -453,7 +453,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["c"].KiloEndpoint, + endpoint: nodes["c"].Endpoint, key: nodes["c"].Key, persistentKeepalive: nodes["c"].PersistentKeepalive, location: nodeLocationPrefix + nodes["c"].Name, @@ -464,7 +464,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["d"].KiloEndpoint, + endpoint: nodes["d"].Endpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -492,7 +492,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["a"].KiloEndpoint, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: nodeLocationPrefix + nodes["a"].Name, @@ -503,7 +503,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["b"].KiloEndpoint, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: nodeLocationPrefix + nodes["b"].Name, @@ -515,7 +515,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["c"].KiloEndpoint, + endpoint: nodes["c"].Endpoint, key: nodes["c"].Key, persistentKeepalive: nodes["c"].PersistentKeepalive, location: nodeLocationPrefix + nodes["c"].Name, @@ -526,7 +526,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, - kiloEndpoint: nodes["d"].KiloEndpoint, + endpoint: nodes["d"].Endpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -575,26 +575,26 @@ func TestFindLeader(t *testing.T) { nodes := []*Node{ { - Name: "a", - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, + Name: "a", + Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort}, }, { - Name: "b", - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, + Name: "b", + Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort}, }, { - Name: "c", - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, + Name: "c", + Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort}, }, { - Name: "d", - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, - Leader: true, + Name: "d", + Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort}, + Leader: true, }, { - Name: "2", - KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, - Leader: true, + Name: "2", + Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort}, + Leader: true, }, } for _, tc := range []struct { diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index 3f624a1..2fc6f66 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -43,7 +43,7 @@ const ( // Conf represents a WireGuard configuration file. type Conf struct { wgtypes.Config - // The Peers field is shadowed because every Peer needs the KiloEndpoint field that contains a DNS endpoint. + // The Peers field is shadowed because every Peer needs the Endpoint field that contains a DNS endpoint. Peers []Peer } @@ -60,16 +60,10 @@ func (c Conf) WGConfig() wgtypes.Config { return r } -// Interface represents the `interface` section of a WireGuard configuration. -type Interface struct { - ListenPort uint32 - PrivateKey []byte -} - // Peer represents a `peer` section of a WireGuard configuration. type Peer struct { wgtypes.PeerConfig - KiloEndpoint *Endpoint + Addr string // eg: dnsname:port } // DeduplicateIPs eliminates duplicate allowed IPs. @@ -86,79 +80,6 @@ func (p *Peer) DeduplicateIPs() { p.AllowedIPs = ips } -// Endpoint represents an `endpoint` key of a `peer` section. -type Endpoint struct { - DNSOrIP - Port int -} - -// String prints the string representation of the endpoint. -func (e *Endpoint) String() string { - if e == nil { - return "" - } - dnsOrIP := e.DNSOrIP.String() - if e.IP != nil && len(e.IP) == net.IPv6len { - dnsOrIP = "[" + dnsOrIP + "]" - } - return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10) -} - -// UDPAddr returns the corresponding net.UDPAddr of the Endpoint or nil. -func (e *Endpoint) UDPAddr() (u *net.UDPAddr) { - if a, err := net.ResolveUDPAddr("udp", e.String()); err == nil { - u = a - } - return -} - -// Equal compares two endpoints. -func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool { - if (e == nil) != (b == nil) { - return false - } - if e != nil { - if e.Port != b.Port { - return false - } - if DNSFirst { - // Check the DNS name first if it was resolved. - if e.DNS != b.DNS { - return false - } - if e.DNS == "" && !e.IP.Equal(b.IP) { - return false - } - } else { - // IPs take priority, so check them first. - if !e.IP.Equal(b.IP) { - return false - } - // Only check the DNS name if the IP is empty. - if e.IP == nil && e.DNS != b.DNS { - return false - } - } - } - - return true -} - -// DNSOrIP represents either a DNS name or an IP address. -// IPs, as they are more specific, are preferred. -type DNSOrIP struct { - DNS string - IP net.IP -} - -// String prints the string representation of the struct. -func (d DNSOrIP) String() string { - if d.IP != nil { - return d.IP.String() - } - return d.DNS -} - // Bytes renders a WireGuard configuration to bytes. func (c Conf) Bytes() ([]byte, error) { var err error @@ -188,13 +109,13 @@ func (c Conf) Bytes() ([]byte, error) { if err = writeAllowedIPs(buf, p.AllowedIPs); err != nil { return nil, fmt.Errorf("failed to write allowed IPs: %v", err) } - if err = writeEndpoint(buf, p.KiloEndpoint); err != nil { + if err = writeEndpoint(buf, p.Endpoint, p.Addr); err != nil { return nil, fmt.Errorf("failed to write endpoint: %v", err) } if p.PersistentKeepaliveInterval == nil { p.PersistentKeepaliveInterval = new(time.Duration) } - if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval), 10)); err != nil { + if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval/time.Second), 10)); err != nil { return nil, fmt.Errorf("failed to write persistent keepalive: %v", err) } if err = writePKey(buf, presharedKeyKey, p.PresharedKey); err != nil { @@ -327,15 +248,20 @@ func writeValue(buf *bytes.Buffer, k key, v string) error { return buf.WriteByte('\n') } -func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error { - if e == nil { +func writeEndpoint(buf *bytes.Buffer, e *net.UDPAddr, d string) error { + str := "" + if d != "" { + str = d + } else if e != nil { + str = e.String() + } else { return nil } var err error if err = writeKey(buf, endpointKey); err != nil { return err } - if _, err = buf.WriteString(e.String()); err != nil { + if _, err = buf.WriteString(str); err != nil { return err } return buf.WriteByte('\n') diff --git a/pkg/wireguard/conf_test.go b/pkg/wireguard/conf_test.go deleted file mode 100644 index 629c2fd..0000000 --- a/pkg/wireguard/conf_test.go +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2021 the Kilo authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wireguard - -import ( - "net" - "testing" -) - -func TestCompareEndpoint(t *testing.T) { - for _, tc := range []struct { - name string - a *Endpoint - b *Endpoint - dnsFirst bool - out bool - }{ - { - name: "both nil", - a: nil, - b: nil, - out: true, - }, - { - name: "a nil", - a: nil, - b: &Endpoint{}, - out: false, - }, - { - name: "b nil", - a: &Endpoint{}, - b: nil, - out: false, - }, - { - name: "zero", - a: &Endpoint{}, - b: &Endpoint{}, - out: true, - }, - { - name: "diff port", - a: &Endpoint{Port: 1234}, - b: &Endpoint{Port: 5678}, - out: false, - }, - { - name: "same IP", - a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}}, - b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}}, - out: true, - }, - { - name: "diff IP", - a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}}, - b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2")}}, - out: false, - }, - { - name: "same IP ignore DNS", - a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}}, - b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "b"}}, - out: true, - }, - { - name: "no IP check DNS", - a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}}, - b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}}, - out: false, - }, - { - name: "no IP check DNS (same)", - a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}}, - b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}}, - out: true, - }, - { - name: "DNS first, ignore IP", - a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}}, - b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}}, - dnsFirst: true, - out: true, - }, - { - name: "DNS first", - a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}}, - b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}}, - dnsFirst: true, - out: false, - }, - { - name: "DNS first, no DNS compare IP", - a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}}, - b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}}, - dnsFirst: true, - out: false, - }, - { - name: "DNS first, no DNS compare IP (same)", - a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}}, - b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}}, - dnsFirst: true, - out: true, - }, - } { - equal := tc.a.Equal(tc.b, tc.dnsFirst) - if equal != tc.out { - t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal) - } - } -}