Refactor to use Endpoint.Equal
Compare IP first by default and compare DNS name first when we know the Endpoint was resolved.
This commit is contained in:
parent
babace573e
commit
4fadaeefff
@ -679,22 +679,11 @@ func nodesAreEqual(a, b *Node) bool {
|
||||
if a == b {
|
||||
return true
|
||||
}
|
||||
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
|
||||
// Check the DNS name first since this package
|
||||
// is doing the DNS resolution.
|
||||
if !a.Endpoint.Equal(b.Endpoint, true) {
|
||||
return false
|
||||
}
|
||||
if a.Endpoint != nil {
|
||||
if a.Endpoint.Port != b.Endpoint.Port {
|
||||
return false
|
||||
}
|
||||
// Check the DNS name first since this package
|
||||
// is doing the DNS resolution.
|
||||
if a.Endpoint.DNS != b.Endpoint.DNS {
|
||||
return false
|
||||
}
|
||||
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
// Ignore LastSeen when comparing equality we want to check if the nodes are
|
||||
// equivalent. However, we do want to check if LastSeen has transitioned
|
||||
// between valid and invalid.
|
||||
@ -708,22 +697,11 @@ func peersAreEqual(a, b *Peer) bool {
|
||||
if a == b {
|
||||
return true
|
||||
}
|
||||
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
|
||||
// Check the DNS name first since this package
|
||||
// is doing the DNS resolution.
|
||||
if !a.Endpoint.Equal(b.Endpoint, true) {
|
||||
return false
|
||||
}
|
||||
if a.Endpoint != nil {
|
||||
if a.Endpoint.Port != b.Endpoint.Port {
|
||||
return false
|
||||
}
|
||||
// Check the DNS name first since this package
|
||||
// is doing the DNS resolution.
|
||||
if a.Endpoint.DNS != b.Endpoint.DNS {
|
||||
return false
|
||||
}
|
||||
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(a.AllowedIPs) != len(b.AllowedIPs) {
|
||||
return false
|
||||
}
|
||||
@ -778,7 +756,7 @@ func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
|
||||
return false
|
||||
}
|
||||
for k := range a {
|
||||
if !a[k].Equal(b[k]) {
|
||||
if !a[k].Equal(b[k], false) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@ -802,17 +780,17 @@ func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *
|
||||
}
|
||||
for _, n := range nodes {
|
||||
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
|
||||
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint))
|
||||
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false))
|
||||
// Should check location leader but only available in topology ... or have topology handle that list
|
||||
// Better check wg latest-handshake
|
||||
if !n.Endpoint.Equal(peer.Endpoint) {
|
||||
if !n.Endpoint.Equal(peer.Endpoint, false) {
|
||||
natEndpoints[string(n.Key)] = peer.Endpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, p := range peers {
|
||||
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
|
||||
if !p.Endpoint.Equal(peer.Endpoint) {
|
||||
if !p.Endpoint.Equal(peer.Endpoint, false) {
|
||||
natEndpoints[string(p.PublicKey)] = peer.Endpoint
|
||||
}
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ func (e *Endpoint) String() string {
|
||||
}
|
||||
|
||||
// Equal compares two endpoints.
|
||||
func (e *Endpoint) Equal(b *Endpoint) bool {
|
||||
func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
|
||||
if (e == nil) != (b == nil) {
|
||||
return false
|
||||
}
|
||||
@ -104,13 +104,23 @@ func (e *Endpoint) Equal(b *Endpoint) bool {
|
||||
if e.Port != b.Port {
|
||||
return false
|
||||
}
|
||||
// IPs take priority, so check them first.
|
||||
if !e.IP.Equal(b.IP) {
|
||||
return false
|
||||
}
|
||||
// Only check the DNS name if the IP is empty.
|
||||
if e.IP == nil && e.DNS != b.DNS {
|
||||
return false
|
||||
if DNSFirst {
|
||||
// Check the DNS name first if it was resolved.
|
||||
if e.DNS != b.DNS {
|
||||
return false
|
||||
}
|
||||
if e.DNS == "" && !e.IP.Equal(b.IP) {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
// IPs take priority, so check them first.
|
||||
if !e.IP.Equal(b.IP) {
|
||||
return false
|
||||
}
|
||||
// Only check the DNS name if the IP is empty.
|
||||
if e.IP == nil && e.DNS != b.DNS {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -331,7 +341,7 @@ func (c *Conf) Equal(b *Conf) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint) {
|
||||
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) {
|
||||
return false
|
||||
}
|
||||
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
|
||||
|
@ -207,10 +207,11 @@ func TestCompareConf(t *testing.T) {
|
||||
|
||||
func TestCompareEndpoint(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
a *Endpoint
|
||||
b *Endpoint
|
||||
out bool
|
||||
name string
|
||||
a *Endpoint
|
||||
b *Endpoint
|
||||
dnsFirst bool
|
||||
out bool
|
||||
}{
|
||||
{
|
||||
name: "both nil",
|
||||
@ -272,8 +273,36 @@ func TestCompareEndpoint(t *testing.T) {
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "DNS first, ignore IP",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}},
|
||||
dnsFirst: true,
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "DNS first",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
|
||||
dnsFirst: true,
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "DNS first, no DNS compare IP",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}},
|
||||
dnsFirst: true,
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "DNS first, no DNS compare IP (same)",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||
dnsFirst: true,
|
||||
out: true,
|
||||
},
|
||||
} {
|
||||
equal := tc.a.Equal(tc.b)
|
||||
equal := tc.a.Equal(tc.b, tc.dnsFirst)
|
||||
if equal != tc.out {
|
||||
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user