Refactor to use Endpoint.Equal
Compare IP first by default and compare DNS name first when we know the Endpoint was resolved.
This commit is contained in:
parent
babace573e
commit
4fadaeefff
@ -679,22 +679,11 @@ func nodesAreEqual(a, b *Node) bool {
|
|||||||
if a == b {
|
if a == b {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
|
// Check the DNS name first since this package
|
||||||
|
// is doing the DNS resolution.
|
||||||
|
if !a.Endpoint.Equal(b.Endpoint, true) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if a.Endpoint != nil {
|
|
||||||
if a.Endpoint.Port != b.Endpoint.Port {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// Check the DNS name first since this package
|
|
||||||
// is doing the DNS resolution.
|
|
||||||
if a.Endpoint.DNS != b.Endpoint.DNS {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Ignore LastSeen when comparing equality we want to check if the nodes are
|
// Ignore LastSeen when comparing equality we want to check if the nodes are
|
||||||
// equivalent. However, we do want to check if LastSeen has transitioned
|
// equivalent. However, we do want to check if LastSeen has transitioned
|
||||||
// between valid and invalid.
|
// between valid and invalid.
|
||||||
@ -708,22 +697,11 @@ func peersAreEqual(a, b *Peer) bool {
|
|||||||
if a == b {
|
if a == b {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
|
// Check the DNS name first since this package
|
||||||
|
// is doing the DNS resolution.
|
||||||
|
if !a.Endpoint.Equal(b.Endpoint, true) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if a.Endpoint != nil {
|
|
||||||
if a.Endpoint.Port != b.Endpoint.Port {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// Check the DNS name first since this package
|
|
||||||
// is doing the DNS resolution.
|
|
||||||
if a.Endpoint.DNS != b.Endpoint.DNS {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(a.AllowedIPs) != len(b.AllowedIPs) {
|
if len(a.AllowedIPs) != len(b.AllowedIPs) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -778,7 +756,7 @@ func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for k := range a {
|
for k := range a {
|
||||||
if !a[k].Equal(b[k]) {
|
if !a[k].Equal(b[k], false) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -802,17 +780,17 @@ func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *
|
|||||||
}
|
}
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
|
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
|
||||||
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint))
|
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false))
|
||||||
// Should check location leader but only available in topology ... or have topology handle that list
|
// Should check location leader but only available in topology ... or have topology handle that list
|
||||||
// Better check wg latest-handshake
|
// Better check wg latest-handshake
|
||||||
if !n.Endpoint.Equal(peer.Endpoint) {
|
if !n.Endpoint.Equal(peer.Endpoint, false) {
|
||||||
natEndpoints[string(n.Key)] = peer.Endpoint
|
natEndpoints[string(n.Key)] = peer.Endpoint
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, p := range peers {
|
for _, p := range peers {
|
||||||
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
|
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
|
||||||
if !p.Endpoint.Equal(peer.Endpoint) {
|
if !p.Endpoint.Equal(peer.Endpoint, false) {
|
||||||
natEndpoints[string(p.PublicKey)] = peer.Endpoint
|
natEndpoints[string(p.PublicKey)] = peer.Endpoint
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -96,7 +96,7 @@ func (e *Endpoint) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Equal compares two endpoints.
|
// Equal compares two endpoints.
|
||||||
func (e *Endpoint) Equal(b *Endpoint) bool {
|
func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
|
||||||
if (e == nil) != (b == nil) {
|
if (e == nil) != (b == nil) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -104,13 +104,23 @@ func (e *Endpoint) Equal(b *Endpoint) bool {
|
|||||||
if e.Port != b.Port {
|
if e.Port != b.Port {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// IPs take priority, so check them first.
|
if DNSFirst {
|
||||||
if !e.IP.Equal(b.IP) {
|
// Check the DNS name first if it was resolved.
|
||||||
return false
|
if e.DNS != b.DNS {
|
||||||
}
|
return false
|
||||||
// Only check the DNS name if the IP is empty.
|
}
|
||||||
if e.IP == nil && e.DNS != b.DNS {
|
if e.DNS == "" && !e.IP.Equal(b.IP) {
|
||||||
return false
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// IPs take priority, so check them first.
|
||||||
|
if !e.IP.Equal(b.IP) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Only check the DNS name if the IP is empty.
|
||||||
|
if e.IP == nil && e.DNS != b.DNS {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -331,7 +341,7 @@ func (c *Conf) Equal(b *Conf) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint) {
|
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
|
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
|
||||||
|
@ -207,10 +207,11 @@ func TestCompareConf(t *testing.T) {
|
|||||||
|
|
||||||
func TestCompareEndpoint(t *testing.T) {
|
func TestCompareEndpoint(t *testing.T) {
|
||||||
for _, tc := range []struct {
|
for _, tc := range []struct {
|
||||||
name string
|
name string
|
||||||
a *Endpoint
|
a *Endpoint
|
||||||
b *Endpoint
|
b *Endpoint
|
||||||
out bool
|
dnsFirst bool
|
||||||
|
out bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "both nil",
|
name: "both nil",
|
||||||
@ -272,8 +273,36 @@ func TestCompareEndpoint(t *testing.T) {
|
|||||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||||
out: true,
|
out: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "DNS first, ignore IP",
|
||||||
|
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
|
||||||
|
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}},
|
||||||
|
dnsFirst: true,
|
||||||
|
out: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DNS first",
|
||||||
|
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||||
|
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
|
||||||
|
dnsFirst: true,
|
||||||
|
out: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DNS first, no DNS compare IP",
|
||||||
|
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||||
|
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}},
|
||||||
|
dnsFirst: true,
|
||||||
|
out: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DNS first, no DNS compare IP (same)",
|
||||||
|
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||||
|
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||||
|
dnsFirst: true,
|
||||||
|
out: true,
|
||||||
|
},
|
||||||
} {
|
} {
|
||||||
equal := tc.a.Equal(tc.b)
|
equal := tc.a.Equal(tc.b, tc.dnsFirst)
|
||||||
if equal != tc.out {
|
if equal != tc.out {
|
||||||
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
|
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user