diff --git a/README.md b/README.md index e3b60e6..b5b58d7 100644 --- a/README.md +++ b/README.md @@ -57,8 +57,9 @@ Kilo allows the topology of the encrypted network to be completely customized. ### Step 4: ensure nodes have public IP -At least one node in each location must have a public IP address. -If the public IP address is not automatically configured on the node's Ethernet device, it can be manually specified using the [kilo.squat.ai/force-external-ip](./docs/annotations.md#force-external-ip) annotation. +At least one node in each location must have an IP address that is routable from the other locations. +If the locations are in different clouds or private networks, then this must be a public IP address. +If this IP address is not automatically configured on the node's Ethernet device, it can be manually specified using the [kilo.squat.ai/force-endpoint](./docs/annotations.md#force-endpoint) annotation. ### Step 5: install Kilo! diff --git a/docs/annotations.md b/docs/annotations.md index 0ae1a0f..de5de25 100644 --- a/docs/annotations.md +++ b/docs/annotations.md @@ -2,19 +2,22 @@ The following annotations can be added to any Kubernetes Node object to configure the Kilo network. -|Name|type|example| +|Name|type|examples| |----|----|-------| -|[kilo.squat.ai/force-external-ip](#force-external-ip)|CIDR|`"55.55.55.55/32"`| -|[kilo.squat.ai/force-internal-ip](#force-internal-ip)|CIDR|`"55.55.55.55/32"`| -|[kilo.squat.ai/leader](#leader)|string|`""`| -|[kilo.squat.ai/location](#location)|string|`"gcp-east"`| +|[kilo.squat.ai/force-endpoint](#force-endpoint)|host:port|`55.55.55.55:51820`, `example.com:1337| +|[kilo.squat.ai/force-internal-ip](#force-internal-ip)|CIDR|`55.55.55.55/32`| +|[kilo.squat.ai/leader](#leader)|string|`""`, `true`| +|[kilo.squat.ai/location](#location)|string|`gcp-east`, `lab`| -### force-external-ip -Kilo requires at least one node in each location to have a publicly accessible IP address in order to create links to other locations. -The Kilo agent running on each node will use heuristics to automatically detect an external IP address for the node; however, in some circumstances it may be necessary to explicitly configure the IP address, for example: +### force-endpoint +In order to create links between locations, Kilo requires at least one node in each location to have an endpoint, ie a `host:port` combination, that is routable from the other locations. +If the locations are in different cloud providers or in different private networks, then the `host` portion of the endpoint should be a publicly accessible IP address, or a DNS name that resolves to a public IP, so that the other locations can route packets to it. +The Kilo agent running on each node will use heuristics to automatically detect an external IP address for the node and correctly configure its endpoint; however, in some circumstances it may be necessary to explicitly configure the endpoint to use, for example: * _no automatic public IP on ethernet device_: on some cloud providers it is common for nodes to be allocated a public IP address but for the Ethernet devices to only be automatically configured with the private network address; in this case the allocated public IP address should be specified; * _multiple public IP addresses_: if a node has multiple public IPs but one is preferred, then the preferred IP address should be specified; - * _IPv6_: if a node has both public IPv4 and IPv6 addresses and the Kilo network should operate over IPv6, then the IPv6 address should be specified. + * _IPv6_: if a node has both public IPv4 and IPv6 addresses and the Kilo network should operate over IPv6, then the IPv6 address should be specified; + * _dynamic IP address_: if a node has a dynamically allocated public IP address, for example an IP leased from a network provider, then a dynamic DNS name can be given can be given and Kilo will periodically lookup the IP to keep the endpoint up-to-date; + * _override port_: if a node should listen on a specific port that is different from the mesh's default WireGuard port, then this annotation can be used to override the port; this can be useful, for example, to ensure that two nodes operating behind the same port-forwarded NAT gateway can each be allocated a different port. ### force-internal-ip Kilo routes packets destined for nodes inside the same logical location using the node's internal IP address. diff --git a/pkg/calico/calico b/pkg/calico/calico new file mode 100644 index 0000000..ff6962a --- /dev/null +++ b/pkg/calico/calico @@ -0,0 +1,179 @@ +// Copyright 2019 the Kilo authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package calico + +import ( + "errors" + "net" + "time" + + "k8s.io/apimachinery/pkg/labels" + v1informers "k8s.io/client-go/informers/core/v1" + "k8s.io/client-go/kubernetes" + v1listers "k8s.io/client-go/listers/core/v1" + "k8s.io/client-go/tools/cache" + + "github.com/squat/kilo/pkg/ipset" + "github.com/squat/kilo/pkg/mesh" +) + +type Compatibility interface { + Apply(*mesh.Topology, mesh.Encapsulate) error + Backend(mesh.Backend) mesh.Backend + CleanUp() error + Run(stop <-chan struct{}) (<-chan error, error) +} + +// calico is a Calico compatibility layer. +type calico struct { + client kubernetes.Interface + errors chan error + ipset *ipset.Set +} + +// New generates a new ipset. +func New(c kubernetes.Interface) Compatibility { + return &calico{ + client: c, + errors: make(chan error), + // This is a patch until Calico supports + // other hosts adding IPIP iptables rules. + ipset: ipset.New("cali40all-hosts-net"), + } +} + +// Run implements the mesh.Compatibility interface. +// It runs the ipset controller and forwards errors along. +func (c *calico) Run(stop <-chan struct{}) (<-chan error, error) { + return c.ipset.Run(stop) +} + +// CleanUp stops the compatibility layer's controllers. +func (c *calico) CleanUp() error { + return c.ipset.CleanUp() +} + +type backend struct { + backend mesh.Backend + client kubernetes.Interface + events chan *mesh.NodeEvent + informer cache.SharedIndexInformer + lister v1listers.NodeLister +} + +func (c *calico) Apply(t *mesh.Topology, encapsulate mesh.Encapsulate, location string) error { + if encapsulate == mesh.NeverEncapsulate { + return nil + } + var peers []net.IP + for _, s := range t.segments { + if s.location == location { + peers = s.privateIPs + break + } + } + return c.ipset.Set(peers) +} + +func (c *calico) Backend(b mesh.Backend) mesh.Backend { + ni := v1informers.NewNodeInformer(c.client, 5*time.Minute, nil) + return &backend{ + backend: b, + events: make(chan *mesh.NodeEvent), + informer: ni, + lister: v1listers.NewNodeLister(ni.GetIndexer()), + } +} + +// Nodes implements the mesh.Backend interface. +func (b *backend) Nodes() mesh.NodeBackend { + return b +} + +// Peers implements the mesh.Backend interface. +func (b *backend) Peers() mesh.PeerBackend { + // The Calico compatibility backend only wraps the node backend. + return b.backend.Peers() +} + +// CleanUp removes configuration applied to the backend. +func (b *backend) CleanUp(name string) error { + return b.backend.Nodes().CleanUp(name) +} + +// Get gets a single Node by name. +func (b *backend) Get(name string) (*mesh.Node, error) { + n, err := b.lister.Get(name) + if err != nil { + return nil, err + } + m, err := b.backend.Nodes().Get(name) + if err != nil { + return nil, err + } + return translateNode(n, m), nil +} + +// Init initializes the backend; for this backend that means +// syncing the informer cache and the wrapped backend. +func (b *backend) Init(stop <-chan struct{}) error { + if err := b.backend.Nodes().Init(stop); err != nil { + return err + } + go b.informer.Run(stop) + if ok := cache.WaitForCacheSync(stop, func() bool { + return b.informer.HasSynced() + }); !ok { + return errors.New("failed to sync node cache") + } + go func() { + w := b.backend.Nodes().Watch() + var ne *mesh.NodeEvent + for { + select { + case ne = <-w: + b.events <- &mesh.NodeEvent{Type: ne.Type, Node: translateNode(n, ne.Node)} + case <-stop: + return + } + } + }() + return nil +} + +// List gets all the Nodes in the cluster. +func (b *backend) List() ([]*mesh.Node, error) { + ns, err := b.lister.List(labels.Everything()) + if err != nil { + return nil, err + } + nodes := make([]*mesh.Node, len(ns)) + for i := range ns { + nodes[i] = translateNode(ns[i]) + } + return nodes, nil +} + +// Set sets the fields of a node. +func (b *backend) Set(name string, node *mesh.Node) error { + // The Calico compatibility backend is read-only. + // Proxy all writes to the underlying backend. + return b.backend.Nodes().Set(name, node) +} + +// Watch returns a chan of node events. +func (b *backend) Watch() <-chan *mesh.NodeEvent { + return b.events +} diff --git a/pkg/encapsulation/none.go b/pkg/encapsulation/none.go new file mode 100644 index 0000000..774987d --- /dev/null +++ b/pkg/encapsulation/none.go @@ -0,0 +1,63 @@ +// Copyright 2019 the Kilo authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package encapsulation + +import ( + "net" + + "github.com/squat/kilo/pkg/iptables" +) + +type none Strategy + +// NewNone returns an new encapsulator that does not encapsulate. +func NewNone(strategy Strategy) Encapsulator { + return none(strategy) +} + +// CleanUp is a no-op. +func (n none) CleanUp() error { + return nil +} + +// Gw always returns nil. +func (n none) Gw(_, _ net.IP, _ *net.IPNet) net.IP { + return nil +} + +// Index always returns 0. +func (n none) Index() int { + return 0 +} + +// Init is a no-op. +func (n none) Init(base int) error { + return nil +} + +// Rules always returns an empty list. +func (n none) Rules(_ []*net.IPNet) []iptables.Rule { + return nil +} + +// Set is a no-op. +func (n none) Set(_ *net.IPNet) error { + return nil +} + +// Strategy returns the configured strategy for encapsulation. +func (n none) Strategy() Strategy { + return Strategy(n) +} diff --git a/pkg/k8s/backend.go b/pkg/k8s/backend.go index 6620735..8f8b615 100644 --- a/pkg/k8s/backend.go +++ b/pkg/k8s/backend.go @@ -32,6 +32,7 @@ import ( "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/strategicpatch" + "k8s.io/apimachinery/pkg/util/validation" v1informers "k8s.io/client-go/informers/core/v1" "k8s.io/client-go/kubernetes" v1listers "k8s.io/client-go/listers/core/v1" @@ -48,8 +49,8 @@ import ( const ( // Backend is the name of this mesh backend. Backend = "kubernetes" - externalIPAnnotationKey = "kilo.squat.ai/external-ip" - forceExternalIPAnnotationKey = "kilo.squat.ai/force-external-ip" + endpointAnnotationKey = "kilo.squat.ai/endpoint" + forceEndpointAnnotationKey = "kilo.squat.ai/force-endpoint" forceInternalIPAnnotationKey = "kilo.squat.ai/force-internal-ip" internalIPAnnotationKey = "kilo.squat.ai/internal-ip" keyAnnotationKey = "kilo.squat.ai/key" @@ -119,7 +120,7 @@ func New(c kubernetes.Interface, kc kiloclient.Interface, ec apiextensions.Inter // CleanUp removes configuration applied to the backend. func (nb *nodeBackend) CleanUp(name string) error { patch := []byte("[" + strings.Join([]string{ - fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(externalIPAnnotationKey, "/", jsonPatchSlash, 1))), + fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(endpointAnnotationKey, "/", jsonPatchSlash, 1))), fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(internalIPAnnotationKey, "/", jsonPatchSlash, 1))), fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(keyAnnotationKey, "/", jsonPatchSlash, 1))), fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(lastSeenAnnotationKey, "/", jsonPatchSlash, 1))), @@ -205,7 +206,7 @@ func (nb *nodeBackend) Set(name string, node *mesh.Node) error { return fmt.Errorf("failed to find node: %v", err) } n := old.DeepCopy() - n.ObjectMeta.Annotations[externalIPAnnotationKey] = node.ExternalIP.String() + n.ObjectMeta.Annotations[endpointAnnotationKey] = node.Endpoint.String() n.ObjectMeta.Annotations[internalIPAnnotationKey] = node.InternalIP.String() n.ObjectMeta.Annotations[keyAnnotationKey] = string(node.Key) n.ObjectMeta.Annotations[lastSeenAnnotationKey] = strconv.FormatInt(node.LastSeen, 10) @@ -254,14 +255,15 @@ func translateNode(node *v1.Node) *mesh.Node { if !ok { location = node.ObjectMeta.Labels[regionLabelKey] } - // Allow the IPs to be overridden. - externalIP, ok := node.ObjectMeta.Annotations[forceExternalIPAnnotationKey] - if !ok { - externalIP = node.ObjectMeta.Annotations[externalIPAnnotationKey] + // Allow the endpoint to be overridden. + endpoint := parseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey]) + if endpoint == nil { + endpoint = parseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey]) } - internalIP, ok := node.ObjectMeta.Annotations[forceInternalIPAnnotationKey] - if !ok { - internalIP = node.ObjectMeta.Annotations[internalIPAnnotationKey] + // Allow the internal IP to be overridden. + internalIP := normalizeIP(node.ObjectMeta.Annotations[forceInternalIPAnnotationKey]) + if internalIP == nil { + internalIP = normalizeIP(node.ObjectMeta.Annotations[internalIPAnnotationKey]) } // Set Wireguard PersistentKeepalive setting for the node. var persistentKeepalive int64 @@ -281,12 +283,12 @@ func translateNode(node *v1.Node) *mesh.Node { } } return &mesh.Node{ - // ExternalIP and InternalIP should only ever fail to parse if the + // Endpoint and InternalIP should only ever fail to parse if the // remote node's agent has not yet set its IP address; // in this case the IP will be nil and // the mesh can wait for the node to be updated. - ExternalIP: normalizeIP(externalIP), - InternalIP: normalizeIP(internalIP), + Endpoint: endpoint, + InternalIP: internalIP, Key: []byte(node.ObjectMeta.Annotations[keyAnnotationKey]), LastSeen: lastSeen, Leader: leader, @@ -325,8 +327,8 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer { } if peer.Spec.Endpoint.Port > 0 && ip != nil { endpoint = &wireguard.Endpoint{ - IP: ip, - Port: peer.Spec.Endpoint.Port, + DNSOrIP: wireguard.DNSOrIP{IP: ip}, + Port: peer.Spec.Endpoint.Port, } } } @@ -487,3 +489,35 @@ func normalizeIP(ip string) *net.IPNet { ipNet.IP = i.To16() return ipNet } + +func parseEndpoint(endpoint string) *wireguard.Endpoint { + if len(endpoint) == 0 { + return nil + } + parts := strings.Split(endpoint, ":") + if len(parts) < 2 { + return nil + } + portRaw := parts[len(parts)-1] + hostRaw := strings.Trim(strings.Join(parts[:len(parts)-1], ":"), "[]") + port, err := strconv.ParseUint(portRaw, 10, 32) + if err != nil { + return nil + } + if len(validation.IsValidPortNum(int(port))) != 0 { + return nil + } + ip := net.ParseIP(hostRaw) + if ip == nil { + if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 { + return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: hostRaw}, Port: uint32(port)} + } + return nil + } + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } else { + ip = ip.To16() + } + return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: ip}, Port: uint32(port)} +} diff --git a/pkg/k8s/backend_test.go b/pkg/k8s/backend_test.go index a0f8afe..530a365 100644 --- a/pkg/k8s/backend_test.go +++ b/pkg/k8s/backend_test.go @@ -42,7 +42,7 @@ func TestTranslateNode(t *testing.T) { { name: "invalid ips", annotations: map[string]string{ - externalIPAnnotationKey: "10.0.0.1", + endpointAnnotationKey: "10.0.0.1", internalIPAnnotationKey: "foo", }, out: &mesh.Node{}, @@ -50,11 +50,11 @@ func TestTranslateNode(t *testing.T) { { name: "valid ips", annotations: map[string]string{ - externalIPAnnotationKey: "10.0.0.1/24", + endpointAnnotationKey: "10.0.0.1:51820", internalIPAnnotationKey: "10.0.0.2/32", }, out: &mesh.Node{ - ExternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort}, InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)}, }, }, @@ -102,13 +102,23 @@ func TestTranslateNode(t *testing.T) { }, }, { - name: "external IP override", + name: "invalid endpoint override", annotations: map[string]string{ - externalIPAnnotationKey: "10.0.0.1/24", - forceExternalIPAnnotationKey: "10.0.0.2/24", + endpointAnnotationKey: "10.0.0.1:51820", + forceEndpointAnnotationKey: "-10.0.0.2:51821", }, out: &mesh.Node{ - ExternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)}, + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort}, + }, + }, + { + name: "endpoint override", + annotations: map[string]string{ + endpointAnnotationKey: "10.0.0.1:51820", + forceEndpointAnnotationKey: "10.0.0.2:51821", + }, + out: &mesh.Node{ + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821}, }, }, { @@ -120,6 +130,16 @@ func TestTranslateNode(t *testing.T) { PersistentKeepalive: 25, }, }, + { + name: "invalid internal IP override", + annotations: map[string]string{ + internalIPAnnotationKey: "10.1.0.1/24", + forceInternalIPAnnotationKey: "-10.1.0.2/24", + }, + out: &mesh.Node{ + InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.1"), Mask: net.CIDRMask(24, 32)}, + }, + }, { name: "internal IP override", annotations: map[string]string{ @@ -140,8 +160,8 @@ func TestTranslateNode(t *testing.T) { { name: "complete", annotations: map[string]string{ - externalIPAnnotationKey: "10.0.0.1/24", - forceExternalIPAnnotationKey: "10.0.0.2/24", + endpointAnnotationKey: "10.0.0.1:51820", + forceEndpointAnnotationKey: "10.0.0.2:51821", forceInternalIPAnnotationKey: "10.1.0.2/32", internalIPAnnotationKey: "10.1.0.1/32", keyAnnotationKey: "foo", @@ -155,7 +175,7 @@ func TestTranslateNode(t *testing.T) { regionLabelKey: "a", }, out: &mesh.Node{ - ExternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)}, + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821}, InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)}, Key: []byte("foo"), LastSeen: 1000000000, @@ -237,8 +257,8 @@ func TestTranslatePeer(t *testing.T) { out: &mesh.Peer{ Peer: wireguard.Peer{ Endpoint: &wireguard.Endpoint{ - IP: net.ParseIP("10.0.0.1"), - Port: mesh.DefaultKiloPort, + DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, + Port: mesh.DefaultKiloPort, }, }, }, @@ -288,3 +308,52 @@ func TestTranslatePeer(t *testing.T) { } } } + +func TestParseEndpoint(t *testing.T) { + for _, tc := range []struct { + name string + endpoint string + out *wireguard.Endpoint + }{ + { + name: "empty", + endpoint: "", + out: nil, + }, + { + name: "invalid IP", + endpoint: "10.0.0.:51820", + out: nil, + }, + { + name: "invalid hostname", + endpoint: "foo-:51820", + out: nil, + }, + { + name: "invalid port", + endpoint: "10.0.0.1:100000000", + out: nil, + }, + { + name: "valid IP", + endpoint: "10.0.0.1:51820", + out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort}, + }, + { + name: "valid IPv6", + endpoint: "[ff02::114]:51820", + out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("ff02::114")}, Port: mesh.DefaultKiloPort}, + }, + { + name: "valid hostname", + endpoint: "foo:51821", + out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: "foo"}, Port: 51821}, + }, + } { + endpoint := parseEndpoint(tc.endpoint) + if diff := pretty.Compare(endpoint, tc.out); diff != "" { + t.Errorf("test case %q: got diff: %v", tc.name, diff) + } + } +} diff --git a/pkg/mesh/graph.go b/pkg/mesh/graph.go index bb5f72b..4ebd220 100644 --- a/pkg/mesh/graph.go +++ b/pkg/mesh/graph.go @@ -17,8 +17,10 @@ package mesh import ( "fmt" "net" + "strings" "github.com/awalterschulze/gographviz" + "github.com/squat/kilo/pkg/wireguard" ) // Dot generates a Graphviz graph of the Topology in DOT fomat. @@ -60,13 +62,15 @@ func (t *Topology) Dot() (string, error) { return "", fmt.Errorf("failed to add node to subgraph") } var wg net.IP + var endpoint *wireguard.Endpoint if j == s.leader { wg = s.wireGuardIP + endpoint = s.endpoint if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Rank), "1"); err != nil { return "", fmt.Errorf("failed to add rank to node") } } - if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], s.privateIPs[j], wg)); err != nil { + if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], s.privateIPs[j], wg, endpoint)); err != nil { return "", fmt.Errorf("failed to add label to node") } } @@ -146,14 +150,22 @@ func subGraphName(name string) string { return graphEscape(fmt.Sprintf("cluster_location_%s", name)) } -func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP) string { - var wg string - if wgIP != nil { - wg = wgIP.String() +func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *wireguard.Endpoint) string { + label := []string{ + location, + name, + cidr.String(), + priv.String(), } - return graphEscape(fmt.Sprintf("%s\n%s\n%s\n%s\n%s", location, name, cidr.String(), priv.String(), wg)) + if wgIP != nil { + label = append(label, wgIP.String()) + } + if endpoint != nil { + label = append(label, endpoint.String()) + } + return graphEscape(strings.Join(label, "\n")) } func peerLabel(peer *Peer) string { - return graphEscape(fmt.Sprintf("%s\n%s\n", peer.Name, peer.Endpoint.IP.String())) + return graphEscape(fmt.Sprintf("%s\n%s\n", peer.Name, peer.Endpoint.String())) } diff --git a/pkg/mesh/ip.go b/pkg/mesh/ip.go index 760321e..92d2667 100644 --- a/pkg/mesh/ip.go +++ b/pkg/mesh/ip.go @@ -72,7 +72,7 @@ func getIP(hostname string, ignoreIfaces ...int) (*net.IPNet, *net.IPNet, error) continue } ip.Mask = mask - if isPublic(ip) { + if isPublic(ip.IP) { hostPub = append(hostPub, ip) continue } @@ -97,7 +97,7 @@ func getIP(hostname string, ignoreIfaces ...int) (*net.IPNet, *net.IPNet, error) if isLocal(ip.IP) { continue } - if isPublic(ip) { + if isPublic(ip.IP) { defaultPub = append(defaultPub, ip) continue } @@ -118,7 +118,7 @@ func getIP(hostname string, ignoreIfaces ...int) (*net.IPNet, *net.IPNet, error) if isLocal(ip.IP) { continue } - if isPublic(ip) { + if isPublic(ip.IP) { interfacePub = append(interfacePub, ip) continue } @@ -206,9 +206,9 @@ func isLocal(ip net.IP) bool { return ip.IsLoopback() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() } -func isPublic(ip *net.IPNet) bool { +func isPublic(ip net.IP) bool { // Check RFC 1918 addresses. - if ip4 := ip.IP.To4(); ip4 != nil { + if ip4 := ip.To4(); ip4 != nil { switch true { // Check for 10.0.0.0/8. case ip4[0] == 10: @@ -224,10 +224,10 @@ func isPublic(ip *net.IPNet) bool { } } // Check RFC 4193 addresses. - if len(ip.IP) == net.IPv6len { + if len(ip) == net.IPv6len { switch true { // Check for fd00::/8. - case ip.IP[0] == 0xfd && ip.IP[1] == 0x00: + case ip[0] == 0xfd && ip[1] == 0x00: return false default: return true diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index 292a93a..230fdcc 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -71,7 +71,7 @@ const ( // Node represents a node in the network. type Node struct { - ExternalIP *net.IPNet + Endpoint *wireguard.Endpoint Key []byte InternalIP *net.IPNet // LastSeen is a Unix time for the last time @@ -90,7 +90,7 @@ type Node struct { // Ready indicates whether or not the node is ready. func (n *Node) Ready() bool { // Nodes that are not leaders will not have WireGuardIPs, so it is not required. - return n != nil && n.ExternalIP != nil && n.Key != nil && n.InternalIP != nil && n.Subnet != nil && time.Now().Unix()-n.LastSeen < int64(resyncPeriod)*2/int64(time.Second) + return n != nil && n.Endpoint != nil && !(n.Endpoint.IP == nil && n.Endpoint.DNS == "") && n.Endpoint.Port != 0 && n.Key != nil && n.InternalIP != nil && n.Subnet != nil && time.Now().Unix()-n.LastSeen < int64(resyncPeriod)*2/int64(time.Second) } // Peer represents a peer in the network. @@ -100,6 +100,10 @@ type Peer struct { } // Ready indicates whether or not the peer is ready. +// Peers can have empty endpoints because they may not have an +// IP, for example if they are behind a NAT, and thus +// will not declare their endpoint and instead allow it to be +// discovered. func (p *Peer) Ready() bool { return p != nil && p.AllowedIPs != nil && len(p.AllowedIPs) != 0 && p.PublicKey != nil } @@ -186,7 +190,6 @@ type Mesh struct { priv []byte privIface int pub []byte - pubIface int stop chan struct{} subnet *net.IPNet table *route.Table @@ -238,11 +241,6 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit return nil, fmt.Errorf("failed to find interface for private IP: %v", err) } privIface := ifaces[0].Index - ifaces, err = interfacesForIP(publicIP) - if err != nil { - return nil, fmt.Errorf("failed to find interface for public IP: %v", err) - } - pubIface := ifaces[0].Index kiloIface, _, err := wireguard.New(iface) if err != nil { return nil, fmt.Errorf("failed to create WireGuard interface: %v", err) @@ -276,7 +274,6 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit priv: private, privIface: privIface, pub: public, - pubIface: pubIface, local: local, stop: make(chan struct{}), subnet: subnet, @@ -511,8 +508,8 @@ func (m *Mesh) checkIn() { func (m *Mesh) handleLocal(n *Node) { // Allow the IPs to be overridden. - if n.ExternalIP == nil { - n.ExternalIP = m.externalIP + if n.Endpoint == nil || (n.Endpoint.DNS == "" && n.Endpoint.IP == nil) { + n.Endpoint = &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: m.externalIP.IP}, Port: m.port} } if n.InternalIP == nil { n.InternalIP = m.internalIP @@ -521,7 +518,7 @@ func (m *Mesh) handleLocal(n *Node) { // Take leader, location, and subnet from the argument, as these // are not determined by kilo. local := &Node{ - ExternalIP: n.ExternalIP, + Endpoint: n.Endpoint, Key: m.pub, InternalIP: n.InternalIP, LastSeen: time.Now().Unix(), @@ -559,6 +556,12 @@ func (m *Mesh) applyTopology() { m.reconcileCounter.Inc() m.mu.Lock() defer m.mu.Unlock() + // If we can't resolve an endpoint, then fail and retry later. + if err := m.resolveEndpoints(); err != nil { + level.Error(m.logger).Log("error", err) + m.errorCounter.WithLabelValues("apply").Inc() + return + } // Ensure only ready nodes are considered. nodes := make(map[string]*Node) var readyNodes float64 @@ -585,7 +588,7 @@ func (m *Mesh) applyTopology() { if nodes[m.hostname] == nil { return } - t, err := NewTopology(nodes, peers, m.granularity, m.hostname, m.port, m.priv, m.subnet) + t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet) if err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() @@ -598,6 +601,7 @@ func (m *Mesh) applyTopology() { if err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() + return } if err := ioutil.WriteFile(ConfPath, buf, 0600); err != nil { level.Error(m.logger).Log("error", err) @@ -733,6 +737,57 @@ func (m *Mesh) cleanUp() { } } +func (m *Mesh) resolveEndpoints() error { + for k := range m.nodes { + // Skip unready nodes, since they will not be used + // in the topology anyways. + if !m.nodes[k].Ready() { + continue + } + // If the node is ready, then the endpoint is not nil + // but it may not have a DNS name. + if m.nodes[k].Endpoint.DNS == "" { + continue + } + if err := resolveEndpoint(m.nodes[k].Endpoint); err != nil { + return err + } + } + for k := range m.peers { + // Skip unready peers, since they will not be used + // in the topology anyways. + if !m.peers[k].Ready() { + continue + } + // If the peer is ready, then the endpoint is not nil + // but it may not have a DNS name. + if m.peers[k].Endpoint.DNS == "" { + continue + } + if err := resolveEndpoint(m.peers[k].Endpoint); err != nil { + return err + } + } + return nil +} + +func resolveEndpoint(endpoint *wireguard.Endpoint) error { + ips, err := net.LookupIP(endpoint.DNS) + if err != nil { + return fmt.Errorf("failed to look up DNS name %q: %v", endpoint.DNS, err) + } + nets := make([]*net.IPNet, len(ips), len(ips)) + for i := range ips { + nets[i] = oneAddressCIDR(ips[i]) + } + sortIPs(nets) + if len(nets) == 0 { + return fmt.Errorf("did not find any addresses for DNS name %q", endpoint.DNS) + } + endpoint.IP = nets[0].IP + return nil +} + func isSelf(hostname string, node *Node) bool { return node != nil && node.Name == hostname } @@ -744,10 +799,26 @@ func nodesAreEqual(a, b *Node) bool { if a == b { return true } + if !(a.Endpoint != nil) == (b.Endpoint != nil) { + return false + } + if a.Endpoint != nil { + if a.Endpoint.Port != b.Endpoint.Port { + return false + } + // Check the DNS name first since this package + // is doing the DNS resolution. + if a.Endpoint.DNS != b.Endpoint.DNS { + return false + } + if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) { + return false + } + } // Ignore LastSeen when comparing equality we want to check if the nodes are // equivalent. However, we do want to check if LastSeen has transitioned // between valid and invalid. - return ipNetsEqual(a.ExternalIP, b.ExternalIP) && string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() + return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() } func peersAreEqual(a, b *Peer) bool { @@ -761,7 +832,15 @@ func peersAreEqual(a, b *Peer) bool { return false } if a.Endpoint != nil { - if !a.Endpoint.IP.Equal(b.Endpoint.IP) || a.Endpoint.Port != b.Endpoint.Port { + if a.Endpoint.Port != b.Endpoint.Port { + return false + } + // Check the DNS name first since this package + // is doing the DNS resolution. + if a.Endpoint.DNS != b.Endpoint.DNS { + return false + } + if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) { return false } } diff --git a/pkg/mesh/mesh_test.go b/pkg/mesh/mesh_test.go index 24c9e13..95ec0df 100644 --- a/pkg/mesh/mesh_test.go +++ b/pkg/mesh/mesh_test.go @@ -18,6 +18,8 @@ import ( "net" "testing" "time" + + "github.com/squat/kilo/pkg/wireguard" ) func TestReady(t *testing.T) { @@ -39,7 +41,7 @@ func TestReady(t *testing.T) { ready: false, }, { - name: "empty external IP", + name: "empty endpoint", node: &Node{ InternalIP: internalIP, Key: []byte{}, @@ -47,19 +49,39 @@ func TestReady(t *testing.T) { }, ready: false, }, + { + name: "empty endpoint IP", + node: &Node{ + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{}, Port: DefaultKiloPort}, + InternalIP: internalIP, + Key: []byte{}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + }, + ready: false, + }, + { + name: "empty endpoint port", + node: &Node{ + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}}, + InternalIP: internalIP, + Key: []byte{}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + }, + ready: false, + }, { name: "empty internal IP", node: &Node{ - ExternalIP: externalIP, - Key: []byte{}, - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, + Key: []byte{}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: false, }, { name: "empty key", node: &Node{ - ExternalIP: externalIP, + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, InternalIP: internalIP, Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, @@ -68,7 +90,7 @@ func TestReady(t *testing.T) { { name: "empty subnet", node: &Node{ - ExternalIP: externalIP, + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, InternalIP: internalIP, Key: []byte{}, }, @@ -77,7 +99,7 @@ func TestReady(t *testing.T) { { name: "valid", node: &Node{ - ExternalIP: externalIP, + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, InternalIP: internalIP, Key: []byte{}, LastSeen: time.Now().Unix(), diff --git a/pkg/mesh/topology.go b/pkg/mesh/topology.go index c58ebef..32ad9f9 100644 --- a/pkg/mesh/topology.go +++ b/pkg/mesh/topology.go @@ -54,7 +54,7 @@ type Topology struct { type segment struct { allowedIPs []*net.IPNet - endpoint net.IP + endpoint *wireguard.Endpoint key []byte // Location is the logical location of this segment. location string @@ -122,7 +122,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra } t.segments = append(t.segments, &segment{ allowedIPs: allowedIPs, - endpoint: topoMap[location][leader].ExternalIP.IP, + endpoint: topoMap[location][leader].Endpoint, key: topoMap[location][leader].Key, location: location, cidrs: cidrs, @@ -175,7 +175,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface var gw net.IP for _, segment := range t.segments { if segment.location == t.location { - gw = enc.Gw(segment.endpoint, segment.privateIPs[segment.leader], segment.cidrs[segment.leader]) + gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader]) break } } @@ -316,7 +316,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface // equals the external IP. This means that the node // is only accessible through an external IP and we // cannot encapsulate traffic to an IP through the IP. - if segment.privateIPs[i].Equal(segment.endpoint) { + if segment.privateIPs[i].Equal(segment.endpoint.IP) { continue } // Add routes to the private IPs of nodes in other segments. @@ -364,11 +364,8 @@ func (t *Topology) Conf() *wireguard.Conf { continue } peer := &wireguard.Peer{ - AllowedIPs: s.allowedIPs, - Endpoint: &wireguard.Endpoint{ - IP: s.endpoint, - Port: uint32(t.port), - }, + AllowedIPs: s.allowedIPs, + Endpoint: s.endpoint, PublicKey: s.key, PersistentKeepalive: s.persistentKeepalive, } @@ -394,11 +391,8 @@ func (t *Topology) AsPeer() *wireguard.Peer { continue } return &wireguard.Peer{ - AllowedIPs: s.allowedIPs, - Endpoint: &wireguard.Endpoint{ - IP: s.endpoint, - Port: uint32(t.port), - }, + AllowedIPs: s.allowedIPs, + Endpoint: s.endpoint, PersistentKeepalive: s.persistentKeepalive, PublicKey: s.key, } @@ -411,11 +405,8 @@ func (t *Topology) PeerConf(name string) *wireguard.Conf { c := &wireguard.Conf{} for _, s := range t.segments { peer := &wireguard.Peer{ - AllowedIPs: s.allowedIPs, - Endpoint: &wireguard.Endpoint{ - IP: s.endpoint, - Port: uint32(t.port), - }, + AllowedIPs: s.allowedIPs, + Endpoint: s.endpoint, PersistentKeepalive: s.persistentKeepalive, PublicKey: s.key, } @@ -450,12 +441,13 @@ func findLeader(nodes []*Node) int { var leaders, public []int for i := range nodes { if nodes[i].Leader { - if isPublic(nodes[i].ExternalIP) { + if isPublic(nodes[i].Endpoint.IP) { return i } leaders = append(leaders, i) + } - if isPublic(nodes[i].ExternalIP) { + if isPublic(nodes[i].Endpoint.IP) { public = append(public, i) } } diff --git a/pkg/mesh/topology_test.go b/pkg/mesh/topology_test.go index eb84c52..a3e1d5b 100644 --- a/pkg/mesh/topology_test.go +++ b/pkg/mesh/topology_test.go @@ -40,7 +40,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) { nodes := map[string]*Node{ "a": { Name: "a", - ExternalIP: e1, + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, InternalIP: i1, Location: "1", Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)}, @@ -49,7 +49,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) { }, "b": { Name: "b", - ExternalIP: e2, + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, InternalIP: i1, Location: "2", Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)}, @@ -57,7 +57,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) { }, "c": { Name: "c", - ExternalIP: e3, + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e3.IP}, Port: DefaultKiloPort}, InternalIP: i2, // Same location a node b. Location: "2", @@ -83,8 +83,8 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) { {IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)}, }, Endpoint: &wireguard.Endpoint{ - IP: net.ParseIP("192.168.0.1"), - Port: DefaultKiloPort, + DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("192.168.0.1")}, + Port: DefaultKiloPort, }, PublicKey: []byte("key5"), }, @@ -119,7 +119,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].ExternalIP.IP, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, location: nodes["a"].Location, cidrs: []*net.IPNet{nodes["a"].Subnet}, @@ -130,7 +130,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].ExternalIP.IP, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, location: nodes["b"].Location, cidrs: []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet}, @@ -156,7 +156,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].ExternalIP.IP, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, location: nodes["a"].Location, cidrs: []*net.IPNet{nodes["a"].Subnet}, @@ -167,7 +167,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].ExternalIP.IP, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, location: nodes["b"].Location, cidrs: []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet}, @@ -193,7 +193,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].ExternalIP.IP, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, location: nodes["a"].Location, cidrs: []*net.IPNet{nodes["a"].Subnet}, @@ -204,7 +204,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].ExternalIP.IP, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, location: nodes["b"].Location, cidrs: []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet}, @@ -230,7 +230,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].ExternalIP.IP, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, location: nodes["a"].Name, cidrs: []*net.IPNet{nodes["a"].Subnet}, @@ -241,7 +241,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].ExternalIP.IP, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, location: nodes["b"].Name, cidrs: []*net.IPNet{nodes["b"].Subnet}, @@ -251,7 +251,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["c"].ExternalIP.IP, + endpoint: nodes["c"].Endpoint, key: nodes["c"].Key, location: nodes["c"].Name, cidrs: []*net.IPNet{nodes["c"].Subnet}, @@ -277,7 +277,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].ExternalIP.IP, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, location: nodes["a"].Name, cidrs: []*net.IPNet{nodes["a"].Subnet}, @@ -288,7 +288,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].ExternalIP.IP, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, location: nodes["b"].Name, cidrs: []*net.IPNet{nodes["b"].Subnet}, @@ -298,7 +298,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["c"].ExternalIP.IP, + endpoint: nodes["c"].Endpoint, key: nodes["c"].Key, location: nodes["c"].Name, cidrs: []*net.IPNet{nodes["c"].Subnet}, @@ -324,7 +324,7 @@ func TestNewTopology(t *testing.T) { segments: []*segment{ { allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].ExternalIP.IP, + endpoint: nodes["a"].Endpoint, key: nodes["a"].Key, location: nodes["a"].Name, cidrs: []*net.IPNet{nodes["a"].Subnet}, @@ -335,7 +335,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].ExternalIP.IP, + endpoint: nodes["b"].Endpoint, key: nodes["b"].Key, location: nodes["b"].Name, cidrs: []*net.IPNet{nodes["b"].Subnet}, @@ -345,7 +345,7 @@ func TestNewTopology(t *testing.T) { }, { allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["c"].ExternalIP.IP, + endpoint: nodes["c"].Endpoint, key: nodes["c"].Key, location: nodes["c"].Name, cidrs: []*net.IPNet{nodes["c"].Subnet}, @@ -1400,26 +1400,26 @@ func TestFindLeader(t *testing.T) { nodes := []*Node{ { - Name: "a", - ExternalIP: e1, + Name: "a", + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, }, { - Name: "b", - ExternalIP: e2, + Name: "b", + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, }, { - Name: "c", - ExternalIP: e2, + Name: "c", + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, }, { - Name: "d", - ExternalIP: e1, - Leader: true, + Name: "d", + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, + Leader: true, }, { - Name: "2", - ExternalIP: e2, - Leader: true, + Name: "2", + Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, + Leader: true, }, } for _, tc := range []struct { diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index 3e3fb76..04e2ddd 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -22,6 +22,8 @@ import ( "sort" "strconv" "strings" + + "k8s.io/apimachinery/pkg/util/validation" ) type section string @@ -75,10 +77,34 @@ func (p *Peer) DeduplicateIPs() { // Endpoint represents an `endpoint` key of a `peer` section. type Endpoint struct { - IP net.IP + DNSOrIP Port uint32 } +// String prints the string representation of the endpoint. +func (e *Endpoint) String() string { + dnsOrIP := e.DNSOrIP.String() + if e.IP != nil && len(e.IP) == net.IPv6len { + dnsOrIP = "[" + dnsOrIP + "]" + } + return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10) +} + +// DNSOrIP represents either a DNS name or an IP address. +// IPs, as they are more specific, are preferred. +type DNSOrIP struct { + DNS string + IP net.IP +} + +// String prints the string representation of the struct. +func (d DNSOrIP) String() string { + if d.IP != nil { + return d.IP.String() + } + return d.DNS +} + // Parse parses a given WireGuard configuration file and produces a Conf struct. func Parse(buf []byte) *Conf { var ( @@ -160,25 +186,30 @@ func Parse(buf []byte) *Conf { case endpointKey: // Reuse string slice. kv = strings.Split(v, ":") - if len(kv) != 2 { + if len(kv) < 2 { continue } - ip = net.ParseIP(kv[0]) - if ip == nil { - continue - } - port, err = strconv.ParseUint(kv[1], 10, 32) + port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32) if err != nil { continue } - if ip4 = ip.To4(); ip4 != nil { - ip = ip4 + d := DNSOrIP{} + ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]")) + if ip == nil { + if len(validation.IsDNS1123Subdomain(kv[0])) != 0 { + continue + } + d.DNS = kv[0] } else { - ip = ip.To16() + if ip4 = ip.To4(); ip4 != nil { + d.IP = ip4 + } else { + d.IP = ip.To16() + } } peer.Endpoint = &Endpoint{ - IP: ip, - Port: uint32(port), + DNSOrIP: d, + Port: uint32(port), } case persistentKeepaliveKey: i, err = strconv.Atoi(v) @@ -242,7 +273,7 @@ func (c *Conf) Bytes() ([]byte, error) { return buf.Bytes(), nil } -// Equal checks if two WireGuare configurations are equivalent. +// Equal checks if two WireGuard configurations are equivalent. func (c *Conf) Equal(b *Conf) bool { if (c.Interface == nil) != (b.Interface == nil) { return false @@ -272,7 +303,15 @@ func (c *Conf) Equal(b *Conf) bool { return false } if c.Peers[i].Endpoint != nil { - if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) || c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port { + if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port { + return false + } + // IPs take priority, so check them first. + if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) { + return false + } + // Only check the DNS name if the IP is empty. + if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS { return false } } @@ -352,13 +391,7 @@ func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error { if err = writeKey(buf, endpointKey); err != nil { return err } - if _, err = buf.WriteString(e.IP.String()); err != nil { - return err - } - if err = buf.WriteByte(':'); err != nil { - return err - } - if _, err = buf.WriteString(strconv.FormatUint(uint64(e.Port), 10)); err != nil { + if _, err = buf.WriteString(e.String()); err != nil { return err } return buf.WriteByte('\n') diff --git a/pkg/wireguard/conf_test.go b/pkg/wireguard/conf_test.go index c85b16b..ab00db2 100644 --- a/pkg/wireguard/conf_test.go +++ b/pkg/wireguard/conf_test.go @@ -95,6 +95,28 @@ func TestCompareConf(t *testing.T) { `), out: false, }, + { + name: "different value", + a: []byte(`[Interface] + PrivateKey = private + ListenPort = 51820 + + [Peer] + Endpoint = 10.1.0.2:51820 + PublicKey = key + AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 + `), + b: []byte(`[Interface] + PrivateKey = private + ListenPort = 51820 + + [Peer] + Endpoint = 10.1.0.2:51820 + PublicKey = key2 + AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 + `), + out: false, + }, { name: "section order", a: []byte(`[Interface]