From 4fadaeefff8e2a978da5bb57057909d8b913e73b Mon Sep 17 00:00:00 2001 From: Julien Viard de Galbert Date: Fri, 16 Apr 2021 16:53:52 +0200 Subject: [PATCH] Refactor to use Endpoint.Equal Compare IP first by default and compare DNS name first when we know the Endpoint was resolved. --- pkg/mesh/mesh.go | 42 +++++++++----------------------------- pkg/wireguard/conf.go | 28 +++++++++++++++++-------- pkg/wireguard/conf_test.go | 39 ++++++++++++++++++++++++++++++----- 3 files changed, 63 insertions(+), 46 deletions(-) diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index 42f9fa1..8015d5e 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -679,22 +679,11 @@ func nodesAreEqual(a, b *Node) bool { if a == b { return true } - if !(a.Endpoint != nil) == (b.Endpoint != nil) { + // Check the DNS name first since this package + // is doing the DNS resolution. + if !a.Endpoint.Equal(b.Endpoint, true) { return false } - if a.Endpoint != nil { - if a.Endpoint.Port != b.Endpoint.Port { - return false - } - // Check the DNS name first since this package - // is doing the DNS resolution. - if a.Endpoint.DNS != b.Endpoint.DNS { - return false - } - if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) { - return false - } - } // Ignore LastSeen when comparing equality we want to check if the nodes are // equivalent. However, we do want to check if LastSeen has transitioned // between valid and invalid. @@ -708,22 +697,11 @@ func peersAreEqual(a, b *Peer) bool { if a == b { return true } - if !(a.Endpoint != nil) == (b.Endpoint != nil) { + // Check the DNS name first since this package + // is doing the DNS resolution. + if !a.Endpoint.Equal(b.Endpoint, true) { return false } - if a.Endpoint != nil { - if a.Endpoint.Port != b.Endpoint.Port { - return false - } - // Check the DNS name first since this package - // is doing the DNS resolution. - if a.Endpoint.DNS != b.Endpoint.DNS { - return false - } - if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) { - return false - } - } if len(a.AllowedIPs) != len(b.AllowedIPs) { return false } @@ -778,7 +756,7 @@ func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool { return false } for k := range a { - if !a[k].Equal(b[k]) { + if !a[k].Equal(b[k], false) { return false } } @@ -802,17 +780,17 @@ func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf * } for _, n := range nodes { if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 { - level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint)) + level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false)) // Should check location leader but only available in topology ... or have topology handle that list // Better check wg latest-handshake - if !n.Endpoint.Equal(peer.Endpoint) { + if !n.Endpoint.Equal(peer.Endpoint, false) { natEndpoints[string(n.Key)] = peer.Endpoint } } } for _, p := range peers { if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 { - if !p.Endpoint.Equal(peer.Endpoint) { + if !p.Endpoint.Equal(peer.Endpoint, false) { natEndpoints[string(p.PublicKey)] = peer.Endpoint } } diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index 20c01bf..0ce55e3 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -96,7 +96,7 @@ func (e *Endpoint) String() string { } // Equal compares two endpoints. -func (e *Endpoint) Equal(b *Endpoint) bool { +func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool { if (e == nil) != (b == nil) { return false } @@ -104,13 +104,23 @@ func (e *Endpoint) Equal(b *Endpoint) bool { if e.Port != b.Port { return false } - // IPs take priority, so check them first. - if !e.IP.Equal(b.IP) { - return false - } - // Only check the DNS name if the IP is empty. - if e.IP == nil && e.DNS != b.DNS { - return false + if DNSFirst { + // Check the DNS name first if it was resolved. + if e.DNS != b.DNS { + return false + } + if e.DNS == "" && !e.IP.Equal(b.IP) { + return false + } + } else { + // IPs take priority, so check them first. + if !e.IP.Equal(b.IP) { + return false + } + // Only check the DNS name if the IP is empty. + if e.IP == nil && e.DNS != b.DNS { + return false + } } } @@ -331,7 +341,7 @@ func (c *Conf) Equal(b *Conf) bool { return false } } - if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint) { + if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) { return false } if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) { diff --git a/pkg/wireguard/conf_test.go b/pkg/wireguard/conf_test.go index c4a87e7..81a3ed0 100644 --- a/pkg/wireguard/conf_test.go +++ b/pkg/wireguard/conf_test.go @@ -207,10 +207,11 @@ func TestCompareConf(t *testing.T) { func TestCompareEndpoint(t *testing.T) { for _, tc := range []struct { - name string - a *Endpoint - b *Endpoint - out bool + name string + a *Endpoint + b *Endpoint + dnsFirst bool + out bool }{ { name: "both nil", @@ -272,8 +273,36 @@ func TestCompareEndpoint(t *testing.T) { b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}}, out: true, }, + { + name: "DNS first, ignore IP", + a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}}, + b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}}, + dnsFirst: true, + out: true, + }, + { + name: "DNS first", + a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}}, + b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}}, + dnsFirst: true, + out: false, + }, + { + name: "DNS first, no DNS compare IP", + a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}}, + b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}}, + dnsFirst: true, + out: false, + }, + { + name: "DNS first, no DNS compare IP (same)", + a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}}, + b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}}, + dnsFirst: true, + out: true, + }, } { - equal := tc.a.Equal(tc.b) + equal := tc.a.Equal(tc.b, tc.dnsFirst) if equal != tc.out { t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal) }