Merge pull request #41 from squat/ignore-nat-peer-changes
pkg/wireguard: ignore changes to peers behind NAT
This commit is contained in:
commit
406a397566
@ -355,7 +355,6 @@ func (m *Mesh) Run() error {
|
|||||||
if m.cni {
|
if m.cni {
|
||||||
m.updateCNIConfig()
|
m.updateCNIConfig()
|
||||||
}
|
}
|
||||||
m.syncEndpoints()
|
|
||||||
m.applyTopology()
|
m.applyTopology()
|
||||||
t.Reset(resyncPeriod)
|
t.Reset(resyncPeriod)
|
||||||
case <-m.stop:
|
case <-m.stop:
|
||||||
@ -364,47 +363,6 @@ func (m *Mesh) Run() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WireGuard updates the endpoints of peers to match the
|
|
||||||
// last place a valid packet was received from.
|
|
||||||
// Periodically we need to syncronize the endpoints
|
|
||||||
// of peers in the backend to match the WireGuard configuration.
|
|
||||||
func (m *Mesh) syncEndpoints() {
|
|
||||||
link, err := linkByIndex(m.kiloIface)
|
|
||||||
if err != nil {
|
|
||||||
level.Error(m.logger).Log("error", err)
|
|
||||||
m.errorCounter.WithLabelValues("endpoints").Inc()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
conf, err := wireguard.ShowConf(link.Attrs().Name)
|
|
||||||
if err != nil {
|
|
||||||
level.Error(m.logger).Log("error", err)
|
|
||||||
m.errorCounter.WithLabelValues("endpoints").Inc()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
c := wireguard.Parse(conf)
|
|
||||||
var key string
|
|
||||||
var tmp *Peer
|
|
||||||
for i := range c.Peers {
|
|
||||||
// Peers are indexed by public key.
|
|
||||||
key = string(c.Peers[i].PublicKey)
|
|
||||||
if p, ok := m.peers[key]; ok {
|
|
||||||
tmp = &Peer{
|
|
||||||
Name: p.Name,
|
|
||||||
Peer: *c.Peers[i],
|
|
||||||
}
|
|
||||||
if !peersAreEqual(tmp, p) {
|
|
||||||
p.Endpoint = tmp.Endpoint
|
|
||||||
if err := m.Peers().Set(p.Name, p); err != nil {
|
|
||||||
level.Error(m.logger).Log("error", err)
|
|
||||||
m.errorCounter.WithLabelValues("endpoints").Inc()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Mesh) syncNodes(e *NodeEvent) {
|
func (m *Mesh) syncNodes(e *NodeEvent) {
|
||||||
logger := log.With(m.logger, "event", e.Type)
|
logger := log.With(m.logger, "event", e.Type)
|
||||||
level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type)
|
level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type)
|
||||||
@ -659,7 +617,7 @@ func (m *Mesh) applyTopology() {
|
|||||||
}
|
}
|
||||||
// Setting the WireGuard configuration interrupts existing connections
|
// Setting the WireGuard configuration interrupts existing connections
|
||||||
// so only set the configuration if it has changed.
|
// so only set the configuration if it has changed.
|
||||||
equal := conf.Equal(wireguard.Parse(oldConf))
|
equal := conf.EqualWithPeerCheck(wireguard.Parse(oldConf), peersAreEqualIgnoreNAT)
|
||||||
if !equal {
|
if !equal {
|
||||||
level.Info(m.logger).Log("msg", "WireGuard configurations are different")
|
level.Info(m.logger).Log("msg", "WireGuard configurations are different")
|
||||||
if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
|
if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
|
||||||
@ -856,6 +814,41 @@ func peersAreEqual(a, b *Peer) bool {
|
|||||||
return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive
|
return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Basic nil checks and checking the lengths of the allowed IPs is
|
||||||
|
// done by the WireGuard package.
|
||||||
|
func peersAreEqualIgnoreNAT(a, b *wireguard.Peer) bool {
|
||||||
|
for j := range a.AllowedIPs {
|
||||||
|
if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if a.PersistentKeepalive != b.PersistentKeepalive || !bytes.Equal(a.PublicKey, b.PublicKey) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// If a persistent keepalive is set, then the peer is behind NAT
|
||||||
|
// and we want to ignore changes in endpoints, since it may roam.
|
||||||
|
if a.PersistentKeepalive != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if (a.Endpoint == nil) != (b.Endpoint == nil) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a.Endpoint != nil {
|
||||||
|
if a.Endpoint.Port != b.Endpoint.Port {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// IPs take priority, so check them first.
|
||||||
|
if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Only check the DNS name if the IP is empty.
|
||||||
|
if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func ipNetsEqual(a, b *net.IPNet) bool {
|
func ipNetsEqual(a, b *net.IPNet) bool {
|
||||||
if a == nil && b == nil {
|
if a == nil && b == nil {
|
||||||
return true
|
return true
|
||||||
|
@ -275,6 +275,12 @@ func (c *Conf) Bytes() ([]byte, error) {
|
|||||||
|
|
||||||
// Equal checks if two WireGuard configurations are equivalent.
|
// Equal checks if two WireGuard configurations are equivalent.
|
||||||
func (c *Conf) Equal(b *Conf) bool {
|
func (c *Conf) Equal(b *Conf) bool {
|
||||||
|
return c.EqualWithPeerCheck(b, strictPeerCheck)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EqualWithPeerCheck checks if two WireGuard configurations are equivalent
|
||||||
|
// when their peers are compared using the given peer comparison func.
|
||||||
|
func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool {
|
||||||
if (c.Interface == nil) != (b.Interface == nil) {
|
if (c.Interface == nil) != (b.Interface == nil) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -288,38 +294,47 @@ func (c *Conf) Equal(b *Conf) bool {
|
|||||||
}
|
}
|
||||||
sortPeers(c.Peers)
|
sortPeers(c.Peers)
|
||||||
sortPeers(b.Peers)
|
sortPeers(b.Peers)
|
||||||
|
var ok bool
|
||||||
for i := range c.Peers {
|
for i := range c.Peers {
|
||||||
if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
|
if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
sortCIDRs(c.Peers[i].AllowedIPs)
|
sortCIDRs(c.Peers[i].AllowedIPs)
|
||||||
sortCIDRs(b.Peers[i].AllowedIPs)
|
sortCIDRs(b.Peers[i].AllowedIPs)
|
||||||
for j := range c.Peers[i].AllowedIPs {
|
if ok = pc(c.Peers[i], b.Peers[i]); !ok {
|
||||||
if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if c.Peers[i].Endpoint != nil {
|
|
||||||
if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// IPs take priority, so check them first.
|
|
||||||
if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// Only check the DNS name if the IP is empty.
|
|
||||||
if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerCheck is a function that compares two peers.
|
||||||
|
type PeerCheck func(a, b *Peer) bool
|
||||||
|
|
||||||
|
func strictPeerCheck(a, b *Peer) bool {
|
||||||
|
for j := range a.AllowedIPs {
|
||||||
|
if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (a.Endpoint == nil) != (b.Endpoint == nil) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a.Endpoint != nil {
|
||||||
|
if a.Endpoint.Port != b.Endpoint.Port {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// IPs take priority, so check them first.
|
||||||
|
if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Only check the DNS name if the IP is empty.
|
||||||
|
if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return a.PersistentKeepalive == b.PersistentKeepalive && bytes.Equal(a.PublicKey, b.PublicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sortPeers(peers []*Peer) {
|
func sortPeers(peers []*Peer) {
|
||||||
|
Loading…
Reference in New Issue
Block a user