migrate to golang.zx2c4.com/wireguard/wgctrl (#239)

* migrate to golang.zx2c4.com/wireguard/wgctrl

This commit introduces the usage of wgctrl.
It avoids the usage of exec calls of the wg command
and parsing the output of `wg show`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* vendor wgctrl

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* apply suggestions from code review

Remove wireguard.Enpoint struct and use net.UDPAddr for the resolved
endpoint and addr string (dnsanme:port) if a DN was supplied.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/*: use wireguard.Enpoint

This commit introduces the wireguard.Enpoint struct.
It encapsulates a DN name with port and a net.UPDAddr.
The fields are private and only accessible over exported Methods
to avoid accidental modification.

Also iptables.GetProtocol is improved to avoid ipv4 rules being applied
by `ip6tables`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/wireguard/conf_test.go: add tests for Endpoint

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* cmd/kg/main.go: validate port range

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* add suggestions from review

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/mesh/mesh.go: use Equal func

Implement an Equal func for Enpoint and use it instead of comparing
strings.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* cmd/kgctl/main.go: check port range

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* vendor

Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
leonnicolas
2022-01-30 17:38:45 +01:00
committed by GitHub
parent 797133f272
commit 6a696e03e7
299 changed files with 26275 additions and 10252 deletions

View File

@@ -0,0 +1,99 @@
package wguser
import (
"fmt"
"net"
"os"
"path/filepath"
"strings"
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var _ wginternal.Client = &Client{}
// A Client provides access to userspace WireGuard device information.
type Client struct {
dial func(device string) (net.Conn, error)
find func() ([]string, error)
}
// New creates a new Client.
func New() (*Client, error) {
return &Client{
// Operating system-specific functions which can identify and connect
// to userspace WireGuard devices. These functions can also be
// overridden for tests.
dial: dial,
find: find,
}, nil
}
// Close implements wginternal.Client.
func (c *Client) Close() error { return nil }
// Devices implements wginternal.Client.
func (c *Client) Devices() ([]*wgtypes.Device, error) {
devices, err := c.find()
if err != nil {
return nil, err
}
var wgds []*wgtypes.Device
for _, d := range devices {
wgd, err := c.getDevice(d)
if err != nil {
return nil, err
}
wgds = append(wgds, wgd)
}
return wgds, nil
}
// Device implements wginternal.Client.
func (c *Client) Device(name string) (*wgtypes.Device, error) {
devices, err := c.find()
if err != nil {
return nil, err
}
for _, d := range devices {
if name != deviceName(d) {
continue
}
return c.getDevice(d)
}
return nil, os.ErrNotExist
}
// ConfigureDevice implements wginternal.Client.
func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
devices, err := c.find()
if err != nil {
return err
}
for _, d := range devices {
if name != deviceName(d) {
continue
}
return c.configureDevice(d, cfg)
}
return os.ErrNotExist
}
// deviceName infers a device name from an absolute file path with extension.
func deviceName(sock string) string {
return strings.TrimSuffix(filepath.Base(sock), filepath.Ext(sock))
}
func panicf(format string, a ...interface{}) {
panic(fmt.Sprintf(format, a...))
}

View File

@@ -0,0 +1,106 @@
package wguser
import (
"bytes"
"encoding/hex"
"fmt"
"io"
"os"
"strings"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// configureDevice configures a device specified by its path.
func (c *Client) configureDevice(device string, cfg wgtypes.Config) error {
conn, err := c.dial(device)
if err != nil {
return err
}
defer conn.Close()
// Start with set command.
var buf bytes.Buffer
buf.WriteString("set=1\n")
// Add any necessary configuration from cfg, then finish with an empty line.
writeConfig(&buf, cfg)
buf.WriteString("\n")
// Apply configuration for the device and then check the error number.
if _, err := io.Copy(conn, &buf); err != nil {
return err
}
res := make([]byte, 32)
n, err := conn.Read(res)
if err != nil {
return err
}
// errno=0 indicates success, anything else returns an error number that
// matches definitions from errno.h.
str := strings.TrimSpace(string(res[:n]))
if str != "errno=0" {
// TODO(mdlayher): return actual errno on Linux?
return os.NewSyscallError("read", fmt.Errorf("wguser: %s", str))
}
return nil
}
// writeConfig writes textual configuration to w as specified by cfg.
func writeConfig(w io.Writer, cfg wgtypes.Config) {
if cfg.PrivateKey != nil {
fmt.Fprintf(w, "private_key=%s\n", hexKey(*cfg.PrivateKey))
}
if cfg.ListenPort != nil {
fmt.Fprintf(w, "listen_port=%d\n", *cfg.ListenPort)
}
if cfg.FirewallMark != nil {
fmt.Fprintf(w, "fwmark=%d\n", *cfg.FirewallMark)
}
if cfg.ReplacePeers {
fmt.Fprintln(w, "replace_peers=true")
}
for _, p := range cfg.Peers {
fmt.Fprintf(w, "public_key=%s\n", hexKey(p.PublicKey))
if p.Remove {
fmt.Fprintln(w, "remove=true")
}
if p.UpdateOnly {
fmt.Fprintln(w, "update_only=true")
}
if p.PresharedKey != nil {
fmt.Fprintf(w, "preshared_key=%s\n", hexKey(*p.PresharedKey))
}
if p.Endpoint != nil {
fmt.Fprintf(w, "endpoint=%s\n", p.Endpoint.String())
}
if p.PersistentKeepaliveInterval != nil {
fmt.Fprintf(w, "persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds()))
}
if p.ReplaceAllowedIPs {
fmt.Fprintln(w, "replace_allowed_ips=true")
}
for _, ip := range p.AllowedIPs {
fmt.Fprintf(w, "allowed_ip=%s\n", ip.String())
}
}
}
// hexKey encodes a wgtypes.Key into a hexadecimal string.
func hexKey(k wgtypes.Key) string {
return hex.EncodeToString(k[:])
}

View File

@@ -0,0 +1,51 @@
//go:build !windows
// +build !windows
package wguser
import (
"errors"
"io/ioutil"
"net"
"os"
"path/filepath"
)
// dial is the default implementation of Client.dial.
func dial(device string) (net.Conn, error) {
return net.Dial("unix", device)
}
// find is the default implementation of Client.find.
func find() ([]string, error) {
return findUNIXSockets([]string{
// It seems that /var/run is a common location between Linux and the
// BSDs, even though it's a symlink on Linux.
"/var/run/wireguard",
})
}
// findUNIXSockets looks for UNIX socket files in the specified directories.
func findUNIXSockets(dirs []string) ([]string, error) {
var socks []string
for _, d := range dirs {
files, err := ioutil.ReadDir(d)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
continue
}
return nil, err
}
for _, f := range files {
if f.Mode()&os.ModeSocket == 0 {
continue
}
socks = append(socks, filepath.Join(d, f.Name()))
}
}
return socks, nil
}

View File

@@ -0,0 +1,82 @@
//go:build windows
// +build windows
package wguser
import (
"net"
"strings"
"time"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
)
// Expected prefixes when dealing with named pipes.
const (
pipePrefix = `\\.\pipe\`
wgPrefix = `ProtectedPrefix\Administrators\WireGuard\`
)
// dial is the default implementation of Client.dial.
func dial(device string) (net.Conn, error) {
localSystem, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
if err != nil {
return nil, err
}
return (&namedpipe.DialConfig{
ExpectedOwner: localSystem,
}).DialTimeout(device, time.Duration(0))
}
// find is the default implementation of Client.find.
func find() ([]string, error) {
return findNamedPipes(wgPrefix)
}
// findNamedPipes looks for Windows named pipes that match the specified
// search string prefix.
func findNamedPipes(search string) ([]string, error) {
var (
pipes []string
data windows.Win32finddata
)
// Thanks @zx2c4 for the tips on the appropriate Windows APIs here:
// https://א.cc/dHGpnhxX/c.
h, err := windows.FindFirstFile(
// Append * to find all named pipes.
windows.StringToUTF16Ptr(pipePrefix+"*"),
&data,
)
if err != nil {
return nil, err
}
// FindClose is used to close file search handles instead of the typical
// CloseHandle used elsewhere, see:
// https://docs.microsoft.com/en-us/windows/desktop/api/fileapi/nf-fileapi-findclose.
defer windows.FindClose(h)
// Check the first file's name for a match, but also keep searching for
// WireGuard named pipes until no more files can be iterated.
for {
name := windows.UTF16ToString(data.FileName[:])
if strings.HasPrefix(name, search) {
// Concatenate strings directly as filepath.Join appears to break the
// named pipe prefix convention.
pipes = append(pipes, pipePrefix+name)
}
if err := windows.FindNextFile(h, &data); err != nil {
if err == windows.ERROR_NO_MORE_FILES {
break
}
return nil, err
}
}
return pipes, nil
}

View File

@@ -0,0 +1,6 @@
// Package wguser provides internal access to the userspace WireGuard
// configuration protocol interface.
//
// This package is internal-only and not meant for end users to consume.
// Please use package wgctrl (an abstraction over this package) instead.
package wguser

View File

@@ -0,0 +1,258 @@
package wguser
import (
"bufio"
"bytes"
"encoding/hex"
"fmt"
"io"
"net"
"os"
"strconv"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// The WireGuard userspace configuration protocol is described here:
// https://www.wireguard.com/xplatform/#cross-platform-userspace-implementation.
// getDevice gathers device information from a device specified by its path
// and returns a Device.
func (c *Client) getDevice(device string) (*wgtypes.Device, error) {
conn, err := c.dial(device)
if err != nil {
return nil, err
}
defer conn.Close()
// Get information about this device.
if _, err := io.WriteString(conn, "get=1\n\n"); err != nil {
return nil, err
}
// Parse the device from the incoming data stream.
d, err := parseDevice(conn)
if err != nil {
return nil, err
}
// TODO(mdlayher): populate interface index too?
d.Name = deviceName(device)
d.Type = wgtypes.Userspace
return d, nil
}
// parseDevice parses a Device and its Peers from an io.Reader.
func parseDevice(r io.Reader) (*wgtypes.Device, error) {
var dp deviceParser
s := bufio.NewScanner(r)
for s.Scan() {
b := s.Bytes()
if len(b) == 0 {
// Empty line, done parsing.
break
}
// All data is in key=value format.
kvs := bytes.Split(b, []byte("="))
if len(kvs) != 2 {
return nil, fmt.Errorf("wguser: invalid key=value pair: %q", string(b))
}
dp.Parse(string(kvs[0]), string(kvs[1]))
}
if err := s.Err(); err != nil {
return nil, err
}
return dp.Device()
}
// A deviceParser accumulates information about a Device and its Peers.
type deviceParser struct {
d wgtypes.Device
err error
parsePeers bool
peers int
hsSec, hsNano int
}
// Device returns a Device or any errors that were encountered while parsing
// a Device.
func (dp *deviceParser) Device() (*wgtypes.Device, error) {
if dp.err != nil {
return nil, dp.err
}
// Compute remaining fields of the Device now that all parsing is done.
dp.d.PublicKey = dp.d.PrivateKey.PublicKey()
return &dp.d, nil
}
// Parse parses a single key/value pair into fields of a Device.
func (dp *deviceParser) Parse(key, value string) {
switch key {
case "errno":
// 0 indicates success, anything else returns an error number that matches
// definitions from errno.h.
if errno := dp.parseInt(value); errno != 0 {
// TODO(mdlayher): return actual errno on Linux?
dp.err = os.NewSyscallError("read", fmt.Errorf("wguser: errno=%d", errno))
return
}
case "public_key":
// We've either found the first peer or the next peer. Stop parsing
// Device fields and start parsing Peer fields, including the public
// key indicated here.
dp.parsePeers = true
dp.peers++
dp.d.Peers = append(dp.d.Peers, wgtypes.Peer{
PublicKey: dp.parseKey(value),
})
return
}
// Are we parsing peer fields?
if dp.parsePeers {
dp.peerParse(key, value)
return
}
// Device field parsing.
switch key {
case "private_key":
dp.d.PrivateKey = dp.parseKey(value)
case "listen_port":
dp.d.ListenPort = dp.parseInt(value)
case "fwmark":
dp.d.FirewallMark = dp.parseInt(value)
}
}
// curPeer returns the current Peer being parsed so its fields can be populated.
func (dp *deviceParser) curPeer() *wgtypes.Peer {
return &dp.d.Peers[dp.peers-1]
}
// peerParse parses a key/value field into the current Peer.
func (dp *deviceParser) peerParse(key, value string) {
p := dp.curPeer()
switch key {
case "preshared_key":
p.PresharedKey = dp.parseKey(value)
case "endpoint":
p.Endpoint = dp.parseAddr(value)
case "last_handshake_time_sec":
dp.hsSec = dp.parseInt(value)
case "last_handshake_time_nsec":
dp.hsNano = dp.parseInt(value)
// Assume that we've seen both seconds and nanoseconds and populate this
// field now. However, if both fields were set to 0, assume we have never
// had a successful handshake with this peer, and return a zero-value
// time.Time to our callers.
if dp.hsSec > 0 && dp.hsNano > 0 {
p.LastHandshakeTime = time.Unix(int64(dp.hsSec), int64(dp.hsNano))
}
case "tx_bytes":
p.TransmitBytes = dp.parseInt64(value)
case "rx_bytes":
p.ReceiveBytes = dp.parseInt64(value)
case "persistent_keepalive_interval":
p.PersistentKeepaliveInterval = time.Duration(dp.parseInt(value)) * time.Second
case "allowed_ip":
cidr := dp.parseCIDR(value)
if cidr != nil {
p.AllowedIPs = append(p.AllowedIPs, *cidr)
}
case "protocol_version":
p.ProtocolVersion = dp.parseInt(value)
}
}
// parseKey parses a Key from a hex string.
func (dp *deviceParser) parseKey(s string) wgtypes.Key {
if dp.err != nil {
return wgtypes.Key{}
}
b, err := hex.DecodeString(s)
if err != nil {
dp.err = err
return wgtypes.Key{}
}
key, err := wgtypes.NewKey(b)
if err != nil {
dp.err = err
return wgtypes.Key{}
}
return key
}
// parseInt parses an integer from a string.
func (dp *deviceParser) parseInt(s string) int {
if dp.err != nil {
return 0
}
v, err := strconv.Atoi(s)
if err != nil {
dp.err = err
return 0
}
return v
}
// parseInt64 parses an int64 from a string.
func (dp *deviceParser) parseInt64(s string) int64 {
if dp.err != nil {
return 0
}
v, err := strconv.ParseInt(s, 10, 64)
if err != nil {
dp.err = err
return 0
}
return v
}
// parseAddr parses a UDP address from a string.
func (dp *deviceParser) parseAddr(s string) *net.UDPAddr {
if dp.err != nil {
return nil
}
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
dp.err = err
return nil
}
return addr
}
// parseInt parses an address CIDR from a string.
func (dp *deviceParser) parseCIDR(s string) *net.IPNet {
if dp.err != nil {
return nil
}
_, cidr, err := net.ParseCIDR(s)
if err != nil {
dp.err = err
return nil
}
return cidr
}