Use LatestHandshake to validate endpoint (#149)
* wireguard: `wg show iface dump` reader and parser * mesh: use LatestHandshake to validate NAT Endpoints * add skip on error * switch to loop parsing So the stop on error pattern can be used * Add error handling to ParseDump
This commit is contained in:
parent
0733c83a0a
commit
e12b5029d7
@ -454,13 +454,18 @@ func (m *Mesh) applyTopology() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Find the old configuration.
|
// Find the old configuration.
|
||||||
oldConfRaw, err := wireguard.ShowConf(link.Attrs().Name)
|
oldConfDump, err := wireguard.ShowDump(link.Attrs().Name)
|
||||||
|
if err != nil {
|
||||||
|
level.Error(m.logger).Log("error", err)
|
||||||
|
m.errorCounter.WithLabelValues("apply").Inc()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
oldConf, err := wireguard.ParseDump(oldConfDump)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
level.Error(m.logger).Log("error", err)
|
level.Error(m.logger).Log("error", err)
|
||||||
m.errorCounter.WithLabelValues("apply").Inc()
|
m.errorCounter.WithLabelValues("apply").Inc()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
oldConf := wireguard.Parse(oldConfRaw)
|
|
||||||
natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
|
natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
|
||||||
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
|
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
|
||||||
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
|
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
|
||||||
@ -782,17 +787,15 @@ func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *
|
|||||||
}
|
}
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
|
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
|
||||||
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false))
|
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false), "latest-handshake", peer.LatestHandshake)
|
||||||
// Should check location leader but only available in topology ... or have topology handle that list
|
if (peer.LatestHandshake != time.Time{}) {
|
||||||
// Better check wg latest-handshake
|
|
||||||
if !n.Endpoint.Equal(peer.Endpoint, false) {
|
|
||||||
natEndpoints[string(n.Key)] = peer.Endpoint
|
natEndpoints[string(n.Key)] = peer.Endpoint
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, p := range peers {
|
for _, p := range peers {
|
||||||
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
|
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
|
||||||
if !p.Endpoint.Equal(peer.Endpoint, false) {
|
if (peer.LatestHandshake != time.Time{}) {
|
||||||
natEndpoints[string(p.PublicKey)] = peer.Endpoint
|
natEndpoints[string(p.PublicKey)] = peer.Endpoint
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,11 +17,13 @@ package wireguard
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"k8s.io/apimachinery/pkg/util/validation"
|
"k8s.io/apimachinery/pkg/util/validation"
|
||||||
)
|
)
|
||||||
@ -31,6 +33,9 @@ type key string
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
separator = "="
|
separator = "="
|
||||||
|
dumpSeparator = "\t"
|
||||||
|
dumpNone = "(none)"
|
||||||
|
dumpOff = "off"
|
||||||
interfaceSection section = "Interface"
|
interfaceSection section = "Interface"
|
||||||
peerSection section = "Peer"
|
peerSection section = "Peer"
|
||||||
listenPortKey key = "ListenPort"
|
listenPortKey key = "ListenPort"
|
||||||
@ -42,6 +47,30 @@ const (
|
|||||||
publicKeyKey key = "PublicKey"
|
publicKeyKey key = "PublicKey"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type dumpInterfaceIndex int
|
||||||
|
|
||||||
|
const (
|
||||||
|
dumpInterfacePrivateKeyIndex = iota
|
||||||
|
dumpInterfacePublicKeyIndex
|
||||||
|
dumpInterfaceListenPortIndex
|
||||||
|
dumpInterfaceFWMarkIndex
|
||||||
|
dumpInterfaceLen
|
||||||
|
)
|
||||||
|
|
||||||
|
type dumpPeerIndex int
|
||||||
|
|
||||||
|
const (
|
||||||
|
dumpPeerPublicKeyIndex = iota
|
||||||
|
dumpPeerPresharedKeyIndex
|
||||||
|
dumpPeerEndpointIndex
|
||||||
|
dumpPeerAllowedIPsIndex
|
||||||
|
dumpPeerLatestHandshakeIndex
|
||||||
|
dumpPeerTransferRXIndex
|
||||||
|
dumpPeerTransferTXIndex
|
||||||
|
dumpPeerPersistentKeepaliveIndex
|
||||||
|
dumpPeerLen
|
||||||
|
)
|
||||||
|
|
||||||
// Conf represents a WireGuard configuration file.
|
// Conf represents a WireGuard configuration file.
|
||||||
type Conf struct {
|
type Conf struct {
|
||||||
Interface *Interface
|
Interface *Interface
|
||||||
@ -61,6 +90,8 @@ type Peer struct {
|
|||||||
PersistentKeepalive int
|
PersistentKeepalive int
|
||||||
PresharedKey []byte
|
PresharedKey []byte
|
||||||
PublicKey []byte
|
PublicKey []byte
|
||||||
|
// The following fields are part of the runtime information, not the configuration.
|
||||||
|
LatestHandshake time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeduplicateIPs eliminates duplicate allowed IPs.
|
// DeduplicateIPs eliminates duplicate allowed IPs.
|
||||||
@ -146,13 +177,11 @@ func (d DNSOrIP) String() string {
|
|||||||
func Parse(buf []byte) *Conf {
|
func Parse(buf []byte) *Conf {
|
||||||
var (
|
var (
|
||||||
active section
|
active section
|
||||||
ai *net.IPNet
|
|
||||||
kv []string
|
kv []string
|
||||||
c Conf
|
c Conf
|
||||||
err error
|
err error
|
||||||
iface *Interface
|
iface *Interface
|
||||||
i int
|
i int
|
||||||
ip, ip4 net.IP
|
|
||||||
k key
|
k key
|
||||||
line, v string
|
line, v string
|
||||||
peer *Peer
|
peer *Peer
|
||||||
@ -205,48 +234,14 @@ func Parse(buf []byte) *Conf {
|
|||||||
case peerSection:
|
case peerSection:
|
||||||
switch k {
|
switch k {
|
||||||
case allowedIPsKey:
|
case allowedIPsKey:
|
||||||
// Reuse string slice.
|
err = peer.parseAllowedIPs(v)
|
||||||
kv = strings.Split(v, ",")
|
|
||||||
for i = range kv {
|
|
||||||
ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ip4 = ip.To4(); ip4 != nil {
|
|
||||||
ip = ip4
|
|
||||||
} else {
|
|
||||||
ip = ip.To16()
|
|
||||||
}
|
|
||||||
ai.IP = ip
|
|
||||||
peer.AllowedIPs = append(peer.AllowedIPs, ai)
|
|
||||||
}
|
|
||||||
case endpointKey:
|
|
||||||
// Reuse string slice.
|
|
||||||
kv = strings.Split(v, ":")
|
|
||||||
if len(kv) < 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
d := DNSOrIP{}
|
case endpointKey:
|
||||||
ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
|
err = peer.parseEndpoint(v)
|
||||||
if ip == nil {
|
if err != nil {
|
||||||
if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
|
continue
|
||||||
continue
|
|
||||||
}
|
|
||||||
d.DNS = kv[0]
|
|
||||||
} else {
|
|
||||||
if ip4 = ip.To4(); ip4 != nil {
|
|
||||||
d.IP = ip4
|
|
||||||
} else {
|
|
||||||
d.IP = ip.To16()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
peer.Endpoint = &Endpoint{
|
|
||||||
DNSOrIP: d,
|
|
||||||
Port: uint32(port),
|
|
||||||
}
|
}
|
||||||
case persistentKeepaliveKey:
|
case persistentKeepaliveKey:
|
||||||
i, err = strconv.Atoi(v)
|
i, err = strconv.Atoi(v)
|
||||||
@ -448,3 +443,177 @@ func writeKey(buf *bytes.Buffer, k key) error {
|
|||||||
_, err = buf.WriteString(" = ")
|
_, err = buf.WriteString(" = ")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errParseEndpoint = errors.New("could not parse Endpoint")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Peer) parseEndpoint(v string) error {
|
||||||
|
var (
|
||||||
|
kv []string
|
||||||
|
err error
|
||||||
|
ip, ip4 net.IP
|
||||||
|
port uint64
|
||||||
|
)
|
||||||
|
kv = strings.Split(v, ":")
|
||||||
|
if len(kv) < 2 {
|
||||||
|
return errParseEndpoint
|
||||||
|
}
|
||||||
|
port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
d := DNSOrIP{}
|
||||||
|
ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
|
||||||
|
if ip == nil {
|
||||||
|
if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
|
||||||
|
return errParseEndpoint
|
||||||
|
}
|
||||||
|
d.DNS = kv[0]
|
||||||
|
} else {
|
||||||
|
if ip4 = ip.To4(); ip4 != nil {
|
||||||
|
d.IP = ip4
|
||||||
|
} else {
|
||||||
|
d.IP = ip.To16()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Endpoint = &Endpoint{
|
||||||
|
DNSOrIP: d,
|
||||||
|
Port: uint32(port),
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) parseAllowedIPs(v string) error {
|
||||||
|
var (
|
||||||
|
ai *net.IPNet
|
||||||
|
kv []string
|
||||||
|
err error
|
||||||
|
i int
|
||||||
|
ip, ip4 net.IP
|
||||||
|
)
|
||||||
|
|
||||||
|
kv = strings.Split(v, ",")
|
||||||
|
for i = range kv {
|
||||||
|
ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ip4 = ip.To4(); ip4 != nil {
|
||||||
|
ip = ip4
|
||||||
|
} else {
|
||||||
|
ip = ip.To16()
|
||||||
|
}
|
||||||
|
ai.IP = ip
|
||||||
|
p.AllowedIPs = append(p.AllowedIPs, ai)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseDump parses a given WireGuard dump and produces a Conf struct.
|
||||||
|
func ParseDump(buf []byte) (*Conf, error) {
|
||||||
|
// from man wg, show section:
|
||||||
|
// If dump is specified, then several lines are printed;
|
||||||
|
// the first contains in order separated by tab: private-key, public-key, listen-port, fw‐mark.
|
||||||
|
// Subsequent lines are printed for each peer and contain in order separated by tab:
|
||||||
|
// public-key, preshared-key, endpoint, allowed-ips, latest-handshake, transfer-rx, transfer-tx, persistent-keepalive.
|
||||||
|
var (
|
||||||
|
active section
|
||||||
|
values []string
|
||||||
|
c Conf
|
||||||
|
err error
|
||||||
|
iface *Interface
|
||||||
|
peer *Peer
|
||||||
|
port uint64
|
||||||
|
sec int64
|
||||||
|
pka int
|
||||||
|
line int
|
||||||
|
)
|
||||||
|
// First line is Interface
|
||||||
|
active = interfaceSection
|
||||||
|
s := bufio.NewScanner(bytes.NewBuffer(buf))
|
||||||
|
for s.Scan() {
|
||||||
|
values = strings.Split(s.Text(), dumpSeparator)
|
||||||
|
|
||||||
|
switch active {
|
||||||
|
case interfaceSection:
|
||||||
|
if len(values) < dumpInterfaceLen {
|
||||||
|
return nil, fmt.Errorf("invalid interface line: missing fields (%d < %d)", len(values), dumpInterfaceLen)
|
||||||
|
}
|
||||||
|
iface = new(Interface)
|
||||||
|
for i := range values {
|
||||||
|
switch i {
|
||||||
|
case dumpInterfacePrivateKeyIndex:
|
||||||
|
iface.PrivateKey = []byte(values[i])
|
||||||
|
case dumpInterfaceListenPortIndex:
|
||||||
|
port, err = strconv.ParseUint(values[i], 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid interface line: error parsing listen-port: %w", err)
|
||||||
|
}
|
||||||
|
iface.ListenPort = uint32(port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Interface = iface
|
||||||
|
// Next lines are Peers
|
||||||
|
active = peerSection
|
||||||
|
case peerSection:
|
||||||
|
if len(values) < dumpPeerLen {
|
||||||
|
return nil, fmt.Errorf("invalid peer line %d: missing fields (%d < %d)", line, len(values), dumpPeerLen)
|
||||||
|
}
|
||||||
|
peer = new(Peer)
|
||||||
|
|
||||||
|
for i := range values {
|
||||||
|
switch i {
|
||||||
|
case dumpPeerPublicKeyIndex:
|
||||||
|
peer.PublicKey = []byte(values[i])
|
||||||
|
case dumpPeerPresharedKeyIndex:
|
||||||
|
if values[i] == dumpNone {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peer.PresharedKey = []byte(values[i])
|
||||||
|
case dumpPeerEndpointIndex:
|
||||||
|
if values[i] == dumpNone {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = peer.parseEndpoint(values[i])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid peer line %d: error parsing endpoint: %w", line, err)
|
||||||
|
}
|
||||||
|
case dumpPeerAllowedIPsIndex:
|
||||||
|
if values[i] == dumpNone {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = peer.parseAllowedIPs(values[i])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid peer line %d: error parsing allowed-ips: %w", line, err)
|
||||||
|
}
|
||||||
|
case dumpPeerLatestHandshakeIndex:
|
||||||
|
if values[i] == "0" {
|
||||||
|
// Use go zero value, not unix 0 timestamp.
|
||||||
|
peer.LatestHandshake = time.Time{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sec, err = strconv.ParseInt(values[i], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid peer line %d: error parsing latest-handshake: %w", line, err)
|
||||||
|
}
|
||||||
|
peer.LatestHandshake = time.Unix(sec, 0)
|
||||||
|
case dumpPeerPersistentKeepaliveIndex:
|
||||||
|
if values[i] == dumpOff {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pka, err = strconv.Atoi(values[i])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid peer line %d: error parsing persistent-keepalive: %w", line, err)
|
||||||
|
}
|
||||||
|
peer.PersistentKeepalive = pka
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Peers = append(c.Peers, peer)
|
||||||
|
peer = nil
|
||||||
|
}
|
||||||
|
line++
|
||||||
|
}
|
||||||
|
return &c, nil
|
||||||
|
}
|
||||||
|
@ -17,6 +17,8 @@ package wireguard
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/kylelemons/godebug/pretty"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCompareConf(t *testing.T) {
|
func TestCompareConf(t *testing.T) {
|
||||||
@ -308,3 +310,47 @@ func TestCompareEndpoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompareDumpConf(t *testing.T) {
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
d []byte
|
||||||
|
c []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
d: []byte{},
|
||||||
|
c: []byte{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "redacted copy from wg output",
|
||||||
|
d: []byte(`private B7qk8EMlob0nfado0ABM6HulUV607r4yqtBKjhap7S4= 51820 off
|
||||||
|
key1 (none) 10.254.1.1:51820 100.64.1.0/24,192.168.0.125/32,10.4.0.1/32 1619012801 67048 34952 10
|
||||||
|
key2 (none) 10.254.2.1:51820 100.64.4.0/24,10.69.76.55/32,100.64.3.0/24,10.66.25.131/32,10.4.0.2/32 1619013058 1134456 10077852 10`),
|
||||||
|
c: []byte(`[Interface]
|
||||||
|
ListenPort = 51820
|
||||||
|
PrivateKey = private
|
||||||
|
|
||||||
|
[Peer]
|
||||||
|
PublicKey = key1
|
||||||
|
AllowedIPs = 100.64.1.0/24, 192.168.0.125/32, 10.4.0.1/32
|
||||||
|
Endpoint = 10.254.1.1:51820
|
||||||
|
PersistentKeepalive = 10
|
||||||
|
|
||||||
|
[Peer]
|
||||||
|
PublicKey = key2
|
||||||
|
AllowedIPs = 100.64.4.0/24, 10.69.76.55/32, 100.64.3.0/24, 10.66.25.131/32, 10.4.0.2/32
|
||||||
|
Endpoint = 10.254.2.1:51820
|
||||||
|
PersistentKeepalive = 10`),
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
|
||||||
|
dumpConf, _ := ParseDump(tc.d)
|
||||||
|
conf := Parse(tc.c)
|
||||||
|
// Equal will ignore runtime fields and only compare configuration fields.
|
||||||
|
if !dumpConf.Equal(conf) {
|
||||||
|
diff := pretty.Compare(dumpConf, conf)
|
||||||
|
t.Errorf("test case %q: got diff: %v", tc.name, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -119,3 +119,15 @@ func ShowConf(iface string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
return stdout.Bytes(), nil
|
return stdout.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ShowDump gets the WireGuard configuration and runtime information for the given interface.
|
||||||
|
func ShowDump(iface string) ([]byte, error) {
|
||||||
|
cmd := exec.Command("wg", "show", iface, "dump")
|
||||||
|
var stderr, stdout bytes.Buffer
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read the WireGuard dump output: %s", stderr.String())
|
||||||
|
}
|
||||||
|
return stdout.Bytes(), nil
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user