apply suggestions from code review
Remove wireguard.Enpoint struct and use net.UDPAddr for the resolved endpoint and addr string (dnsanme:port) if a DN was supplied. Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
@@ -43,7 +43,7 @@ const (
|
||||
// Conf represents a WireGuard configuration file.
|
||||
type Conf struct {
|
||||
wgtypes.Config
|
||||
// The Peers field is shadowed because every Peer needs the KiloEndpoint field that contains a DNS endpoint.
|
||||
// The Peers field is shadowed because every Peer needs the Endpoint field that contains a DNS endpoint.
|
||||
Peers []Peer
|
||||
}
|
||||
|
||||
@@ -60,16 +60,10 @@ func (c Conf) WGConfig() wgtypes.Config {
|
||||
return r
|
||||
}
|
||||
|
||||
// Interface represents the `interface` section of a WireGuard configuration.
|
||||
type Interface struct {
|
||||
ListenPort uint32
|
||||
PrivateKey []byte
|
||||
}
|
||||
|
||||
// Peer represents a `peer` section of a WireGuard configuration.
|
||||
type Peer struct {
|
||||
wgtypes.PeerConfig
|
||||
KiloEndpoint *Endpoint
|
||||
Addr string // eg: dnsname:port
|
||||
}
|
||||
|
||||
// DeduplicateIPs eliminates duplicate allowed IPs.
|
||||
@@ -86,79 +80,6 @@ func (p *Peer) DeduplicateIPs() {
|
||||
p.AllowedIPs = ips
|
||||
}
|
||||
|
||||
// Endpoint represents an `endpoint` key of a `peer` section.
|
||||
type Endpoint struct {
|
||||
DNSOrIP
|
||||
Port int
|
||||
}
|
||||
|
||||
// String prints the string representation of the endpoint.
|
||||
func (e *Endpoint) String() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
dnsOrIP := e.DNSOrIP.String()
|
||||
if e.IP != nil && len(e.IP) == net.IPv6len {
|
||||
dnsOrIP = "[" + dnsOrIP + "]"
|
||||
}
|
||||
return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
|
||||
}
|
||||
|
||||
// UDPAddr returns the corresponding net.UDPAddr of the Endpoint or nil.
|
||||
func (e *Endpoint) UDPAddr() (u *net.UDPAddr) {
|
||||
if a, err := net.ResolveUDPAddr("udp", e.String()); err == nil {
|
||||
u = a
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Equal compares two endpoints.
|
||||
func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
|
||||
if (e == nil) != (b == nil) {
|
||||
return false
|
||||
}
|
||||
if e != nil {
|
||||
if e.Port != b.Port {
|
||||
return false
|
||||
}
|
||||
if DNSFirst {
|
||||
// Check the DNS name first if it was resolved.
|
||||
if e.DNS != b.DNS {
|
||||
return false
|
||||
}
|
||||
if e.DNS == "" && !e.IP.Equal(b.IP) {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
// IPs take priority, so check them first.
|
||||
if !e.IP.Equal(b.IP) {
|
||||
return false
|
||||
}
|
||||
// Only check the DNS name if the IP is empty.
|
||||
if e.IP == nil && e.DNS != b.DNS {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// DNSOrIP represents either a DNS name or an IP address.
|
||||
// IPs, as they are more specific, are preferred.
|
||||
type DNSOrIP struct {
|
||||
DNS string
|
||||
IP net.IP
|
||||
}
|
||||
|
||||
// String prints the string representation of the struct.
|
||||
func (d DNSOrIP) String() string {
|
||||
if d.IP != nil {
|
||||
return d.IP.String()
|
||||
}
|
||||
return d.DNS
|
||||
}
|
||||
|
||||
// Bytes renders a WireGuard configuration to bytes.
|
||||
func (c Conf) Bytes() ([]byte, error) {
|
||||
var err error
|
||||
@@ -188,13 +109,13 @@ func (c Conf) Bytes() ([]byte, error) {
|
||||
if err = writeAllowedIPs(buf, p.AllowedIPs); err != nil {
|
||||
return nil, fmt.Errorf("failed to write allowed IPs: %v", err)
|
||||
}
|
||||
if err = writeEndpoint(buf, p.KiloEndpoint); err != nil {
|
||||
if err = writeEndpoint(buf, p.Endpoint, p.Addr); err != nil {
|
||||
return nil, fmt.Errorf("failed to write endpoint: %v", err)
|
||||
}
|
||||
if p.PersistentKeepaliveInterval == nil {
|
||||
p.PersistentKeepaliveInterval = new(time.Duration)
|
||||
}
|
||||
if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval), 10)); err != nil {
|
||||
if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval/time.Second), 10)); err != nil {
|
||||
return nil, fmt.Errorf("failed to write persistent keepalive: %v", err)
|
||||
}
|
||||
if err = writePKey(buf, presharedKeyKey, p.PresharedKey); err != nil {
|
||||
@@ -327,15 +248,20 @@ func writeValue(buf *bytes.Buffer, k key, v string) error {
|
||||
return buf.WriteByte('\n')
|
||||
}
|
||||
|
||||
func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error {
|
||||
if e == nil {
|
||||
func writeEndpoint(buf *bytes.Buffer, e *net.UDPAddr, d string) error {
|
||||
str := ""
|
||||
if d != "" {
|
||||
str = d
|
||||
} else if e != nil {
|
||||
str = e.String()
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
if err = writeKey(buf, endpointKey); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = buf.WriteString(e.String()); err != nil {
|
||||
if _, err = buf.WriteString(str); err != nil {
|
||||
return err
|
||||
}
|
||||
return buf.WriteByte('\n')
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
// Copyright 2021 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompareEndpoint(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
a *Endpoint
|
||||
b *Endpoint
|
||||
dnsFirst bool
|
||||
out bool
|
||||
}{
|
||||
{
|
||||
name: "both nil",
|
||||
a: nil,
|
||||
b: nil,
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "a nil",
|
||||
a: nil,
|
||||
b: &Endpoint{},
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "b nil",
|
||||
a: &Endpoint{},
|
||||
b: nil,
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "zero",
|
||||
a: &Endpoint{},
|
||||
b: &Endpoint{},
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "diff port",
|
||||
a: &Endpoint{Port: 1234},
|
||||
b: &Endpoint{Port: 5678},
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "same IP",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "diff IP",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2")}},
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "same IP ignore DNS",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "b"}},
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "no IP check DNS",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "no IP check DNS (same)",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "DNS first, ignore IP",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}},
|
||||
dnsFirst: true,
|
||||
out: true,
|
||||
},
|
||||
{
|
||||
name: "DNS first",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
|
||||
dnsFirst: true,
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "DNS first, no DNS compare IP",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}},
|
||||
dnsFirst: true,
|
||||
out: false,
|
||||
},
|
||||
{
|
||||
name: "DNS first, no DNS compare IP (same)",
|
||||
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
|
||||
dnsFirst: true,
|
||||
out: true,
|
||||
},
|
||||
} {
|
||||
equal := tc.a.Equal(tc.b, tc.dnsFirst)
|
||||
if equal != tc.out {
|
||||
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user