From f66efc71406aad7f4faf622793c9ea2582df96a3 Mon Sep 17 00:00:00 2001 From: Julien Viard de Galbert Date: Fri, 16 Apr 2021 15:33:17 +0200 Subject: [PATCH] wireguard: export an Endpoint comparison method --- pkg/wireguard/conf.go | 37 ++++++++++++------- pkg/wireguard/conf_test.go | 76 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 14 deletions(-) diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index ac81144..20c01bf 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -95,6 +95,28 @@ func (e *Endpoint) String() string { return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10) } +// Equal compares two endpoints. +func (e *Endpoint) Equal(b *Endpoint) bool { + if (e == nil) != (b == nil) { + return false + } + if e != nil { + if e.Port != b.Port { + return false + } + // IPs take priority, so check them first. + if !e.IP.Equal(b.IP) { + return false + } + // Only check the DNS name if the IP is empty. + if e.IP == nil && e.DNS != b.DNS { + return false + } + } + + return true +} + // DNSOrIP represents either a DNS name or an IP address. // IPs, as they are more specific, are preferred. type DNSOrIP struct { @@ -309,22 +331,9 @@ func (c *Conf) Equal(b *Conf) bool { return false } } - if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) { + if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint) { return false } - if c.Peers[i].Endpoint != nil { - if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port { - return false - } - // IPs take priority, so check them first. - if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) { - return false - } - // Only check the DNS name if the IP is empty. - if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS { - return false - } - } if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) { return false } diff --git a/pkg/wireguard/conf_test.go b/pkg/wireguard/conf_test.go index 97c19a3..c4a87e7 100644 --- a/pkg/wireguard/conf_test.go +++ b/pkg/wireguard/conf_test.go @@ -15,6 +15,7 @@ package wireguard import ( + "net" "testing" ) @@ -203,3 +204,78 @@ func TestCompareConf(t *testing.T) { } } } + +func TestCompareEndpoint(t *testing.T) { + for _, tc := range []struct { + name string + a *Endpoint + b *Endpoint + out bool + }{ + { + name: "both nil", + a: nil, + b: nil, + out: true, + }, + { + name: "a nil", + a: nil, + b: &Endpoint{}, + out: false, + }, + { + name: "b nil", + a: &Endpoint{}, + b: nil, + out: false, + }, + { + name: "zero", + a: &Endpoint{}, + b: &Endpoint{}, + out: true, + }, + { + name: "diff port", + a: &Endpoint{Port: 1234}, + b: &Endpoint{Port: 5678}, + out: false, + }, + { + name: "same IP", + a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}}, + b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}}, + out: true, + }, + { + name: "diff IP", + a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}}, + b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2")}}, + out: false, + }, + { + name: "same IP ignore DNS", + a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}}, + b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "b"}}, + out: true, + }, + { + name: "no IP check DNS", + a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}}, + b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}}, + out: false, + }, + { + name: "no IP check DNS (same)", + a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}}, + b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}}, + out: true, + }, + } { + equal := tc.a.Equal(tc.b) + if equal != tc.out { + t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal) + } + } +}