migrate to golang.zx2c4.com/wireguard/wgctrl (#239)

* migrate to golang.zx2c4.com/wireguard/wgctrl

This commit introduces the usage of wgctrl.
It avoids the usage of exec calls of the wg command
and parsing the output of `wg show`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* vendor wgctrl

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* apply suggestions from code review

Remove wireguard.Enpoint struct and use net.UDPAddr for the resolved
endpoint and addr string (dnsanme:port) if a DN was supplied.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/*: use wireguard.Enpoint

This commit introduces the wireguard.Enpoint struct.
It encapsulates a DN name with port and a net.UPDAddr.
The fields are private and only accessible over exported Methods
to avoid accidental modification.

Also iptables.GetProtocol is improved to avoid ipv4 rules being applied
by `ip6tables`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/wireguard/conf_test.go: add tests for Endpoint

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* cmd/kg/main.go: validate port range

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* add suggestions from review

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/mesh/mesh.go: use Equal func

Implement an Equal func for Enpoint and use it instead of comparing
strings.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* cmd/kgctl/main.go: check port range

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* vendor

Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
leonnicolas
2022-01-30 17:38:45 +01:00
committed by GitHub
parent 797133f272
commit 6a696e03e7
299 changed files with 26275 additions and 10252 deletions

View File

@@ -0,0 +1,265 @@
//go:build linux
// +build linux
package wglinux
import (
"errors"
"fmt"
"os"
"syscall"
"github.com/mdlayher/genetlink"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nlenc"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
"golang.zx2c4.com/wireguard/wgctrl/internal/wglinux/internal/wgh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var _ wginternal.Client = &Client{}
// A Client provides access to Linux WireGuard netlink information.
type Client struct {
c *genetlink.Conn
family genetlink.Family
interfaces func() ([]string, error)
}
// New creates a new Client and returns whether or not the generic netlink
// interface is available.
func New() (*Client, bool, error) {
c, err := genetlink.Dial(nil)
if err != nil {
return nil, false, err
}
return initClient(c)
}
// initClient is the internal Client constructor used in some tests.
func initClient(c *genetlink.Conn) (*Client, bool, error) {
f, err := c.GetFamily(wgh.GenlName)
if err != nil {
_ = c.Close()
if errors.Is(err, os.ErrNotExist) {
// The generic netlink interface is not available.
return nil, false, nil
}
return nil, false, err
}
return &Client{
c: c,
family: f,
// By default, gather only WireGuard interfaces using rtnetlink.
interfaces: rtnlInterfaces,
}, true, nil
}
// Close implements wginternal.Client.
func (c *Client) Close() error {
return c.c.Close()
}
// Devices implements wginternal.Client.
func (c *Client) Devices() ([]*wgtypes.Device, error) {
// By default, rtnetlink is used to fetch a list of all interfaces and then
// filter that list to only find WireGuard interfaces.
//
// The remainder of this function assumes that any returned device from this
// function is a valid WireGuard device.
ifis, err := c.interfaces()
if err != nil {
return nil, err
}
ds := make([]*wgtypes.Device, 0, len(ifis))
for _, ifi := range ifis {
d, err := c.Device(ifi)
if err != nil {
return nil, err
}
ds = append(ds, d)
}
return ds, nil
}
// Device implements wginternal.Client.
func (c *Client) Device(name string) (*wgtypes.Device, error) {
// Don't bother querying netlink with empty input.
if name == "" {
return nil, os.ErrNotExist
}
// Fetching a device by interface index is possible as well, but we only
// support fetching by name as it seems to be more convenient in general.
b, err := netlink.MarshalAttributes([]netlink.Attribute{{
Type: wgh.DeviceAIfname,
Data: nlenc.Bytes(name),
}})
if err != nil {
return nil, err
}
msgs, err := c.execute(wgh.CmdGetDevice, netlink.Request|netlink.Dump, b)
if err != nil {
return nil, err
}
return parseDevice(msgs)
}
// ConfigureDevice implements wginternal.Client.
func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
// Large configurations are split into batches for use with netlink.
for _, b := range buildBatches(cfg) {
attrs, err := configAttrs(name, b)
if err != nil {
return err
}
// Request acknowledgement of our request from netlink, even though the
// output messages are unused. The netlink package checks and trims the
// status code value.
if _, err := c.execute(wgh.CmdSetDevice, netlink.Request|netlink.Acknowledge, attrs); err != nil {
return err
}
}
return nil
}
// execute executes a single WireGuard netlink request with the specified command,
// header flags, and attribute arguments.
func (c *Client) execute(command uint8, flags netlink.HeaderFlags, attrb []byte) ([]genetlink.Message, error) {
msg := genetlink.Message{
Header: genetlink.Header{
Command: command,
Version: wgh.GenlVersion,
},
Data: attrb,
}
msgs, err := c.c.Execute(msg, c.family.ID, flags)
if err == nil {
return msgs, nil
}
// We don't want to expose netlink errors directly to callers so unpack to
// something more generic.
oerr, ok := err.(*netlink.OpError)
if !ok {
// Expect all errors to conform to netlink.OpError.
return nil, fmt.Errorf("wglinux: netlink operation returned non-netlink error (please file a bug: https://golang.zx2c4.com/wireguard/wgctrl): %v", err)
}
switch oerr.Err {
// Convert "no such device" and "not a wireguard device" to an error
// compatible with os.ErrNotExist for easy checking.
case unix.ENODEV, unix.ENOTSUP:
return nil, os.ErrNotExist
default:
// Expose the inner error directly (such as EPERM).
return nil, oerr.Err
}
}
// rtnlInterfaces uses rtnetlink to fetch a list of WireGuard interfaces.
func rtnlInterfaces() ([]string, error) {
// Use the stdlib's rtnetlink helpers to get ahold of a table of all
// interfaces, so we can begin filtering it down to just WireGuard devices.
tab, err := syscall.NetlinkRIB(unix.RTM_GETLINK, unix.AF_UNSPEC)
if err != nil {
return nil, fmt.Errorf("wglinux: failed to get list of interfaces from rtnetlink: %v", err)
}
msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil {
return nil, fmt.Errorf("wglinux: failed to parse rtnetlink messages: %v", err)
}
return parseRTNLInterfaces(msgs)
}
// parseRTNLInterfaces unpacks rtnetlink messages and returns WireGuard
// interface names.
func parseRTNLInterfaces(msgs []syscall.NetlinkMessage) ([]string, error) {
var ifis []string
for _, m := range msgs {
// Only deal with link messages, and they must have an ifinfomsg
// structure appear before the attributes.
if m.Header.Type != unix.RTM_NEWLINK {
continue
}
if len(m.Data) < unix.SizeofIfInfomsg {
return nil, fmt.Errorf("wglinux: rtnetlink message is too short for ifinfomsg: %d", len(m.Data))
}
ad, err := netlink.NewAttributeDecoder(m.Data[syscall.SizeofIfInfomsg:])
if err != nil {
return nil, err
}
// Determine the interface's name and if it's a WireGuard device.
var (
ifi string
isWG bool
)
for ad.Next() {
switch ad.Type() {
case unix.IFLA_IFNAME:
ifi = ad.String()
case unix.IFLA_LINKINFO:
ad.Do(isWGKind(&isWG))
}
}
if err := ad.Err(); err != nil {
return nil, err
}
if isWG {
// Found one; append it to the list.
ifis = append(ifis, ifi)
}
}
return ifis, nil
}
// wgKind is the IFLA_INFO_KIND value for WireGuard devices.
const wgKind = "wireguard"
// isWGKind parses netlink attributes to determine if a link is a WireGuard
// device, then populates ok with the result.
func isWGKind(ok *bool) func(b []byte) error {
return func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return err
}
for ad.Next() {
if ad.Type() != unix.IFLA_INFO_KIND {
continue
}
if ad.String() == wgKind {
*ok = true
return nil
}
}
return ad.Err()
}
}

View File

@@ -0,0 +1,294 @@
//go:build linux
// +build linux
package wglinux
import (
"encoding/binary"
"fmt"
"net"
"unsafe"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nlenc"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl/internal/wglinux/internal/wgh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// configAttrs creates the required encoded netlink attributes to configure
// the device specified by name using the non-nil fields in cfg.
func configAttrs(name string, cfg wgtypes.Config) ([]byte, error) {
ae := netlink.NewAttributeEncoder()
ae.String(wgh.DeviceAIfname, name)
if cfg.PrivateKey != nil {
ae.Bytes(wgh.DeviceAPrivateKey, (*cfg.PrivateKey)[:])
}
if cfg.ListenPort != nil {
ae.Uint16(wgh.DeviceAListenPort, uint16(*cfg.ListenPort))
}
if cfg.FirewallMark != nil {
ae.Uint32(wgh.DeviceAFwmark, uint32(*cfg.FirewallMark))
}
if cfg.ReplacePeers {
ae.Uint32(wgh.DeviceAFlags, wgh.DeviceFReplacePeers)
}
// Only apply peer attributes if necessary.
if len(cfg.Peers) > 0 {
ae.Nested(wgh.DeviceAPeers, func(nae *netlink.AttributeEncoder) error {
// Netlink arrays use type as an array index.
for i, p := range cfg.Peers {
nae.Nested(uint16(i), encodePeer(p))
}
return nil
})
}
return ae.Encode()
}
// ipBatchChunk is a tunable allowed IP batch limit per peer.
//
// Because we don't necessarily know how much space a given peer will occupy,
// we play it safe and use a reasonably small value. Note that this constant
// is used both in this package and tests, so be aware when making changes.
const ipBatchChunk = 256
// peerBatchChunk specifies the number of peers that can appear in a
// configuration before we start splitting it into chunks.
const peerBatchChunk = 32
// shouldBatch determines if a configuration is sufficiently complex that it
// should be split into batches.
func shouldBatch(cfg wgtypes.Config) bool {
if len(cfg.Peers) > peerBatchChunk {
return true
}
var ips int
for _, p := range cfg.Peers {
ips += len(p.AllowedIPs)
}
return ips > ipBatchChunk
}
// buildBatches produces a batch of configs from a single config, if needed.
func buildBatches(cfg wgtypes.Config) []wgtypes.Config {
// Is this a small configuration; no need to batch?
if !shouldBatch(cfg) {
return []wgtypes.Config{cfg}
}
// Use most fields of cfg for our "base" configuration, and only differ
// peers in each batch.
base := cfg
base.Peers = nil
// Track the known peers so that peer IPs are not replaced if a single
// peer has its allowed IPs split into multiple batches.
knownPeers := make(map[wgtypes.Key]struct{})
batches := make([]wgtypes.Config, 0)
for _, p := range cfg.Peers {
batch := base
// Iterate until no more allowed IPs.
var done bool
for !done {
var tmp []net.IPNet
if len(p.AllowedIPs) < ipBatchChunk {
// IPs all fit within a batch; we are done.
tmp = make([]net.IPNet, len(p.AllowedIPs))
copy(tmp, p.AllowedIPs)
done = true
} else {
// IPs are larger than a single batch, copy a batch out and
// advance the cursor.
tmp = make([]net.IPNet, ipBatchChunk)
copy(tmp, p.AllowedIPs[:ipBatchChunk])
p.AllowedIPs = p.AllowedIPs[ipBatchChunk:]
if len(p.AllowedIPs) == 0 {
// IPs ended on a batch boundary; no more IPs left so end
// iteration after this loop.
done = true
}
}
pcfg := wgtypes.PeerConfig{
// PublicKey denotes the peer and must be present.
PublicKey: p.PublicKey,
// Apply the update only flag to every chunk to ensure
// consistency between batches when the kernel module processes
// them.
UpdateOnly: p.UpdateOnly,
// It'd be a bit weird to have a remove peer message with many
// IPs, but just in case, add this to every peer's message.
Remove: p.Remove,
// The IPs for this chunk.
AllowedIPs: tmp,
}
// Only pass certain fields on the first occurrence of a peer, so
// that subsequent IPs won't be wiped out and space isn't wasted.
if _, ok := knownPeers[p.PublicKey]; !ok {
knownPeers[p.PublicKey] = struct{}{}
pcfg.PresharedKey = p.PresharedKey
pcfg.Endpoint = p.Endpoint
pcfg.PersistentKeepaliveInterval = p.PersistentKeepaliveInterval
// Important: do not move or appending peers won't work.
pcfg.ReplaceAllowedIPs = p.ReplaceAllowedIPs
}
// Add a peer configuration to this batch and keep going.
batch.Peers = []wgtypes.PeerConfig{pcfg}
batches = append(batches, batch)
}
}
// Do not allow peer replacement beyond the first message in a batch,
// so we don't overwrite our previous batch work.
for i := range batches {
if i > 0 {
batches[i].ReplacePeers = false
}
}
return batches
}
// encodePeer returns a function to encode PeerConfig nested attributes.
func encodePeer(p wgtypes.PeerConfig) func(ae *netlink.AttributeEncoder) error {
return func(ae *netlink.AttributeEncoder) error {
ae.Bytes(wgh.PeerAPublicKey, p.PublicKey[:])
// Flags are stored in a single attribute.
var flags uint32
if p.Remove {
flags |= wgh.PeerFRemoveMe
}
if p.ReplaceAllowedIPs {
flags |= wgh.PeerFReplaceAllowedips
}
if p.UpdateOnly {
flags |= wgh.PeerFUpdateOnly
}
if flags != 0 {
ae.Uint32(wgh.PeerAFlags, flags)
}
if p.PresharedKey != nil {
ae.Bytes(wgh.PeerAPresharedKey, (*p.PresharedKey)[:])
}
if p.Endpoint != nil {
ae.Do(wgh.PeerAEndpoint, encodeSockaddr(*p.Endpoint))
}
if p.PersistentKeepaliveInterval != nil {
ae.Uint16(wgh.PeerAPersistentKeepaliveInterval, uint16(p.PersistentKeepaliveInterval.Seconds()))
}
// Only apply allowed IPs if necessary.
if len(p.AllowedIPs) > 0 {
ae.Nested(wgh.PeerAAllowedips, encodeAllowedIPs(p.AllowedIPs))
}
return nil
}
}
// encodeSockaddr returns a function which encodes a net.UDPAddr as raw
// sockaddr_in or sockaddr_in6 bytes.
func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) {
return func() ([]byte, error) {
if !isValidIP(endpoint.IP) {
return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String())
}
// Is this an IPv6 address?
if isIPv6(endpoint.IP) {
var addr [16]byte
copy(addr[:], endpoint.IP.To16())
sa := unix.RawSockaddrInet6{
Family: unix.AF_INET6,
Port: sockaddrPort(endpoint.Port),
Addr: addr,
}
return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil
}
// IPv4 address handling.
var addr [4]byte
copy(addr[:], endpoint.IP.To4())
sa := unix.RawSockaddrInet4{
Family: unix.AF_INET,
Port: sockaddrPort(endpoint.Port),
Addr: addr,
}
return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil
}
}
// encodeAllowedIPs returns a function to encode allowed IP nested attributes.
func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error {
return func(ae *netlink.AttributeEncoder) error {
for i, ipn := range ipns {
if !isValidIP(ipn.IP) {
return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.IP.String())
}
family := uint16(unix.AF_INET6)
if !isIPv6(ipn.IP) {
// Make sure address is 4 bytes if IPv4.
family = unix.AF_INET
ipn.IP = ipn.IP.To4()
}
// Netlink arrays use type as an array index.
ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error {
nae.Uint16(wgh.AllowedipAFamily, family)
nae.Bytes(wgh.AllowedipAIpaddr, ipn.IP)
ones, _ := ipn.Mask.Size()
nae.Uint8(wgh.AllowedipACidrMask, uint8(ones))
return nil
})
}
return nil
}
}
// isValidIP determines if IP is a valid IPv4 or IPv6 address.
func isValidIP(ip net.IP) bool {
return ip.To16() != nil
}
// isIPv6 determines if IP is a valid IPv6 address.
func isIPv6(ip net.IP) bool {
return isValidIP(ip) && ip.To4() == nil
}
// sockaddrPort interprets port as a big endian uint16 for use passing sockaddr
// structures to the kernel.
func sockaddrPort(port int) uint16 {
return binary.BigEndian.Uint16(nlenc.Uint16Bytes(uint16(port)))
}

View File

@@ -0,0 +1,6 @@
// Package wglinux provides internal access to Linux's WireGuard generic
// netlink interface.
//
// This package is internal-only and not meant for end users to consume.
// Please use package wgctrl (an abstraction over this package) instead.
package wglinux

View File

@@ -0,0 +1,99 @@
// WARNING: This file has automatically been generated on Tue, 04 May 2021 18:36:46 EDT.
// Code generated by https://git.io/c-for-go. DO NOT EDIT.
package wgh
const (
// GenlName as defined in wgh/wireguard.h:134
GenlName = "wireguard"
// GenlVersion as defined in wgh/wireguard.h:135
GenlVersion = 1
// KeyLen as defined in wgh/wireguard.h:137
KeyLen = 32
// CmdMax as defined in wgh/wireguard.h:144
CmdMax = (__CmdMax - 1)
// DeviceAMax as defined in wgh/wireguard.h:162
DeviceAMax = (_DeviceALast - 1)
// PeerAMax as defined in wgh/wireguard.h:185
PeerAMax = (_PeerALast - 1)
// AllowedipAMax as defined in wgh/wireguard.h:194
AllowedipAMax = (_AllowedipALast - 1)
)
// wgCmd as declared in wgh/wireguard.h:139
type wgCmd int32
// wgCmd enumeration from wgh/wireguard.h:139
const (
CmdGetDevice = iota
CmdSetDevice = 1
__CmdMax = 2
)
// wgdeviceFlag as declared in wgh/wireguard.h:146
type wgdeviceFlag int32
// wgdeviceFlag enumeration from wgh/wireguard.h:146
const (
DeviceFReplacePeers = uint32(1) << 0
_DeviceFAll = DeviceFReplacePeers
)
// wgdeviceAttribute as declared in wgh/wireguard.h:150
type wgdeviceAttribute int32
// wgdeviceAttribute enumeration from wgh/wireguard.h:150
const (
DeviceAUnspec = iota
DeviceAIfindex = 1
DeviceAIfname = 2
DeviceAPrivateKey = 3
DeviceAPublicKey = 4
DeviceAFlags = 5
DeviceAListenPort = 6
DeviceAFwmark = 7
DeviceAPeers = 8
_DeviceALast = 9
)
// wgpeerFlag as declared in wgh/wireguard.h:164
type wgpeerFlag int32
// wgpeerFlag enumeration from wgh/wireguard.h:164
const (
PeerFRemoveMe = uint32(1) << 0
PeerFReplaceAllowedips = uint32(1) << 1
PeerFUpdateOnly = uint32(1) << 2
_PeerFAll = PeerFRemoveMe | PeerFReplaceAllowedips | PeerFUpdateOnly
)
// wgpeerAttribute as declared in wgh/wireguard.h:171
type wgpeerAttribute int32
// wgpeerAttribute enumeration from wgh/wireguard.h:171
const (
PeerAUnspec = iota
PeerAPublicKey = 1
PeerAPresharedKey = 2
PeerAFlags = 3
PeerAEndpoint = 4
PeerAPersistentKeepaliveInterval = 5
PeerALastHandshakeTime = 6
PeerARxBytes = 7
PeerATxBytes = 8
PeerAAllowedips = 9
PeerAProtocolVersion = 10
_PeerALast = 11
)
// wgallowedipAttribute as declared in wgh/wireguard.h:187
type wgallowedipAttribute int32
// wgallowedipAttribute enumeration from wgh/wireguard.h:187
const (
AllowedipAUnspec = iota
AllowedipAFamily = 1
AllowedipAIpaddr = 2
AllowedipACidrMask = 3
_AllowedipALast = 4
)

View File

@@ -0,0 +1,12 @@
// Package wgh is an auto-generated package which contains constants and
// types used to access WireGuard information using generic netlink.
package wgh
// Pull the latest wireguard.h from GitHub for code generation.
//go:generate wget https://raw.githubusercontent.com/torvalds/linux/master/include/uapi/linux/wireguard.h
// Generate Go source from C constants.
//go:generate c-for-go -out ../ -nocgo wgh.yml
// Clean up build artifacts.
//go:generate rm -rf wireguard.h _obj/

View File

@@ -0,0 +1,22 @@
---
GENERATOR:
PackageName: wgh
PARSER:
IncludePaths: [/usr/include]
SourcesPaths: [wireguard.h]
TRANSLATOR:
ConstRules:
defines: expand
enum: expand
Rules:
const:
- {transform: lower}
- {action: accept, from: "(?i)wg_"}
- {action: replace, from: "(?i)wg_", to: _}
- {action: accept, from: "(?i)wg"}
- {action: replace, from: "(?i)wg", to: }
- {transform: export}
post-global:
- {load: snakecase}

View File

@@ -0,0 +1,304 @@
//go:build linux
// +build linux
package wglinux
import (
"fmt"
"net"
"time"
"unsafe"
"github.com/mdlayher/genetlink"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl/internal/wglinux/internal/wgh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// parseDevice parses a Device from a slice of generic netlink messages,
// automatically merging peer lists from subsequent messages into the Device
// from the first message.
func parseDevice(msgs []genetlink.Message) (*wgtypes.Device, error) {
var first wgtypes.Device
knownPeers := make(map[wgtypes.Key]int)
for i, m := range msgs {
d, err := parseDeviceLoop(m)
if err != nil {
return nil, err
}
if i == 0 {
// First message contains our target device.
first = *d
// Gather the known peers so that we can merge
// them later if needed
for i := range first.Peers {
knownPeers[first.Peers[i].PublicKey] = i
}
continue
}
// Any subsequent messages have their peer contents merged into the
// first "target" message.
mergeDevices(&first, d, knownPeers)
}
return &first, nil
}
// parseDeviceLoop parses a Device from a single generic netlink message.
func parseDeviceLoop(m genetlink.Message) (*wgtypes.Device, error) {
ad, err := netlink.NewAttributeDecoder(m.Data)
if err != nil {
return nil, err
}
d := wgtypes.Device{Type: wgtypes.LinuxKernel}
for ad.Next() {
switch ad.Type() {
case wgh.DeviceAIfindex:
// Ignored; interface index isn't exposed at all in the userspace
// configuration protocol, and name is more friendly anyway.
case wgh.DeviceAIfname:
d.Name = ad.String()
case wgh.DeviceAPrivateKey:
ad.Do(parseKey(&d.PrivateKey))
case wgh.DeviceAPublicKey:
ad.Do(parseKey(&d.PublicKey))
case wgh.DeviceAListenPort:
d.ListenPort = int(ad.Uint16())
case wgh.DeviceAFwmark:
d.FirewallMark = int(ad.Uint32())
case wgh.DeviceAPeers:
// Netlink array of peers.
//
// Errors while parsing are propagated up to top-level ad.Err check.
ad.Nested(func(nad *netlink.AttributeDecoder) error {
// Initialize to the number of peers in this decoder and begin
// handling nested Peer attributes.
d.Peers = make([]wgtypes.Peer, 0, nad.Len())
for nad.Next() {
nad.Nested(func(nnad *netlink.AttributeDecoder) error {
d.Peers = append(d.Peers, parsePeer(nnad))
return nil
})
}
return nil
})
}
}
if err := ad.Err(); err != nil {
return nil, err
}
return &d, nil
}
// parseAllowedIPs parses a wgtypes.Peer from a netlink attribute payload.
func parsePeer(ad *netlink.AttributeDecoder) wgtypes.Peer {
var p wgtypes.Peer
for ad.Next() {
switch ad.Type() {
case wgh.PeerAPublicKey:
ad.Do(parseKey(&p.PublicKey))
case wgh.PeerAPresharedKey:
ad.Do(parseKey(&p.PresharedKey))
case wgh.PeerAEndpoint:
p.Endpoint = &net.UDPAddr{}
ad.Do(parseSockaddr(p.Endpoint))
case wgh.PeerAPersistentKeepaliveInterval:
p.PersistentKeepaliveInterval = time.Duration(ad.Uint16()) * time.Second
case wgh.PeerALastHandshakeTime:
ad.Do(parseTimespec(&p.LastHandshakeTime))
case wgh.PeerARxBytes:
p.ReceiveBytes = int64(ad.Uint64())
case wgh.PeerATxBytes:
p.TransmitBytes = int64(ad.Uint64())
case wgh.PeerAAllowedips:
ad.Nested(parseAllowedIPs(&p.AllowedIPs))
case wgh.PeerAProtocolVersion:
p.ProtocolVersion = int(ad.Uint32())
}
}
return p
}
// parseAllowedIPs parses a slice of net.IPNet from a netlink attribute payload.
func parseAllowedIPs(ipns *[]net.IPNet) func(ad *netlink.AttributeDecoder) error {
return func(ad *netlink.AttributeDecoder) error {
// Initialize to the number of allowed IPs and begin iterating through
// the netlink array to decode each one.
*ipns = make([]net.IPNet, 0, ad.Len())
for ad.Next() {
// Allowed IP nested attributes.
ad.Nested(func(nad *netlink.AttributeDecoder) error {
var (
ipn net.IPNet
mask int
family int
)
for nad.Next() {
switch nad.Type() {
case wgh.AllowedipAIpaddr:
nad.Do(parseAddr(&ipn.IP))
case wgh.AllowedipACidrMask:
mask = int(nad.Uint8())
case wgh.AllowedipAFamily:
family = int(nad.Uint16())
}
}
if err := nad.Err(); err != nil {
return err
}
// The address family determines the correct number of bits in
// the mask.
switch family {
case unix.AF_INET:
ipn.Mask = net.CIDRMask(mask, 32)
case unix.AF_INET6:
ipn.Mask = net.CIDRMask(mask, 128)
}
*ipns = append(*ipns, ipn)
return nil
})
}
return nil
}
}
// parseKey parses a wgtypes.Key from a byte slice.
func parseKey(key *wgtypes.Key) func(b []byte) error {
return func(b []byte) error {
k, err := wgtypes.NewKey(b)
if err != nil {
return err
}
*key = k
return nil
}
}
// parseAddr parses a net.IP from raw in_addr or in6_addr struct bytes.
func parseAddr(ip *net.IP) func(b []byte) error {
return func(b []byte) error {
switch len(b) {
case net.IPv4len, net.IPv6len:
// Okay to convert directly to net.IP; memory layout is identical.
*ip = make(net.IP, len(b))
copy(*ip, b)
return nil
default:
return fmt.Errorf("wglinux: unexpected IP address size: %d", len(b))
}
}
}
// parseSockaddr parses a *net.UDPAddr from raw sockaddr_in or sockaddr_in6 bytes.
func parseSockaddr(endpoint *net.UDPAddr) func(b []byte) error {
return func(b []byte) error {
switch len(b) {
case unix.SizeofSockaddrInet4:
// IPv4 address parsing.
sa := *(*unix.RawSockaddrInet4)(unsafe.Pointer(&b[0]))
*endpoint = net.UDPAddr{
IP: net.IP(sa.Addr[:]).To4(),
Port: int(sockaddrPort(int(sa.Port))),
}
return nil
case unix.SizeofSockaddrInet6:
// IPv6 address parsing.
sa := *(*unix.RawSockaddrInet6)(unsafe.Pointer(&b[0]))
*endpoint = net.UDPAddr{
IP: net.IP(sa.Addr[:]),
Port: int(sockaddrPort(int(sa.Port))),
}
return nil
default:
return fmt.Errorf("wglinux: unexpected sockaddr size: %d", len(b))
}
}
}
// timespec32 is a unix.Timespec with 32-bit integers.
type timespec32 struct {
Sec int32
Nsec int32
}
// timespec64 is a unix.Timespec with 64-bit integers.
type timespec64 struct {
Sec int64
Nsec int64
}
const (
sizeofTimespec32 = int(unsafe.Sizeof(timespec32{}))
sizeofTimespec64 = int(unsafe.Sizeof(timespec64{}))
)
// parseTimespec parses a time.Time from raw timespec bytes.
func parseTimespec(t *time.Time) func(b []byte) error {
return func(b []byte) error {
// It would appear that WireGuard can return a __kernel_timespec which
// uses 64-bit integers, even on 32-bit platforms. Clarification of this
// behavior is being sought in:
// https://lists.zx2c4.com/pipermail/wireguard/2019-April/004088.html.
//
// In the mean time, be liberal and accept 32-bit and 64-bit variants.
var sec, nsec int64
switch len(b) {
case sizeofTimespec32:
ts := *(*timespec32)(unsafe.Pointer(&b[0]))
sec = int64(ts.Sec)
nsec = int64(ts.Nsec)
case sizeofTimespec64:
ts := *(*timespec64)(unsafe.Pointer(&b[0]))
sec = ts.Sec
nsec = ts.Nsec
default:
return fmt.Errorf("wglinux: unexpected timespec size: %d bytes, expected 8 or 16 bytes", len(b))
}
// Only set fields if UNIX timestamp value is greater than 0, so the
// caller will see a zero-value time.Time otherwise.
if sec > 0 || nsec > 0 {
*t = time.Unix(sec, nsec)
}
return nil
}
}
// mergeDevices merges Peer information from d into target. mergeDevices is
// used to deal with multiple incoming netlink messages for the same device.
func mergeDevices(target, d *wgtypes.Device, knownPeers map[wgtypes.Key]int) {
for i := range d.Peers {
// Peer is already known, append to it's allowed IP networks
if peerIndex, ok := knownPeers[d.Peers[i].PublicKey]; ok {
target.Peers[peerIndex].AllowedIPs = append(target.Peers[peerIndex].AllowedIPs, d.Peers[i].AllowedIPs...)
} else { // New peer, add it to the target peers.
target.Peers = append(target.Peers, d.Peers[i])
knownPeers[d.Peers[i].PublicKey] = len(target.Peers) - 1
}
}
}