pkg/*: use wireguard.Enpoint
This commit introduces the wireguard.Enpoint struct. It encapsulates a DN name with port and a net.UPDAddr. The fields are private and only accessible over exported Methods to avoid accidental modification. Also iptables.GetProtocol is improved to avoid ipv4 rules being applied by `ip6tables`. Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
@@ -16,6 +16,7 @@ package wireguard
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"k8s.io/apimachinery/pkg/util/validation"
|
||||
)
|
||||
|
||||
type section string
|
||||
@@ -53,6 +55,10 @@ func (c Conf) WGConfig() wgtypes.Config {
|
||||
wgPs := make([]wgtypes.PeerConfig, len(c.Peers))
|
||||
for i, p := range c.Peers {
|
||||
wgPs[i] = p.PeerConfig
|
||||
if p.Endpoint.Resolved() {
|
||||
// We can ingore the error because we already checked if the Endpoint was resolved in the above line.
|
||||
wgPs[i].Endpoint, _ = p.Endpoint.UDPAddr(false)
|
||||
}
|
||||
wgPs[i].ReplaceAllowedIPs = true
|
||||
}
|
||||
r.Peers = wgPs
|
||||
@@ -60,10 +66,169 @@ func (c Conf) WGConfig() wgtypes.Config {
|
||||
return r
|
||||
}
|
||||
|
||||
// Endpoint represents a WireGuard endpoint.
|
||||
type Endpoint struct {
|
||||
udpAddr *net.UDPAddr
|
||||
addr string
|
||||
}
|
||||
|
||||
// ParseEndpoint returns an Endpoint from a string.
|
||||
// The input should look like "10.0.0.0:100", "[ff10::10]:100"
|
||||
// or "example.com:100".
|
||||
func ParseEndpoint(endpoint string) *Endpoint {
|
||||
if len(endpoint) == 0 {
|
||||
return nil
|
||||
}
|
||||
hostRaw, portRaw, err := net.SplitHostPort(endpoint)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
port, err := strconv.ParseUint(portRaw, 10, 32)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(validation.IsValidPortNum(int(port))) != 0 {
|
||||
return nil
|
||||
}
|
||||
ip := net.ParseIP(hostRaw)
|
||||
if ip == nil {
|
||||
if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 {
|
||||
return &Endpoint{
|
||||
addr: endpoint,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// ResolveUDPAddr will not resolve the endpoint as long as a valid IP and port is given.
|
||||
// This should be the case here.
|
||||
u, err := net.ResolveUDPAddr("udp", endpoint)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
u.IP = cutIP(u.IP)
|
||||
return &Endpoint{
|
||||
udpAddr: u,
|
||||
}
|
||||
}
|
||||
|
||||
// NewEndpointFromUDPAddr returns an Endpoint from a net.UDPAddr.
|
||||
func NewEndpointFromUDPAddr(u *net.UDPAddr) *Endpoint {
|
||||
if u != nil {
|
||||
u.IP = cutIP(u.IP)
|
||||
}
|
||||
return &Endpoint{
|
||||
udpAddr: u,
|
||||
}
|
||||
}
|
||||
|
||||
// NewEndpoint returns an Endpoint from a net.IP and port.
|
||||
func NewEndpoint(ip net.IP, port int) *Endpoint {
|
||||
return &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: cutIP(ip),
|
||||
Port: port,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Ready return true, if the Enpoint is ready.
|
||||
// Ready means that an IP or DN and port exists.
|
||||
func (e *Endpoint) Ready() bool {
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
return (e.udpAddr != nil && e.udpAddr.IP != nil && e.udpAddr.Port > 0) || len(e.addr) > 0
|
||||
}
|
||||
|
||||
// Port returns the port of the Endpoint.
|
||||
func (e *Endpoint) Port() int {
|
||||
if !e.Ready() {
|
||||
return 0
|
||||
}
|
||||
if e.udpAddr != nil {
|
||||
return e.udpAddr.Port
|
||||
}
|
||||
// We can ignore the errors here bacause the returned port will be "".
|
||||
// This will result to Port 0 after the conversion to and int.
|
||||
_, p, _ := net.SplitHostPort(e.addr)
|
||||
port, _ := strconv.ParseUint(p, 10, 32)
|
||||
return int(port)
|
||||
}
|
||||
|
||||
// HasDNS returns true if the endpoint has a DN.
|
||||
func (e *Endpoint) HasDNS() bool {
|
||||
return e != nil && e.addr != ""
|
||||
}
|
||||
|
||||
// DNS returns the DN of the Endpoint.
|
||||
func (e *Endpoint) DNS() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
_, s, _ := net.SplitHostPort(e.addr)
|
||||
return s
|
||||
}
|
||||
|
||||
// Resolved returns true, if the DN of the Endpoint was resolved
|
||||
// or if the Endpoint has a resolved endpoint.
|
||||
func (e *Endpoint) Resolved() bool {
|
||||
return e != nil && e.udpAddr != nil
|
||||
}
|
||||
|
||||
// UDPAddr returns the UDPAddr of the Endpoint. If resolve is false,
|
||||
// UDPAddr() will not try to resolve a DN name, if the Endpoint is not yet resolved.
|
||||
func (e *Endpoint) UDPAddr(resolve bool) (*net.UDPAddr, error) {
|
||||
if !e.Ready() {
|
||||
return nil, errors.New("Enpoint is not ready")
|
||||
}
|
||||
if e.udpAddr != nil {
|
||||
// Make a copy of the UDPAddr to protect it from modification outside this package.
|
||||
h := *e.udpAddr
|
||||
return &h, nil
|
||||
}
|
||||
if !resolve {
|
||||
return nil, errors.New("Endpoint is not resolved")
|
||||
}
|
||||
var err error
|
||||
if e.udpAddr, err = net.ResolveUDPAddr("udp", e.addr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Make a copy of the UDPAddr to protect it from modification outside this package.
|
||||
h := *e.udpAddr
|
||||
return &h, nil
|
||||
}
|
||||
|
||||
// IP returns the IP address of the Enpoint or nil.
|
||||
func (e *Endpoint) IP() net.IP {
|
||||
if !e.Resolved() {
|
||||
return nil
|
||||
}
|
||||
return e.udpAddr.IP
|
||||
}
|
||||
|
||||
// String will return the endpoint as a string.
|
||||
// If a DN exists, it will take prcedence over the resolved endpoint.
|
||||
func (e *Endpoint) String() string {
|
||||
return e.StringOpt(true)
|
||||
}
|
||||
|
||||
// StringOpt will return string of the Endpoint.
|
||||
// If dnsFirst is false, the resolved Endpoint will
|
||||
// take precedence over the DN.
|
||||
func (e *Endpoint) StringOpt(dnsFirst bool) string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.udpAddr != nil && (!dnsFirst || e.addr == "") {
|
||||
return e.udpAddr.String()
|
||||
}
|
||||
return e.addr
|
||||
}
|
||||
|
||||
// Peer represents a `peer` section of a WireGuard configuration.
|
||||
type Peer struct {
|
||||
wgtypes.PeerConfig
|
||||
Addr string // eg: dnsname:port
|
||||
Endpoint *Endpoint
|
||||
}
|
||||
|
||||
// DeduplicateIPs eliminates duplicate allowed IPs.
|
||||
@@ -109,7 +274,7 @@ func (c Conf) Bytes() ([]byte, error) {
|
||||
if err = writeAllowedIPs(buf, p.AllowedIPs); err != nil {
|
||||
return nil, fmt.Errorf("failed to write allowed IPs: %v", err)
|
||||
}
|
||||
if err = writeEndpoint(buf, p.Endpoint, p.Addr); err != nil {
|
||||
if err = writeEndpoint(buf, p.Endpoint); err != nil {
|
||||
return nil, fmt.Errorf("failed to write endpoint: %v", err)
|
||||
}
|
||||
if p.PersistentKeepaliveInterval == nil {
|
||||
@@ -158,8 +323,8 @@ func (c *Conf) Equal(d *wgtypes.Device) (bool, string) {
|
||||
if c.Peers[i].Endpoint == nil || d.Peers[i].Endpoint == nil {
|
||||
return c.Peers[i].Endpoint == nil && d.Peers[i].Endpoint == nil, "peer endpoints: nil value"
|
||||
}
|
||||
if !c.Peers[i].Endpoint.IP.Equal(d.Peers[i].Endpoint.IP) || c.Peers[i].Endpoint.Port != d.Peers[i].Endpoint.Port {
|
||||
return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.String())
|
||||
if c.Peers[i].Endpoint.StringOpt(false) != d.Peers[i].Endpoint.String() {
|
||||
return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.StringOpt(false))
|
||||
}
|
||||
|
||||
pki := time.Duration(0)
|
||||
@@ -201,6 +366,13 @@ func sortCIDRs(cidrs []net.IPNet) {
|
||||
})
|
||||
}
|
||||
|
||||
func cutIP(ip net.IP) net.IP {
|
||||
if i4 := ip.To4(); i4 != nil {
|
||||
return i4
|
||||
}
|
||||
return ip.To16()
|
||||
}
|
||||
|
||||
func writeAllowedIPs(buf *bytes.Buffer, ais []net.IPNet) error {
|
||||
if len(ais) == 0 {
|
||||
return nil
|
||||
@@ -248,13 +420,9 @@ func writeValue(buf *bytes.Buffer, k key, v string) error {
|
||||
return buf.WriteByte('\n')
|
||||
}
|
||||
|
||||
func writeEndpoint(buf *bytes.Buffer, e *net.UDPAddr, d string) error {
|
||||
str := ""
|
||||
if d != "" {
|
||||
str = d
|
||||
} else if e != nil {
|
||||
str = e.String()
|
||||
} else {
|
||||
func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error {
|
||||
str := e.String()
|
||||
if str == "" {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
|
Reference in New Issue
Block a user