pkg/mesh/mesh.go: use Equal func

Implement an Equal func for Enpoint and use it instead of comparing
strings.

Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
leonnicolas 2021-09-30 12:29:31 +02:00
parent 61ad16dea5
commit 16bb0491e2
No known key found for this signature in database
GPG Key ID: 088D0743E2B65C07
3 changed files with 155 additions and 3 deletions

View File

@ -660,7 +660,7 @@ func nodesAreEqual(a, b *Node) bool {
} }
// Check the DNS name first since this package // Check the DNS name first since this package
// is doing the DNS resolution. // is doing the DNS resolution.
if a.Endpoint.StringOpt(true) != b.Endpoint.StringOpt(true) { if !a.Endpoint.Equal(b.Endpoint, true) {
return false return false
} }
// Ignore LastSeen when comparing equality we want to check if the nodes are // Ignore LastSeen when comparing equality we want to check if the nodes are
@ -689,7 +689,7 @@ func peersAreEqual(a, b *Peer) bool {
} }
// Check the DNS name first since this package // Check the DNS name first since this package
// is doing the DNS resolution. // is doing the DNS resolution.
if a.Endpoint.StringOpt(true) != b.Endpoint.StringOpt(true) { if !a.Endpoint.Equal(b.Endpoint, true) {
return false return false
} }
if len(a.AllowedIPs) != len(b.AllowedIPs) { if len(a.AllowedIPs) != len(b.AllowedIPs) {

View File

@ -216,7 +216,7 @@ func (e *Endpoint) String() string {
return e.StringOpt(true) return e.StringOpt(true)
} }
// StringOpt will return string of the Endpoint. // StringOpt will return the string of the Endpoint.
// If dnsFirst is false, the resolved Endpoint will // If dnsFirst is false, the resolved Endpoint will
// take precedence over the DN. // take precedence over the DN.
func (e *Endpoint) StringOpt(dnsFirst bool) string { func (e *Endpoint) StringOpt(dnsFirst bool) string {
@ -229,6 +229,13 @@ func (e *Endpoint) StringOpt(dnsFirst bool) string {
return e.addr return e.addr
} }
// Equal will return true, if the Enpoints are equal.
// If dnsFirst is false, the DN will only be compared if
// the IPs are nil.
func (e *Endpoint) Equal(b *Endpoint, dnsFirst bool) bool {
return e.StringOpt(dnsFirst) == b.StringOpt(dnsFirst)
}
// Peer represents a `peer` section of a WireGuard configuration. // Peer represents a `peer` section of a WireGuard configuration.
type Peer struct { type Peer struct {
wgtypes.PeerConfig wgtypes.PeerConfig

View File

@ -304,3 +304,148 @@ func TestReady(t *testing.T) {
} }
} }
} }
func TestEqual(t *testing.T) {
for i, tc := range []struct {
name string
a *Endpoint
b *Endpoint
df bool
r bool
}{
{
name: "nil dns last",
r: true,
},
{
name: "nil dns first",
df: true,
r: true,
},
{
name: "equal: only port",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
r: true,
},
{
name: "not equal: only port",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1001,
},
},
r: false,
},
{
name: "equal dns first",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
r: true,
},
{
name: "equal dns last",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "foo",
},
r: true,
},
{
name: "unequal dns first",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "foo",
},
df: true,
r: false,
},
{
name: "unequal dns last",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "foo",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("11.0.0.0"),
},
addr: "foo",
},
r: false,
},
{
name: "unequal dns last empty IP",
a: &Endpoint{
addr: "foo",
},
b: &Endpoint{
addr: "bar",
},
r: false,
},
{
name: "equal dns last empty IP",
a: &Endpoint{
addr: "foo",
},
b: &Endpoint{
addr: "foo",
},
r: true,
},
} {
if out := tc.a.Equal(tc.b, tc.df); out != tc.r {
t.Errorf("ParseEndpoint %s(%d): expected: %v\tgot: %v\n", tc.name, i, tc.r, out)
}
}
}