// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mesh

import (
	"errors"
	"net"
	"sort"
	"time"

	"github.com/go-kit/kit/log"
	"github.com/go-kit/kit/log/level"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

	"github.com/squat/kilo/pkg/wireguard"
)

// Prefixes used to build the location keys under which nodes are grouped
// into segments.
const (
	// logicalLocationPrefix prefixes keys derived from a node's Location.
	logicalLocationPrefix = "location:"
	// nodeLocationPrefix prefixes keys derived from a node's name.
	nodeLocationPrefix = "node:"
)

// Topology represents the logical structure of the overlay network.
type Topology struct {
	// key is the private key of the node creating the topology.
	key wgtypes.Key
	// port is the listen port placed in the generated WireGuard configuration.
	port int
	// location is the logical location of the local host,
	// i.e. the prefixed location key of the local host's segment.
	location string
	// nodeLocation is the location annotation of the node. This is set only in cross location topology.
	nodeLocation string
	// segments holds one entry per group of nodes, sorted by location.
	segments []*segment
	// peers holds the peers sorted by name, with duplicate allowed IPs removed.
	peers []*Peer

	// hostname is the hostname of the local host.
	hostname string
	// leader represents whether or not the local host
	// is the segment leader.
	// In cross location topology it is set as soon as any segment
	// carries a node location.
	leader bool
	// persistentKeepalive is the interval of the emission
	// of keepalive packets by the local node to its peers.
	// It is a time.Duration, so it is not limited to whole seconds.
	persistentKeepalive time.Duration
	// privateIP is the private IP address of the local node.
	privateIP *net.IPNet
	// subnet is the Pod subnet of the local node.
	subnet *net.IPNet
	// wireGuardCIDR is the allocated CIDR of the WireGuard
	// interface of the local node within the Kilo subnet.
	// If the local node is not the leader of a location, then
	// the IP is the 0th address in the subnet, i.e. the CIDR
	// is equal to the Kilo subnet.
	wireGuardCIDR *net.IPNet
	// discoveredEndpoints is the updated map of valid discovered Endpoints.
	// It is looked up by the string form of a WireGuard key.
	discoveredEndpoints map[string]*net.UDPAddr
	// logger receives the debug and warning messages of the topology.
	logger log.Logger
}

// segment represents one logical unit in the topology that is united by one common WireGuard IP.
type segment struct {
	// allowedIPs collects the subnets and the single-address CIDRs of the
	// internal IPs of the nodes in the segment, plus the single-address
	// CIDR of the segment's WireGuard IP.
	allowedIPs []net.IPNet
	// endpoint is the endpoint of the leader of the segment.
	endpoint *wireguard.Endpoint
	// key is the key of the leader of the segment; it is used as the
	// public key of the segment in generated configurations.
	key wgtypes.Key
	// persistentKeepalive is the keepalive setting of the leader of the segment.
	persistentKeepalive time.Duration
	// location is the logical location of this segment.
	location string
	// nodeLocation is the node location annotation. This is set only for cross location topology.
	nodeLocation string

	// cidrs is a slice of subnets of all peers in the segment.
	// An entry is nil for a node that has no subnet.
	cidrs []*net.IPNet
	// hostnames is a slice of the hostnames of the peers in the segment.
	hostnames []string
	// leader is the index of the leader of the segment.
	leader int
	// privateIPs is a slice of private IPs of all peers in the segment.
	privateIPs []net.IP
	// wireGuardIP is the allocated IP address of the WireGuard
	// interface on the leader of the segment.
	wireGuardIP net.IP
	// allowedLocationIPs are not part of the cluster and are not peers.
	// They are directly routable from nodes within the segment.
	// A classic example is a printer that ought to be routable from other locations.
	allowedLocationIPs []net.IPNet
}

// topoKey is used to group nodes into locations.
type topoKey struct {
	// location is the prefixed location key of the group.
	location string
	// nodeLocation is the prefixed node location; it is empty unless
	// the cross location topology is used.
	nodeLocation string
}

// NewTopology creates a new Topology struct from a given set of nodes and peers.
func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration, logger log.Logger) (*Topology, error) { if logger == nil { logger = log.NewNopLogger() } topoMap := make(map[topoKey][]*Node) var localLocation, localNodeLocation string switch granularity { case LogicalGranularity: localLocation = logicalLocationPrefix + nodes[hostname].Location if nodes[hostname].InternalIP == nil { localLocation = nodeLocationPrefix + hostname } case FullGranularity: localLocation = nodeLocationPrefix + hostname case CrossGranularity: localLocation = nodeLocationPrefix + hostname localNodeLocation = logicalLocationPrefix + nodes[hostname].Location } for _, node := range nodes { var location, nodeLocation string switch granularity { case LogicalGranularity: location = logicalLocationPrefix + node.Location // Put node in a different location, if no private // IP was found. if node.InternalIP == nil { location = nodeLocationPrefix + node.Name } case FullGranularity: location = nodeLocationPrefix + node.Name case CrossGranularity: location = nodeLocationPrefix + node.Name nodeLocation = logicalLocationPrefix + node.Location } key := topoKey{location: location, nodeLocation: nodeLocation} topoMap[key] = append(topoMap[key], node) } t := Topology{ key: key, port: port, hostname: hostname, location: localLocation, nodeLocation: localNodeLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet, wireGuardCIDR: subnet, discoveredEndpoints: make(map[string]*net.UDPAddr), logger: logger, } for location := range topoMap { // Sort the location so the result is stable. 
sort.Slice(topoMap[location], func(i, j int) bool { return topoMap[location][i].Name < topoMap[location][j].Name }) leader := findLeader(topoMap[location]) if location.nodeLocation != "" || (location.location == localLocation && topoMap[location][leader].Name == hostname) { t.leader = true } var allowedIPs []net.IPNet allowedLocationIPsMap := make(map[string]struct{}) var allowedLocationIPs []net.IPNet var cidrs []*net.IPNet var hostnames []string var privateIPs []net.IP for _, node := range topoMap[location] { // Allowed IPs should include: // - the node's allocated subnet // - the node's WireGuard IP // - the node's internal IP // - IPs that were specified by the allowed-location-ips annotation if node.Subnet != nil { allowedIPs = append(allowedIPs, *node.Subnet) } for _, ip := range node.AllowedLocationIPs { if _, ok := allowedLocationIPsMap[ip.String()]; !ok { allowedLocationIPs = append(allowedLocationIPs, ip) allowedLocationIPsMap[ip.String()] = struct{}{} } } if node.InternalIP != nil { allowedIPs = append(allowedIPs, *oneAddressCIDR(node.InternalIP.IP)) privateIPs = append(privateIPs, node.InternalIP.IP) } cidrs = append(cidrs, node.Subnet) hostnames = append(hostnames, node.Name) } // The sorting has no function, but makes testing easier. 
sort.Slice(allowedLocationIPs, func(i, j int) bool { return allowedLocationIPs[i].String() < allowedLocationIPs[j].String() }) t.segments = append(t.segments, &segment{ allowedIPs: allowedIPs, endpoint: topoMap[location][leader].Endpoint, key: topoMap[location][leader].Key, persistentKeepalive: topoMap[location][leader].PersistentKeepalive, location: location.location, nodeLocation: location.nodeLocation, cidrs: cidrs, hostnames: hostnames, leader: leader, privateIPs: privateIPs, allowedLocationIPs: allowedLocationIPs, }) level.Debug(t.logger).Log("msg", "generated segment", "location", location, "allowedIPs", allowedIPs, "endpoint", topoMap[location][leader].Endpoint, "cidrs", cidrs, "hostnames", hostnames, "leader", leader, "privateIPs", privateIPs, "allowedLocationIPs", allowedLocationIPs) } // Sort the Topology segments so the result is stable. sort.Slice(t.segments, func(i, j int) bool { return t.segments[i].location < t.segments[j].location }) for _, peer := range peers { t.peers = append(t.peers, peer) } // Sort the Topology peers so the result is stable. sort.Slice(t.peers, func(i, j int) bool { return t.peers[i].Name < t.peers[j].Name }) // We need to defensively deduplicate peer allowed IPs. If two peers claim the same IP, // the WireGuard configuration could flap, causing the interface to churn. t.peers = deduplicatePeerIPs(t.peers) // Copy the host node DiscoveredEndpoints in the topology as a starting point. for key := range nodes[hostname].DiscoveredEndpoints { t.discoveredEndpoints[key] = nodes[hostname].DiscoveredEndpoints[key] } // Allocate IPs to the segment leaders in a stable, coordination-free manner. 
a := newAllocator(*subnet) for _, segment := range t.segments { ipNet := a.next() if ipNet == nil { return nil, errors.New("failed to allocate an IP address; ran out of IP addresses") } segment.wireGuardIP = ipNet.IP segment.allowedIPs = append(segment.allowedIPs, *oneAddressCIDR(ipNet.IP)) if t.leader && segment.location == t.location { t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask} } // Now that the topology is ordered, update the discoveredEndpoints map // add new ones by going through the ordered topology: segments, nodes for _, node := range topoMap[topoKey{location: segment.location, nodeLocation: segment.nodeLocation}] { for key := range node.DiscoveredEndpoints { if _, ok := t.discoveredEndpoints[key]; !ok { t.discoveredEndpoints[key] = node.DiscoveredEndpoints[key] } } } // Check for intersecting IPs in allowed location IPs segment.allowedLocationIPs = t.filterAllowedLocationIPs(segment.allowedLocationIPs, segment.location) } level.Debug(t.logger).Log("msg", "generated topology", "location", t.location, "hostname", t.hostname, "wireGuardIP", t.wireGuardCIDR, "privateIP", t.privateIP, "subnet", t.subnet, "leader", t.leader) return &t, nil } func intersect(n1, n2 net.IPNet) bool { return n1.Contains(n2.IP) || n2.Contains(n1.IP) } func (t *Topology) filterAllowedLocationIPs(ips []net.IPNet, location string) (ret []net.IPNet) { CheckIPs: for _, ip := range ips { for _, s := range t.segments { // Check if allowed location IPs are also allowed in other locations. if location != s.location { for _, i := range s.allowedLocationIPs { if intersect(ip, i) { level.Warn(t.logger).Log("msg", "overlapping allowed location IPnets", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location) continue CheckIPs } } } // Check if allowed location IPs intersect with the allowed IPs. 
for _, i := range s.allowedIPs { if intersect(ip, i) { level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with allowed IPnets", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location) continue CheckIPs } } // Check if allowed location IPs intersect with the private IPs of the segment. for _, i := range s.privateIPs { if ip.Contains(i) { level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with privateIP", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location) continue CheckIPs } } } // Check if allowed location IPs intersect with allowed IPs of peers. for _, p := range t.peers { for _, i := range p.AllowedIPs { if intersect(ip, i) { level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with peer IPnet", "IP", ip.String(), "IP2", i.String(), "peer", p.Name) continue CheckIPs } } } ret = append(ret, ip) } return } func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *wireguard.Endpoint { // Do not update non-nat peers if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) { return endpoint } e, ok := t.discoveredEndpoints[key.String()] if ok { return wireguard.NewEndpointFromUDPAddr(e) } return endpoint } // Conf generates a WireGuard configuration file for a given Topology. 
func (t *Topology) Conf() *wireguard.Conf { c := &wireguard.Conf{ Config: wgtypes.Config{ PrivateKey: &t.key, ListenPort: &t.port, ReplacePeers: true, }, } for _, s := range t.segments { if (s.location == t.location) || (t.nodeLocation != "" && t.nodeLocation == s.nodeLocation) { continue } peer := wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...), PersistentKeepaliveInterval: &t.persistentKeepalive, PublicKey: s.key, ReplaceAllowedIPs: true, }, Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive), } c.Peers = append(c.Peers, peer) } for _, p := range t.peers { peer := wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: p.AllowedIPs, PersistentKeepaliveInterval: &t.persistentKeepalive, PresharedKey: p.PresharedKey, PublicKey: p.PublicKey, ReplaceAllowedIPs: true, }, Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval), } c.Peers = append(c.Peers, peer) } return c } // AsPeer generates the WireGuard peer configuration for the local location of the given Topology. // This configuration can be used to configure this location as a peer of another WireGuard interface. func (t *Topology) AsPeer() *wireguard.Peer { for _, s := range t.segments { if s.location != t.location { continue } p := &wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: s.allowedIPs, PublicKey: s.key, }, Endpoint: s.endpoint, } return p } return nil } // PeerConf generates a WireGuard configuration file for a given peer in a Topology. 
func (t *Topology) PeerConf(name string) *wireguard.Conf { var pka *time.Duration var psk *wgtypes.Key for i := range t.peers { if t.peers[i].Name == name { pka = t.peers[i].PersistentKeepaliveInterval psk = t.peers[i].PresharedKey break } } c := &wireguard.Conf{} for _, s := range t.segments { peer := wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...), PersistentKeepaliveInterval: pka, PresharedKey: psk, PublicKey: s.key, }, Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive), } c.Peers = append(c.Peers, peer) } for i := range t.peers { if t.peers[i].Name == name { continue } peer := wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: t.peers[i].AllowedIPs, PersistentKeepaliveInterval: pka, PublicKey: t.peers[i].PublicKey, }, Endpoint: t.updateEndpoint(t.peers[i].Endpoint, t.peers[i].PublicKey, t.peers[i].PersistentKeepaliveInterval), } c.Peers = append(c.Peers, peer) } return c } // oneAddressCIDR takes an IP address and returns a CIDR // that contains only that address. 
func oneAddressCIDR(ip net.IP) *net.IPNet { return &net.IPNet{IP: ip, Mask: net.CIDRMask(len(ip)*8, len(ip)*8)} } // findLeader selects a leader for the nodes in a segment; // it will select the first node that says it should lead // or the first node in the segment if none have volunteered, // always preferring those with a public external IP address, func findLeader(nodes []*Node) int { var leaders, public []int for i := range nodes { if nodes[i].Leader { if isPublic(nodes[i].Endpoint.IP()) { return i } leaders = append(leaders, i) } if nodes[i].Endpoint.IP() != nil && isPublic(nodes[i].Endpoint.IP()) { public = append(public, i) } } if len(leaders) != 0 { return leaders[0] } if len(public) != 0 { return public[0] } return 0 } func deduplicatePeerIPs(peers []*Peer) []*Peer { ps := make([]*Peer, len(peers)) ips := make(map[string]struct{}) for i, peer := range peers { p := Peer{ Name: peer.Name, Peer: wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval, PresharedKey: peer.PresharedKey, PublicKey: peer.PublicKey, }, Endpoint: peer.Endpoint, }, } for _, ip := range peer.AllowedIPs { if _, ok := ips[ip.String()]; ok { continue } p.AllowedIPs = append(p.AllowedIPs, ip) ips[ip.String()] = struct{}{} } ps[i] = &p } return ps }