wireguard: export an Endpoint comparison method
This commit is contained in:
parent
863628ffaa
commit
f66efc7140
@ -95,6 +95,28 @@ func (e *Endpoint) String() string {
|
|||||||
return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
|
return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Equal compares two endpoints.
|
||||||
|
func (e *Endpoint) Equal(b *Endpoint) bool {
|
||||||
|
if (e == nil) != (b == nil) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if e != nil {
|
||||||
|
if e.Port != b.Port {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// IPs take priority, so check them first.
|
||||||
|
if !e.IP.Equal(b.IP) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Only check the DNS name if the IP is empty.
|
||||||
|
if e.IP == nil && e.DNS != b.DNS {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// DNSOrIP represents either a DNS name or an IP address.
|
// DNSOrIP represents either a DNS name or an IP address.
|
||||||
// IPs, as they are more specific, are preferred.
|
// IPs, as they are more specific, are preferred.
|
||||||
type DNSOrIP struct {
|
type DNSOrIP struct {
|
||||||
@ -309,22 +331,9 @@ func (c *Conf) Equal(b *Conf) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) {
|
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if c.Peers[i].Endpoint != nil {
|
|
||||||
if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// IPs take priority, so check them first.
|
|
||||||
if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// Only check the DNS name if the IP is empty.
|
|
||||||
if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
|
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -203,3 +204,78 @@ func TestCompareConf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompareEndpoint(t *testing.T) {
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
a *Endpoint
|
||||||
|
b *Endpoint
|
||||||
|
out bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "both nil",
|
||||||
|
a: nil,
|
||||||
|
b: nil,
|
||||||
|
out: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "a nil",
|
||||||
|
a: nil,
|
||||||
|
b: &Endpoint{},
|
||||||
|
out: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "b nil",
|
||||||
|
a: &Endpoint{},
|
||||||
|
b: nil,
|
||||||
|
out: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero",
|
||||||
|
a: &Endpoint{},
|
||||||
|
b: &Endpoint{},
|
||||||
|
out: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "diff port",
|
||||||
|
a: &Endpoint{Port: 1234},
|
||||||
|
b: &Endpoint{Port: 5678},
|
||||||
|
out: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same IP",
|
||||||
|
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
|
||||||
|
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
|
||||||
|
out: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "diff IP",
|
||||||
|
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
|
||||||
|
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2")}},
|
||||||
|
out: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same IP ignore DNS",
|
||||||
|
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
|
||||||
|
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "b"}},
|
||||||
|
out: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no IP check DNS",
|
||||||
|
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||||
|
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
|
||||||
|
out: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no IP check DNS (same)",
|
||||||
|
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||||
|
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||||
|
out: true,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
equal := tc.a.Equal(tc.b)
|
||||||
|
if equal != tc.out {
|
||||||
|
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user