Refactor to use Endpoint.Equal

Compare IP first by default and compare DNS name first when we know the Endpoint was resolved.
This commit is contained in:
Julien Viard de Galbert
2021-04-16 16:53:52 +02:00
parent babace573e
commit 4fadaeefff
3 changed files with 63 additions and 46 deletions

View File

@@ -96,7 +96,7 @@ func (e *Endpoint) String() string {
}
// Equal compares two endpoints.
func (e *Endpoint) Equal(b *Endpoint) bool {
func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
if (e == nil) != (b == nil) {
return false
}
@@ -104,13 +104,23 @@ func (e *Endpoint) Equal(b *Endpoint) bool {
if e.Port != b.Port {
return false
}
// IPs take priority, so check them first.
if !e.IP.Equal(b.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if e.IP == nil && e.DNS != b.DNS {
return false
if DNSFirst {
// Check the DNS name first if it was resolved.
if e.DNS != b.DNS {
return false
}
if e.DNS == "" && !e.IP.Equal(b.IP) {
return false
}
} else {
// IPs take priority, so check them first.
if !e.IP.Equal(b.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if e.IP == nil && e.DNS != b.DNS {
return false
}
}
}
@@ -331,7 +341,7 @@ func (c *Conf) Equal(b *Conf) bool {
return false
}
}
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint) {
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) {
return false
}
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {

View File

@@ -207,10 +207,11 @@ func TestCompareConf(t *testing.T) {
func TestCompareEndpoint(t *testing.T) {
for _, tc := range []struct {
name string
a *Endpoint
b *Endpoint
out bool
name string
a *Endpoint
b *Endpoint
dnsFirst bool
out bool
}{
{
name: "both nil",
@@ -272,8 +273,36 @@ func TestCompareEndpoint(t *testing.T) {
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
out: true,
},
{
name: "DNS first, ignore IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}},
dnsFirst: true,
out: true,
},
{
name: "DNS first",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
dnsFirst: true,
out: false,
},
{
name: "DNS first, no DNS compare IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}},
dnsFirst: true,
out: false,
},
{
name: "DNS first, no DNS compare IP (same)",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
dnsFirst: true,
out: true,
},
} {
equal := tc.a.Equal(tc.b)
equal := tc.a.Equal(tc.b, tc.dnsFirst)
if equal != tc.out {
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
}