migrate to golang.zx2c4.com/wireguard/wgctrl (#239)

* migrate to golang.zx2c4.com/wireguard/wgctrl

This commit introduces the usage of wgctrl.
It avoids the usage of exec calls of the wg command
and parsing the output of `wg show`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* vendor wgctrl

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* apply suggestions from code review

Remove wireguard.Enpoint struct and use net.UDPAddr for the resolved
endpoint and addr string (dnsanme:port) if a DN was supplied.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/*: use wireguard.Enpoint

This commit introduces the wireguard.Enpoint struct.
It encapsulates a DN name with port and a net.UPDAddr.
The fields are private and only accessible over exported Methods
to avoid accidental modification.

Also iptables.GetProtocol is improved to avoid ipv4 rules being applied
by `ip6tables`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/wireguard/conf_test.go: add tests for Endpoint

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* cmd/kg/main.go: validate port range

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* add suggestions from review

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/mesh/mesh.go: use Equal func

Implement an Equal func for Enpoint and use it instead of comparing
strings.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* cmd/kgctl/main.go: check port range

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* vendor

Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
leonnicolas
2022-01-30 17:38:45 +01:00
committed by GitHub
parent 797133f272
commit 6a696e03e7
299 changed files with 26275 additions and 10252 deletions

View File

@@ -0,0 +1,21 @@
package wginternal
import (
"errors"
"io"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// ErrReadOnly indicates that the driver backing a device is read-only. It is
// a sentinel value used in integration tests.
// TODO(mdlayher): consider exposing in API.
var ErrReadOnly = errors.New("driver is read-only")
// A Client is a type which can control a WireGuard device.
type Client interface {
io.Closer
Devices() ([]*wgtypes.Device, error)
Device(name string) (*wgtypes.Device, error)
ConfigureDevice(name string, cfg wgtypes.Config) error
}

View File

@@ -0,0 +1,5 @@
// Package wginternal contains shared internal types for wgctrl.
//
// This package is internal-only and not meant for end users to consume.
// Please use package wgctrl (an abstraction over this package) instead.
package wginternal

View File

@@ -0,0 +1,265 @@
//go:build linux
// +build linux
package wglinux
import (
"errors"
"fmt"
"os"
"syscall"
"github.com/mdlayher/genetlink"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nlenc"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
"golang.zx2c4.com/wireguard/wgctrl/internal/wglinux/internal/wgh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var _ wginternal.Client = &Client{}
// A Client provides access to Linux WireGuard netlink information.
type Client struct {
c *genetlink.Conn
family genetlink.Family
interfaces func() ([]string, error)
}
// New creates a new Client and returns whether or not the generic netlink
// interface is available.
func New() (*Client, bool, error) {
c, err := genetlink.Dial(nil)
if err != nil {
return nil, false, err
}
return initClient(c)
}
// initClient is the internal Client constructor used in some tests.
func initClient(c *genetlink.Conn) (*Client, bool, error) {
f, err := c.GetFamily(wgh.GenlName)
if err != nil {
_ = c.Close()
if errors.Is(err, os.ErrNotExist) {
// The generic netlink interface is not available.
return nil, false, nil
}
return nil, false, err
}
return &Client{
c: c,
family: f,
// By default, gather only WireGuard interfaces using rtnetlink.
interfaces: rtnlInterfaces,
}, true, nil
}
// Close implements wginternal.Client.
func (c *Client) Close() error {
return c.c.Close()
}
// Devices implements wginternal.Client.
func (c *Client) Devices() ([]*wgtypes.Device, error) {
// By default, rtnetlink is used to fetch a list of all interfaces and then
// filter that list to only find WireGuard interfaces.
//
// The remainder of this function assumes that any returned device from this
// function is a valid WireGuard device.
ifis, err := c.interfaces()
if err != nil {
return nil, err
}
ds := make([]*wgtypes.Device, 0, len(ifis))
for _, ifi := range ifis {
d, err := c.Device(ifi)
if err != nil {
return nil, err
}
ds = append(ds, d)
}
return ds, nil
}
// Device implements wginternal.Client.
func (c *Client) Device(name string) (*wgtypes.Device, error) {
// Don't bother querying netlink with empty input.
if name == "" {
return nil, os.ErrNotExist
}
// Fetching a device by interface index is possible as well, but we only
// support fetching by name as it seems to be more convenient in general.
b, err := netlink.MarshalAttributes([]netlink.Attribute{{
Type: wgh.DeviceAIfname,
Data: nlenc.Bytes(name),
}})
if err != nil {
return nil, err
}
msgs, err := c.execute(wgh.CmdGetDevice, netlink.Request|netlink.Dump, b)
if err != nil {
return nil, err
}
return parseDevice(msgs)
}
// ConfigureDevice implements wginternal.Client.
func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
// Large configurations are split into batches for use with netlink.
for _, b := range buildBatches(cfg) {
attrs, err := configAttrs(name, b)
if err != nil {
return err
}
// Request acknowledgement of our request from netlink, even though the
// output messages are unused. The netlink package checks and trims the
// status code value.
if _, err := c.execute(wgh.CmdSetDevice, netlink.Request|netlink.Acknowledge, attrs); err != nil {
return err
}
}
return nil
}
// execute executes a single WireGuard netlink request with the specified command,
// header flags, and attribute arguments.
func (c *Client) execute(command uint8, flags netlink.HeaderFlags, attrb []byte) ([]genetlink.Message, error) {
msg := genetlink.Message{
Header: genetlink.Header{
Command: command,
Version: wgh.GenlVersion,
},
Data: attrb,
}
msgs, err := c.c.Execute(msg, c.family.ID, flags)
if err == nil {
return msgs, nil
}
// We don't want to expose netlink errors directly to callers so unpack to
// something more generic.
oerr, ok := err.(*netlink.OpError)
if !ok {
// Expect all errors to conform to netlink.OpError.
return nil, fmt.Errorf("wglinux: netlink operation returned non-netlink error (please file a bug: https://golang.zx2c4.com/wireguard/wgctrl): %v", err)
}
switch oerr.Err {
// Convert "no such device" and "not a wireguard device" to an error
// compatible with os.ErrNotExist for easy checking.
case unix.ENODEV, unix.ENOTSUP:
return nil, os.ErrNotExist
default:
// Expose the inner error directly (such as EPERM).
return nil, oerr.Err
}
}
// rtnlInterfaces uses rtnetlink to fetch a list of WireGuard interfaces.
func rtnlInterfaces() ([]string, error) {
// Use the stdlib's rtnetlink helpers to get ahold of a table of all
// interfaces, so we can begin filtering it down to just WireGuard devices.
tab, err := syscall.NetlinkRIB(unix.RTM_GETLINK, unix.AF_UNSPEC)
if err != nil {
return nil, fmt.Errorf("wglinux: failed to get list of interfaces from rtnetlink: %v", err)
}
msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil {
return nil, fmt.Errorf("wglinux: failed to parse rtnetlink messages: %v", err)
}
return parseRTNLInterfaces(msgs)
}
// parseRTNLInterfaces unpacks rtnetlink messages and returns WireGuard
// interface names.
func parseRTNLInterfaces(msgs []syscall.NetlinkMessage) ([]string, error) {
var ifis []string
for _, m := range msgs {
// Only deal with link messages, and they must have an ifinfomsg
// structure appear before the attributes.
if m.Header.Type != unix.RTM_NEWLINK {
continue
}
if len(m.Data) < unix.SizeofIfInfomsg {
return nil, fmt.Errorf("wglinux: rtnetlink message is too short for ifinfomsg: %d", len(m.Data))
}
ad, err := netlink.NewAttributeDecoder(m.Data[syscall.SizeofIfInfomsg:])
if err != nil {
return nil, err
}
// Determine the interface's name and if it's a WireGuard device.
var (
ifi string
isWG bool
)
for ad.Next() {
switch ad.Type() {
case unix.IFLA_IFNAME:
ifi = ad.String()
case unix.IFLA_LINKINFO:
ad.Do(isWGKind(&isWG))
}
}
if err := ad.Err(); err != nil {
return nil, err
}
if isWG {
// Found one; append it to the list.
ifis = append(ifis, ifi)
}
}
return ifis, nil
}
// wgKind is the IFLA_INFO_KIND value for WireGuard devices.
const wgKind = "wireguard"
// isWGKind parses netlink attributes to determine if a link is a WireGuard
// device, then populates ok with the result.
func isWGKind(ok *bool) func(b []byte) error {
return func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return err
}
for ad.Next() {
if ad.Type() != unix.IFLA_INFO_KIND {
continue
}
if ad.String() == wgKind {
*ok = true
return nil
}
}
return ad.Err()
}
}

View File

@@ -0,0 +1,294 @@
//go:build linux
// +build linux
package wglinux
import (
"encoding/binary"
"fmt"
"net"
"unsafe"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nlenc"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl/internal/wglinux/internal/wgh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// configAttrs creates the required encoded netlink attributes to configure
// the device specified by name using the non-nil fields in cfg.
func configAttrs(name string, cfg wgtypes.Config) ([]byte, error) {
ae := netlink.NewAttributeEncoder()
ae.String(wgh.DeviceAIfname, name)
if cfg.PrivateKey != nil {
ae.Bytes(wgh.DeviceAPrivateKey, (*cfg.PrivateKey)[:])
}
if cfg.ListenPort != nil {
ae.Uint16(wgh.DeviceAListenPort, uint16(*cfg.ListenPort))
}
if cfg.FirewallMark != nil {
ae.Uint32(wgh.DeviceAFwmark, uint32(*cfg.FirewallMark))
}
if cfg.ReplacePeers {
ae.Uint32(wgh.DeviceAFlags, wgh.DeviceFReplacePeers)
}
// Only apply peer attributes if necessary.
if len(cfg.Peers) > 0 {
ae.Nested(wgh.DeviceAPeers, func(nae *netlink.AttributeEncoder) error {
// Netlink arrays use type as an array index.
for i, p := range cfg.Peers {
nae.Nested(uint16(i), encodePeer(p))
}
return nil
})
}
return ae.Encode()
}
// ipBatchChunk is a tunable allowed IP batch limit per peer.
//
// Because we don't necessarily know how much space a given peer will occupy,
// we play it safe and use a reasonably small value. Note that this constant
// is used both in this package and tests, so be aware when making changes.
const ipBatchChunk = 256
// peerBatchChunk specifies the number of peers that can appear in a
// configuration before we start splitting it into chunks.
const peerBatchChunk = 32
// shouldBatch determines if a configuration is sufficiently complex that it
// should be split into batches.
func shouldBatch(cfg wgtypes.Config) bool {
if len(cfg.Peers) > peerBatchChunk {
return true
}
var ips int
for _, p := range cfg.Peers {
ips += len(p.AllowedIPs)
}
return ips > ipBatchChunk
}
// buildBatches produces a batch of configs from a single config, if needed.
func buildBatches(cfg wgtypes.Config) []wgtypes.Config {
// Is this a small configuration; no need to batch?
if !shouldBatch(cfg) {
return []wgtypes.Config{cfg}
}
// Use most fields of cfg for our "base" configuration, and only differ
// peers in each batch.
base := cfg
base.Peers = nil
// Track the known peers so that peer IPs are not replaced if a single
// peer has its allowed IPs split into multiple batches.
knownPeers := make(map[wgtypes.Key]struct{})
batches := make([]wgtypes.Config, 0)
for _, p := range cfg.Peers {
batch := base
// Iterate until no more allowed IPs.
var done bool
for !done {
var tmp []net.IPNet
if len(p.AllowedIPs) < ipBatchChunk {
// IPs all fit within a batch; we are done.
tmp = make([]net.IPNet, len(p.AllowedIPs))
copy(tmp, p.AllowedIPs)
done = true
} else {
// IPs are larger than a single batch, copy a batch out and
// advance the cursor.
tmp = make([]net.IPNet, ipBatchChunk)
copy(tmp, p.AllowedIPs[:ipBatchChunk])
p.AllowedIPs = p.AllowedIPs[ipBatchChunk:]
if len(p.AllowedIPs) == 0 {
// IPs ended on a batch boundary; no more IPs left so end
// iteration after this loop.
done = true
}
}
pcfg := wgtypes.PeerConfig{
// PublicKey denotes the peer and must be present.
PublicKey: p.PublicKey,
// Apply the update only flag to every chunk to ensure
// consistency between batches when the kernel module processes
// them.
UpdateOnly: p.UpdateOnly,
// It'd be a bit weird to have a remove peer message with many
// IPs, but just in case, add this to every peer's message.
Remove: p.Remove,
// The IPs for this chunk.
AllowedIPs: tmp,
}
// Only pass certain fields on the first occurrence of a peer, so
// that subsequent IPs won't be wiped out and space isn't wasted.
if _, ok := knownPeers[p.PublicKey]; !ok {
knownPeers[p.PublicKey] = struct{}{}
pcfg.PresharedKey = p.PresharedKey
pcfg.Endpoint = p.Endpoint
pcfg.PersistentKeepaliveInterval = p.PersistentKeepaliveInterval
// Important: do not move or appending peers won't work.
pcfg.ReplaceAllowedIPs = p.ReplaceAllowedIPs
}
// Add a peer configuration to this batch and keep going.
batch.Peers = []wgtypes.PeerConfig{pcfg}
batches = append(batches, batch)
}
}
// Do not allow peer replacement beyond the first message in a batch,
// so we don't overwrite our previous batch work.
for i := range batches {
if i > 0 {
batches[i].ReplacePeers = false
}
}
return batches
}
// encodePeer returns a function to encode PeerConfig nested attributes.
func encodePeer(p wgtypes.PeerConfig) func(ae *netlink.AttributeEncoder) error {
return func(ae *netlink.AttributeEncoder) error {
ae.Bytes(wgh.PeerAPublicKey, p.PublicKey[:])
// Flags are stored in a single attribute.
var flags uint32
if p.Remove {
flags |= wgh.PeerFRemoveMe
}
if p.ReplaceAllowedIPs {
flags |= wgh.PeerFReplaceAllowedips
}
if p.UpdateOnly {
flags |= wgh.PeerFUpdateOnly
}
if flags != 0 {
ae.Uint32(wgh.PeerAFlags, flags)
}
if p.PresharedKey != nil {
ae.Bytes(wgh.PeerAPresharedKey, (*p.PresharedKey)[:])
}
if p.Endpoint != nil {
ae.Do(wgh.PeerAEndpoint, encodeSockaddr(*p.Endpoint))
}
if p.PersistentKeepaliveInterval != nil {
ae.Uint16(wgh.PeerAPersistentKeepaliveInterval, uint16(p.PersistentKeepaliveInterval.Seconds()))
}
// Only apply allowed IPs if necessary.
if len(p.AllowedIPs) > 0 {
ae.Nested(wgh.PeerAAllowedips, encodeAllowedIPs(p.AllowedIPs))
}
return nil
}
}
// encodeSockaddr returns a function which encodes a net.UDPAddr as raw
// sockaddr_in or sockaddr_in6 bytes.
func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) {
return func() ([]byte, error) {
if !isValidIP(endpoint.IP) {
return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String())
}
// Is this an IPv6 address?
if isIPv6(endpoint.IP) {
var addr [16]byte
copy(addr[:], endpoint.IP.To16())
sa := unix.RawSockaddrInet6{
Family: unix.AF_INET6,
Port: sockaddrPort(endpoint.Port),
Addr: addr,
}
return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil
}
// IPv4 address handling.
var addr [4]byte
copy(addr[:], endpoint.IP.To4())
sa := unix.RawSockaddrInet4{
Family: unix.AF_INET,
Port: sockaddrPort(endpoint.Port),
Addr: addr,
}
return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil
}
}
// encodeAllowedIPs returns a function to encode allowed IP nested attributes.
func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error {
return func(ae *netlink.AttributeEncoder) error {
for i, ipn := range ipns {
if !isValidIP(ipn.IP) {
return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.IP.String())
}
family := uint16(unix.AF_INET6)
if !isIPv6(ipn.IP) {
// Make sure address is 4 bytes if IPv4.
family = unix.AF_INET
ipn.IP = ipn.IP.To4()
}
// Netlink arrays use type as an array index.
ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error {
nae.Uint16(wgh.AllowedipAFamily, family)
nae.Bytes(wgh.AllowedipAIpaddr, ipn.IP)
ones, _ := ipn.Mask.Size()
nae.Uint8(wgh.AllowedipACidrMask, uint8(ones))
return nil
})
}
return nil
}
}
// isValidIP determines if IP is a valid IPv4 or IPv6 address.
func isValidIP(ip net.IP) bool {
return ip.To16() != nil
}
// isIPv6 determines if IP is a valid IPv6 address.
func isIPv6(ip net.IP) bool {
return isValidIP(ip) && ip.To4() == nil
}
// sockaddrPort interprets port as a big endian uint16 for use passing sockaddr
// structures to the kernel.
func sockaddrPort(port int) uint16 {
return binary.BigEndian.Uint16(nlenc.Uint16Bytes(uint16(port)))
}

View File

@@ -0,0 +1,6 @@
// Package wglinux provides internal access to Linux's WireGuard generic
// netlink interface.
//
// This package is internal-only and not meant for end users to consume.
// Please use package wgctrl (an abstraction over this package) instead.
package wglinux

View File

@@ -0,0 +1,99 @@
// WARNING: This file has automatically been generated on Tue, 04 May 2021 18:36:46 EDT.
// Code generated by https://git.io/c-for-go. DO NOT EDIT.
package wgh
const (
// GenlName as defined in wgh/wireguard.h:134
GenlName = "wireguard"
// GenlVersion as defined in wgh/wireguard.h:135
GenlVersion = 1
// KeyLen as defined in wgh/wireguard.h:137
KeyLen = 32
// CmdMax as defined in wgh/wireguard.h:144
CmdMax = (__CmdMax - 1)
// DeviceAMax as defined in wgh/wireguard.h:162
DeviceAMax = (_DeviceALast - 1)
// PeerAMax as defined in wgh/wireguard.h:185
PeerAMax = (_PeerALast - 1)
// AllowedipAMax as defined in wgh/wireguard.h:194
AllowedipAMax = (_AllowedipALast - 1)
)
// wgCmd as declared in wgh/wireguard.h:139
type wgCmd int32
// wgCmd enumeration from wgh/wireguard.h:139
const (
CmdGetDevice = iota
CmdSetDevice = 1
__CmdMax = 2
)
// wgdeviceFlag as declared in wgh/wireguard.h:146
type wgdeviceFlag int32
// wgdeviceFlag enumeration from wgh/wireguard.h:146
const (
DeviceFReplacePeers = uint32(1) << 0
_DeviceFAll = DeviceFReplacePeers
)
// wgdeviceAttribute as declared in wgh/wireguard.h:150
type wgdeviceAttribute int32
// wgdeviceAttribute enumeration from wgh/wireguard.h:150
const (
DeviceAUnspec = iota
DeviceAIfindex = 1
DeviceAIfname = 2
DeviceAPrivateKey = 3
DeviceAPublicKey = 4
DeviceAFlags = 5
DeviceAListenPort = 6
DeviceAFwmark = 7
DeviceAPeers = 8
_DeviceALast = 9
)
// wgpeerFlag as declared in wgh/wireguard.h:164
type wgpeerFlag int32
// wgpeerFlag enumeration from wgh/wireguard.h:164
const (
PeerFRemoveMe = uint32(1) << 0
PeerFReplaceAllowedips = uint32(1) << 1
PeerFUpdateOnly = uint32(1) << 2
_PeerFAll = PeerFRemoveMe | PeerFReplaceAllowedips | PeerFUpdateOnly
)
// wgpeerAttribute as declared in wgh/wireguard.h:171
type wgpeerAttribute int32
// wgpeerAttribute enumeration from wgh/wireguard.h:171
const (
PeerAUnspec = iota
PeerAPublicKey = 1
PeerAPresharedKey = 2
PeerAFlags = 3
PeerAEndpoint = 4
PeerAPersistentKeepaliveInterval = 5
PeerALastHandshakeTime = 6
PeerARxBytes = 7
PeerATxBytes = 8
PeerAAllowedips = 9
PeerAProtocolVersion = 10
_PeerALast = 11
)
// wgallowedipAttribute as declared in wgh/wireguard.h:187
type wgallowedipAttribute int32
// wgallowedipAttribute enumeration from wgh/wireguard.h:187
const (
AllowedipAUnspec = iota
AllowedipAFamily = 1
AllowedipAIpaddr = 2
AllowedipACidrMask = 3
_AllowedipALast = 4
)

View File

@@ -0,0 +1,12 @@
// Package wgh is an auto-generated package which contains constants and
// types used to access WireGuard information using generic netlink.
package wgh
// Pull the latest wireguard.h from GitHub for code generation.
//go:generate wget https://raw.githubusercontent.com/torvalds/linux/master/include/uapi/linux/wireguard.h
// Generate Go source from C constants.
//go:generate c-for-go -out ../ -nocgo wgh.yml
// Clean up build artifacts.
//go:generate rm -rf wireguard.h _obj/

View File

@@ -0,0 +1,22 @@
---
GENERATOR:
PackageName: wgh
PARSER:
IncludePaths: [/usr/include]
SourcesPaths: [wireguard.h]
TRANSLATOR:
ConstRules:
defines: expand
enum: expand
Rules:
const:
- {transform: lower}
- {action: accept, from: "(?i)wg_"}
- {action: replace, from: "(?i)wg_", to: _}
- {action: accept, from: "(?i)wg"}
- {action: replace, from: "(?i)wg", to: }
- {transform: export}
post-global:
- {load: snakecase}

View File

@@ -0,0 +1,304 @@
//go:build linux
// +build linux
package wglinux
import (
"fmt"
"net"
"time"
"unsafe"
"github.com/mdlayher/genetlink"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl/internal/wglinux/internal/wgh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// parseDevice parses a Device from a slice of generic netlink messages,
// automatically merging peer lists from subsequent messages into the Device
// from the first message.
func parseDevice(msgs []genetlink.Message) (*wgtypes.Device, error) {
var first wgtypes.Device
knownPeers := make(map[wgtypes.Key]int)
for i, m := range msgs {
d, err := parseDeviceLoop(m)
if err != nil {
return nil, err
}
if i == 0 {
// First message contains our target device.
first = *d
// Gather the known peers so that we can merge
// them later if needed
for i := range first.Peers {
knownPeers[first.Peers[i].PublicKey] = i
}
continue
}
// Any subsequent messages have their peer contents merged into the
// first "target" message.
mergeDevices(&first, d, knownPeers)
}
return &first, nil
}
// parseDeviceLoop parses a Device from a single generic netlink message.
func parseDeviceLoop(m genetlink.Message) (*wgtypes.Device, error) {
ad, err := netlink.NewAttributeDecoder(m.Data)
if err != nil {
return nil, err
}
d := wgtypes.Device{Type: wgtypes.LinuxKernel}
for ad.Next() {
switch ad.Type() {
case wgh.DeviceAIfindex:
// Ignored; interface index isn't exposed at all in the userspace
// configuration protocol, and name is more friendly anyway.
case wgh.DeviceAIfname:
d.Name = ad.String()
case wgh.DeviceAPrivateKey:
ad.Do(parseKey(&d.PrivateKey))
case wgh.DeviceAPublicKey:
ad.Do(parseKey(&d.PublicKey))
case wgh.DeviceAListenPort:
d.ListenPort = int(ad.Uint16())
case wgh.DeviceAFwmark:
d.FirewallMark = int(ad.Uint32())
case wgh.DeviceAPeers:
// Netlink array of peers.
//
// Errors while parsing are propagated up to top-level ad.Err check.
ad.Nested(func(nad *netlink.AttributeDecoder) error {
// Initialize to the number of peers in this decoder and begin
// handling nested Peer attributes.
d.Peers = make([]wgtypes.Peer, 0, nad.Len())
for nad.Next() {
nad.Nested(func(nnad *netlink.AttributeDecoder) error {
d.Peers = append(d.Peers, parsePeer(nnad))
return nil
})
}
return nil
})
}
}
if err := ad.Err(); err != nil {
return nil, err
}
return &d, nil
}
// parseAllowedIPs parses a wgtypes.Peer from a netlink attribute payload.
func parsePeer(ad *netlink.AttributeDecoder) wgtypes.Peer {
var p wgtypes.Peer
for ad.Next() {
switch ad.Type() {
case wgh.PeerAPublicKey:
ad.Do(parseKey(&p.PublicKey))
case wgh.PeerAPresharedKey:
ad.Do(parseKey(&p.PresharedKey))
case wgh.PeerAEndpoint:
p.Endpoint = &net.UDPAddr{}
ad.Do(parseSockaddr(p.Endpoint))
case wgh.PeerAPersistentKeepaliveInterval:
p.PersistentKeepaliveInterval = time.Duration(ad.Uint16()) * time.Second
case wgh.PeerALastHandshakeTime:
ad.Do(parseTimespec(&p.LastHandshakeTime))
case wgh.PeerARxBytes:
p.ReceiveBytes = int64(ad.Uint64())
case wgh.PeerATxBytes:
p.TransmitBytes = int64(ad.Uint64())
case wgh.PeerAAllowedips:
ad.Nested(parseAllowedIPs(&p.AllowedIPs))
case wgh.PeerAProtocolVersion:
p.ProtocolVersion = int(ad.Uint32())
}
}
return p
}
// parseAllowedIPs parses a slice of net.IPNet from a netlink attribute payload.
func parseAllowedIPs(ipns *[]net.IPNet) func(ad *netlink.AttributeDecoder) error {
return func(ad *netlink.AttributeDecoder) error {
// Initialize to the number of allowed IPs and begin iterating through
// the netlink array to decode each one.
*ipns = make([]net.IPNet, 0, ad.Len())
for ad.Next() {
// Allowed IP nested attributes.
ad.Nested(func(nad *netlink.AttributeDecoder) error {
var (
ipn net.IPNet
mask int
family int
)
for nad.Next() {
switch nad.Type() {
case wgh.AllowedipAIpaddr:
nad.Do(parseAddr(&ipn.IP))
case wgh.AllowedipACidrMask:
mask = int(nad.Uint8())
case wgh.AllowedipAFamily:
family = int(nad.Uint16())
}
}
if err := nad.Err(); err != nil {
return err
}
// The address family determines the correct number of bits in
// the mask.
switch family {
case unix.AF_INET:
ipn.Mask = net.CIDRMask(mask, 32)
case unix.AF_INET6:
ipn.Mask = net.CIDRMask(mask, 128)
}
*ipns = append(*ipns, ipn)
return nil
})
}
return nil
}
}
// parseKey parses a wgtypes.Key from a byte slice.
func parseKey(key *wgtypes.Key) func(b []byte) error {
return func(b []byte) error {
k, err := wgtypes.NewKey(b)
if err != nil {
return err
}
*key = k
return nil
}
}
// parseAddr parses a net.IP from raw in_addr or in6_addr struct bytes.
func parseAddr(ip *net.IP) func(b []byte) error {
return func(b []byte) error {
switch len(b) {
case net.IPv4len, net.IPv6len:
// Okay to convert directly to net.IP; memory layout is identical.
*ip = make(net.IP, len(b))
copy(*ip, b)
return nil
default:
return fmt.Errorf("wglinux: unexpected IP address size: %d", len(b))
}
}
}
// parseSockaddr parses a *net.UDPAddr from raw sockaddr_in or sockaddr_in6 bytes.
func parseSockaddr(endpoint *net.UDPAddr) func(b []byte) error {
return func(b []byte) error {
switch len(b) {
case unix.SizeofSockaddrInet4:
// IPv4 address parsing.
sa := *(*unix.RawSockaddrInet4)(unsafe.Pointer(&b[0]))
*endpoint = net.UDPAddr{
IP: net.IP(sa.Addr[:]).To4(),
Port: int(sockaddrPort(int(sa.Port))),
}
return nil
case unix.SizeofSockaddrInet6:
// IPv6 address parsing.
sa := *(*unix.RawSockaddrInet6)(unsafe.Pointer(&b[0]))
*endpoint = net.UDPAddr{
IP: net.IP(sa.Addr[:]),
Port: int(sockaddrPort(int(sa.Port))),
}
return nil
default:
return fmt.Errorf("wglinux: unexpected sockaddr size: %d", len(b))
}
}
}
// timespec32 is a unix.Timespec with 32-bit integers.
type timespec32 struct {
Sec int32
Nsec int32
}
// timespec64 is a unix.Timespec with 64-bit integers.
type timespec64 struct {
Sec int64
Nsec int64
}
const (
sizeofTimespec32 = int(unsafe.Sizeof(timespec32{}))
sizeofTimespec64 = int(unsafe.Sizeof(timespec64{}))
)
// parseTimespec parses a time.Time from raw timespec bytes.
func parseTimespec(t *time.Time) func(b []byte) error {
return func(b []byte) error {
// It would appear that WireGuard can return a __kernel_timespec which
// uses 64-bit integers, even on 32-bit platforms. Clarification of this
// behavior is being sought in:
// https://lists.zx2c4.com/pipermail/wireguard/2019-April/004088.html.
//
// In the mean time, be liberal and accept 32-bit and 64-bit variants.
var sec, nsec int64
switch len(b) {
case sizeofTimespec32:
ts := *(*timespec32)(unsafe.Pointer(&b[0]))
sec = int64(ts.Sec)
nsec = int64(ts.Nsec)
case sizeofTimespec64:
ts := *(*timespec64)(unsafe.Pointer(&b[0]))
sec = ts.Sec
nsec = ts.Nsec
default:
return fmt.Errorf("wglinux: unexpected timespec size: %d bytes, expected 8 or 16 bytes", len(b))
}
// Only set fields if UNIX timestamp value is greater than 0, so the
// caller will see a zero-value time.Time otherwise.
if sec > 0 || nsec > 0 {
*t = time.Unix(sec, nsec)
}
return nil
}
}
// mergeDevices merges Peer information from d into target. mergeDevices is
// used to deal with multiple incoming netlink messages for the same device.
func mergeDevices(target, d *wgtypes.Device, knownPeers map[wgtypes.Key]int) {
for i := range d.Peers {
// Peer is already known, append to it's allowed IP networks
if peerIndex, ok := knownPeers[d.Peers[i].PublicKey]; ok {
target.Peers[peerIndex].AllowedIPs = append(target.Peers[peerIndex].AllowedIPs, d.Peers[i].AllowedIPs...)
} else { // New peer, add it to the target peers.
target.Peers = append(target.Peers, d.Peers[i])
knownPeers[d.Peers[i].PublicKey] = len(target.Peers) - 1
}
}
}

View File

@@ -0,0 +1,373 @@
//go:build openbsd
// +build openbsd
package wgopenbsd
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"os"
"runtime"
"time"
"unsafe"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
"golang.zx2c4.com/wireguard/wgctrl/internal/wgopenbsd/internal/wgh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var (
// ifGroupWG is the WireGuard interface group name passed to the kernel.
ifGroupWG = [16]byte{0: 'w', 1: 'g'}
)
var _ wginternal.Client = &Client{}
// A Client provides access to OpenBSD WireGuard ioctl information.
type Client struct {
// Hooks which use system calls by default, but can also be swapped out
// during tests.
close func() error
ioctlIfgroupreq func(ifg *wgh.Ifgroupreq) error
ioctlWGDataIO func(data *wgh.WGDataIO) error
}
// New creates a new Client and returns whether or not the ioctl interface
// is available.
func New() (*Client, bool, error) {
// The OpenBSD ioctl interface operates on a generic AF_INET socket.
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
if err != nil {
return nil, false, err
}
// TODO(mdlayher): find a call to invoke here to probe for availability.
// c.Devices won't work because it returns a "not found" error when the
// kernel WireGuard implementation is available but the interface group
// has no members.
// By default, use system call implementations for all hook functions.
return &Client{
close: func() error { return unix.Close(fd) },
ioctlIfgroupreq: ioctlIfgroupreq(fd),
ioctlWGDataIO: ioctlWGDataIO(fd),
}, true, nil
}
// Close implements wginternal.Client.
func (c *Client) Close() error {
return c.close()
}
// Devices implements wginternal.Client.
func (c *Client) Devices() ([]*wgtypes.Device, error) {
ifg := wgh.Ifgroupreq{
// Query for devices in the "wg" group.
Name: ifGroupWG,
}
// Determine how many device names we must allocate memory for.
if err := c.ioctlIfgroupreq(&ifg); err != nil {
return nil, err
}
// ifg.Len is size in bytes; allocate enough memory for the correct number
// of wgh.Ifgreq and then store a pointer to the memory where the data
// should be written (ifgrs) in ifg.Groups.
//
// From a thread in golang-nuts, this pattern is valid:
// "It would be OK to pass a pointer to a struct to ioctl if the struct
// contains a pointer to other Go memory, but the struct field must have
// pointer type."
// See: https://groups.google.com/forum/#!topic/golang-nuts/FfasFTZvU_o.
ifgrs := make([]wgh.Ifgreq, ifg.Len/wgh.SizeofIfgreq)
ifg.Groups = &ifgrs[0]
// Now actually fetch the device names.
if err := c.ioctlIfgroupreq(&ifg); err != nil {
return nil, err
}
// Keep this alive until we're done doing the ioctl dance.
runtime.KeepAlive(&ifg)
devices := make([]*wgtypes.Device, 0, len(ifgrs))
for _, ifgr := range ifgrs {
// Remove any trailing NULL bytes from the interface names.
d, err := c.Device(string(bytes.TrimRight(ifgr.Ifgrqu[:], "\x00")))
if err != nil {
return nil, err
}
devices = append(devices, d)
}
return devices, nil
}
// Device implements wginternal.Client.
func (c *Client) Device(name string) (*wgtypes.Device, error) {
dname, err := deviceName(name)
if err != nil {
return nil, err
}
// First, specify the name of the device and determine how much memory
// must be allocated in order to store the WGInterfaceIO structure and
// any trailing WGPeerIO/WGAIPIOs.
data := wgh.WGDataIO{Name: dname}
// TODO: consider preallocating some memory to avoid a second system call
// if it proves to be a concern.
var mem []byte
for {
if err := c.ioctlWGDataIO(&data); err != nil {
// ioctl functions always return a wrapped unix.Errno value.
// Conform to the wgctrl contract by unwrapping some values:
// ENXIO: "no such device": (no such WireGuard device)
// ENOTTY: "inappropriate ioctl for device" (device is not a
// WireGuard device)
switch err.(*os.SyscallError).Err {
case unix.ENXIO, unix.ENOTTY:
return nil, os.ErrNotExist
default:
return nil, err
}
}
if len(mem) >= int(data.Size) {
// Allocated enough memory!
break
}
// Ensure we don't unsafe cast into uninitialized memory. We need at very
// least a single WGInterfaceIO with no peers.
if data.Size < wgh.SizeofWGInterfaceIO {
return nil, fmt.Errorf("wgopenbsd: kernel returned unexpected number of bytes for WGInterfaceIO: %d", data.Size)
}
// Allocate the appropriate amount of memory and point the kernel at
// the first byte of our slice's backing array. When the loop continues,
// we will check if we've allocated enough memory.
mem = make([]byte, data.Size)
data.Interface = (*wgh.WGInterfaceIO)(unsafe.Pointer(&mem[0]))
}
return parseDevice(name, data.Interface)
}
// parseDevice unpacks a Device from ifio, along with its associated peers
// and their allowed IPs.
func parseDevice(name string, ifio *wgh.WGInterfaceIO) (*wgtypes.Device, error) {
d := &wgtypes.Device{
Name: name,
Type: wgtypes.OpenBSDKernel,
}
// The kernel populates ifio.Flags to indicate which fields are present.
if ifio.Flags&wgh.WG_INTERFACE_HAS_PRIVATE != 0 {
d.PrivateKey = wgtypes.Key(ifio.Private)
}
if ifio.Flags&wgh.WG_INTERFACE_HAS_PUBLIC != 0 {
d.PublicKey = wgtypes.Key(ifio.Public)
}
if ifio.Flags&wgh.WG_INTERFACE_HAS_PORT != 0 {
d.ListenPort = int(ifio.Port)
}
if ifio.Flags&wgh.WG_INTERFACE_HAS_RTABLE != 0 {
d.FirewallMark = int(ifio.Rtable)
}
d.Peers = make([]wgtypes.Peer, 0, ifio.Peers_count)
// If there were no peers, exit early so we do not advance the pointer
// beyond the end of the WGInterfaceIO structure.
if ifio.Peers_count == 0 {
return d, nil
}
// Set our pointer to the beginning of the first peer's location in memory.
peer := (*wgh.WGPeerIO)(unsafe.Pointer(
uintptr(unsafe.Pointer(ifio)) + wgh.SizeofWGInterfaceIO,
))
for i := 0; i < int(ifio.Peers_count); i++ {
p := parsePeer(peer)
// Same idea, we know how many allowed IPs we need to account for, so
// reserve the space and advance the pointer through each WGAIP structure.
p.AllowedIPs = make([]net.IPNet, 0, peer.Aips_count)
for j := uintptr(0); j < uintptr(peer.Aips_count); j++ {
aip := (*wgh.WGAIPIO)(unsafe.Pointer(
uintptr(unsafe.Pointer(peer)) + wgh.SizeofWGPeerIO + j*wgh.SizeofWGAIPIO,
))
p.AllowedIPs = append(p.AllowedIPs, parseAllowedIP(aip))
}
// Prepare for the next iteration.
d.Peers = append(d.Peers, p)
peer = (*wgh.WGPeerIO)(unsafe.Pointer(
uintptr(unsafe.Pointer(peer)) + wgh.SizeofWGPeerIO +
uintptr(peer.Aips_count)*wgh.SizeofWGAIPIO,
))
}
return d, nil
}
// ConfigureDevice implements wginternal.Client.
func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
// Currently read-only: we must determine if a device belongs to this driver,
// and if it does, return a sentinel so integration tests that configure a
// device can be skipped.
if _, err := c.Device(name); err != nil {
return err
}
return wginternal.ErrReadOnly
}
// deviceName converts an interface name string to the format required to pass
// with wgh.WGGetServ.
func deviceName(name string) ([16]byte, error) {
var out [unix.IFNAMSIZ]byte
if len(name) > unix.IFNAMSIZ {
return out, fmt.Errorf("wgopenbsd: interface name %q too long", name)
}
copy(out[:], name)
return out, nil
}
// parsePeer unpacks a wgtypes.Peer from a WGPeerIO structure.
func parsePeer(pio *wgh.WGPeerIO) wgtypes.Peer {
p := wgtypes.Peer{
ReceiveBytes: int64(pio.Rxbytes),
TransmitBytes: int64(pio.Txbytes),
ProtocolVersion: int(pio.Protocol_version),
}
// Only set last handshake if a non-zero timespec was provided, matching
// the time.Time.IsZero() behavior of internal/wglinux.
if pio.Last_handshake.Sec > 0 && pio.Last_handshake.Nsec > 0 {
p.LastHandshakeTime = time.Unix(
pio.Last_handshake.Sec,
// Conversion required for GOARCH=386.
int64(pio.Last_handshake.Nsec),
)
}
if pio.Flags&wgh.WG_PEER_HAS_PUBLIC != 0 {
p.PublicKey = wgtypes.Key(pio.Public)
}
if pio.Flags&wgh.WG_PEER_HAS_PSK != 0 {
p.PresharedKey = wgtypes.Key(pio.Psk)
}
if pio.Flags&wgh.WG_PEER_HAS_PKA != 0 {
p.PersistentKeepaliveInterval = time.Duration(pio.Pka) * time.Second
}
if pio.Flags&wgh.WG_PEER_HAS_ENDPOINT != 0 {
p.Endpoint = parseEndpoint(pio.Endpoint)
}
return p
}
// parseAllowedIP unpacks a net.IPNet from a WGAIP structure.
func parseAllowedIP(aip *wgh.WGAIPIO) net.IPNet {
switch aip.Af {
case unix.AF_INET:
return net.IPNet{
IP: net.IP(aip.Addr[:net.IPv4len]),
Mask: net.CIDRMask(int(aip.Cidr), 32),
}
case unix.AF_INET6:
return net.IPNet{
IP: net.IP(aip.Addr[:]),
Mask: net.CIDRMask(int(aip.Cidr), 128),
}
default:
panicf("wgopenbsd: invalid address family for allowed IP: %+v", aip)
return net.IPNet{}
}
}
// parseEndpoint parses a peer endpoint from a wgh.WGIP structure.
func parseEndpoint(ep [28]byte) *net.UDPAddr {
// sockaddr* structures have family at index 1.
switch ep[1] {
case unix.AF_INET:
sa := *(*unix.RawSockaddrInet4)(unsafe.Pointer(&ep[0]))
ep := &net.UDPAddr{
IP: make(net.IP, net.IPv4len),
Port: bePort(sa.Port),
}
copy(ep.IP, sa.Addr[:])
return ep
case unix.AF_INET6:
sa := *(*unix.RawSockaddrInet6)(unsafe.Pointer(&ep[0]))
// TODO(mdlayher): IPv6 zone?
ep := &net.UDPAddr{
IP: make(net.IP, net.IPv6len),
Port: bePort(sa.Port),
}
copy(ep.IP, sa.Addr[:])
return ep
default:
// No endpoint configured.
return nil
}
}
// bePort interprets a port integer stored in native endianness as a big
// endian value. This is necessary for proper endpoint port handling on
// little endian machines.
func bePort(port uint16) int {
b := *(*[2]byte)(unsafe.Pointer(&port))
return int(binary.BigEndian.Uint16(b[:]))
}
// ioctlIfgroupreq returns a function which performs the appropriate ioctl on
// fd to retrieve members of an interface group.
func ioctlIfgroupreq(fd int) func(*wgh.Ifgroupreq) error {
return func(ifg *wgh.Ifgroupreq) error {
return ioctl(fd, wgh.SIOCGIFGMEMB, unsafe.Pointer(ifg))
}
}
// ioctlWGDataIO returns a function which performs the appropriate ioctl on
// fd to issue a WireGuard data I/O.
func ioctlWGDataIO(fd int) func(*wgh.WGDataIO) error {
return func(data *wgh.WGDataIO) error {
return ioctl(fd, wgh.SIOCGWG, unsafe.Pointer(data))
}
}
// ioctl is a raw wrapper for the ioctl system call.
func ioctl(fd int, req uint, arg unsafe.Pointer) error {
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg))
if errno != 0 {
return os.NewSyscallError("ioctl", errno)
}
return nil
}
func panicf(format string, a ...interface{}) {
panic(fmt.Sprintf(format, a...))
}

View File

@@ -0,0 +1,6 @@
// Package wgopenbsd provides internal access to OpenBSD's WireGuard
// ioctl interface.
//
// This package is internal-only and not meant for end users to consume.
// Please use package wgctrl (an abstraction over this package) instead.
package wgopenbsd

View File

@@ -0,0 +1,84 @@
//go:build openbsd && 386
// +build openbsd,386
// Code generated by cmd/cgo -godefs; DO NOT EDIT.
// cgo -godefs defs.go
package wgh
const (
SIOCGIFGMEMB = 0xc024698a
SizeofIfgreq = 0x10
)
type Ifgroupreq struct {
Name [16]byte
Len uint32
Pad1 [0]byte
Groups *Ifgreq
Pad2 [12]byte
}
type Ifgreq struct {
Ifgrqu [16]byte
}
type Timespec struct {
Sec int64
Nsec int32
}
type WGAIPIO struct {
Af uint8
Cidr int32
Addr [16]byte
}
type WGDataIO struct {
Name [16]byte
Size uint32
Interface *WGInterfaceIO
}
type WGInterfaceIO struct {
Flags uint8
Port uint16
Rtable int32
Public [32]byte
Private [32]byte
Peers_count uint32
}
type WGPeerIO struct {
Flags int32
Protocol_version int32
Public [32]byte
Psk [32]byte
Pka uint16
Pad_cgo_0 [2]byte
Endpoint [28]byte
Txbytes uint64
Rxbytes uint64
Last_handshake Timespec
Aips_count uint32
}
const (
SIOCGWG = 0xc01869d3
WG_INTERFACE_HAS_PUBLIC = 0x1
WG_INTERFACE_HAS_PRIVATE = 0x2
WG_INTERFACE_HAS_PORT = 0x4
WG_INTERFACE_HAS_RTABLE = 0x8
WG_INTERFACE_REPLACE_PEERS = 0x10
WG_PEER_HAS_PUBLIC = 0x1
WG_PEER_HAS_PSK = 0x2
WG_PEER_HAS_PKA = 0x4
WG_PEER_HAS_ENDPOINT = 0x8
SizeofWGAIPIO = 0x18
SizeofWGInterfaceIO = 0x4c
SizeofWGPeerIO = 0x88
)

View File

@@ -0,0 +1,84 @@
//go:build openbsd && amd64
// +build openbsd,amd64
// Code generated by cmd/cgo -godefs; DO NOT EDIT.
// cgo -godefs defs.go
package wgh
const (
SIOCGIFGMEMB = 0xc028698a
SizeofIfgreq = 0x10
)
type Ifgroupreq struct {
Name [16]byte
Len uint32
Pad1 [4]byte
Groups *Ifgreq
Pad2 [8]byte
}
type Ifgreq struct {
Ifgrqu [16]byte
}
type Timespec struct {
Sec int64
Nsec int64
}
type WGAIPIO struct {
Af uint8
Cidr int32
Addr [16]byte
}
type WGDataIO struct {
Name [16]byte
Size uint64
Interface *WGInterfaceIO
}
type WGInterfaceIO struct {
Flags uint8
Port uint16
Rtable int32
Public [32]byte
Private [32]byte
Peers_count uint64
}
type WGPeerIO struct {
Flags int32
Protocol_version int32
Public [32]byte
Psk [32]byte
Pka uint16
Pad_cgo_0 [2]byte
Endpoint [28]byte
Txbytes uint64
Rxbytes uint64
Last_handshake Timespec
Aips_count uint64
}
const (
SIOCGWG = 0xc02069d3
WG_INTERFACE_HAS_PUBLIC = 0x1
WG_INTERFACE_HAS_PRIVATE = 0x2
WG_INTERFACE_HAS_PORT = 0x4
WG_INTERFACE_HAS_RTABLE = 0x8
WG_INTERFACE_REPLACE_PEERS = 0x10
WG_PEER_HAS_PUBLIC = 0x1
WG_PEER_HAS_PSK = 0x2
WG_PEER_HAS_PKA = 0x4
WG_PEER_HAS_ENDPOINT = 0x8
SizeofWGAIPIO = 0x18
SizeofWGInterfaceIO = 0x50
SizeofWGPeerIO = 0x90
)

View File

@@ -0,0 +1,3 @@
// Package wgh is an auto-generated package which contains constants and
// types used to access WireGuard information using ioctl calls.
package wgh

View File

@@ -0,0 +1,25 @@
#/bin/sh
set -x
# Fix up generated code.
gofix()
{
IN=$1
OUT=$2
# Change types that are a nuisance to deal with in Go, use byte for
# consistency, and produce gofmt'd output.
sed 's/]u*int8/]byte/g' $1 | gofmt -s > $2
}
echo -e "//+build openbsd,amd64\n" > /tmp/wgamd64.go
GOARCH=amd64 go tool cgo -godefs defs.go >> /tmp/wgamd64.go
echo -e "//+build openbsd,386\n" > /tmp/wg386.go
GOARCH=386 go tool cgo -godefs defs.go >> /tmp/wg386.go
gofix /tmp/wgamd64.go defs_openbsd_amd64.go
gofix /tmp/wg386.go defs_openbsd_386.go
rm -rf _obj/ /tmp/wg*.go

View File

@@ -0,0 +1,99 @@
package wguser
import (
"fmt"
"net"
"os"
"path/filepath"
"strings"
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var _ wginternal.Client = &Client{}
// A Client provides access to userspace WireGuard device information.
type Client struct {
dial func(device string) (net.Conn, error)
find func() ([]string, error)
}
// New creates a new Client.
func New() (*Client, error) {
return &Client{
// Operating system-specific functions which can identify and connect
// to userspace WireGuard devices. These functions can also be
// overridden for tests.
dial: dial,
find: find,
}, nil
}
// Close implements wginternal.Client.
func (c *Client) Close() error { return nil }
// Devices implements wginternal.Client.
func (c *Client) Devices() ([]*wgtypes.Device, error) {
devices, err := c.find()
if err != nil {
return nil, err
}
var wgds []*wgtypes.Device
for _, d := range devices {
wgd, err := c.getDevice(d)
if err != nil {
return nil, err
}
wgds = append(wgds, wgd)
}
return wgds, nil
}
// Device implements wginternal.Client.
func (c *Client) Device(name string) (*wgtypes.Device, error) {
devices, err := c.find()
if err != nil {
return nil, err
}
for _, d := range devices {
if name != deviceName(d) {
continue
}
return c.getDevice(d)
}
return nil, os.ErrNotExist
}
// ConfigureDevice implements wginternal.Client.
func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
devices, err := c.find()
if err != nil {
return err
}
for _, d := range devices {
if name != deviceName(d) {
continue
}
return c.configureDevice(d, cfg)
}
return os.ErrNotExist
}
// deviceName infers a device name from an absolute file path with extension.
func deviceName(sock string) string {
return strings.TrimSuffix(filepath.Base(sock), filepath.Ext(sock))
}
func panicf(format string, a ...interface{}) {
panic(fmt.Sprintf(format, a...))
}

View File

@@ -0,0 +1,106 @@
package wguser
import (
"bytes"
"encoding/hex"
"fmt"
"io"
"os"
"strings"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// configureDevice configures a device specified by its path.
func (c *Client) configureDevice(device string, cfg wgtypes.Config) error {
conn, err := c.dial(device)
if err != nil {
return err
}
defer conn.Close()
// Start with set command.
var buf bytes.Buffer
buf.WriteString("set=1\n")
// Add any necessary configuration from cfg, then finish with an empty line.
writeConfig(&buf, cfg)
buf.WriteString("\n")
// Apply configuration for the device and then check the error number.
if _, err := io.Copy(conn, &buf); err != nil {
return err
}
res := make([]byte, 32)
n, err := conn.Read(res)
if err != nil {
return err
}
// errno=0 indicates success, anything else returns an error number that
// matches definitions from errno.h.
str := strings.TrimSpace(string(res[:n]))
if str != "errno=0" {
// TODO(mdlayher): return actual errno on Linux?
return os.NewSyscallError("read", fmt.Errorf("wguser: %s", str))
}
return nil
}
// writeConfig writes textual configuration to w as specified by cfg.
func writeConfig(w io.Writer, cfg wgtypes.Config) {
if cfg.PrivateKey != nil {
fmt.Fprintf(w, "private_key=%s\n", hexKey(*cfg.PrivateKey))
}
if cfg.ListenPort != nil {
fmt.Fprintf(w, "listen_port=%d\n", *cfg.ListenPort)
}
if cfg.FirewallMark != nil {
fmt.Fprintf(w, "fwmark=%d\n", *cfg.FirewallMark)
}
if cfg.ReplacePeers {
fmt.Fprintln(w, "replace_peers=true")
}
for _, p := range cfg.Peers {
fmt.Fprintf(w, "public_key=%s\n", hexKey(p.PublicKey))
if p.Remove {
fmt.Fprintln(w, "remove=true")
}
if p.UpdateOnly {
fmt.Fprintln(w, "update_only=true")
}
if p.PresharedKey != nil {
fmt.Fprintf(w, "preshared_key=%s\n", hexKey(*p.PresharedKey))
}
if p.Endpoint != nil {
fmt.Fprintf(w, "endpoint=%s\n", p.Endpoint.String())
}
if p.PersistentKeepaliveInterval != nil {
fmt.Fprintf(w, "persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds()))
}
if p.ReplaceAllowedIPs {
fmt.Fprintln(w, "replace_allowed_ips=true")
}
for _, ip := range p.AllowedIPs {
fmt.Fprintf(w, "allowed_ip=%s\n", ip.String())
}
}
}
// hexKey encodes a wgtypes.Key into a hexadecimal string.
func hexKey(k wgtypes.Key) string {
return hex.EncodeToString(k[:])
}

View File

@@ -0,0 +1,51 @@
//go:build !windows
// +build !windows
package wguser
import (
"errors"
"io/ioutil"
"net"
"os"
"path/filepath"
)
// dial is the default implementation of Client.dial.
func dial(device string) (net.Conn, error) {
return net.Dial("unix", device)
}
// find is the default implementation of Client.find.
func find() ([]string, error) {
return findUNIXSockets([]string{
// It seems that /var/run is a common location between Linux and the
// BSDs, even though it's a symlink on Linux.
"/var/run/wireguard",
})
}
// findUNIXSockets looks for UNIX socket files in the specified directories.
func findUNIXSockets(dirs []string) ([]string, error) {
var socks []string
for _, d := range dirs {
files, err := ioutil.ReadDir(d)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
continue
}
return nil, err
}
for _, f := range files {
if f.Mode()&os.ModeSocket == 0 {
continue
}
socks = append(socks, filepath.Join(d, f.Name()))
}
}
return socks, nil
}

View File

@@ -0,0 +1,82 @@
//go:build windows
// +build windows
package wguser
import (
"net"
"strings"
"time"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
)
// Expected prefixes when dealing with named pipes.
const (
pipePrefix = `\\.\pipe\`
wgPrefix = `ProtectedPrefix\Administrators\WireGuard\`
)
// dial is the default implementation of Client.dial.
func dial(device string) (net.Conn, error) {
localSystem, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
if err != nil {
return nil, err
}
return (&namedpipe.DialConfig{
ExpectedOwner: localSystem,
}).DialTimeout(device, time.Duration(0))
}
// find is the default implementation of Client.find.
func find() ([]string, error) {
return findNamedPipes(wgPrefix)
}
// findNamedPipes looks for Windows named pipes that match the specified
// search string prefix.
func findNamedPipes(search string) ([]string, error) {
var (
pipes []string
data windows.Win32finddata
)
// Thanks @zx2c4 for the tips on the appropriate Windows APIs here:
// https://א.cc/dHGpnhxX/c.
h, err := windows.FindFirstFile(
// Append * to find all named pipes.
windows.StringToUTF16Ptr(pipePrefix+"*"),
&data,
)
if err != nil {
return nil, err
}
// FindClose is used to close file search handles instead of the typical
// CloseHandle used elsewhere, see:
// https://docs.microsoft.com/en-us/windows/desktop/api/fileapi/nf-fileapi-findclose.
defer windows.FindClose(h)
// Check the first file's name for a match, but also keep searching for
// WireGuard named pipes until no more files can be iterated.
for {
name := windows.UTF16ToString(data.FileName[:])
if strings.HasPrefix(name, search) {
// Concatenate strings directly as filepath.Join appears to break the
// named pipe prefix convention.
pipes = append(pipes, pipePrefix+name)
}
if err := windows.FindNextFile(h, &data); err != nil {
if err == windows.ERROR_NO_MORE_FILES {
break
}
return nil, err
}
}
return pipes, nil
}

View File

@@ -0,0 +1,6 @@
// Package wguser provides internal access to the userspace WireGuard
// configuration protocol interface.
//
// This package is internal-only and not meant for end users to consume.
// Please use package wgctrl (an abstraction over this package) instead.
package wguser

View File

@@ -0,0 +1,258 @@
package wguser
import (
"bufio"
"bytes"
"encoding/hex"
"fmt"
"io"
"net"
"os"
"strconv"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// The WireGuard userspace configuration protocol is described here:
// https://www.wireguard.com/xplatform/#cross-platform-userspace-implementation.
// getDevice gathers device information from a device specified by its path
// and returns a Device.
func (c *Client) getDevice(device string) (*wgtypes.Device, error) {
conn, err := c.dial(device)
if err != nil {
return nil, err
}
defer conn.Close()
// Get information about this device.
if _, err := io.WriteString(conn, "get=1\n\n"); err != nil {
return nil, err
}
// Parse the device from the incoming data stream.
d, err := parseDevice(conn)
if err != nil {
return nil, err
}
// TODO(mdlayher): populate interface index too?
d.Name = deviceName(device)
d.Type = wgtypes.Userspace
return d, nil
}
// parseDevice parses a Device and its Peers from an io.Reader.
func parseDevice(r io.Reader) (*wgtypes.Device, error) {
var dp deviceParser
s := bufio.NewScanner(r)
for s.Scan() {
b := s.Bytes()
if len(b) == 0 {
// Empty line, done parsing.
break
}
// All data is in key=value format.
kvs := bytes.Split(b, []byte("="))
if len(kvs) != 2 {
return nil, fmt.Errorf("wguser: invalid key=value pair: %q", string(b))
}
dp.Parse(string(kvs[0]), string(kvs[1]))
}
if err := s.Err(); err != nil {
return nil, err
}
return dp.Device()
}
// A deviceParser accumulates information about a Device and its Peers.
type deviceParser struct {
d wgtypes.Device
err error
parsePeers bool
peers int
hsSec, hsNano int
}
// Device returns a Device or any errors that were encountered while parsing
// a Device.
func (dp *deviceParser) Device() (*wgtypes.Device, error) {
if dp.err != nil {
return nil, dp.err
}
// Compute remaining fields of the Device now that all parsing is done.
dp.d.PublicKey = dp.d.PrivateKey.PublicKey()
return &dp.d, nil
}
// Parse parses a single key/value pair into fields of a Device.
func (dp *deviceParser) Parse(key, value string) {
switch key {
case "errno":
// 0 indicates success, anything else returns an error number that matches
// definitions from errno.h.
if errno := dp.parseInt(value); errno != 0 {
// TODO(mdlayher): return actual errno on Linux?
dp.err = os.NewSyscallError("read", fmt.Errorf("wguser: errno=%d", errno))
return
}
case "public_key":
// We've either found the first peer or the next peer. Stop parsing
// Device fields and start parsing Peer fields, including the public
// key indicated here.
dp.parsePeers = true
dp.peers++
dp.d.Peers = append(dp.d.Peers, wgtypes.Peer{
PublicKey: dp.parseKey(value),
})
return
}
// Are we parsing peer fields?
if dp.parsePeers {
dp.peerParse(key, value)
return
}
// Device field parsing.
switch key {
case "private_key":
dp.d.PrivateKey = dp.parseKey(value)
case "listen_port":
dp.d.ListenPort = dp.parseInt(value)
case "fwmark":
dp.d.FirewallMark = dp.parseInt(value)
}
}
// curPeer returns the current Peer being parsed so its fields can be populated.
func (dp *deviceParser) curPeer() *wgtypes.Peer {
return &dp.d.Peers[dp.peers-1]
}
// peerParse parses a key/value field into the current Peer.
func (dp *deviceParser) peerParse(key, value string) {
p := dp.curPeer()
switch key {
case "preshared_key":
p.PresharedKey = dp.parseKey(value)
case "endpoint":
p.Endpoint = dp.parseAddr(value)
case "last_handshake_time_sec":
dp.hsSec = dp.parseInt(value)
case "last_handshake_time_nsec":
dp.hsNano = dp.parseInt(value)
// Assume that we've seen both seconds and nanoseconds and populate this
// field now. However, if both fields were set to 0, assume we have never
// had a successful handshake with this peer, and return a zero-value
// time.Time to our callers.
if dp.hsSec > 0 && dp.hsNano > 0 {
p.LastHandshakeTime = time.Unix(int64(dp.hsSec), int64(dp.hsNano))
}
case "tx_bytes":
p.TransmitBytes = dp.parseInt64(value)
case "rx_bytes":
p.ReceiveBytes = dp.parseInt64(value)
case "persistent_keepalive_interval":
p.PersistentKeepaliveInterval = time.Duration(dp.parseInt(value)) * time.Second
case "allowed_ip":
cidr := dp.parseCIDR(value)
if cidr != nil {
p.AllowedIPs = append(p.AllowedIPs, *cidr)
}
case "protocol_version":
p.ProtocolVersion = dp.parseInt(value)
}
}
// parseKey parses a Key from a hex string.
func (dp *deviceParser) parseKey(s string) wgtypes.Key {
if dp.err != nil {
return wgtypes.Key{}
}
b, err := hex.DecodeString(s)
if err != nil {
dp.err = err
return wgtypes.Key{}
}
key, err := wgtypes.NewKey(b)
if err != nil {
dp.err = err
return wgtypes.Key{}
}
return key
}
// parseInt parses an integer from a string.
func (dp *deviceParser) parseInt(s string) int {
if dp.err != nil {
return 0
}
v, err := strconv.Atoi(s)
if err != nil {
dp.err = err
return 0
}
return v
}
// parseInt64 parses an int64 from a string.
func (dp *deviceParser) parseInt64(s string) int64 {
if dp.err != nil {
return 0
}
v, err := strconv.ParseInt(s, 10, 64)
if err != nil {
dp.err = err
return 0
}
return v
}
// parseAddr parses a UDP address from a string.
func (dp *deviceParser) parseAddr(s string) *net.UDPAddr {
if dp.err != nil {
return nil
}
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
dp.err = err
return nil
}
return addr
}
// parseInt parses an address CIDR from a string.
func (dp *deviceParser) parseCIDR(s string) *net.IPNet {
if dp.err != nil {
return nil
}
_, cidr, err := net.ParseCIDR(s)
if err != nil {
dp.err = err
return nil
}
return cidr
}

View File

@@ -0,0 +1,295 @@
package wgwindows
import (
"net"
"os"
"time"
"unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
"golang.zx2c4.com/wireguard/wgctrl/internal/wgwindows/internal/ioctl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var _ wginternal.Client = &Client{}
// A Client provides access to WireGuardNT ioctl information.
type Client struct {
cachedAdapters map[string]string
lastLenGuess uint32
}
var (
deviceClassNetGUID = windows.GUID{0x4d36e972, 0xe325, 0x11ce, [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
deviceInterfaceNetGUID = windows.GUID{0xcac88484, 0x7515, 0x4c03, [8]byte{0x82, 0xe6, 0x71, 0xa8, 0x7a, 0xba, 0xc3, 0x61}}
devpkeyWgName = windows.DEVPROPKEY{
FmtID: windows.DEVPROPGUID{0x65726957, 0x7547, 0x7261, [8]byte{0x64, 0x4e, 0x61, 0x6d, 0x65, 0x4b, 0x65, 0x79}},
PID: windows.DEVPROPID_FIRST_USABLE + 1,
}
)
var enumerator = `SWD\WireGuard`
func init() {
if maj, min, _ := windows.RtlGetNtVersionNumbers(); (maj == 6 && min <= 1) || maj < 6 {
enumerator = `ROOT\WIREGUARD`
}
}
func (c *Client) refreshInstanceIdCache() error {
cachedAdapters := make(map[string]string, 5)
devInfo, err := windows.SetupDiGetClassDevsEx(&deviceClassNetGUID, enumerator, 0, windows.DIGCF_PRESENT, 0, "")
if err != nil {
return err
}
defer windows.SetupDiDestroyDeviceInfoList(devInfo)
for i := 0; ; i++ {
devInfoData, err := windows.SetupDiEnumDeviceInfo(devInfo, i)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
}
continue
}
prop, err := windows.SetupDiGetDeviceProperty(devInfo, devInfoData, &devpkeyWgName)
if err != nil {
continue
}
adapterName, ok := prop.(string)
if !ok {
continue
}
var status, problemCode uint32
ret := windows.CM_Get_DevNode_Status(&status, &problemCode, devInfoData.DevInst, 0)
if ret != windows.CR_SUCCESS || (status&windows.DN_DRIVER_LOADED|windows.DN_STARTED) != windows.DN_DRIVER_LOADED|windows.DN_STARTED {
continue
}
instanceId, err := windows.SetupDiGetDeviceInstanceId(devInfo, devInfoData)
if err != nil {
continue
}
cachedAdapters[adapterName] = instanceId
}
c.cachedAdapters = cachedAdapters
return nil
}
func (c *Client) interfaceHandle(name string) (windows.Handle, error) {
instanceId, ok := c.cachedAdapters[name]
if !ok {
err := c.refreshInstanceIdCache()
if err != nil {
return 0, err
}
instanceId, ok = c.cachedAdapters[name]
if !ok {
return 0, os.ErrNotExist
}
}
interfaces, err := windows.CM_Get_Device_Interface_List(instanceId, &deviceInterfaceNetGUID, windows.CM_GET_DEVICE_INTERFACE_LIST_PRESENT)
if err != nil {
return 0, err
}
interface16, err := windows.UTF16PtrFromString(interfaces[0])
if err != nil {
return 0, err
}
return windows.CreateFile(interface16, windows.GENERIC_READ|windows.GENERIC_WRITE, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, 0, 0)
}
// Devices implements wginternal.Client.
func (c *Client) Devices() ([]*wgtypes.Device, error) {
err := c.refreshInstanceIdCache()
if err != nil {
return nil, err
}
ds := make([]*wgtypes.Device, 0, len(c.cachedAdapters))
for name := range c.cachedAdapters {
d, err := c.Device(name)
if err != nil {
return nil, err
}
ds = append(ds, d)
}
return ds, nil
}
// New creates a new Client
func New() *Client {
return &Client{}
}
// Close implements wginternal.Client.
func (c *Client) Close() error {
return nil
}
// Device implements wginternal.Client.
func (c *Client) Device(name string) (*wgtypes.Device, error) {
handle, err := c.interfaceHandle(name)
if err != nil {
return nil, err
}
defer windows.CloseHandle(handle)
size := c.lastLenGuess
if size == 0 {
size = 512
}
var buf []byte
for {
buf = make([]byte, size)
err = windows.DeviceIoControl(handle, ioctl.IoctlGet, nil, 0, &buf[0], size, &size, nil)
if err == windows.ERROR_MORE_DATA {
continue
}
if err != nil {
return nil, err
}
break
}
c.lastLenGuess = size
interfaze := (*ioctl.Interface)(unsafe.Pointer(&buf[0]))
device := wgtypes.Device{Type: wgtypes.WindowsKernel, Name: name}
if interfaze.Flags&ioctl.InterfaceHasPrivateKey != 0 {
device.PrivateKey = interfaze.PrivateKey
}
if interfaze.Flags&ioctl.InterfaceHasPublicKey != 0 {
device.PublicKey = interfaze.PublicKey
}
if interfaze.Flags&ioctl.InterfaceHasListenPort != 0 {
device.ListenPort = int(interfaze.ListenPort)
}
var p *ioctl.Peer
for i := uint32(0); i < interfaze.PeerCount; i++ {
if p == nil {
p = interfaze.FirstPeer()
} else {
p = p.NextPeer()
}
peer := wgtypes.Peer{}
if p.Flags&ioctl.PeerHasPublicKey != 0 {
peer.PublicKey = p.PublicKey
}
if p.Flags&ioctl.PeerHasPresharedKey != 0 {
peer.PresharedKey = p.PresharedKey
}
if p.Flags&ioctl.PeerHasEndpoint != 0 {
peer.Endpoint = &net.UDPAddr{IP: p.Endpoint.IP(), Port: int(p.Endpoint.Port())}
}
if p.Flags&ioctl.PeerHasPersistentKeepalive != 0 {
peer.PersistentKeepaliveInterval = time.Duration(p.PersistentKeepalive) * time.Second
}
if p.Flags&ioctl.PeerHasProtocolVersion != 0 {
peer.ProtocolVersion = int(p.ProtocolVersion)
}
peer.TransmitBytes = int64(p.TxBytes)
peer.ReceiveBytes = int64(p.RxBytes)
if p.LastHandshake != 0 {
peer.LastHandshakeTime = time.Unix(0, int64((p.LastHandshake-116444736000000000)*100))
}
var a *ioctl.AllowedIP
for j := uint32(0); j < p.AllowedIPsCount; j++ {
if a == nil {
a = p.FirstAllowedIP()
} else {
a = a.NextAllowedIP()
}
var ip net.IP
var bits int
if a.AddressFamily == windows.AF_INET {
ip = a.Address[:4]
bits = 32
} else if a.AddressFamily == windows.AF_INET6 {
ip = a.Address[:16]
bits = 128
}
peer.AllowedIPs = append(peer.AllowedIPs, net.IPNet{
IP: ip,
Mask: net.CIDRMask(int(a.Cidr), bits),
})
}
device.Peers = append(device.Peers, peer)
}
return &device, nil
}
// ConfigureDevice implements wginternal.Client.
func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
handle, err := c.interfaceHandle(name)
if err != nil {
return err
}
defer windows.CloseHandle(handle)
preallocation := unsafe.Sizeof(ioctl.Interface{}) + uintptr(len(cfg.Peers))*unsafe.Sizeof(ioctl.Peer{})
for i := range cfg.Peers {
preallocation += uintptr(len(cfg.Peers[i].AllowedIPs)) * unsafe.Sizeof(ioctl.AllowedIP{})
}
var b ioctl.ConfigBuilder
b.Preallocate(uint32(preallocation))
interfaze := &ioctl.Interface{PeerCount: uint32(len(cfg.Peers))}
if cfg.ReplacePeers {
interfaze.Flags |= ioctl.InterfaceReplacePeers
}
if cfg.PrivateKey != nil {
interfaze.PrivateKey = *cfg.PrivateKey
interfaze.Flags |= ioctl.InterfaceHasPrivateKey
}
if cfg.ListenPort != nil {
interfaze.ListenPort = uint16(*cfg.ListenPort)
interfaze.Flags |= ioctl.InterfaceHasListenPort
}
b.AppendInterface(interfaze)
for i := range cfg.Peers {
peer := &ioctl.Peer{
Flags: ioctl.PeerHasPublicKey,
PublicKey: cfg.Peers[i].PublicKey,
AllowedIPsCount: uint32(len(cfg.Peers[i].AllowedIPs)),
}
if cfg.Peers[i].ReplaceAllowedIPs {
peer.Flags |= ioctl.PeerReplaceAllowedIPs
}
if cfg.Peers[i].UpdateOnly {
peer.Flags |= ioctl.PeerUpdateOnly
}
if cfg.Peers[i].Remove {
peer.Flags |= ioctl.PeerRemove
}
if cfg.Peers[i].PresharedKey != nil {
peer.Flags |= ioctl.PeerHasPresharedKey
peer.PresharedKey = *cfg.Peers[i].PresharedKey
}
if cfg.Peers[i].Endpoint != nil {
peer.Flags |= ioctl.PeerHasEndpoint
peer.Endpoint.SetIP(cfg.Peers[i].Endpoint.IP, uint16(cfg.Peers[i].Endpoint.Port))
}
if cfg.Peers[i].PersistentKeepaliveInterval != nil {
peer.Flags |= ioctl.PeerHasPersistentKeepalive
peer.PersistentKeepalive = uint16(*cfg.Peers[i].PersistentKeepaliveInterval / time.Second)
}
b.AppendPeer(peer)
for j := range cfg.Peers[i].AllowedIPs {
var family ioctl.AddressFamily
var ip net.IP
if ip = cfg.Peers[i].AllowedIPs[j].IP.To4(); ip != nil {
family = windows.AF_INET
} else if ip = cfg.Peers[i].AllowedIPs[j].IP.To16(); ip != nil {
family = windows.AF_INET6
} else {
ip = cfg.Peers[i].AllowedIPs[j].IP
}
cidr, _ := cfg.Peers[i].AllowedIPs[j].Mask.Size()
a := &ioctl.AllowedIP{
AddressFamily: family,
Cidr: uint8(cidr),
}
copy(a.Address[:], ip)
b.AppendAllowedIP(a)
}
}
interfaze, size := b.Interface()
return windows.DeviceIoControl(handle, ioctl.IoctlSet, nil, 0, (*byte)(unsafe.Pointer(interfaze)), size, &size, nil)
}

View File

@@ -0,0 +1,135 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package ioctl
import "unsafe"
const (
IoctlGet = 0xb098c506
IoctlSet = 0xb098c509
)
type AllowedIP struct {
Address [16]byte
AddressFamily AddressFamily
Cidr uint8
_ [4]byte
}
type PeerFlag uint32
const (
PeerHasPublicKey PeerFlag = 1 << 0
PeerHasPresharedKey PeerFlag = 1 << 1
PeerHasPersistentKeepalive PeerFlag = 1 << 2
PeerHasEndpoint PeerFlag = 1 << 3
PeerHasProtocolVersion PeerFlag = 1 << 4
PeerReplaceAllowedIPs PeerFlag = 1 << 5
PeerRemove PeerFlag = 1 << 6
PeerUpdateOnly PeerFlag = 1 << 7
)
type Peer struct {
Flags PeerFlag
ProtocolVersion uint32
PublicKey [32]byte
PresharedKey [32]byte
PersistentKeepalive uint16
_ uint16
Endpoint RawSockaddrInet
TxBytes uint64
RxBytes uint64
LastHandshake uint64
AllowedIPsCount uint32
_ [4]byte
}
type InterfaceFlag uint32
const (
InterfaceHasPublicKey InterfaceFlag = 1 << 0
InterfaceHasPrivateKey InterfaceFlag = 1 << 1
InterfaceHasListenPort InterfaceFlag = 1 << 2
InterfaceReplacePeers InterfaceFlag = 1 << 3
)
type Interface struct {
Flags InterfaceFlag
ListenPort uint16
PrivateKey [32]byte
PublicKey [32]byte
PeerCount uint32
_ [4]byte
}
func (interfaze *Interface) FirstPeer() *Peer {
return (*Peer)(unsafe.Pointer(uintptr(unsafe.Pointer(interfaze)) + unsafe.Sizeof(*interfaze)))
}
func (peer *Peer) NextPeer() *Peer {
return (*Peer)(unsafe.Pointer(uintptr(unsafe.Pointer(peer)) + unsafe.Sizeof(*peer) + uintptr(peer.AllowedIPsCount)*unsafe.Sizeof(AllowedIP{})))
}
func (peer *Peer) FirstAllowedIP() *AllowedIP {
return (*AllowedIP)(unsafe.Pointer(uintptr(unsafe.Pointer(peer)) + unsafe.Sizeof(*peer)))
}
func (allowedIP *AllowedIP) NextAllowedIP() *AllowedIP {
return (*AllowedIP)(unsafe.Pointer(uintptr(unsafe.Pointer(allowedIP)) + unsafe.Sizeof(*allowedIP)))
}
type ConfigBuilder struct {
buffer []byte
}
func (builder *ConfigBuilder) Preallocate(size uint32) {
if builder.buffer == nil {
builder.buffer = make([]byte, 0, size)
}
}
func (builder *ConfigBuilder) AppendInterface(interfaze *Interface) {
var newBytes []byte
unsafeSlice(unsafe.Pointer(&newBytes), unsafe.Pointer(interfaze), int(unsafe.Sizeof(*interfaze)))
builder.buffer = append(builder.buffer, newBytes...)
}
func (builder *ConfigBuilder) AppendPeer(peer *Peer) {
var newBytes []byte
unsafeSlice(unsafe.Pointer(&newBytes), unsafe.Pointer(peer), int(unsafe.Sizeof(*peer)))
builder.buffer = append(builder.buffer, newBytes...)
}
func (builder *ConfigBuilder) AppendAllowedIP(allowedIP *AllowedIP) {
var newBytes []byte
unsafeSlice(unsafe.Pointer(&newBytes), unsafe.Pointer(allowedIP), int(unsafe.Sizeof(*allowedIP)))
builder.buffer = append(builder.buffer, newBytes...)
}
func (builder *ConfigBuilder) Interface() (*Interface, uint32) {
if builder.buffer == nil {
return nil, 0
}
return (*Interface)(unsafe.Pointer(&builder.buffer[0])), uint32(len(builder.buffer))
}
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: whenGo 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View File

@@ -0,0 +1,87 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package ioctl
import (
"encoding/binary"
"net"
"unsafe"
"golang.org/x/sys/windows"
)
// AddressFamily enumeration specifies protocol family and is one of the windows.AF_* constants.
type AddressFamily uint16
// RawSockaddrInet union contains an IPv4, an IPv6 address, or an address family.
// https://docs.microsoft.com/en-us/windows/desktop/api/ws2ipdef/ns-ws2ipdef-_sockaddr_inet
type RawSockaddrInet struct {
Family AddressFamily
data [26]byte
}
func ntohs(i uint16) uint16 {
return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&i))[:])
}
func htons(i uint16) uint16 {
b := make([]byte, 2)
binary.BigEndian.PutUint16(b, i)
return *(*uint16)(unsafe.Pointer(&b[0]))
}
// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port.
// All other members of the structure are set to zero.
func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
if v4 := ip.To4(); v4 != nil {
addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr))
addr4.Family = windows.AF_INET
copy(addr4.Addr[:], v4)
addr4.Port = htons(port)
for i := 0; i < 8; i++ {
addr4.Zero[i] = 0
}
return nil
}
if v6 := ip.To16(); v6 != nil {
addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr))
addr6.Family = windows.AF_INET6
addr6.Port = htons(port)
addr6.Flowinfo = 0
copy(addr6.Addr[:], v6)
addr6.Scope_id = 0
return nil
}
return windows.ERROR_INVALID_PARAMETER
}
// IP returns IPv4 or IPv6 address, or nil if the address is neither.
func (addr *RawSockaddrInet) IP() net.IP {
switch addr.Family {
case windows.AF_INET:
return (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr[:]
case windows.AF_INET6:
return (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Addr[:]
}
return nil
}
// Port returns the port if the address if IPv4 or IPv6, or 0 if neither.
func (addr *RawSockaddrInet) Port() uint16 {
switch addr.Family {
case windows.AF_INET:
return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port)
case windows.AF_INET6:
return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port)
}
return 0
}