migrate to golang.zx2c4.com/wireguard/wgctrl (#239)

* migrate to golang.zx2c4.com/wireguard/wgctrl

This commit introduces the usage of wgctrl.
It avoids the usage of exec calls of the wg command
and parsing the output of `wg show`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* vendor wgctrl

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* apply suggestions from code review

Remove wireguard.Enpoint struct and use net.UDPAddr for the resolved
endpoint and addr string (dnsanme:port) if a DN was supplied.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/*: use wireguard.Enpoint

This commit introduces the wireguard.Enpoint struct.
It encapsulates a DN name with port and a net.UPDAddr.
The fields are private and only accessible over exported Methods
to avoid accidental modification.

Also iptables.GetProtocol is improved to avoid ipv4 rules being applied
by `ip6tables`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/wireguard/conf_test.go: add tests for Endpoint

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* cmd/kg/main.go: validate port range

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* add suggestions from review

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/mesh/mesh.go: use Equal func

Implement an Equal func for Enpoint and use it instead of comparing
strings.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* cmd/kgctl/main.go: check port range

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* vendor

Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
leonnicolas
2022-01-30 17:38:45 +01:00
committed by GitHub
parent 797133f272
commit 6a696e03e7
299 changed files with 26275 additions and 10252 deletions

View File

@@ -15,16 +15,15 @@
package wireguard
import (
"bufio"
"bytes"
"errors"
"fmt"
"net"
"sort"
"strconv"
"strings"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"k8s.io/apimachinery/pkg/util/validation"
)
@@ -32,10 +31,6 @@ type section string
type key string
const (
separator = "="
dumpSeparator = "\t"
dumpNone = "(none)"
dumpOff = "off"
interfaceSection section = "Interface"
peerSection section = "Peer"
listenPortKey key = "ListenPort"
@@ -47,56 +42,209 @@ const (
publicKeyKey key = "PublicKey"
)
type dumpInterfaceIndex int
const (
dumpInterfacePrivateKeyIndex = iota
dumpInterfacePublicKeyIndex
dumpInterfaceListenPortIndex
dumpInterfaceFWMarkIndex
dumpInterfaceLen
)
type dumpPeerIndex int
const (
dumpPeerPublicKeyIndex = iota
dumpPeerPresharedKeyIndex
dumpPeerEndpointIndex
dumpPeerAllowedIPsIndex
dumpPeerLatestHandshakeIndex
dumpPeerTransferRXIndex
dumpPeerTransferTXIndex
dumpPeerPersistentKeepaliveIndex
dumpPeerLen
)
// Conf represents a WireGuard configuration file.
type Conf struct {
Interface *Interface
Peers []*Peer
wgtypes.Config
// The Peers field is shadowed because every Peer needs the Endpoint field that contains a DNS endpoint.
Peers []Peer
}
// Interface represents the `interface` section of a WireGuard configuration.
type Interface struct {
ListenPort uint32
PrivateKey []byte
// WGConfig returns a wgytpes.Config from a Conf.
func (c *Conf) WGConfig() wgtypes.Config {
if c == nil {
// The empty Config will do nothing, when applied.
return wgtypes.Config{}
}
r := c.Config
wgPs := make([]wgtypes.PeerConfig, len(c.Peers))
for i, p := range c.Peers {
wgPs[i] = p.PeerConfig
if p.Endpoint.Resolved() {
// We can ingore the error because we already checked if the Endpoint was resolved in the above line.
wgPs[i].Endpoint, _ = p.Endpoint.UDPAddr(false)
}
wgPs[i].ReplaceAllowedIPs = true
}
r.Peers = wgPs
r.ReplacePeers = true
return r
}
// Endpoint represents a WireGuard endpoint.
type Endpoint struct {
udpAddr *net.UDPAddr
addr string
}
// ParseEndpoint returns an Endpoint from a string.
// The input should look like "10.0.0.0:100", "[ff10::10]:100"
// or "example.com:100".
func ParseEndpoint(endpoint string) *Endpoint {
if len(endpoint) == 0 {
return nil
}
hostRaw, portRaw, err := net.SplitHostPort(endpoint)
if err != nil {
return nil
}
port, err := strconv.ParseUint(portRaw, 10, 32)
if err != nil {
return nil
}
if len(validation.IsValidPortNum(int(port))) != 0 {
return nil
}
ip := net.ParseIP(hostRaw)
if ip == nil {
if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 {
return &Endpoint{
addr: endpoint,
}
}
return nil
}
// ResolveUDPAddr will not resolve the endpoint as long as a valid IP and port is given.
// This should be the case here.
u, err := net.ResolveUDPAddr("udp", endpoint)
if err != nil {
return nil
}
u.IP = cutIP(u.IP)
return &Endpoint{
udpAddr: u,
}
}
// NewEndpointFromUDPAddr returns an Endpoint from a net.UDPAddr.
func NewEndpointFromUDPAddr(u *net.UDPAddr) *Endpoint {
if u != nil {
u.IP = cutIP(u.IP)
}
return &Endpoint{
udpAddr: u,
}
}
// NewEndpoint returns an Endpoint from a net.IP and port.
func NewEndpoint(ip net.IP, port int) *Endpoint {
return &Endpoint{
udpAddr: &net.UDPAddr{
IP: cutIP(ip),
Port: port,
},
}
}
// Ready return true, if the Enpoint is ready.
// Ready means that an IP or DN and port exists.
func (e *Endpoint) Ready() bool {
if e == nil {
return false
}
return (e.udpAddr != nil && e.udpAddr.IP != nil && e.udpAddr.Port > 0) || len(e.addr) > 0
}
// Port returns the port of the Endpoint.
func (e *Endpoint) Port() int {
if !e.Ready() {
return 0
}
if e.udpAddr != nil {
return e.udpAddr.Port
}
// We can ignore the errors here bacause the returned port will be "".
// This will result to Port 0 after the conversion to and int.
_, p, _ := net.SplitHostPort(e.addr)
port, _ := strconv.ParseUint(p, 10, 32)
return int(port)
}
// HasDNS returns true if the endpoint has a DN.
func (e *Endpoint) HasDNS() bool {
return e != nil && e.addr != ""
}
// DNS returns the DN of the Endpoint.
func (e *Endpoint) DNS() string {
if e == nil {
return ""
}
_, s, _ := net.SplitHostPort(e.addr)
return s
}
// Resolved returns true, if the DN of the Endpoint was resolved
// or if the Endpoint has a resolved endpoint.
func (e *Endpoint) Resolved() bool {
return e != nil && e.udpAddr != nil
}
// UDPAddr returns the UDPAddr of the Endpoint. If resolve is false,
// UDPAddr() will not try to resolve a DN name, if the Endpoint is not yet resolved.
func (e *Endpoint) UDPAddr(resolve bool) (*net.UDPAddr, error) {
if !e.Ready() {
return nil, errors.New("Enpoint is not ready")
}
if e.udpAddr != nil {
// Make a copy of the UDPAddr to protect it from modification outside this package.
h := *e.udpAddr
return &h, nil
}
if !resolve {
return nil, errors.New("Endpoint is not resolved")
}
var err error
if e.udpAddr, err = net.ResolveUDPAddr("udp", e.addr); err != nil {
return nil, err
}
// Make a copy of the UDPAddr to protect it from modification outside this package.
h := *e.udpAddr
return &h, nil
}
// IP returns the IP address of the Enpoint or nil.
func (e *Endpoint) IP() net.IP {
if !e.Resolved() {
return nil
}
return e.udpAddr.IP
}
// String will return the endpoint as a string.
// If a DN exists, it will take prcedence over the resolved endpoint.
func (e *Endpoint) String() string {
return e.StringOpt(true)
}
// StringOpt will return the string of the Endpoint.
// If dnsFirst is false, the resolved Endpoint will
// take precedence over the DN.
func (e *Endpoint) StringOpt(dnsFirst bool) string {
if e == nil {
return ""
}
if e.udpAddr != nil && (!dnsFirst || e.addr == "") {
return e.udpAddr.String()
}
return e.addr
}
// Equal will return true, if the Enpoints are equal.
// If dnsFirst is false, the DN will only be compared if
// the IPs are nil.
func (e *Endpoint) Equal(b *Endpoint, dnsFirst bool) bool {
return e.StringOpt(dnsFirst) == b.StringOpt(dnsFirst)
}
// Peer represents a `peer` section of a WireGuard configuration.
type Peer struct {
AllowedIPs []*net.IPNet
Endpoint *Endpoint
PersistentKeepalive int
PresharedKey []byte
PublicKey []byte
// The following fields are part of the runtime information, not the configuration.
LatestHandshake time.Time
wgtypes.PeerConfig
Endpoint *Endpoint
}
// DeduplicateIPs eliminates duplicate allowed IPs.
func (p *Peer) DeduplicateIPs() {
var ips []*net.IPNet
var ips []net.IPNet
seen := make(map[string]struct{})
for _, ip := range p.AllowedIPs {
if _, ok := seen[ip.String()]; ok {
@@ -108,181 +256,27 @@ func (p *Peer) DeduplicateIPs() {
p.AllowedIPs = ips
}
// Endpoint represents an `endpoint` key of a `peer` section.
type Endpoint struct {
DNSOrIP
Port uint32
}
// String prints the string representation of the endpoint.
func (e *Endpoint) String() string {
if e == nil {
return ""
}
dnsOrIP := e.DNSOrIP.String()
if e.IP != nil && len(e.IP) == net.IPv6len {
dnsOrIP = "[" + dnsOrIP + "]"
}
return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
}
// Equal compares two endpoints.
func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
if (e == nil) != (b == nil) {
return false
}
if e != nil {
if e.Port != b.Port {
return false
}
if DNSFirst {
// Check the DNS name first if it was resolved.
if e.DNS != b.DNS {
return false
}
if e.DNS == "" && !e.IP.Equal(b.IP) {
return false
}
} else {
// IPs take priority, so check them first.
if !e.IP.Equal(b.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if e.IP == nil && e.DNS != b.DNS {
return false
}
}
}
return true
}
// DNSOrIP represents either a DNS name or an IP address.
// IPs, as they are more specific, are preferred.
type DNSOrIP struct {
DNS string
IP net.IP
}
// String prints the string representation of the struct.
func (d DNSOrIP) String() string {
if d.IP != nil {
return d.IP.String()
}
return d.DNS
}
// Parse parses a given WireGuard configuration file and produces a Conf struct.
func Parse(buf []byte) *Conf {
var (
active section
kv []string
c Conf
err error
iface *Interface
i int
k key
line, v string
peer *Peer
port uint64
)
s := bufio.NewScanner(bytes.NewBuffer(buf))
for s.Scan() {
line = strings.TrimSpace(s.Text())
// Skip comments.
if strings.HasPrefix(line, "#") {
continue
}
// Line is a section title.
if strings.HasPrefix(line, "[") {
if peer != nil {
c.Peers = append(c.Peers, peer)
peer = nil
}
if iface != nil {
c.Interface = iface
iface = nil
}
active = section(strings.TrimSpace(strings.Trim(line, "[]")))
switch active {
case interfaceSection:
iface = new(Interface)
case peerSection:
peer = new(Peer)
}
continue
}
kv = strings.SplitN(line, separator, 2)
if len(kv) != 2 {
continue
}
k = key(strings.TrimSpace(kv[0]))
v = strings.TrimSpace(kv[1])
switch active {
case interfaceSection:
switch k {
case listenPortKey:
port, err = strconv.ParseUint(v, 10, 32)
if err != nil {
continue
}
iface.ListenPort = uint32(port)
case privateKeyKey:
iface.PrivateKey = []byte(v)
}
case peerSection:
switch k {
case allowedIPsKey:
err = peer.parseAllowedIPs(v)
if err != nil {
continue
}
case endpointKey:
err = peer.parseEndpoint(v)
if err != nil {
continue
}
case persistentKeepaliveKey:
i, err = strconv.Atoi(v)
if err != nil {
continue
}
peer.PersistentKeepalive = i
case presharedKeyKey:
peer.PresharedKey = []byte(v)
case publicKeyKey:
peer.PublicKey = []byte(v)
}
}
}
if peer != nil {
c.Peers = append(c.Peers, peer)
}
if iface != nil {
c.Interface = iface
}
return &c
}
// Bytes renders a WireGuard configuration to bytes.
func (c *Conf) Bytes() ([]byte, error) {
if c == nil {
return nil, nil
}
var err error
buf := bytes.NewBuffer(make([]byte, 0, 512))
if c.Interface != nil {
if c.PrivateKey != nil {
if err = writeSection(buf, interfaceSection); err != nil {
return nil, fmt.Errorf("failed to write interface: %v", err)
}
if err = writePKey(buf, privateKeyKey, c.Interface.PrivateKey); err != nil {
if err = writePKey(buf, privateKeyKey, c.PrivateKey); err != nil {
return nil, fmt.Errorf("failed to write private key: %v", err)
}
if err = writeValue(buf, listenPortKey, strconv.FormatUint(uint64(c.Interface.ListenPort), 10)); err != nil {
if err = writeValue(buf, listenPortKey, strconv.Itoa(*c.ListenPort)); err != nil {
return nil, fmt.Errorf("failed to write listen port: %v", err)
}
}
for i, p := range c.Peers {
// Add newlines to make the formatting nicer.
if i == 0 && c.Interface != nil || i != 0 {
if i == 0 && c.PrivateKey != nil || i != 0 {
if err = buf.WriteByte('\n'); err != nil {
return nil, err
}
@@ -297,71 +291,103 @@ func (c *Conf) Bytes() ([]byte, error) {
if err = writeEndpoint(buf, p.Endpoint); err != nil {
return nil, fmt.Errorf("failed to write endpoint: %v", err)
}
if err = writeValue(buf, persistentKeepaliveKey, strconv.Itoa(p.PersistentKeepalive)); err != nil {
if p.PersistentKeepaliveInterval == nil {
p.PersistentKeepaliveInterval = new(time.Duration)
}
if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval/time.Second), 10)); err != nil {
return nil, fmt.Errorf("failed to write persistent keepalive: %v", err)
}
if err = writePKey(buf, presharedKeyKey, p.PresharedKey); err != nil {
return nil, fmt.Errorf("failed to write preshared key: %v", err)
}
if err = writePKey(buf, publicKeyKey, p.PublicKey); err != nil {
if err = writePKey(buf, publicKeyKey, &p.PublicKey); err != nil {
return nil, fmt.Errorf("failed to write public key: %v", err)
}
}
return buf.Bytes(), nil
}
// Equal checks if two WireGuard configurations are equivalent.
func (c *Conf) Equal(b *Conf) bool {
if (c.Interface == nil) != (b.Interface == nil) {
return false
// Equal returns true if the Conf and wgtypes.Device are equal.
func (c *Conf) Equal(d *wgtypes.Device) (bool, string) {
if c == nil || d == nil {
return c == nil && d == nil, "nil values"
}
if c.Interface != nil {
if c.Interface.ListenPort != b.Interface.ListenPort || !bytes.Equal(c.Interface.PrivateKey, b.Interface.PrivateKey) {
return false
}
if c.ListenPort == nil || *c.ListenPort != d.ListenPort {
return false, fmt.Sprintf("port: old=%q, new=\"%v\"", d.ListenPort, c.ListenPort)
}
if len(c.Peers) != len(b.Peers) {
return false
if c.PrivateKey == nil || *c.PrivateKey != d.PrivateKey {
return false, fmt.Sprintf("private key: old=\"%s...\", new=\"%s\"", d.PrivateKey.String()[0:5], c.PrivateKey.String()[0:5])
}
if len(c.Peers) != len(d.Peers) {
return false, fmt.Sprintf("number of peers: old=%d, new=%d", len(d.Peers), len(c.Peers))
}
sortPeerConfigs(d.Peers)
sortPeers(c.Peers)
sortPeers(b.Peers)
for i := range c.Peers {
if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
return false
if len(c.Peers[i].AllowedIPs) != len(d.Peers[i].AllowedIPs) {
return false, fmt.Sprintf("Peer %d allowed IP length: old=%d, new=%d", i, len(d.Peers[i].AllowedIPs), len(c.Peers[i].AllowedIPs))
}
sortCIDRs(c.Peers[i].AllowedIPs)
sortCIDRs(b.Peers[i].AllowedIPs)
sortCIDRs(d.Peers[i].AllowedIPs)
for j := range c.Peers[i].AllowedIPs {
if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() {
return false
if c.Peers[i].AllowedIPs[j].String() != d.Peers[i].AllowedIPs[j].String() {
return false, fmt.Sprintf("Peer %d allowed IP: old=%q, new=%q", i, d.Peers[i].AllowedIPs[j].String(), c.Peers[i].AllowedIPs[j].String())
}
}
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) {
return false
if c.Peers[i].Endpoint == nil || d.Peers[i].Endpoint == nil {
return c.Peers[i].Endpoint == nil && d.Peers[i].Endpoint == nil, "peer endpoints: nil value"
}
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
return false
if c.Peers[i].Endpoint.StringOpt(false) != d.Peers[i].Endpoint.String() {
return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.StringOpt(false))
}
pki := time.Duration(0)
if p := c.Peers[i].PersistentKeepaliveInterval; p != nil {
pki = *p
}
psk := wgtypes.Key{}
if p := c.Peers[i].PresharedKey; p != nil {
psk = *p
}
if pki != d.Peers[i].PersistentKeepaliveInterval || psk != d.Peers[i].PresharedKey || c.Peers[i].PublicKey != d.Peers[i].PublicKey {
return false, "persistent keepalive or pershared key"
}
}
return true
return true, ""
}
func sortPeers(peers []*Peer) {
func sortPeerConfigs(peers []wgtypes.Peer) {
sort.Slice(peers, func(i, j int) bool {
if bytes.Compare(peers[i].PublicKey, peers[j].PublicKey) < 0 {
if peers[i].PublicKey.String() < peers[j].PublicKey.String() {
return true
}
return false
})
}
func sortCIDRs(cidrs []*net.IPNet) {
func sortPeers(peers []Peer) {
sort.Slice(peers, func(i, j int) bool {
if peers[i].PublicKey.String() < peers[j].PublicKey.String() {
return true
}
return false
})
}
func sortCIDRs(cidrs []net.IPNet) {
sort.Slice(cidrs, func(i, j int) bool {
return cidrs[i].String() < cidrs[j].String()
})
}
func writeAllowedIPs(buf *bytes.Buffer, ais []*net.IPNet) error {
func cutIP(ip net.IP) net.IP {
if i4 := ip.To4(); i4 != nil {
return i4
}
return ip.To16()
}
func writeAllowedIPs(buf *bytes.Buffer, ais []net.IPNet) error {
if len(ais) == 0 {
return nil
}
@@ -382,15 +408,16 @@ func writeAllowedIPs(buf *bytes.Buffer, ais []*net.IPNet) error {
return buf.WriteByte('\n')
}
func writePKey(buf *bytes.Buffer, k key, b []byte) error {
if len(b) == 0 {
func writePKey(buf *bytes.Buffer, k key, b *wgtypes.Key) error {
// Print nothing if the public key was never initialized.
if b == nil || (wgtypes.Key{}) == *b {
return nil
}
var err error
if err = writeKey(buf, k); err != nil {
return err
}
if _, err = buf.Write(b); err != nil {
if _, err = buf.Write([]byte(b.String())); err != nil {
return err
}
return buf.WriteByte('\n')
@@ -408,14 +435,15 @@ func writeValue(buf *bytes.Buffer, k key, v string) error {
}
func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error {
if e == nil {
str := e.String()
if str == "" {
return nil
}
var err error
if err = writeKey(buf, endpointKey); err != nil {
return err
}
if _, err = buf.WriteString(e.String()); err != nil {
if _, err = buf.WriteString(str); err != nil {
return err
}
return buf.WriteByte('\n')
@@ -443,177 +471,3 @@ func writeKey(buf *bytes.Buffer, k key) error {
_, err = buf.WriteString(" = ")
return err
}
var (
errParseEndpoint = errors.New("could not parse Endpoint")
)
func (p *Peer) parseEndpoint(v string) error {
var (
kv []string
err error
ip, ip4 net.IP
port uint64
)
kv = strings.Split(v, ":")
if len(kv) < 2 {
return errParseEndpoint
}
port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
if err != nil {
return err
}
d := DNSOrIP{}
ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
if ip == nil {
if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
return errParseEndpoint
}
d.DNS = kv[0]
} else {
if ip4 = ip.To4(); ip4 != nil {
d.IP = ip4
} else {
d.IP = ip.To16()
}
}
p.Endpoint = &Endpoint{
DNSOrIP: d,
Port: uint32(port),
}
return nil
}
func (p *Peer) parseAllowedIPs(v string) error {
var (
ai *net.IPNet
kv []string
err error
i int
ip, ip4 net.IP
)
kv = strings.Split(v, ",")
for i = range kv {
ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
if err != nil {
return err
}
if ip4 = ip.To4(); ip4 != nil {
ip = ip4
} else {
ip = ip.To16()
}
ai.IP = ip
p.AllowedIPs = append(p.AllowedIPs, ai)
}
return nil
}
// ParseDump parses a given WireGuard dump and produces a Conf struct.
func ParseDump(buf []byte) (*Conf, error) {
// from man wg, show section:
// If dump is specified, then several lines are printed;
// the first contains in order separated by tab: private-key, public-key, listen-port, fwmark.
// Subsequent lines are printed for each peer and contain in order separated by tab:
// public-key, preshared-key, endpoint, allowed-ips, latest-handshake, transfer-rx, transfer-tx, persistent-keepalive.
var (
active section
values []string
c Conf
err error
iface *Interface
peer *Peer
port uint64
sec int64
pka int
line int
)
// First line is Interface
active = interfaceSection
s := bufio.NewScanner(bytes.NewBuffer(buf))
for s.Scan() {
values = strings.Split(s.Text(), dumpSeparator)
switch active {
case interfaceSection:
if len(values) < dumpInterfaceLen {
return nil, fmt.Errorf("invalid interface line: missing fields (%d < %d)", len(values), dumpInterfaceLen)
}
iface = new(Interface)
for i := range values {
switch i {
case dumpInterfacePrivateKeyIndex:
iface.PrivateKey = []byte(values[i])
case dumpInterfaceListenPortIndex:
port, err = strconv.ParseUint(values[i], 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid interface line: error parsing listen-port: %w", err)
}
iface.ListenPort = uint32(port)
}
}
c.Interface = iface
// Next lines are Peers
active = peerSection
case peerSection:
if len(values) < dumpPeerLen {
return nil, fmt.Errorf("invalid peer line %d: missing fields (%d < %d)", line, len(values), dumpPeerLen)
}
peer = new(Peer)
for i := range values {
switch i {
case dumpPeerPublicKeyIndex:
peer.PublicKey = []byte(values[i])
case dumpPeerPresharedKeyIndex:
if values[i] == dumpNone {
continue
}
peer.PresharedKey = []byte(values[i])
case dumpPeerEndpointIndex:
if values[i] == dumpNone {
continue
}
err = peer.parseEndpoint(values[i])
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing endpoint: %w", line, err)
}
case dumpPeerAllowedIPsIndex:
if values[i] == dumpNone {
continue
}
err = peer.parseAllowedIPs(values[i])
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing allowed-ips: %w", line, err)
}
case dumpPeerLatestHandshakeIndex:
if values[i] == "0" {
// Use go zero value, not unix 0 timestamp.
peer.LatestHandshake = time.Time{}
continue
}
sec, err = strconv.ParseInt(values[i], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing latest-handshake: %w", line, err)
}
peer.LatestHandshake = time.Unix(sec, 0)
case dumpPeerPersistentKeepaliveIndex:
if values[i] == dumpOff {
continue
}
pka, err = strconv.Atoi(values[i])
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing persistent-keepalive: %w", line, err)
}
peer.PersistentKeepalive = pka
}
}
c.Peers = append(c.Peers, peer)
peer = nil
}
line++
}
return &c, nil
}

View File

@@ -1,4 +1,4 @@
// Copyright 2019 the Kilo authors
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -21,336 +21,431 @@ import (
"github.com/kylelemons/godebug/pretty"
)
func TestCompareConf(t *testing.T) {
for _, tc := range []struct {
func TestNewEndpoint(t *testing.T) {
for i, tc := range []struct {
name string
a []byte
b []byte
out bool
ip net.IP
port int
out *Endpoint
}{
{
name: "empty",
a: []byte{},
b: []byte{},
out: true,
name: "no ip, no port",
out: &Endpoint{
udpAddr: &net.UDPAddr{},
},
},
{
name: "key and value order",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
ListenPort = 51820
PrivateKey = private
[Peer]
PublicKey = key
AllowedIPs = 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 10.2.2.0/24
PresharedKey = psk
Endpoint = 10.1.0.2:51820
`),
out: true,
name: "only port",
ip: nil,
port: 99,
out: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 99,
},
},
},
{
name: "whitespace",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
PrivateKey=private
ListenPort=51820
[Peer]
Endpoint=10.1.0.2:51820
PresharedKey = psk
PublicKey=key
AllowedIPs=10.2.2.0/24,192.168.0.1/32,10.2.3.0/24,192.168.0.2/32,10.4.0.2/32
`),
out: true,
name: "only ipv4",
ip: net.ParseIP("10.0.0.0"),
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0").To4(),
},
},
},
{
name: "missing key",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
out: false,
name: "only ipv6",
ip: net.ParseIP("ff50::10"),
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff50::10").To16(),
},
},
},
{
name: "different value",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key2
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
out: false,
name: "ipv4",
ip: net.ParseIP("10.0.0.0"),
port: 1000,
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0").To4(),
Port: 1000,
},
},
},
{
name: "section order",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
[Interface]
PrivateKey = private
ListenPort = 51820
`),
out: true,
name: "ipv6",
ip: net.ParseIP("ff50::10"),
port: 1000,
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff50::10").To16(),
Port: 1000,
},
},
},
{
name: "out of order peers",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk2
PublicKey = key2
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk1
PublicKey = key1
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk1
PublicKey = key1
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk2
PublicKey = key2
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
out: true,
},
{
name: "one empty",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(``),
out: false,
name: "ipv6",
ip: net.ParseIP("fc00:f853:ccd:e793::3"),
port: 51820,
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("fc00:f853:ccd:e793::3").To16(),
Port: 51820,
},
},
},
} {
equal := Parse(tc.a).Equal(Parse(tc.b))
if equal != tc.out {
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
out := NewEndpoint(tc.ip, tc.port)
if diff := pretty.Compare(out, tc.out); diff != "" {
t.Errorf("%d %s: got diff:\n%s\n", i, tc.name, diff)
}
}
}
func TestCompareEndpoint(t *testing.T) {
for _, tc := range []struct {
name string
a *Endpoint
b *Endpoint
dnsFirst bool
out bool
}{
{
name: "both nil",
a: nil,
b: nil,
out: true,
},
{
name: "a nil",
a: nil,
b: &Endpoint{},
out: false,
},
{
name: "b nil",
a: &Endpoint{},
b: nil,
out: false,
},
{
name: "zero",
a: &Endpoint{},
b: &Endpoint{},
out: true,
},
{
name: "diff port",
a: &Endpoint{Port: 1234},
b: &Endpoint{Port: 5678},
out: false,
},
{
name: "same IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
out: true,
},
{
name: "diff IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2")}},
out: false,
},
{
name: "same IP ignore DNS",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "b"}},
out: true,
},
{
name: "no IP check DNS",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
out: false,
},
{
name: "no IP check DNS (same)",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
out: true,
},
{
name: "DNS first, ignore IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}},
dnsFirst: true,
out: true,
},
{
name: "DNS first",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
dnsFirst: true,
out: false,
},
{
name: "DNS first, no DNS compare IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}},
dnsFirst: true,
out: false,
},
{
name: "DNS first, no DNS compare IP (same)",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
dnsFirst: true,
out: true,
},
} {
equal := tc.a.Equal(tc.b, tc.dnsFirst)
if equal != tc.out {
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
}
}
}
func TestCompareDumpConf(t *testing.T) {
for _, tc := range []struct {
func TestParseEndpoint(t *testing.T) {
for i, tc := range []struct {
name string
d []byte
c []byte
str string
out *Endpoint
}{
{
name: "empty",
d: []byte{},
c: []byte{},
name: "no ip, no port",
},
{
name: "redacted copy from wg output",
d: []byte(`private B7qk8EMlob0nfado0ABM6HulUV607r4yqtBKjhap7S4= 51820 off
key1 (none) 10.254.1.1:51820 100.64.1.0/24,192.168.0.125/32,10.4.0.1/32 1619012801 67048 34952 10
key2 (none) 10.254.2.1:51820 100.64.4.0/24,10.69.76.55/32,100.64.3.0/24,10.66.25.131/32,10.4.0.2/32 1619013058 1134456 10077852 10`),
c: []byte(`[Interface]
ListenPort = 51820
PrivateKey = private
[Peer]
PublicKey = key1
AllowedIPs = 100.64.1.0/24, 192.168.0.125/32, 10.4.0.1/32
Endpoint = 10.254.1.1:51820
PersistentKeepalive = 10
[Peer]
PublicKey = key2
AllowedIPs = 100.64.4.0/24, 10.69.76.55/32, 100.64.3.0/24, 10.66.25.131/32, 10.4.0.2/32
Endpoint = 10.254.2.1:51820
PersistentKeepalive = 10`),
name: "only port",
str: ":1000",
},
{
name: "only ipv4",
str: "10.0.0.0",
},
{
name: "only ipv6",
str: "ff50::10",
},
{
name: "ipv4",
str: "10.0.0.0:1000",
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0").To4(),
Port: 1000,
},
},
},
{
name: "ipv6",
str: "[ff50::10]:1000",
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff50::10").To16(),
Port: 1000,
},
},
},
} {
dumpConf, _ := ParseDump(tc.d)
conf := Parse(tc.c)
// Equal will ignore runtime fields and only compare configuration fields.
if !dumpConf.Equal(conf) {
diff := pretty.Compare(dumpConf, conf)
t.Errorf("test case %q: got diff: %v", tc.name, diff)
out := ParseEndpoint(tc.str)
if diff := pretty.Compare(out, tc.out); diff != "" {
t.Errorf("ParseEndpoint %s(%d): got diff:\n%s\n", tc.name, i, diff)
}
}
}
func TestNewEndpointFromUDPAddr(t *testing.T) {
for i, tc := range []struct {
name string
u *net.UDPAddr
out *Endpoint
}{
{
name: "no ip, no port",
out: &Endpoint{
addr: "",
},
},
{
name: "only port",
u: &net.UDPAddr{
Port: 1000,
},
out: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
addr: "",
},
},
{
name: "only ipv4",
u: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0"),
},
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0").To4(),
},
addr: "",
},
},
{
name: "only ipv6",
u: &net.UDPAddr{
IP: net.ParseIP("ff60::10"),
},
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff60::10").To16(),
},
},
},
{
name: "ipv4",
u: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0"),
Port: 1000,
},
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0").To4(),
Port: 1000,
},
},
},
{
name: "ipv6",
u: &net.UDPAddr{
IP: net.ParseIP("ff50::10"),
Port: 1000,
},
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff50::10").To16(),
Port: 1000,
},
},
},
} {
out := NewEndpointFromUDPAddr(tc.u)
if diff := pretty.Compare(out, tc.out); diff != "" {
t.Errorf("ParseEndpoint %s(%d): got diff:\n%s\n", tc.name, i, diff)
}
}
}
func TestReady(t *testing.T) {
for i, tc := range []struct {
name string
in *Endpoint
r bool
}{
{
name: "nil",
r: false,
},
{
name: "no ip, no port",
in: &Endpoint{
addr: "",
udpAddr: &net.UDPAddr{},
},
r: false,
},
{
name: "only port",
in: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
r: false,
},
{
name: "only ipv4",
in: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0"),
},
},
r: false,
},
{
name: "only ipv6",
in: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff60::10"),
},
},
r: false,
},
{
name: "ipv4",
in: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0"),
Port: 1000,
},
},
r: true,
},
{
name: "ipv6",
in: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff50::10"),
Port: 1000,
},
},
r: true,
},
} {
if tc.r != tc.in.Ready() {
t.Errorf("Endpoint.Ready() %s(%d): expected=%v\tgot=%v\n", tc.name, i, tc.r, tc.in.Ready())
}
}
}
func TestEqual(t *testing.T) {
for i, tc := range []struct {
name string
a *Endpoint
b *Endpoint
df bool
r bool
}{
{
name: "nil dns last",
r: true,
},
{
name: "nil dns first",
df: true,
r: true,
},
{
name: "equal: only port",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
r: true,
},
{
name: "not equal: only port",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1001,
},
},
r: false,
},
{
name: "equal dns first",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
r: true,
},
{
name: "equal dns last",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "foo",
},
r: true,
},
{
name: "unequal dns first",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "foo",
},
df: true,
r: false,
},
{
name: "unequal dns last",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "foo",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("11.0.0.0"),
},
addr: "foo",
},
r: false,
},
{
name: "unequal dns last empty IP",
a: &Endpoint{
addr: "foo",
},
b: &Endpoint{
addr: "bar",
},
r: false,
},
{
name: "equal dns last empty IP",
a: &Endpoint{
addr: "foo",
},
b: &Endpoint{
addr: "foo",
},
r: true,
},
} {
if out := tc.a.Equal(tc.b, tc.df); out != tc.r {
t.Errorf("ParseEndpoint %s(%d): expected: %v\tgot: %v\n", tc.name, i, tc.r, out)
}
}
}

View File

@@ -18,9 +18,7 @@
package wireguard
import (
"bytes"
"fmt"
"os/exec"
"github.com/vishvananda/netlink"
)
@@ -65,74 +63,3 @@ func New(name string, mtu uint) (int, bool, error) {
}
return link.Attrs().Index, true, nil
}
// Keys generates a WireGuard private and public key-pair.
func Keys() ([]byte, []byte, error) {
private, err := GenKey()
if err != nil {
return nil, nil, fmt.Errorf("failed to generate private key: %v", err)
}
public, err := PubKey(private)
return private, public, err
}
// GenKey generates a WireGuard private key.
func GenKey() ([]byte, error) {
key, err := exec.Command("wg", "genkey").Output()
return bytes.Trim(key, "\n"), err
}
// PubKey generates a WireGuard public key for a given private key.
func PubKey(key []byte) ([]byte, error) {
cmd := exec.Command("wg", "pubkey")
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("failed to open pipe to stdin: %v", err)
}
go func() {
defer stdin.Close()
stdin.Write(key)
}()
public, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("failed to generate public key: %v", err)
}
return bytes.Trim(public, "\n"), nil
}
// SetConf applies a WireGuard configuration file to the given interface.
func SetConf(iface string, path string) error {
cmd := exec.Command("wg", "setconf", iface, path)
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to apply the WireGuard configuration: %s", stderr.String())
}
return nil
}
// ShowConf gets the WireGuard configuration for the given interface.
func ShowConf(iface string) ([]byte, error) {
cmd := exec.Command("wg", "showconf", iface)
var stderr, stdout bytes.Buffer
cmd.Stderr = &stderr
cmd.Stdout = &stdout
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("failed to read the WireGuard configuration: %s", stderr.String())
}
return stdout.Bytes(), nil
}
// ShowDump gets the WireGuard configuration and runtime information for the given interface.
func ShowDump(iface string) ([]byte, error) {
cmd := exec.Command("wg", "show", iface, "dump")
var stderr, stdout bytes.Buffer
cmd.Stderr = &stderr
cmd.Stdout = &stdout
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("failed to read the WireGuard dump output: %s", stderr.String())
}
return stdout.Bytes(), nil
}