diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index 76bc049..cec9ec0 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -529,7 +529,9 @@ func (m *Mesh) applyTopology() { if !m.nodes[k].Ready() { continue } - nodes[k] = m.nodes[k] + // Make a shallow copy of the node. + node := *m.nodes[k] + nodes[k] = &node readyNodes++ } // Ensure only ready nodes are considered. @@ -539,7 +541,9 @@ func (m *Mesh) applyTopology() { if !m.peers[k].Ready() { continue } - peers[k] = m.peers[k] + // Make a shallow copy of the peer. + peer := *m.peers[k] + peers[k] = &peer readyPeers++ } m.nodesGuage.Set(readyNodes) @@ -548,6 +552,22 @@ func (m *Mesh) applyTopology() { if nodes[m.hostname] == nil { return } + // Find the Kilo interface name. + link, err := linkByIndex(m.kiloIface) + if err != nil { + level.Error(m.logger).Log("error", err) + m.errorCounter.WithLabelValues("apply").Inc() + return + } + // Find the old configuration. + oldConfRaw, err := wireguard.ShowConf(link.Attrs().Name) + if err != nil { + level.Error(m.logger).Log("error", err) + m.errorCounter.WithLabelValues("apply").Inc() + return + } + oldConf := wireguard.Parse(oldConfRaw) + updateNATEndpoints(nodes, peers, oldConf) t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive) if err != nil { level.Error(m.logger).Log("error", err) @@ -582,7 +602,6 @@ func (m *Mesh) applyTopology() { } } ipRules = append(ipRules, m.enc.Rules(cidrs)...) - // If we are handling local routes, ensure the local // tunnel has an IP address. if err := m.enc.Set(oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil { @@ -596,28 +615,15 @@ func (m *Mesh) applyTopology() { m.errorCounter.WithLabelValues("apply").Inc() return } - // Find the Kilo interface name. - link, err := linkByIndex(m.kiloIface) - if err != nil { - level.Error(m.logger).Log("error", err) - m.errorCounter.WithLabelValues("apply").Inc() - return - } if t.leader { if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() return } - oldConf, err := wireguard.ShowConf(link.Attrs().Name) - if err != nil { - level.Error(m.logger).Log("error", err) - m.errorCounter.WithLabelValues("apply").Inc() - return - } // Setting the WireGuard configuration interrupts existing connections // so only set the configuration if it has changed. - equal := conf.EqualWithPeerCheck(wireguard.Parse(oldConf), peersAreEqualIgnoreNAT) + equal := conf.Equal(oldConf) if !equal { level.Info(m.logger).Log("msg", "WireGuard configurations are different") if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil { @@ -814,41 +820,6 @@ func peersAreEqual(a, b *Peer) bool { return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive } -// Basic nil checks and checking the lengths of the allowed IPs is -// done by the WireGuard package. -func peersAreEqualIgnoreNAT(a, b *wireguard.Peer) bool { - for j := range a.AllowedIPs { - if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() { - return false - } - } - if a.PersistentKeepalive != b.PersistentKeepalive || !bytes.Equal(a.PublicKey, b.PublicKey) { - return false - } - // If a persistent keepalive is set, then the peer is behind NAT - // and we want to ignore changes in endpoints, since it may roam. - if a.PersistentKeepalive != 0 { - return true - } - if (a.Endpoint == nil) != (b.Endpoint == nil) { - return false - } - if a.Endpoint != nil { - if a.Endpoint.Port != b.Endpoint.Port { - return false - } - // IPs take priority, so check them first. - if !a.Endpoint.IP.Equal(b.Endpoint.IP) { - return false - } - // Only check the DNS name if the IP is empty. - if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS { - return false - } - } - return true -} - func ipNetsEqual(a, b *net.IPNet) bool { if a == nil && b == nil { return true @@ -888,3 +859,22 @@ func linkByIndex(index int) (netlink.Link, error) { } return link, nil } + +// updateNATEndpoints ensures that nodes and peers behind NAT update +// their endpoints from the WireGuard configuration so they can roam. +func updateNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf) { + keys := make(map[string]*wireguard.Peer) + for i := range conf.Peers { + keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i] + } + for _, n := range nodes { + if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 { + n.Endpoint = peer.Endpoint + } + } + for _, p := range peers { + if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 { + p.Endpoint = peer.Endpoint + } + } +} diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index c1fd0fd..04e2ddd 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -275,12 +275,6 @@ func (c *Conf) Bytes() ([]byte, error) { // Equal checks if two WireGuard configurations are equivalent. func (c *Conf) Equal(b *Conf) bool { - return c.EqualWithPeerCheck(b, strictPeerCheck) -} - -// EqualWithPeerCheck checks if two WireGuard configurations are equivalent -// when their peers are compared using the given peer comparison func. -func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool { if (c.Interface == nil) != (b.Interface == nil) { return false } @@ -294,47 +288,38 @@ func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool { } sortPeers(c.Peers) sortPeers(b.Peers) - var ok bool for i := range c.Peers { if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) { return false } sortCIDRs(c.Peers[i].AllowedIPs) sortCIDRs(b.Peers[i].AllowedIPs) - if ok = pc(c.Peers[i], b.Peers[i]); !ok { + for j := range c.Peers[i].AllowedIPs { + if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() { + return false + } + } + if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) { + return false + } + if c.Peers[i].Endpoint != nil { + if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port { + return false + } + // IPs take priority, so check them first. + if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) { + return false + } + // Only check the DNS name if the IP is empty. + if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS { + return false + } + } + if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) { return false } } return true - -} - -// PeerCheck is a function that compares two peers. -type PeerCheck func(a, b *Peer) bool - -func strictPeerCheck(a, b *Peer) bool { - for j := range a.AllowedIPs { - if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() { - return false - } - } - if (a.Endpoint == nil) != (b.Endpoint == nil) { - return false - } - if a.Endpoint != nil { - if a.Endpoint.Port != b.Endpoint.Port { - return false - } - // IPs take priority, so check them first. - if !a.Endpoint.IP.Equal(b.Endpoint.IP) { - return false - } - // Only check the DNS name if the IP is empty. - if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS { - return false - } - } - return a.PersistentKeepalive == b.PersistentKeepalive && bytes.Equal(a.PublicKey, b.PublicKey) } func sortPeers(peers []*Peer) {