wireguard: export an Endpoint comparison method

This commit is contained in:
Julien Viard de Galbert 2021-04-16 15:33:17 +02:00
parent 863628ffaa
commit f66efc7140
2 changed files with 99 additions and 14 deletions

View File

@ -95,6 +95,28 @@ func (e *Endpoint) String() string {
return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
}
// Equal compares two endpoints.
func (e *Endpoint) Equal(b *Endpoint) bool {
if (e == nil) != (b == nil) {
return false
}
if e != nil {
if e.Port != b.Port {
return false
}
// IPs take priority, so check them first.
if !e.IP.Equal(b.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if e.IP == nil && e.DNS != b.DNS {
return false
}
}
return true
}
// DNSOrIP represents either a DNS name or an IP address.
// IPs, as they are more specific, are preferred.
type DNSOrIP struct {
@ -309,22 +331,9 @@ func (c *Conf) Equal(b *Conf) bool {
return false
}
}
if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) {
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint) {
return false
}
if c.Peers[i].Endpoint != nil {
if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port {
return false
}
// IPs take priority, so check them first.
if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS {
return false
}
}
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
return false
}

View File

@ -15,6 +15,7 @@
package wireguard
import (
"net"
"testing"
)
@ -203,3 +204,78 @@ func TestCompareConf(t *testing.T) {
}
}
}
func TestCompareEndpoint(t *testing.T) {
for _, tc := range []struct {
name string
a *Endpoint
b *Endpoint
out bool
}{
{
name: "both nil",
a: nil,
b: nil,
out: true,
},
{
name: "a nil",
a: nil,
b: &Endpoint{},
out: false,
},
{
name: "b nil",
a: &Endpoint{},
b: nil,
out: false,
},
{
name: "zero",
a: &Endpoint{},
b: &Endpoint{},
out: true,
},
{
name: "diff port",
a: &Endpoint{Port: 1234},
b: &Endpoint{Port: 5678},
out: false,
},
{
name: "same IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
out: true,
},
{
name: "diff IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2")}},
out: false,
},
{
name: "same IP ignore DNS",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "b"}},
out: true,
},
{
name: "no IP check DNS",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
out: false,
},
{
name: "no IP check DNS (same)",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
out: true,
},
} {
equal := tc.a.Equal(tc.b)
if equal != tc.out {
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
}
}
}