diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index d09b466..b64bcd1 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -659,7 +659,7 @@ func (m *Mesh) applyTopology() { } // Setting the WireGuard configuration interrupts existing connections // so only set the configuration if it has changed. - equal := conf.Equal(wireguard.Parse(oldConf)) + equal := conf.EqualWithPeerCheck(wireguard.Parse(oldConf), peersAreEqualIgnoreNAT) if !equal { level.Info(m.logger).Log("msg", "WireGuard configurations are different") if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil { @@ -856,6 +856,41 @@ func peersAreEqual(a, b *Peer) bool { return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive } +// Basic nil checks and checking the lengths of the allowed IPs is +// done by the WireGuard package. +func peersAreEqualIgnoreNAT(a, b *wireguard.Peer) bool { + for j := range a.AllowedIPs { + if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() { + return false + } + } + if a.PersistentKeepalive != b.PersistentKeepalive || !bytes.Equal(a.PublicKey, b.PublicKey) { + return false + } + // If a persistent keepalive is set, then the peer is behind NAT + // and we want to ignore changes in endpoints, since it may roam. + if a.PersistentKeepalive != 0 { + return true + } + if (a.Endpoint == nil) != (b.Endpoint == nil) { + return false + } + if a.Endpoint != nil { + if a.Endpoint.Port != b.Endpoint.Port { + return false + } + // IPs take priority, so check them first. + if !a.Endpoint.IP.Equal(b.Endpoint.IP) { + return false + } + // Only check the DNS name if the IP is empty. + if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS { + return false + } + } + return true +} + func ipNetsEqual(a, b *net.IPNet) bool { if a == nil && b == nil { return true diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index 04e2ddd..c1fd0fd 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -275,6 +275,12 @@ func (c *Conf) Bytes() ([]byte, error) { // Equal checks if two WireGuard configurations are equivalent. func (c *Conf) Equal(b *Conf) bool { + return c.EqualWithPeerCheck(b, strictPeerCheck) +} + +// EqualWithPeerCheck checks if two WireGuard configurations are equivalent +// when their peers are compared using the given peer comparison func. +func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool { if (c.Interface == nil) != (b.Interface == nil) { return false } @@ -288,38 +294,47 @@ func (c *Conf) Equal(b *Conf) bool { } sortPeers(c.Peers) sortPeers(b.Peers) + var ok bool for i := range c.Peers { if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) { return false } sortCIDRs(c.Peers[i].AllowedIPs) sortCIDRs(b.Peers[i].AllowedIPs) - for j := range c.Peers[i].AllowedIPs { - if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() { - return false - } - } - if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) { - return false - } - if c.Peers[i].Endpoint != nil { - if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port { - return false - } - // IPs take priority, so check them first. - if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) { - return false - } - // Only check the DNS name if the IP is empty. - if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS { - return false - } - } - if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) { + if ok = pc(c.Peers[i], b.Peers[i]); !ok { return false } } return true + +} + +// PeerCheck is a function that compares two peers. +type PeerCheck func(a, b *Peer) bool + +func strictPeerCheck(a, b *Peer) bool { + for j := range a.AllowedIPs { + if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() { + return false + } + } + if (a.Endpoint == nil) != (b.Endpoint == nil) { + return false + } + if a.Endpoint != nil { + if a.Endpoint.Port != b.Endpoint.Port { + return false + } + // IPs take priority, so check them first. + if !a.Endpoint.IP.Equal(b.Endpoint.IP) { + return false + } + // Only check the DNS name if the IP is empty. + if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS { + return false + } + } + return a.PersistentKeepalive == b.PersistentKeepalive && bytes.Equal(a.PublicKey, b.PublicKey) } func sortPeers(peers []*Peer) {