diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index 146a6e5..61f4a05 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -660,7 +660,7 @@ func nodesAreEqual(a, b *Node) bool { } // Check the DNS name first since this package // is doing the DNS resolution. - if a.Endpoint.StringOpt(true) != b.Endpoint.StringOpt(true) { + if !a.Endpoint.Equal(b.Endpoint, true) { return false } // Ignore LastSeen when comparing equality we want to check if the nodes are @@ -689,7 +689,7 @@ func peersAreEqual(a, b *Peer) bool { } // Check the DNS name first since this package // is doing the DNS resolution. - if a.Endpoint.StringOpt(true) != b.Endpoint.StringOpt(true) { + if !a.Endpoint.Equal(b.Endpoint, true) { return false } if len(a.AllowedIPs) != len(b.AllowedIPs) { diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index 67ba589..8dbcde4 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -216,7 +216,7 @@ func (e *Endpoint) String() string { return e.StringOpt(true) } -// StringOpt will return string of the Endpoint. +// StringOpt will return the string of the Endpoint. // If dnsFirst is false, the resolved Endpoint will // take precedence over the DN. func (e *Endpoint) StringOpt(dnsFirst bool) string { @@ -229,6 +229,13 @@ func (e *Endpoint) StringOpt(dnsFirst bool) string { return e.addr } +// Equal will return true, if the Enpoints are equal. +// If dnsFirst is false, the DN will only be compared if +// the IPs are nil. +func (e *Endpoint) Equal(b *Endpoint, dnsFirst bool) bool { + return e.StringOpt(dnsFirst) == b.StringOpt(dnsFirst) +} + // Peer represents a `peer` section of a WireGuard configuration. type Peer struct { wgtypes.PeerConfig diff --git a/pkg/wireguard/conf_test.go b/pkg/wireguard/conf_test.go index 42948c0..f2eb11d 100644 --- a/pkg/wireguard/conf_test.go +++ b/pkg/wireguard/conf_test.go @@ -304,3 +304,148 @@ func TestReady(t *testing.T) { } } } + +func TestEqual(t *testing.T) { + for i, tc := range []struct { + name string + a *Endpoint + b *Endpoint + df bool + r bool + }{ + { + name: "nil dns last", + r: true, + }, + { + name: "nil dns first", + df: true, + r: true, + }, + { + name: "equal: only port", + a: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + }, + }, + b: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + }, + }, + r: true, + }, + { + name: "not equal: only port", + a: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + }, + }, + b: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1001, + }, + }, + r: false, + }, + { + name: "equal dns first", + a: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + IP: net.ParseIP("10.0.0.0"), + }, + addr: "example.com:1000", + }, + b: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + IP: net.ParseIP("10.0.0.0"), + }, + addr: "example.com:1000", + }, + r: true, + }, + { + name: "equal dns last", + a: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + IP: net.ParseIP("10.0.0.0"), + }, + addr: "example.com:1000", + }, + b: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + IP: net.ParseIP("10.0.0.0"), + }, + addr: "foo", + }, + r: true, + }, + { + name: "unequal dns first", + a: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + IP: net.ParseIP("10.0.0.0"), + }, + addr: "example.com:1000", + }, + b: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + IP: net.ParseIP("10.0.0.0"), + }, + addr: "foo", + }, + df: true, + r: false, + }, + { + name: "unequal dns last", + a: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + IP: net.ParseIP("10.0.0.0"), + }, + addr: "foo", + }, + b: &Endpoint{ + udpAddr: &net.UDPAddr{ + Port: 1000, + IP: net.ParseIP("11.0.0.0"), + }, + addr: "foo", + }, + r: false, + }, + { + name: "unequal dns last empty IP", + a: &Endpoint{ + addr: "foo", + }, + b: &Endpoint{ + addr: "bar", + }, + r: false, + }, + { + name: "equal dns last empty IP", + a: &Endpoint{ + addr: "foo", + }, + b: &Endpoint{ + addr: "foo", + }, + r: true, + }, + } { + if out := tc.a.Equal(tc.b, tc.df); out != tc.r { + t.Errorf("ParseEndpoint %s(%d): expected: %v\tgot: %v\n", tc.name, i, tc.r, out) + } + } +}