diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index d09b466..9add840 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -355,7 +355,6 @@ func (m *Mesh) Run() error { if m.cni { m.updateCNIConfig() } - m.syncEndpoints() m.applyTopology() t.Reset(resyncPeriod) case <-m.stop: @@ -364,47 +363,6 @@ func (m *Mesh) Run() error { } } -// WireGuard updates the endpoints of peers to match the -// last place a valid packet was received from. -// Periodically we need to syncronize the endpoints -// of peers in the backend to match the WireGuard configuration. -func (m *Mesh) syncEndpoints() { - link, err := linkByIndex(m.kiloIface) - if err != nil { - level.Error(m.logger).Log("error", err) - m.errorCounter.WithLabelValues("endpoints").Inc() - return - } - conf, err := wireguard.ShowConf(link.Attrs().Name) - if err != nil { - level.Error(m.logger).Log("error", err) - m.errorCounter.WithLabelValues("endpoints").Inc() - return - } - m.mu.Lock() - defer m.mu.Unlock() - c := wireguard.Parse(conf) - var key string - var tmp *Peer - for i := range c.Peers { - // Peers are indexed by public key. - key = string(c.Peers[i].PublicKey) - if p, ok := m.peers[key]; ok { - tmp = &Peer{ - Name: p.Name, - Peer: *c.Peers[i], - } - if !peersAreEqual(tmp, p) { - p.Endpoint = tmp.Endpoint - if err := m.Peers().Set(p.Name, p); err != nil { - level.Error(m.logger).Log("error", err) - m.errorCounter.WithLabelValues("endpoints").Inc() - } - } - } - } -} - func (m *Mesh) syncNodes(e *NodeEvent) { logger := log.With(m.logger, "event", e.Type) level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type) @@ -659,7 +617,7 @@ func (m *Mesh) applyTopology() { } // Setting the WireGuard configuration interrupts existing connections // so only set the configuration if it has changed. - equal := conf.Equal(wireguard.Parse(oldConf)) + equal := conf.EqualWithPeerCheck(wireguard.Parse(oldConf), peersAreEqualIgnoreNAT) if !equal { level.Info(m.logger).Log("msg", "WireGuard configurations are different") if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil { @@ -856,6 +814,41 @@ func peersAreEqual(a, b *Peer) bool { return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive } +// Basic nil checks and checking the lengths of the allowed IPs is +// done by the WireGuard package. +func peersAreEqualIgnoreNAT(a, b *wireguard.Peer) bool { + for j := range a.AllowedIPs { + if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() { + return false + } + } + if a.PersistentKeepalive != b.PersistentKeepalive || !bytes.Equal(a.PublicKey, b.PublicKey) { + return false + } + // If a persistent keepalive is set, then the peer is behind NAT + // and we want to ignore changes in endpoints, since it may roam. + if a.PersistentKeepalive != 0 { + return true + } + if (a.Endpoint == nil) != (b.Endpoint == nil) { + return false + } + if a.Endpoint != nil { + if a.Endpoint.Port != b.Endpoint.Port { + return false + } + // IPs take priority, so check them first. + if !a.Endpoint.IP.Equal(b.Endpoint.IP) { + return false + } + // Only check the DNS name if the IP is empty. + if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS { + return false + } + } + return true +} + func ipNetsEqual(a, b *net.IPNet) bool { if a == nil && b == nil { return true diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index 04e2ddd..c1fd0fd 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -275,6 +275,12 @@ func (c *Conf) Bytes() ([]byte, error) { // Equal checks if two WireGuard configurations are equivalent. func (c *Conf) Equal(b *Conf) bool { + return c.EqualWithPeerCheck(b, strictPeerCheck) +} + +// EqualWithPeerCheck checks if two WireGuard configurations are equivalent +// when their peers are compared using the given peer comparison func. +func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool { if (c.Interface == nil) != (b.Interface == nil) { return false } @@ -288,38 +294,47 @@ func (c *Conf) Equal(b *Conf) bool { } sortPeers(c.Peers) sortPeers(b.Peers) + var ok bool for i := range c.Peers { if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) { return false } sortCIDRs(c.Peers[i].AllowedIPs) sortCIDRs(b.Peers[i].AllowedIPs) - for j := range c.Peers[i].AllowedIPs { - if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() { - return false - } - } - if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) { - return false - } - if c.Peers[i].Endpoint != nil { - if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port { - return false - } - // IPs take priority, so check them first. - if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) { - return false - } - // Only check the DNS name if the IP is empty. - if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS { - return false - } - } - if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) { + if ok = pc(c.Peers[i], b.Peers[i]); !ok { return false } } return true + +} + +// PeerCheck is a function that compares two peers. +type PeerCheck func(a, b *Peer) bool + +func strictPeerCheck(a, b *Peer) bool { + for j := range a.AllowedIPs { + if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() { + return false + } + } + if (a.Endpoint == nil) != (b.Endpoint == nil) { + return false + } + if a.Endpoint != nil { + if a.Endpoint.Port != b.Endpoint.Port { + return false + } + // IPs take priority, so check them first. + if !a.Endpoint.IP.Equal(b.Endpoint.IP) { + return false + } + // Only check the DNS name if the IP is empty. + if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS { + return false + } + } + return a.PersistentKeepalive == b.PersistentKeepalive && bytes.Equal(a.PublicKey, b.PublicKey) } func sortPeers(peers []*Peer) {