migrate to golang.zx2c4.com/wireguard/wgctrl (#239)

* migrate to golang.zx2c4.com/wireguard/wgctrl

This commit introduces the usage of wgctrl.
It avoids the usage of exec calls of the wg command
and parsing the output of `wg show`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* vendor wgctrl

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* apply suggestions from code review

Remove wireguard.Enpoint struct and use net.UDPAddr for the resolved
endpoint and addr string (dnsanme:port) if a DN was supplied.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/*: use wireguard.Enpoint

This commit introduces the wireguard.Enpoint struct.
It encapsulates a DN name with port and a net.UPDAddr.
The fields are private and only accessible over exported Methods
to avoid accidental modification.

Also iptables.GetProtocol is improved to avoid ipv4 rules being applied
by `ip6tables`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/wireguard/conf_test.go: add tests for Endpoint

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* cmd/kg/main.go: validate port range

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* add suggestions from review

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* pkg/mesh/mesh.go: use Equal func

Implement an Equal func for Enpoint and use it instead of comparing
strings.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* cmd/kgctl/main.go: check port range

Signed-off-by: leonnicolas <leonloechner@gmx.de>

* vendor

Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
leonnicolas 2022-01-30 17:38:45 +01:00 committed by GitHub
parent 797133f272
commit 6a696e03e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
299 changed files with 26275 additions and 10252 deletions

View File

@ -11,7 +11,7 @@ ARG GOARCH
ARG ALPINE_VERSION=v3.12
LABEL maintainer="squat <lserven@gmail.com>"
RUN echo -e "https://alpine.global.ssl.fastly.net/alpine/$ALPINE_VERSION/main\nhttps://alpine.global.ssl.fastly.net/alpine/$ALPINE_VERSION/community" > /etc/apk/repositories && \
apk add --no-cache ipset iptables ip6tables wireguard-tools graphviz font-noto
apk add --no-cache ipset iptables ip6tables graphviz font-noto
COPY --from=cni bridge host-local loopback portmap /opt/cni/bin/
COPY bin/linux/$GOARCH/kg /opt/bin/
ENTRYPOINT ["/opt/bin/kg"]

View File

@ -24,6 +24,8 @@ import (
"os"
"os/exec"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/mesh"
)
@ -62,7 +64,7 @@ func (h *graphHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
peers[p.Name] = p
}
}
topo, err := mesh.NewTopology(nodes, peers, h.granularity, *h.hostname, 0, []byte{}, h.subnet, nodes[*h.hostname].PersistentKeepalive, nil)
topo, err := mesh.NewTopology(nodes, peers, h.granularity, *h.hostname, 0, wgtypes.Key{}, h.subnet, nodes[*h.hostname].PersistentKeepalive, nil)
if err != nil {
http.Error(w, fmt.Sprintf("failed to create topology: %v", err), http.StatusInternalServerError)
return

View File

@ -109,7 +109,7 @@ var (
master string
mtu uint
topologyLabel string
port uint
port int
subnet string
resyncPeriod time.Duration
iptablesForwardRule bool
@ -139,7 +139,7 @@ func init() {
cmd.Flags().StringVar(&master, "master", "", "The address of the Kubernetes API server (overrides any value in kubeconfig).")
cmd.Flags().UintVar(&mtu, "mtu", wireguard.DefaultMTU, "The MTU of the WireGuard interface created by Kilo.")
cmd.Flags().StringVar(&topologyLabel, "topology-label", k8s.RegionLabelKey, "Kubernetes node label used to group nodes into logical locations.")
cmd.Flags().UintVar(&port, "port", mesh.DefaultKiloPort, "The port over which WireGuard peers should communicate.")
cmd.Flags().IntVar(&port, "port", mesh.DefaultKiloPort, "The port over which WireGuard peers should communicate.")
cmd.Flags().StringVar(&subnet, "subnet", mesh.DefaultKiloSubnet.String(), "CIDR from which to allocate addresses for WireGuard interfaces.")
cmd.Flags().DurationVar(&resyncPeriod, "resync-period", 30*time.Second, "How often should the Kilo controllers reconcile?")
cmd.Flags().BoolVar(&iptablesForwardRule, "iptables-forward-rules", false, "Add default accept rules to the FORWARD chain in iptables. Warning: this may break firewalls with a deny all policy and is potentially insecure!")
@ -234,12 +234,15 @@ func runRoot(_ *cobra.Command, _ []string) error {
c := kubernetes.NewForConfigOrDie(config)
kc := kiloclient.NewForConfigOrDie(config)
ec := apiextensions.NewForConfigOrDie(config)
b = k8s.New(c, kc, ec, topologyLabel)
b = k8s.New(c, kc, ec, topologyLabel, log.With(logger, "component", "k8s backend"))
default:
return fmt.Errorf("backend %v unknown; possible values are: %s", backend, availableBackends)
}
m, err := mesh.New(b, enc, gr, hostname, uint32(port), s, local, cni, cniPath, iface, cleanUpIface, createIface, mtu, resyncPeriod, prioritisePrivateAddr, iptablesForwardRule, log.With(logger, "component", "kilo"))
if port < 1 || port > 1<<16-1 {
return fmt.Errorf("invalid port: port mus be in range [%d:%d], but got %d", 1, 1<<16-1, port)
}
m, err := mesh.New(b, enc, gr, hostname, port, s, local, cni, cniPath, iface, cleanUpIface, createIface, mtu, resyncPeriod, prioritisePrivateAddr, iptablesForwardRule, log.With(logger, "component", "kilo"))
if err != nil {
return fmt.Errorf("failed to create Kilo mesh: %v", err)
}

View File

@ -18,6 +18,8 @@ import (
"fmt"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/mesh"
)
@ -65,7 +67,7 @@ func runGraph(_ *cobra.Command, _ []string) error {
peers[p.Name] = p
}
}
t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, 0, []byte{}, subnet, nodes[hostname].PersistentKeepalive, nil)
t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, 0, wgtypes.Key{}, subnet, nodes[hostname].PersistentKeepalive, nil)
if err != nil {
return fmt.Errorf("failed to create topology: %v", err)
}

View File

@ -21,6 +21,7 @@ import (
"path/filepath"
"strings"
"github.com/go-kit/kit/log"
"github.com/spf13/cobra"
apiextensions "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset"
"k8s.io/client-go/kubernetes"
@ -61,7 +62,7 @@ var (
opts struct {
backend mesh.Backend
granularity mesh.Granularity
port uint32
port int
}
backend string
granularity string
@ -70,6 +71,10 @@ var (
)
func runRoot(_ *cobra.Command, _ []string) error {
if opts.port < 1 || opts.port > 1<<16-1 {
return fmt.Errorf("invalid port: port mus be in range [%d:%d], but got %d", 1, 1<<16-1, opts.port)
}
opts.granularity = mesh.Granularity(granularity)
switch opts.granularity {
case mesh.LogicalGranularity:
@ -88,7 +93,7 @@ func runRoot(_ *cobra.Command, _ []string) error {
c := kubernetes.NewForConfigOrDie(config)
kc := kiloclient.NewForConfigOrDie(config)
ec := apiextensions.NewForConfigOrDie(config)
opts.backend = k8s.New(c, kc, ec, topologyLabel)
opts.backend = k8s.New(c, kc, ec, topologyLabel, log.NewNopLogger())
default:
return fmt.Errorf("backend %v unknown; posible values are: %s", backend, availableBackends)
}
@ -119,7 +124,7 @@ func main() {
defaultKubeconfig = filepath.Join(os.Getenv("HOME"), ".kube/config")
}
cmd.PersistentFlags().StringVar(&kubeconfig, "kubeconfig", defaultKubeconfig, "Path to kubeconfig.")
cmd.PersistentFlags().Uint32Var(&opts.port, "port", mesh.DefaultKiloPort, "The WireGuard port over which the nodes communicate.")
cmd.PersistentFlags().IntVar(&opts.port, "port", mesh.DefaultKiloPort, "The WireGuard port over which the nodes communicate.")
cmd.PersistentFlags().StringVar(&topologyLabel, "topology-label", k8s.RegionLabelKey, "Kubernetes node label used to group nodes into logical locations.")
for _, subCmd := range []*cobra.Command{

View File

@ -1,4 +1,4 @@
// Copyright 2019 the Kilo authors
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -15,14 +15,15 @@
package main
import (
"bytes"
"errors"
"fmt"
"net"
"os"
"strings"
"time"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
@ -47,7 +48,7 @@ var (
}, ", ")
allowedIPs []string
showConfOpts struct {
allowedIPs []*net.IPNet
allowedIPs []net.IPNet
serializer *json.Serializer
output string
asPeer bool
@ -89,7 +90,7 @@ func runShowConf(c *cobra.Command, args []string) error {
if err != nil {
return fmt.Errorf("allowed-ips must contain only valid CIDRs; got %q", allowedIPs[i])
}
showConfOpts.allowedIPs = append(showConfOpts.allowedIPs, aip)
showConfOpts.allowedIPs = append(showConfOpts.allowedIPs, *aip)
}
return runRoot(c, args)
}
@ -151,14 +152,14 @@ func runShowConfNode(_ *cobra.Command, args []string) error {
}
}
t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, []byte{}, subnet, nodes[hostname].PersistentKeepalive, nil)
t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, int(opts.port), wgtypes.Key{}, subnet, nodes[hostname].PersistentKeepalive, nil)
if err != nil {
return fmt.Errorf("failed to create topology: %v", err)
}
var found bool
for _, p := range t.PeerConf("").Peers {
if bytes.Equal(p.PublicKey, nodes[hostname].Key) {
if p.PublicKey == nodes[hostname].Key {
found = true
break
}
@ -182,6 +183,9 @@ func runShowConfNode(_ *cobra.Command, args []string) error {
fallthrough
case outputFormatYAML:
p := t.AsPeer()
if p == nil {
return errors.New("cannot generate config from nil peer")
}
p.AllowedIPs = append(p.AllowedIPs, showConfOpts.allowedIPs...)
p.DeduplicateIPs()
k8sp := translatePeer(p)
@ -189,10 +193,13 @@ func runShowConfNode(_ *cobra.Command, args []string) error {
return showConfOpts.serializer.Encode(k8sp, os.Stdout)
case outputFormatWireGuard:
p := t.AsPeer()
if p == nil {
return errors.New("cannot generate config from nil peer")
}
p.AllowedIPs = append(p.AllowedIPs, showConfOpts.allowedIPs...)
p.DeduplicateIPs()
c, err := (&wireguard.Conf{
Peers: []*wireguard.Peer{p},
Peers: []wireguard.Peer{*p},
}).Bytes()
if err != nil {
return fmt.Errorf("failed to generate configuration: %v", err)
@ -244,7 +251,11 @@ func runShowConfPeer(_ *cobra.Command, args []string) error {
return fmt.Errorf("did not find any peer named %q in the cluster", peer)
}
t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, mesh.DefaultKiloPort, []byte{}, subnet, peers[peer].PersistentKeepalive, nil)
pka := time.Duration(0)
if p := peers[peer].PersistentKeepaliveInterval; p != nil {
pka = *p
}
t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, mesh.DefaultKiloPort, wgtypes.Key{}, subnet, pka, nil)
if err != nil {
return fmt.Errorf("failed to create topology: %v", err)
}
@ -272,7 +283,7 @@ func runShowConfPeer(_ *cobra.Command, args []string) error {
p.AllowedIPs = append(p.AllowedIPs, showConfOpts.allowedIPs...)
p.DeduplicateIPs()
c, err := (&wireguard.Conf{
Peers: []*wireguard.Peer{p},
Peers: []wireguard.Peer{*p},
}).Bytes()
if err != nil {
return fmt.Errorf("failed to generate configuration: %v", err)
@ -284,6 +295,7 @@ func runShowConfPeer(_ *cobra.Command, args []string) error {
}
// translatePeer translates a wireguard.Peer to a Peer CRD.
// TODO this function has many similarities to peerBackend.Set(name, peer)
func translatePeer(peer *wireguard.Peer) *v1alpha1.Peer {
if peer == nil {
return &v1alpha1.Peer{}
@ -291,36 +303,33 @@ func translatePeer(peer *wireguard.Peer) *v1alpha1.Peer {
var aips []string
for _, aip := range peer.AllowedIPs {
// Skip any invalid IPs.
if aip == nil {
// TODO all IPs should be valid, so no need to skip here?
if aip.String() == (&net.IPNet{}).String() {
continue
}
aips = append(aips, aip.String())
}
var endpoint *v1alpha1.PeerEndpoint
if peer.Endpoint != nil && peer.Endpoint.Port > 0 && (peer.Endpoint.IP != nil || peer.Endpoint.DNS != "") {
var ip string
if peer.Endpoint.IP != nil {
ip = peer.Endpoint.IP.String()
}
if peer.Endpoint.Port() > 0 || !peer.Endpoint.HasDNS() {
endpoint = &v1alpha1.PeerEndpoint{
DNSOrIP: v1alpha1.DNSOrIP{
DNS: peer.Endpoint.DNS,
IP: ip,
IP: peer.Endpoint.IP().String(),
DNS: peer.Endpoint.DNS(),
},
Port: peer.Endpoint.Port,
Port: uint32(peer.Endpoint.Port()),
}
}
var key string
if len(peer.PublicKey) > 0 {
key = string(peer.PublicKey)
if peer.PublicKey != (wgtypes.Key{}) {
key = peer.PublicKey.String()
}
var psk string
if len(peer.PresharedKey) > 0 {
psk = string(peer.PresharedKey)
if peer.PresharedKey != nil {
psk = peer.PresharedKey.String()
}
var pka int
if peer.PersistentKeepalive > 0 {
pka = peer.PersistentKeepalive
if peer.PersistentKeepaliveInterval != nil && *peer.PersistentKeepaliveInterval > time.Duration(0) {
pka = int(*peer.PersistentKeepaliveInterval)
}
return &v1alpha1.Peer{
TypeMeta: metav1.TypeMeta{

View File

@ -18,7 +18,7 @@ test_full_mesh_connectivity() {
}
test_full_mesh_peer() {
check_peer wg1 e2e 10.5.0.1/32 full
check_peer wg99 e2e 10.5.0.1/32 full
}
test_full_mesh_allowed_location_ips() {

View File

@ -184,14 +184,14 @@ check_peer() {
local ALLOWED_IP=$3
local GRANULARITY=$4
create_interface "$INTERFACE"
docker run --rm --entrypoint=/usr/bin/wg "$KILO_IMAGE" genkey > "$INTERFACE"
assert "create_peer $PEER $ALLOWED_IP 10 $(docker run --rm --entrypoint=/bin/sh -v "$PWD/$INTERFACE":/key "$KILO_IMAGE" -c 'cat /key | wg pubkey')" "should be able to create Peer"
docker run --rm leonnicolas/wg-tools wg genkey > "$INTERFACE"
assert "create_peer $PEER $ALLOWED_IP 10 $(docker run --rm --entrypoint=/bin/sh -v "$PWD/$INTERFACE":/key leonnicolas/wg-tools -c 'cat /key | wg pubkey')" "should be able to create Peer"
assert "_kgctl showconf peer $PEER --mesh-granularity=$GRANULARITY > $PEER.ini" "should be able to get Peer configuration"
assert "docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/usr/bin/wg -v /var/run/wireguard:/var/run/wireguard -v $PWD/$PEER.ini:/peer.ini $KILO_IMAGE setconf $INTERFACE /peer.ini" "should be able to apply configuration from kgctl"
docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/usr/bin/wg -v /var/run/wireguard:/var/run/wireguard -v "$PWD/$INTERFACE":/key "$KILO_IMAGE" set "$INTERFACE" private-key /key
docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip "$KILO_IMAGE" address add "$ALLOWED_IP" dev "$INTERFACE"
docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip "$KILO_IMAGE" link set "$INTERFACE" up
docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip "$KILO_IMAGE" route add 10.42/16 dev "$INTERFACE"
assert "docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/usr/bin/wg -v /var/run/wireguard:/var/run/wireguard -v $PWD/$PEER.ini:/peer.ini leonnicolas/wg-tools setconf $INTERFACE /peer.ini" "should be able to apply configuration from kgctl"
docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/usr/bin/wg -v /var/run/wireguard:/var/run/wireguard -v "$PWD/$INTERFACE":/key leonnicolas/wg-tools set "$INTERFACE" private-key /key
docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip leonnicolas/wg-tools address add "$ALLOWED_IP" dev "$INTERFACE"
docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip leonnicolas/wg-tools link set "$INTERFACE" up
docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip leonnicolas/wg-tools route add 10.42/16 dev "$INTERFACE"
assert "retry 10 5 '' check_ping --local" "should be able to ping Pods from host"
assert_equals "$(_kgctl showconf peer "$PEER")" "$(_kgctl showconf peer "$PEER" --mesh-granularity="$GRANULARITY")" "kgctl should be able to auto detect the mesh granularity"
rm "$INTERFACE" "$PEER".ini

View File

@ -18,7 +18,7 @@ test_location_mesh_connectivity() {
}
test_location_mesh_peer() {
check_peer wg1 e2e 10.5.0.1/32 location
check_peer wg99 e2e 10.5.0.1/32 location
}
test_mesh_granularity_auto_detect() {

11
go.mod
View File

@ -17,7 +17,8 @@ require (
github.com/vishvananda/netlink v1.0.0
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc // indirect
golang.org/x/lint v0.0.0-20200302205851-738671d3881b
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211124212657-dd7407c86d22
k8s.io/api v0.21.1
k8s.io/apiextensions-apiserver v0.21.1
k8s.io/apimachinery v0.21.1
@ -42,10 +43,14 @@ require (
github.com/googleapis/gnostic v0.4.1 // indirect
github.com/hashicorp/golang-lru v0.5.1 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
github.com/json-iterator/go v1.1.11 // indirect
github.com/mattn/go-colorable v0.1.8 // indirect
github.com/mattn/go-isatty v0.0.12 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
github.com/mdlayher/genetlink v1.0.0 // indirect
github.com/mdlayher/netlink v1.4.1 // indirect
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
@ -54,14 +59,16 @@ require (
github.com/prometheus/common v0.26.0 // indirect
github.com/prometheus/procfs v0.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 // indirect
golang.org/x/mod v0.4.2 // indirect
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 // indirect
golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9 // indirect
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d // indirect
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d // indirect
golang.org/x/text v0.3.6 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
golang.org/x/tools v0.1.2 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
golang.zx2c4.com/wireguard v0.0.0-20211123210315-387f7c461a16 // indirect
google.golang.org/appengine v1.6.5 // indirect
google.golang.org/protobuf v1.26.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect

74
go.sum
View File

@ -65,6 +65,7 @@ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/cilium/ebpf v0.5.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8=
github.com/containernetworking/cni v0.6.0 h1:FXICGBZNMtdHlW65trpoHviHctQD3seWhRRcqp2hMOU=
@ -105,6 +106,7 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/fatih/color v1.12.0 h1:mRhaKNwANqRgUBGKmnI5ZxEk7QXmjQeCcuYFMX2bfcc=
github.com/fatih/color v1.12.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
@ -237,7 +239,18 @@ github.com/imdario/mergo v0.3.6/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJ
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 h1:uhL5Gw7BINiiPAo24A2sxkcDI0Jt/sqp1v5xQCniEFA=
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
github.com/jsimonetti/rtnetlink v0.0.0-20201009170750-9c6f07d100c1/go.mod h1:hqoO/u39cqLeBLebZ8fWdE96O7FxrAsRYhnVOdgHxok=
github.com/jsimonetti/rtnetlink v0.0.0-20201216134343-bde56ed16391/go.mod h1:cR77jAZG3Y3bsb8hF6fHJbFoyFukLFOkQ98S0pQz3xw=
github.com/jsimonetti/rtnetlink v0.0.0-20201220180245-69540ac93943/go.mod h1:z4c53zj6Eex712ROyh8WI0ihysb5j2ROyV42iNogmAs=
github.com/jsimonetti/rtnetlink v0.0.0-20210122163228-8d122574c736/go.mod h1:ZXpIyOK59ZnN7J0BV99cZUPmsqDRZ3eq5X+st7u/oSA=
github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9Rh8m+aHZIG69YPGGem1i5VzoyRC8nw2kA8B+ik5U=
github.com/jsimonetti/rtnetlink v0.0.0-20210525051524-4cc836578190 h1:iycCSDo8EKVueI9sfVBBJmtNn9DnXV/K1YWwEJO+uOs=
github.com/jsimonetti/rtnetlink v0.0.0-20210525051524-4cc836578190/go.mod h1:NmKSdU4VGSiv1bMsdqNALI4RSvvjtz65tTMCnD05qLo=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@ -256,6 +269,7 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
@ -278,7 +292,27 @@ github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzp
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI=
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43 h1:WgyLFv10Ov49JAQI/ZLUkCZ7VJS3r74hwFIGXJsgZlY=
github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43/go.mod h1:+t7E0lkKfbBsebllff1xdTmyJt8lH37niI6kwFk9OTo=
github.com/mdlayher/genetlink v1.0.0 h1:OoHN1OdyEIkScEmRgxLEe2M9U8ClMytqA5niynLtfj0=
github.com/mdlayher/genetlink v1.0.0/go.mod h1:0rJ0h4itni50A86M2kHcgS85ttZazNt7a8H2a2cw0Gc=
github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA=
github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M=
github.com/mdlayher/netlink v1.1.0/go.mod h1:H4WCitaheIsdF9yOYu8CFmCgQthAPIWZmcKp9uZHgmY=
github.com/mdlayher/netlink v1.1.1/go.mod h1:WTYpFb/WTvlRJAyKhZL5/uy69TDDpHHu2VZmb2XgV7o=
github.com/mdlayher/netlink v1.2.0/go.mod h1:kwVW1io0AZy9A1E2YYgaD4Cj+C+GPkU6klXCMzIJ9p8=
github.com/mdlayher/netlink v1.2.1/go.mod h1:bacnNlfhqHqqLo4WsYeXSqfyXkInQ9JneWI68v1KwSU=
github.com/mdlayher/netlink v1.2.2-0.20210123213345-5cc92139ae3e/go.mod h1:bacnNlfhqHqqLo4WsYeXSqfyXkInQ9JneWI68v1KwSU=
github.com/mdlayher/netlink v1.3.0/go.mod h1:xK/BssKuwcRXHrtN04UBkwQ6dY9VviGGuriDdoPSWys=
github.com/mdlayher/netlink v1.4.0/go.mod h1:dRJi5IABcZpBD2A3D0Mv/AiX8I9uDEu5oGkAVrekmf8=
github.com/mdlayher/netlink v1.4.1 h1:I154BCU+mKlIf7BgcAJB2r7QjveNPty6uNY1g9ChVfI=
github.com/mdlayher/netlink v1.4.1/go.mod h1:e4/KuJ+s8UhfUpO9z00/fDZZmhSrs+oxyqAS9cNgn6Q=
github.com/mdlayher/socket v0.0.0-20210307095302-262dc9984e00/go.mod h1:GAFlyu4/XV68LkQKYzKhIo/WW7j3Zi0YRAz/BOoanUc=
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb h1:2dC7L10LmTqlyMVzFJ00qM25lqESg9Z4u3GuEXN5iHY=
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
@ -429,6 +463,9 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 h1:/pEO3GD/ABYAjuakUS6xSEmmlyVS4kxBNkeA9tLJiTI=
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -482,6 +519,7 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191007182048-72f939374954/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -491,11 +529,22 @@ golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210224082022-3d97a244fca7/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211111083644-e5c967477495/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9 h1:0qxwC5n+ttVOINCBeRHO0nq9X7uy8SDsPoi5OaCdIEI=
golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -520,6 +569,7 @@ golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -532,6 +582,7 @@ golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -549,16 +600,29 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200831180312-196b9ba8737a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201118182958-a01c418693c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201218084310-7d0127a74742/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210110051926-789bb1bd4061/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210123111255-9b0068b26619/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210216163648-f7da38b97c65/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40 h1:JWgyZ1qgdTaF3N3oxC+MdTV7qvEEgHo3otj+HB5CM7Q=
golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211110154304-99a53858aa08/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 h1:TyHqChC80pFkXWraUUf6RuB5IqFdQieMLwwCJokV2pc=
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d h1:SZxvLBoTP5yHO3Frd4z4vrF+DBX9vMVanchswa69toE=
@ -624,6 +688,12 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20211123210315-387f7c461a16 h1:SCBV/ayxt56AuC0R8oN5p8Mmc9Llv34tr9VbNlh1kL0=
golang.zx2c4.com/wireguard v0.0.0-20211123210315-387f7c461a16/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211124212657-dd7407c86d22 h1:TnoJ6AWs/q4xGE9smgTi+bJmEDet3nrBqdHSV0d5/zA=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211124212657-dd7407c86d22/go.mod h1:ELrVb7MHNIb8OV6iZ6TP/ScunqUha+vWsUP+SVBH7lQ=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=

View File

@ -74,7 +74,7 @@ func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule {
rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
for _, n := range nodes {
// Accept encapsulated traffic from peers.
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(n.IP)), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
}
// Drop all other IPIP traffic.
rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))

View File

@ -53,11 +53,11 @@ const (
)
// GetProtocol will return a protocol from the length of an IP address.
func GetProtocol(length int) Protocol {
if length == net.IPv6len {
return ProtocolIPv6
func GetProtocol(ip net.IP) Protocol {
if len(ip) == net.IPv4len || ip.To4() != nil {
return ProtocolIPv4
}
return ProtocolIPv4
return ProtocolIPv6
}
// Client represents any type that can administer iptables rules.

View File

@ -25,13 +25,15 @@ import (
"strings"
"time"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
v1 "k8s.io/api/core/v1"
apiextensions "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/strategicpatch"
"k8s.io/apimachinery/pkg/util/validation"
v1informers "k8s.io/client-go/informers/core/v1"
"k8s.io/client-go/kubernetes"
v1listers "k8s.io/client-go/listers/core/v1"
@ -67,6 +69,8 @@ const (
jsonRemovePatch = `{"op": "remove", "path": "%s"}`
)
var logger = log.NewNopLogger()
type backend struct {
nodes *nodeBackend
peers *peerBackend
@ -99,10 +103,12 @@ type peerBackend struct {
}
// New creates a new instance of a mesh.Backend.
func New(c kubernetes.Interface, kc kiloclient.Interface, ec apiextensions.Interface, topologyLabel string) mesh.Backend {
func New(c kubernetes.Interface, kc kiloclient.Interface, ec apiextensions.Interface, topologyLabel string, l log.Logger) mesh.Backend {
ni := v1informers.NewNodeInformer(c, 5*time.Minute, nil)
pi := v1alpha1informers.NewPeerInformer(kc, 5*time.Minute, nil)
logger = l
return &backend{
&nodeBackend{
client: c,
@ -218,7 +224,7 @@ func (nb *nodeBackend) Set(name string, node *mesh.Node) error {
} else {
n.ObjectMeta.Annotations[internalIPAnnotationKey] = node.InternalIP.String()
}
n.ObjectMeta.Annotations[keyAnnotationKey] = string(node.Key)
n.ObjectMeta.Annotations[keyAnnotationKey] = node.Key.String()
n.ObjectMeta.Annotations[lastSeenAnnotationKey] = strconv.FormatInt(node.LastSeen, 10)
if node.WireGuardIP == nil {
n.ObjectMeta.Annotations[wireGuardIPAnnotationKey] = ""
@ -276,9 +282,9 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
location = node.ObjectMeta.Labels[topologyLabel]
}
// Allow the endpoint to be overridden.
endpoint := parseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey])
endpoint := wireguard.ParseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey])
if endpoint == nil {
endpoint = parseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey])
endpoint = wireguard.ParseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey])
}
// Allow the internal IP to be overridden.
internalIP := normalizeIP(node.ObjectMeta.Annotations[forceInternalIPAnnotationKey])
@ -292,13 +298,11 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
internalIP = nil
}
// Set Wireguard PersistentKeepalive setting for the node.
var persistentKeepalive int64
if keepAlive, ok := node.ObjectMeta.Annotations[persistentKeepaliveKey]; !ok {
persistentKeepalive = 0
} else {
if persistentKeepalive, err = strconv.ParseInt(keepAlive, 10, 64); err != nil {
persistentKeepalive = 0
}
var persistentKeepalive = time.Duration(0)
if keepAlive, ok := node.ObjectMeta.Annotations[persistentKeepaliveKey]; ok {
// We can ignore the error, because p will be set to 0 if an error occures.
p, _ := strconv.ParseInt(keepAlive, 10, 64)
persistentKeepalive = time.Duration(p) * time.Second
}
var lastSeen int64
if ls, ok := node.ObjectMeta.Annotations[lastSeenAnnotationKey]; !ok {
@ -308,7 +312,7 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
lastSeen = 0
}
}
var discoveredEndpoints map[string]*wireguard.Endpoint
var discoveredEndpoints map[string]*net.UDPAddr
if de, ok := node.ObjectMeta.Annotations[discoveredEndpointsKey]; ok {
err := json.Unmarshal([]byte(de), &discoveredEndpoints)
if err != nil {
@ -316,11 +320,11 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
}
}
// Set allowed IPs for a location.
var allowedLocationIPs []*net.IPNet
var allowedLocationIPs []net.IPNet
if str, ok := node.ObjectMeta.Annotations[allowedLocationIPsKey]; ok {
for _, ip := range strings.Split(str, ",") {
if ipnet := normalizeIP(ip); ipnet != nil {
allowedLocationIPs = append(allowedLocationIPs, ipnet)
allowedLocationIPs = append(allowedLocationIPs, *ipnet)
}
}
}
@ -335,6 +339,9 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
}
}
// TODO log some error or warning.
key, _ := wgtypes.ParseKey(node.ObjectMeta.Annotations[keyAnnotationKey])
return &mesh.Node{
// Endpoint and InternalIP should only ever fail to parse if the
// remote node's agent has not yet set its IP address;
@ -345,12 +352,12 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
Endpoint: endpoint,
NoInternalIP: noInternalIP,
InternalIP: internalIP,
Key: []byte(node.ObjectMeta.Annotations[keyAnnotationKey]),
Key: key,
LastSeen: lastSeen,
Leader: leader,
Location: location,
Name: node.Name,
PersistentKeepalive: int(persistentKeepalive),
PersistentKeepalive: persistentKeepalive,
Subnet: subnet,
// WireGuardIP can fail to parse if the node is not a leader or if
// the node's agent has not yet reconciled. In either case, the IP
@ -367,14 +374,14 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
if peer == nil {
return nil
}
var aips []*net.IPNet
var aips []net.IPNet
for _, aip := range peer.Spec.AllowedIPs {
aip := normalizeIP(aip)
// Skip any invalid IPs.
if aip == nil {
continue
}
aips = append(aips, aip)
aips = append(aips, *aip)
}
var endpoint *wireguard.Endpoint
if peer.Spec.Endpoint != nil {
@ -384,36 +391,41 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
} else {
ip = ip.To16()
}
if peer.Spec.Endpoint.Port > 0 && (ip != nil || peer.Spec.Endpoint.DNS != "") {
endpoint = &wireguard.Endpoint{
DNSOrIP: wireguard.DNSOrIP{
DNS: peer.Spec.Endpoint.DNS,
IP: ip,
},
Port: peer.Spec.Endpoint.Port,
if peer.Spec.Endpoint.Port > 0 {
if ip != nil {
endpoint = wireguard.NewEndpoint(ip, int(peer.Spec.Endpoint.Port))
}
if peer.Spec.Endpoint.DNS != "" {
endpoint = wireguard.ParseEndpoint(fmt.Sprintf("%s:%d", peer.Spec.Endpoint.DNS, peer.Spec.Endpoint.Port))
}
}
}
var key []byte
if len(peer.Spec.PublicKey) > 0 {
key = []byte(peer.Spec.PublicKey)
key, err := wgtypes.ParseKey(peer.Spec.PublicKey)
if err != nil {
level.Error(logger).Log("msg", "failed to parse public key", "peer", peer.Name, "err", err.Error())
}
var psk []byte
if len(peer.Spec.PresharedKey) > 0 {
psk = []byte(peer.Spec.PresharedKey)
var psk *wgtypes.Key
if k, err := wgtypes.ParseKey(peer.Spec.PresharedKey); err != nil {
// Set key to nil to avoid setting a key to the zero value wgtypes.Key{}
psk = nil
} else {
psk = &k
}
var pka int
var pka time.Duration
if peer.Spec.PersistentKeepalive > 0 {
pka = peer.Spec.PersistentKeepalive
pka = time.Duration(peer.Spec.PersistentKeepalive)
}
return &mesh.Peer{
Name: peer.Name,
Peer: wireguard.Peer{
AllowedIPs: aips,
Endpoint: endpoint,
PersistentKeepalive: pka,
PresharedKey: psk,
PublicKey: key,
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: aips,
PersistentKeepaliveInterval: &pka,
PresharedKey: psk,
PublicKey: key,
},
Endpoint: endpoint,
},
}
}
@ -511,21 +523,25 @@ func (pb *peerBackend) Set(name string, peer *mesh.Peer) error {
p.Spec.AllowedIPs[i] = peer.AllowedIPs[i].String()
}
if peer.Endpoint != nil {
var ip string
if peer.Endpoint.IP != nil {
ip = peer.Endpoint.IP.String()
}
p.Spec.Endpoint = &v1alpha1.PeerEndpoint{
DNSOrIP: v1alpha1.DNSOrIP{
IP: ip,
DNS: peer.Endpoint.DNS,
IP: peer.Endpoint.IP().String(),
DNS: peer.Endpoint.DNS(),
},
Port: peer.Endpoint.Port,
Port: uint32(peer.Endpoint.Port()),
}
}
p.Spec.PersistentKeepalive = peer.PersistentKeepalive
p.Spec.PresharedKey = string(peer.PresharedKey)
p.Spec.PublicKey = string(peer.PublicKey)
if peer.PersistentKeepaliveInterval == nil {
p.Spec.PersistentKeepalive = 0
} else {
p.Spec.PersistentKeepalive = int(*peer.PersistentKeepaliveInterval)
}
if peer.PresharedKey == nil {
p.Spec.PresharedKey = ""
} else {
p.Spec.PresharedKey = peer.PresharedKey.String()
}
p.Spec.PublicKey = peer.PublicKey.String()
if _, err = pb.client.KiloV1alpha1().Peers().Update(context.TODO(), p, metav1.UpdateOptions{}); err != nil {
return fmt.Errorf("failed to update peer: %v", err)
}
@ -549,35 +565,3 @@ func normalizeIP(ip string) *net.IPNet {
ipNet.IP = i.To16()
return ipNet
}
func parseEndpoint(endpoint string) *wireguard.Endpoint {
if len(endpoint) == 0 {
return nil
}
parts := strings.Split(endpoint, ":")
if len(parts) < 2 {
return nil
}
portRaw := parts[len(parts)-1]
hostRaw := strings.Trim(strings.Join(parts[:len(parts)-1], ":"), "[]")
port, err := strconv.ParseUint(portRaw, 10, 32)
if err != nil {
return nil
}
if len(validation.IsValidPortNum(int(port))) != 0 {
return nil
}
ip := net.ParseIP(hostRaw)
if ip == nil {
if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 {
return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: hostRaw}, Port: uint32(port)}
}
return nil
}
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
} else {
ip = ip.To16()
}
return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: ip}, Port: uint32(port)}
}

View File

@ -1,4 +1,4 @@
// Copyright 2019 the Kilo authors
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -17,8 +17,10 @@ package k8s
import (
"net"
"testing"
"time"
"github.com/kylelemons/godebug/pretty"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
v1 "k8s.io/api/core/v1"
"github.com/squat/kilo/pkg/k8s/apis/kilo/v1alpha1"
@ -26,6 +28,30 @@ import (
"github.com/squat/kilo/pkg/wireguard"
)
func mustKey() (k wgtypes.Key) {
var err error
if k, err = wgtypes.GeneratePrivateKey(); err != nil {
panic(err.Error())
}
return
}
func mustPSKKey() (key *wgtypes.Key) {
if k, err := wgtypes.GenerateKey(); err != nil {
panic(err.Error())
} else {
key = &k
}
return
}
var (
fooKey = mustKey()
pskKey = mustPSKKey()
second = time.Second
zero = time.Duration(0)
)
func TestTranslateNode(t *testing.T) {
for _, tc := range []struct {
name string
@ -54,8 +80,19 @@ func TestTranslateNode(t *testing.T) {
internalIPAnnotationKey: "10.0.0.2/32",
},
out: &mesh.Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort},
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort),
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2").To4(), Mask: net.CIDRMask(32, 32)},
},
},
{
name: "valid ips with ipv6",
annotations: map[string]string{
endpointAnnotationKey: "[ff10::10]:51820",
internalIPAnnotationKey: "ff60::10/64",
},
out: &mesh.Node{
Endpoint: wireguard.NewEndpoint(net.ParseIP("ff10::10").To16(), mesh.DefaultKiloPort),
InternalIP: &net.IPNet{IP: net.ParseIP("ff60::10").To16(), Mask: net.CIDRMask(64, 128)},
},
},
{
@ -68,7 +105,7 @@ func TestTranslateNode(t *testing.T) {
name: "normalize subnet",
annotations: map[string]string{},
out: &mesh.Node{
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(24, 32)},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0").To4(), Mask: net.CIDRMask(24, 32)},
},
subnet: "10.2.0.1/24",
},
@ -76,7 +113,7 @@ func TestTranslateNode(t *testing.T) {
name: "valid subnet",
annotations: map[string]string{},
out: &mesh.Node{
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0").To4(), Mask: net.CIDRMask(24, 32)},
},
subnet: "10.2.1.0/24",
},
@ -108,7 +145,7 @@ func TestTranslateNode(t *testing.T) {
forceEndpointAnnotationKey: "-10.0.0.2:51821",
},
out: &mesh.Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort),
},
},
{
@ -118,7 +155,7 @@ func TestTranslateNode(t *testing.T) {
forceEndpointAnnotationKey: "10.0.0.2:51821",
},
out: &mesh.Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821},
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.2").To4(), 51821),
},
},
{
@ -127,7 +164,7 @@ func TestTranslateNode(t *testing.T) {
persistentKeepaliveKey: "25",
},
out: &mesh.Node{
PersistentKeepalive: 25,
PersistentKeepalive: 25 * time.Second,
},
},
{
@ -137,7 +174,7 @@ func TestTranslateNode(t *testing.T) {
forceInternalIPAnnotationKey: "-10.1.0.2/24",
},
out: &mesh.Node{
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.1"), Mask: net.CIDRMask(24, 32)},
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.1").To4(), Mask: net.CIDRMask(24, 32)},
NoInternalIP: false,
},
},
@ -148,7 +185,7 @@ func TestTranslateNode(t *testing.T) {
forceInternalIPAnnotationKey: "10.1.0.2/24",
},
out: &mesh.Node{
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(24, 32)},
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(24, 32)},
NoInternalIP: false,
},
},
@ -166,7 +203,7 @@ func TestTranslateNode(t *testing.T) {
forceEndpointAnnotationKey: "10.0.0.2:51821",
forceInternalIPAnnotationKey: "10.1.0.2/32",
internalIPAnnotationKey: "10.1.0.1/32",
keyAnnotationKey: "foo",
keyAnnotationKey: fooKey.String(),
lastSeenAnnotationKey: "1000000000",
leaderAnnotationKey: "",
locationAnnotationKey: "b",
@ -177,14 +214,45 @@ func TestTranslateNode(t *testing.T) {
RegionLabelKey: "a",
},
out: &mesh.Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821},
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.2").To4(), 51821),
NoInternalIP: false,
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)},
Key: []byte("foo"),
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(32, 32)},
Key: fooKey,
LastSeen: 1000000000,
Leader: true,
Location: "b",
PersistentKeepalive: 25,
PersistentKeepalive: 25 * time.Second,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0").To4(), Mask: net.CIDRMask(24, 32)},
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1").To4(), Mask: net.CIDRMask(16, 32)},
},
subnet: "10.2.1.0/24",
},
{
name: "complete with ipv6",
annotations: map[string]string{
endpointAnnotationKey: "10.0.0.1:51820",
forceEndpointAnnotationKey: "[1100::10]:51821",
forceInternalIPAnnotationKey: "10.1.0.2/32",
internalIPAnnotationKey: "10.1.0.1/32",
keyAnnotationKey: fooKey.String(),
lastSeenAnnotationKey: "1000000000",
leaderAnnotationKey: "",
locationAnnotationKey: "b",
persistentKeepaliveKey: "25",
wireGuardIPAnnotationKey: "10.4.0.1/16",
},
labels: map[string]string{
RegionLabelKey: "a",
},
out: &mesh.Node{
Endpoint: wireguard.NewEndpoint(net.ParseIP("1100::10"), 51821),
NoInternalIP: false,
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)},
Key: fooKey,
LastSeen: 1000000000,
Leader: true,
Location: "b",
PersistentKeepalive: 25 * time.Second,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)},
},
@ -195,7 +263,7 @@ func TestTranslateNode(t *testing.T) {
annotations: map[string]string{
endpointAnnotationKey: "10.0.0.1:51820",
internalIPAnnotationKey: "",
keyAnnotationKey: "foo",
keyAnnotationKey: fooKey.String(),
lastSeenAnnotationKey: "1000000000",
locationAnnotationKey: "b",
persistentKeepaliveKey: "25",
@ -205,13 +273,13 @@ func TestTranslateNode(t *testing.T) {
RegionLabelKey: "a",
},
out: &mesh.Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820},
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1"), 51820),
InternalIP: nil,
Key: []byte("foo"),
Key: fooKey,
LastSeen: 1000000000,
Leader: false,
Location: "b",
PersistentKeepalive: 25,
PersistentKeepalive: 25 * time.Second,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)},
},
@ -223,7 +291,7 @@ func TestTranslateNode(t *testing.T) {
endpointAnnotationKey: "10.0.0.1:51820",
internalIPAnnotationKey: "10.1.0.1/32",
forceInternalIPAnnotationKey: "",
keyAnnotationKey: "foo",
keyAnnotationKey: fooKey.String(),
lastSeenAnnotationKey: "1000000000",
locationAnnotationKey: "b",
persistentKeepaliveKey: "25",
@ -233,14 +301,14 @@ func TestTranslateNode(t *testing.T) {
RegionLabelKey: "a",
},
out: &mesh.Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820},
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1"), 51820),
NoInternalIP: true,
InternalIP: nil,
Key: []byte("foo"),
Key: fooKey,
LastSeen: 1000000000,
Leader: false,
Location: "b",
PersistentKeepalive: 25,
PersistentKeepalive: 25 * time.Second,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)},
},
@ -266,7 +334,13 @@ func TestTranslatePeer(t *testing.T) {
}{
{
name: "empty",
out: &mesh.Peer{},
out: &mesh.Peer{
Peer: wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
},
},
},
},
{
name: "invalid ips",
@ -276,7 +350,13 @@ func TestTranslatePeer(t *testing.T) {
"foo",
},
},
out: &mesh.Peer{},
out: &mesh.Peer{
Peer: wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
},
},
},
},
{
name: "valid ips",
@ -288,9 +368,12 @@ func TestTranslatePeer(t *testing.T) {
},
out: &mesh.Peer{
Peer: wireguard.Peer{
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
},
PersistentKeepaliveInterval: &zero,
},
},
},
@ -305,7 +388,13 @@ func TestTranslatePeer(t *testing.T) {
Port: mesh.DefaultKiloPort,
},
},
out: &mesh.Peer{},
out: &mesh.Peer{
Peer: wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
},
},
},
},
{
name: "only endpoint port",
@ -314,7 +403,13 @@ func TestTranslatePeer(t *testing.T) {
Port: mesh.DefaultKiloPort,
},
},
out: &mesh.Peer{},
out: &mesh.Peer{
Peer: wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
},
},
},
},
{
name: "valid endpoint ip",
@ -328,10 +423,29 @@ func TestTranslatePeer(t *testing.T) {
},
out: &mesh.Peer{
Peer: wireguard.Peer{
Endpoint: &wireguard.Endpoint{
DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")},
Port: mesh.DefaultKiloPort,
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
},
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort),
},
},
},
{
name: "valid endpoint ipv6",
spec: v1alpha1.PeerSpec{
Endpoint: &v1alpha1.PeerEndpoint{
DNSOrIP: v1alpha1.DNSOrIP{
IP: "ff60::2",
},
Port: mesh.DefaultKiloPort,
},
},
out: &mesh.Peer{
Peer: wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
},
Endpoint: wireguard.NewEndpoint(net.ParseIP("ff60::2").To16(), mesh.DefaultKiloPort),
},
},
},
@ -347,9 +461,9 @@ func TestTranslatePeer(t *testing.T) {
},
out: &mesh.Peer{
Peer: wireguard.Peer{
Endpoint: &wireguard.Endpoint{
DNSOrIP: wireguard.DNSOrIP{DNS: "example.com"},
Port: mesh.DefaultKiloPort,
Endpoint: wireguard.ParseEndpoint("example.com:51820"),
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
},
},
},
@ -359,16 +473,25 @@ func TestTranslatePeer(t *testing.T) {
spec: v1alpha1.PeerSpec{
PublicKey: "",
},
out: &mesh.Peer{},
out: &mesh.Peer{
Peer: wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
},
},
},
},
{
name: "valid key",
spec: v1alpha1.PeerSpec{
PublicKey: "foo",
PublicKey: fooKey.String(),
},
out: &mesh.Peer{
Peer: wireguard.Peer{
PublicKey: []byte("foo"),
PeerConfig: wgtypes.PeerConfig{
PublicKey: fooKey,
PersistentKeepaliveInterval: &zero,
},
},
},
},
@ -377,27 +500,38 @@ func TestTranslatePeer(t *testing.T) {
spec: v1alpha1.PeerSpec{
PersistentKeepalive: -1,
},
out: &mesh.Peer{},
out: &mesh.Peer{
Peer: wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
},
},
},
},
{
name: "valid keepalive",
spec: v1alpha1.PeerSpec{
PersistentKeepalive: 1,
PersistentKeepalive: 1 * int(time.Second),
},
out: &mesh.Peer{
Peer: wireguard.Peer{
PersistentKeepalive: 1,
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &second,
},
},
},
},
{
name: "valid preshared key",
spec: v1alpha1.PeerSpec{
PresharedKey: "psk",
PresharedKey: pskKey.String(),
},
out: &mesh.Peer{
Peer: wireguard.Peer{
PresharedKey: []byte("psk"),
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: &zero,
PresharedKey: pskKey,
},
},
},
},
@ -410,52 +544,3 @@ func TestTranslatePeer(t *testing.T) {
}
}
}
func TestParseEndpoint(t *testing.T) {
for _, tc := range []struct {
name string
endpoint string
out *wireguard.Endpoint
}{
{
name: "empty",
endpoint: "",
out: nil,
},
{
name: "invalid IP",
endpoint: "10.0.0.:51820",
out: nil,
},
{
name: "invalid hostname",
endpoint: "foo-:51820",
out: nil,
},
{
name: "invalid port",
endpoint: "10.0.0.1:100000000",
out: nil,
},
{
name: "valid IP",
endpoint: "10.0.0.1:51820",
out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort},
},
{
name: "valid IPv6",
endpoint: "[ff02::114]:51820",
out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("ff02::114")}, Port: mesh.DefaultKiloPort},
},
{
name: "valid hostname",
endpoint: "foo:51821",
out: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: "foo"}, Port: 51821},
},
} {
endpoint := parseEndpoint(tc.endpoint)
if diff := pretty.Compare(endpoint, tc.out); diff != "" {
t.Errorf("test case %q: got diff: %v", tc.name, diff)
}
}
}

View File

@ -18,6 +18,8 @@ import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/wireguard"
)
@ -55,7 +57,7 @@ const (
// Node represents a node in the network.
type Node struct {
Endpoint *wireguard.Endpoint
Key []byte
Key wgtypes.Key
NoInternalIP bool
InternalIP *net.IPNet
// LastSeen is a Unix time for the last time
@ -66,18 +68,23 @@ type Node struct {
Leader bool
Location string
Name string
PersistentKeepalive int
PersistentKeepalive time.Duration
Subnet *net.IPNet
WireGuardIP *net.IPNet
DiscoveredEndpoints map[string]*wireguard.Endpoint
AllowedLocationIPs []*net.IPNet
// DiscoveredEndpoints cannot be DNS endpoints, only net.UDPAddr.
DiscoveredEndpoints map[string]*net.UDPAddr
AllowedLocationIPs []net.IPNet
Granularity Granularity
}
// Ready indicates whether or not the node is ready.
func (n *Node) Ready() bool {
// Nodes that are not leaders will not have WireGuardIPs, so it is not required.
return n != nil && n.Endpoint != nil && !(n.Endpoint.IP == nil && n.Endpoint.DNS == "") && n.Endpoint.Port != 0 && n.Key != nil && n.Subnet != nil && time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second)
return n != nil &&
n.Endpoint.Ready() &&
n.Key != wgtypes.Key{} &&
n.Subnet != nil &&
time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second)
}
// Peer represents a peer in the network.
@ -92,7 +99,10 @@ type Peer struct {
// will not declare their endpoint and instead allow it to be
// discovered.
func (p *Peer) Ready() bool {
return p != nil && p.AllowedIPs != nil && len(p.AllowedIPs) != 0 && p.PublicKey != nil
return p != nil &&
p.AllowedIPs != nil &&
len(p.AllowedIPs) != 0 &&
p.PublicKey != wgtypes.Key{} // If Key was not set, it will be wgtypes.Key{}.
}
// EventType describes what kind of an action an event represents.

View File

@ -1,4 +1,4 @@
// Copyright 2019 the Kilo authors
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import (
"strings"
"github.com/awalterschulze/gographviz"
"github.com/squat/kilo/pkg/wireguard"
)
@ -166,8 +167,9 @@ func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoi
if wgIP != nil {
label = append(label, wgIP.String())
}
if endpoint != nil {
label = append(label, endpoint.String())
str := endpoint.String()
if str != "" {
label = append(label, str)
}
return graphEscape(strings.Join(label, "\\n"))
}

View File

@ -1,4 +1,4 @@
// Copyright 2019 the Kilo authors
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -30,6 +30,8 @@ import (
"github.com/go-kit/kit/log/level"
"github.com/prometheus/client_golang/prometheus"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/encapsulation"
"github.com/squat/kilo/pkg/iproute"
@ -43,8 +45,6 @@ const (
kiloPath = "/var/lib/kilo"
// privateKeyPath is the filepath where the WireGuard private key is stored.
privateKeyPath = kiloPath + "/key"
// confPath is the filepath where the WireGuard configuration is stored.
confPath = kiloPath + "/conf"
)
// Mesh is able to create Kilo network meshes.
@ -60,12 +60,13 @@ type Mesh struct {
internalIP *net.IPNet
ipTables *iptables.Controller
kiloIface int
kiloIfaceName string
key []byte
local bool
port uint32
priv []byte
port int
priv wgtypes.Key
privIface int
pub []byte
pub wgtypes.Key
resyncPeriod time.Duration
iptablesForwardRule bool
stop chan struct{}
@ -88,23 +89,24 @@ type Mesh struct {
}
// New returns a new Mesh instance.
func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port uint32, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, mtu uint, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, logger log.Logger) (*Mesh, error) {
func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port int, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, mtu uint, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, logger log.Logger) (*Mesh, error) {
if err := os.MkdirAll(kiloPath, 0700); err != nil {
return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
}
private, err := ioutil.ReadFile(privateKeyPath)
private = bytes.Trim(private, "\n")
privateB, err := ioutil.ReadFile(privateKeyPath)
privateB = bytes.Trim(privateB, "\n")
private, err := wgtypes.ParseKey(string(privateB))
if err != nil {
level.Warn(logger).Log("msg", "no private key found on disk; generating one now")
if private, err = wireguard.GenKey(); err != nil {
if private, err = wgtypes.GenerateKey(); err != nil {
return nil, err
}
}
public, err := wireguard.PubKey(private)
public := private.PublicKey()
if err != nil {
return nil, err
}
if err := ioutil.WriteFile(privateKeyPath, private, 0600); err != nil {
if err := ioutil.WriteFile(privateKeyPath, []byte(private.String()), 0600); err != nil {
return nil, fmt.Errorf("failed to write private key to disk: %v", err)
}
cniIndex, err := cniDeviceIndex()
@ -168,6 +170,7 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
internalIP: privateIP,
ipTables: ipTables,
kiloIface: kiloIface,
kiloIfaceName: iface,
nodes: make(map[string]*Node),
peers: make(map[string]*Peer),
port: port,
@ -314,7 +317,7 @@ func (m *Mesh) syncPeers(e *PeerEvent) {
var diff bool
m.mu.Lock()
// Peers are indexed by public key.
key := string(e.Peer.PublicKey)
key := e.Peer.PublicKey.String()
if !e.Peer.Ready() {
// Trace non ready peer with their presence in the mesh.
_, ok := m.peers[key]
@ -324,8 +327,8 @@ func (m *Mesh) syncPeers(e *PeerEvent) {
case AddEvent:
fallthrough
case UpdateEvent:
if e.Old != nil && key != string(e.Old.PublicKey) {
delete(m.peers, string(e.Old.PublicKey))
if e.Old != nil && key != e.Old.PublicKey.String() {
delete(m.peers, e.Old.PublicKey.String())
diff = true
}
if !peersAreEqual(m.peers[key], e.Peer) {
@ -367,8 +370,10 @@ func (m *Mesh) checkIn() {
func (m *Mesh) handleLocal(n *Node) {
// Allow the IPs to be overridden.
if n.Endpoint == nil || (n.Endpoint.DNS == "" && n.Endpoint.IP == nil) {
n.Endpoint = &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: m.externalIP.IP}, Port: m.port}
if !n.Endpoint.Ready() {
e := wireguard.NewEndpoint(m.externalIP.IP, m.port)
level.Info(m.logger).Log("msg", "overriding endpoint", "node", m.hostname, "old endpoint", n.Endpoint.String(), "new endpoint", e.String())
n.Endpoint = e
}
if n.InternalIP == nil && !n.NoInternalIP {
n.InternalIP = m.internalIP
@ -462,22 +467,26 @@ func (m *Mesh) applyTopology() {
m.errorCounter.WithLabelValues("apply").Inc()
return
}
// Find the old configuration.
oldConfDump, err := wireguard.ShowDump(link.Attrs().Name)
wgClient, err := wgctrl.New()
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
oldConf, err := wireguard.ParseDump(oldConfDump)
defer wgClient.Close()
// wgDevice is the current configuration of the wg interface.
wgDevice, err := wgClient.Device(m.kiloIfaceName)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger)
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port(), m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
@ -489,19 +498,8 @@ func (m *Mesh) applyTopology() {
} else {
m.wireGuardIP = nil
}
conf := t.Conf()
buf, err := conf.Bytes()
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
if err := ioutil.WriteFile(confPath, buf, 0600); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
ipRules := t.Rules(m.cni, m.iptablesForwardRule)
// If we are handling local routes, ensure the local
// tunnel has an IP address and IPIP traffic is allowed.
if m.enc.Strategy() != encapsulation.Never && m.local {
@ -540,10 +538,12 @@ func (m *Mesh) applyTopology() {
}
// Setting the WireGuard configuration interrupts existing connections
// so only set the configuration if it has changed.
equal := conf.Equal(oldConf)
conf := t.Conf()
equal, diff := conf.Equal(wgDevice)
if !equal {
level.Info(m.logger).Log("msg", "WireGuard configurations are different")
if err := wireguard.SetConf(link.Attrs().Name, confPath); err != nil {
level.Info(m.logger).Log("msg", "WireGuard configurations are different", "diff", diff)
level.Debug(m.logger).Log("msg", "changing wg config", "config", conf.WGConfig())
if err := wgClient.ConfigureDevice(m.kiloIfaceName, conf.WGConfig()); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
@ -598,10 +598,6 @@ func (m *Mesh) cleanUp() {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up routes: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := os.Remove(confPath); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to delete configuration file: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if m.cleanUpIface {
if err := iproute.RemoveInterface(m.kiloIface); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err))
@ -629,12 +625,8 @@ func (m *Mesh) resolveEndpoints() error {
if !m.nodes[k].Ready() {
continue
}
// If the node is ready, then the endpoint is not nil
// but it may not have a DNS name.
if m.nodes[k].Endpoint.DNS == "" {
continue
}
if err := resolveEndpoint(m.nodes[k].Endpoint); err != nil {
// Resolve the Endpoint
if _, err := m.nodes[k].Endpoint.UDPAddr(true); err != nil {
return err
}
}
@ -645,33 +637,16 @@ func (m *Mesh) resolveEndpoints() error {
continue
}
// Peers may have nil endpoints.
if m.peers[k].Endpoint == nil || m.peers[k].Endpoint.DNS == "" {
if !m.peers[k].Endpoint.Ready() {
continue
}
if err := resolveEndpoint(m.peers[k].Endpoint); err != nil {
if _, err := m.peers[k].Endpoint.UDPAddr(true); err != nil {
return err
}
}
return nil
}
func resolveEndpoint(endpoint *wireguard.Endpoint) error {
ips, err := net.LookupIP(endpoint.DNS)
if err != nil {
return fmt.Errorf("failed to look up DNS name %q: %v", endpoint.DNS, err)
}
nets := make([]*net.IPNet, len(ips), len(ips))
for i := range ips {
nets[i] = oneAddressCIDR(ips[i])
}
sortIPs(nets)
if len(nets) == 0 {
return fmt.Errorf("did not find any addresses for DNS name %q", endpoint.DNS)
}
endpoint.IP = nets[0].IP
return nil
}
func isSelf(hostname string, node *Node) bool {
return node != nil && node.Name == hostname
}
@ -691,7 +666,18 @@ func nodesAreEqual(a, b *Node) bool {
// Ignore LastSeen when comparing equality we want to check if the nodes are
// equivalent. However, we do want to check if LastSeen has transitioned
// between valid and invalid.
return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive && discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) && ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs) && a.Granularity == b.Granularity
return a.Key.String() == b.Key.String() &&
ipNetsEqual(a.WireGuardIP, b.WireGuardIP) &&
ipNetsEqual(a.InternalIP, b.InternalIP) &&
a.Leader == b.Leader &&
a.Location == b.Location &&
a.Name == b.Name &&
subnetsEqual(a.Subnet, b.Subnet) &&
a.Ready() == b.Ready() &&
a.PersistentKeepalive == b.PersistentKeepalive &&
discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) &&
ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs) &&
a.Granularity == b.Granularity
}
func peersAreEqual(a, b *Peer) bool {
@ -710,11 +696,15 @@ func peersAreEqual(a, b *Peer) bool {
return false
}
for i := range a.AllowedIPs {
if !ipNetsEqual(a.AllowedIPs[i], b.AllowedIPs[i]) {
if !ipNetsEqual(&a.AllowedIPs[i], &b.AllowedIPs[i]) {
return false
}
}
return string(a.PublicKey) == string(b.PublicKey) && string(a.PresharedKey) == string(b.PresharedKey) && a.PersistentKeepalive == b.PersistentKeepalive
return a.PublicKey.String() == b.PublicKey.String() &&
(a.PresharedKey == nil) == (b.PresharedKey == nil) &&
(a.PresharedKey == nil || a.PresharedKey.String() == b.PresharedKey.String()) &&
(a.PersistentKeepaliveInterval == nil) == (b.PersistentKeepaliveInterval == nil) &&
(a.PersistentKeepaliveInterval == nil || a.PersistentKeepaliveInterval == b.PersistentKeepaliveInterval)
}
func ipNetsEqual(a, b *net.IPNet) bool {
@ -730,12 +720,12 @@ func ipNetsEqual(a, b *net.IPNet) bool {
return a.IP.Equal(b.IP)
}
func ipNetSlicesEqual(a, b []*net.IPNet) bool {
func ipNetSlicesEqual(a, b []net.IPNet) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !ipNetsEqual(a[i], b[i]) {
if !ipNetsEqual(&a[i], &b[i]) {
return false
}
}
@ -761,7 +751,7 @@ func subnetsEqual(a, b *net.IPNet) bool {
return true
}
func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
func discoveredEndpointsAreEqual(a, b map[string]*net.UDPAddr) bool {
if a == nil && b == nil {
return true
}
@ -772,7 +762,7 @@ func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
return false
}
for k := range a {
if !a[k].Equal(b[k], false) {
if a[k] != b[k] {
return false
}
}
@ -788,24 +778,26 @@ func linkByIndex(index int) (netlink.Link, error) {
}
// discoverNATEndpoints uses the node's WireGuard configuration to returns a list of the most recently discovered endpoints for all nodes and peers behind NAT so that they can roam.
func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf, logger log.Logger) map[string]*wireguard.Endpoint {
natEndpoints := make(map[string]*wireguard.Endpoint)
keys := make(map[string]*wireguard.Peer)
// Discovered endpionts will never be DNS names, because WireGuard will always resolve them to net.UDPAddr.
func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wgtypes.Device, logger log.Logger) map[string]*net.UDPAddr {
natEndpoints := make(map[string]*net.UDPAddr)
keys := make(map[string]wgtypes.Peer)
for i := range conf.Peers {
keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
keys[conf.Peers[i].PublicKey.String()] = conf.Peers[i]
}
for _, n := range nodes {
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false), "latest-handshake", peer.LatestHandshake)
if (peer.LatestHandshake != time.Time{}) {
natEndpoints[string(n.Key)] = peer.Endpoint
if peer, ok := keys[n.Key.String()]; ok && n.PersistentKeepalive != time.Duration(0) {
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", peer.Endpoint.String() == n.Endpoint.String(), "latest-handshake", peer.LastHandshakeTime)
// Don't update the endpoint, if there was never any handshake.
if !peer.LastHandshakeTime.Equal(time.Time{}) {
natEndpoints[n.Key.String()] = peer.Endpoint
}
}
}
for _, p := range peers {
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
if (peer.LatestHandshake != time.Time{}) {
natEndpoints[string(p.PublicKey)] = peer.Endpoint
if peer, ok := keys[p.PublicKey.String()]; ok && p.PersistentKeepaliveInterval != nil {
if !peer.LastHandshakeTime.Equal(time.Time{}) {
natEndpoints[p.PublicKey.String()] = peer.Endpoint
}
}
}

View File

@ -19,9 +19,21 @@ import (
"testing"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/wireguard"
)
func mustKey() wgtypes.Key {
if k, err := wgtypes.GeneratePrivateKey(); err != nil {
panic(err.Error())
} else {
return k
}
}
var key = mustKey()
func TestReady(t *testing.T) {
internalIP := oneAddressCIDR(net.ParseIP("1.1.1.1"))
externalIP := oneAddressCIDR(net.ParseIP("2.2.2.2"))
@ -44,7 +56,7 @@ func TestReady(t *testing.T) {
name: "empty endpoint",
node: &Node{
InternalIP: internalIP,
Key: []byte{},
Key: key,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
@ -52,9 +64,9 @@ func TestReady(t *testing.T) {
{
name: "empty endpoint IP",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(nil, DefaultKiloPort),
InternalIP: internalIP,
Key: []byte{},
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
@ -62,9 +74,9 @@ func TestReady(t *testing.T) {
{
name: "empty endpoint port",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}},
Endpoint: wireguard.NewEndpoint(externalIP.IP, 0),
InternalIP: internalIP,
Key: []byte{},
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
@ -72,8 +84,8 @@ func TestReady(t *testing.T) {
{
name: "empty internal IP",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
Key: []byte{},
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
ready: false,
@ -81,7 +93,7 @@ func TestReady(t *testing.T) {
{
name: "empty key",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
InternalIP: internalIP,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
@ -90,18 +102,18 @@ func TestReady(t *testing.T) {
{
name: "empty subnet",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
InternalIP: internalIP,
Key: []byte{},
Key: wgtypes.Key{},
},
ready: false,
},
{
name: "valid",
node: &Node{
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
InternalIP: internalIP,
Key: []byte{},
Key: key,
LastSeen: time.Now().Unix(),
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},

View File

@ -40,7 +40,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
var gw net.IP
for _, segment := range t.segments {
if segment.location == t.location {
gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
gw = enc.Gw(segment.endpoint.IP(), segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
break
}
}
@ -113,7 +113,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
// we need to set routes for allowed location IPs over the leader in the current location.
for i := range segment.allowedLocationIPs {
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: segment.allowedLocationIPs[i],
Dst: &segment.allowedLocationIPs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: gw,
LinkIndex: privIface,
@ -125,7 +125,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
for _, peer := range t.peers {
for i := range peer.AllowedIPs {
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: peer.AllowedIPs[i],
Dst: &peer.AllowedIPs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: gw,
LinkIndex: privIface,
@ -196,7 +196,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
// equals the external IP. This means that the node
// is only accessible through an external IP and we
// cannot encapsulate traffic to an IP through the IP.
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP) {
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP()) {
continue
}
// Add routes to the private IPs of nodes in other segments.
@ -214,7 +214,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
// we need to set routes for allowed location IPs over the wg interface.
for i := range segment.allowedLocationIPs {
routes = append(routes, &netlink.Route{
Dst: segment.allowedLocationIPs[i],
Dst: &segment.allowedLocationIPs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: segment.wireGuardIP,
LinkIndex: kiloIface,
@ -226,7 +226,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
for _, peer := range t.peers {
for i := range peer.AllowedIPs {
routes = append(routes, &netlink.Route{
Dst: peer.AllowedIPs[i],
Dst: &peer.AllowedIPs[i],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
})
@ -248,7 +248,7 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
if cni {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
// Some linux distros or docker will set forward DROP in the filter table.
// To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location.
// Leader nodes will forward packets from all nodes within a location because they act as a gateway for them.
@ -258,30 +258,30 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
if s.location == t.location {
// Make sure packets to and from pod cidrs are not dropped in the forward chain.
for _, c := range s.cidrs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
}
// Make sure packets to and from allowed location IPs are not dropped in the forward chain.
for _, c := range s.allowedLocationIPs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
}
// Make sure packets to and from private IPs are not dropped in the forward chain.
for _, c := range s.privateIPs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
}
}
}
} else if iptablesForwardRule {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
}
}
for _, s := range t.segments {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(s.wireGuardIP)), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
for _, aip := range s.allowedIPs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
}
// Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd,
// Otherwise packets to these destinations will reach the destination, but never find their way back.
@ -289,7 +289,7 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
if t.location == s.location {
for _, alip := range s.allowedLocationIPs {
rules = append(rules,
iptables.NewRule(iptables.GetProtocol(len(alip.IP)), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
)
}
}
@ -297,8 +297,8 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
for _, p := range t.peers {
for _, aip := range p.AllowedIPs {
rules = append(rules,
iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
)
}
}

View File

@ -75,7 +75,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@ -89,17 +89,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -132,17 +132,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -196,21 +196,21 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
@ -266,24 +266,24 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["d"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -309,7 +309,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@ -337,17 +337,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -394,17 +394,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -444,7 +444,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@ -458,17 +458,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -509,7 +509,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@ -523,17 +523,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -574,7 +574,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@ -588,17 +588,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -639,17 +639,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -698,17 +698,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -782,21 +782,21 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: privIface,
@ -868,21 +868,21 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: tunlIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: tunlIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: nodes["b"].InternalIP.IP,
LinkIndex: tunlIface,
@ -918,7 +918,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@ -946,17 +946,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -1004,17 +1004,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
@ -1055,7 +1055,7 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: nodes["b"].AllowedLocationIPs[0],
Dst: &nodes["b"].AllowedLocationIPs[0],
Flags: int(netlink.FLAG_ONLINK),
Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[1].wireGuardIP,
LinkIndex: kiloIface,
@ -1069,17 +1069,17 @@ func TestRoutes(t *testing.T) {
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[0],
Dst: &peers["a"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["a"].AllowedIPs[1],
Dst: &peers["a"].AllowedIPs[1],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},
{
Dst: peers["b"].AllowedIPs[0],
Dst: &peers["b"].AllowedIPs[0],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
},

View File

@ -1,4 +1,4 @@
// Copyright 2019 the Kilo authors
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -18,9 +18,11 @@ import (
"errors"
"net"
"sort"
"time"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/wireguard"
)
@ -33,8 +35,8 @@ const (
// Topology represents the logical structure of the overlay network.
type Topology struct {
// key is the private key of the node creating the topology.
key []byte
port uint32
key wgtypes.Key
port int
// Location is the logical location of the local host.
location string
segments []*segment
@ -47,7 +49,7 @@ type Topology struct {
leader bool
// persistentKeepalive is the interval in seconds of the emission
// of keepalive packets by the local node to its peers.
persistentKeepalive int
persistentKeepalive time.Duration
// privateIP is the private IP address of the local node.
privateIP *net.IPNet
// subnet is the Pod subnet of the local node.
@ -59,15 +61,15 @@ type Topology struct {
// is equal to the Kilo subnet.
wireGuardCIDR *net.IPNet
// discoveredEndpoints is the updated map of valid discovered Endpoints
discoveredEndpoints map[string]*wireguard.Endpoint
discoveredEndpoints map[string]*net.UDPAddr
logger log.Logger
}
type segment struct {
allowedIPs []*net.IPNet
allowedIPs []net.IPNet
endpoint *wireguard.Endpoint
key []byte
persistentKeepalive int
key wgtypes.Key
persistentKeepalive time.Duration
// Location is the logical location of this segment.
location string
@ -85,11 +87,11 @@ type segment struct {
// allowedLocationIPs are not part of the cluster and are not peers.
// They are directly routable from nodes within the segment.
// A classic example is a printer that ought to be routable from other locations.
allowedLocationIPs []*net.IPNet
allowedLocationIPs []net.IPNet
}
// NewTopology creates a new Topology struct from a given set of nodes and peers.
func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int, logger log.Logger) (*Topology, error) {
func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration, logger log.Logger) (*Topology, error) {
if logger == nil {
logger = log.NewNopLogger()
}
@ -120,7 +122,18 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
localLocation = nodeLocationPrefix + hostname
}
t := Topology{key: key, port: port, hostname: hostname, location: localLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet, wireGuardCIDR: subnet, discoveredEndpoints: make(map[string]*wireguard.Endpoint), logger: logger}
t := Topology{
key: key,
port: port,
hostname: hostname,
location: localLocation,
persistentKeepalive: persistentKeepalive,
privateIP: nodes[hostname].InternalIP,
subnet: nodes[hostname].Subnet,
wireGuardCIDR: subnet,
discoveredEndpoints: make(map[string]*net.UDPAddr),
logger: logger,
}
for location := range topoMap {
// Sort the location so the result is stable.
sort.Slice(topoMap[location], func(i, j int) bool {
@ -130,9 +143,9 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
if location == localLocation && topoMap[location][leader].Name == hostname {
t.leader = true
}
var allowedIPs []*net.IPNet
var allowedIPs []net.IPNet
allowedLocationIPsMap := make(map[string]struct{})
var allowedLocationIPs []*net.IPNet
var allowedLocationIPs []net.IPNet
var cidrs []*net.IPNet
var hostnames []string
var privateIPs []net.IP
@ -142,7 +155,9 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
// - the node's WireGuard IP
// - the node's internal IP
// - IPs that were specified by the allowed-location-ips annotation
allowedIPs = append(allowedIPs, node.Subnet)
if node.Subnet != nil {
allowedIPs = append(allowedIPs, *node.Subnet)
}
for _, ip := range node.AllowedLocationIPs {
if _, ok := allowedLocationIPsMap[ip.String()]; !ok {
allowedLocationIPs = append(allowedLocationIPs, ip)
@ -150,7 +165,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
}
}
if node.InternalIP != nil {
allowedIPs = append(allowedIPs, oneAddressCIDR(node.InternalIP.IP))
allowedIPs = append(allowedIPs, *oneAddressCIDR(node.InternalIP.IP))
privateIPs = append(privateIPs, node.InternalIP.IP)
}
cidrs = append(cidrs, node.Subnet)
@ -202,7 +217,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
return nil, errors.New("failed to allocate an IP address; ran out of IP addresses")
}
segment.wireGuardIP = ipNet.IP
segment.allowedIPs = append(segment.allowedIPs, oneAddressCIDR(ipNet.IP))
segment.allowedIPs = append(segment.allowedIPs, *oneAddressCIDR(ipNet.IP))
if t.leader && segment.location == t.location {
t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask}
}
@ -224,11 +239,11 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
return &t, nil
}
func intersect(n1, n2 *net.IPNet) bool {
func intersect(n1, n2 net.IPNet) bool {
return n1.Contains(n2.IP) || n2.Contains(n1.IP)
}
func (t *Topology) filterAllowedLocationIPs(ips []*net.IPNet, location string) (ret []*net.IPNet) {
func (t *Topology) filterAllowedLocationIPs(ips []net.IPNet, location string) (ret []net.IPNet) {
CheckIPs:
for _, ip := range ips {
for _, s := range t.segments {
@ -270,14 +285,14 @@ CheckIPs:
return
}
func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key []byte, persistentKeepalive int) *wireguard.Endpoint {
func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *wireguard.Endpoint {
// Do not update non-nat peers
if persistentKeepalive == 0 {
if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) {
return endpoint
}
e, ok := t.discoveredEndpoints[string(key)]
e, ok := t.discoveredEndpoints[key.String()]
if ok {
return e
return wireguard.NewEndpointFromUDPAddr(e)
}
return endpoint
}
@ -285,30 +300,37 @@ func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key []byte, pers
// Conf generates a WireGuard configuration file for a given Topology.
func (t *Topology) Conf() *wireguard.Conf {
c := &wireguard.Conf{
Interface: &wireguard.Interface{
PrivateKey: t.key,
ListenPort: t.port,
Config: wgtypes.Config{
PrivateKey: &t.key,
ListenPort: &t.port,
ReplacePeers: true,
},
}
for _, s := range t.segments {
if s.location == t.location {
continue
}
peer := &wireguard.Peer{
AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
Endpoint: t.updateEndpoint(s.endpoint, s.key, s.persistentKeepalive),
PersistentKeepalive: t.persistentKeepalive,
PublicKey: s.key,
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
PersistentKeepaliveInterval: &t.persistentKeepalive,
PublicKey: s.key,
ReplaceAllowedIPs: true,
},
Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
}
c.Peers = append(c.Peers, peer)
}
for _, p := range t.peers {
peer := &wireguard.Peer{
AllowedIPs: p.AllowedIPs,
Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepalive),
PersistentKeepalive: t.persistentKeepalive,
PresharedKey: p.PresharedKey,
PublicKey: p.PublicKey,
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: p.AllowedIPs,
PersistentKeepaliveInterval: &t.persistentKeepalive,
PresharedKey: p.PresharedKey,
PublicKey: p.PublicKey,
ReplaceAllowedIPs: true,
},
Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval),
}
c.Peers = append(c.Peers, peer)
}
@ -322,34 +344,39 @@ func (t *Topology) AsPeer() *wireguard.Peer {
if s.location != t.location {
continue
}
return &wireguard.Peer{
AllowedIPs: s.allowedIPs,
Endpoint: s.endpoint,
PublicKey: s.key,
p := &wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: s.allowedIPs,
PublicKey: s.key,
},
Endpoint: s.endpoint,
}
return p
}
return nil
}
// PeerConf generates a WireGuard configuration file for a given peer in a Topology.
func (t *Topology) PeerConf(name string) *wireguard.Conf {
var pka int
var psk []byte
var pka *time.Duration
var psk *wgtypes.Key
for i := range t.peers {
if t.peers[i].Name == name {
pka = t.peers[i].PersistentKeepalive
pka = t.peers[i].PersistentKeepaliveInterval
psk = t.peers[i].PresharedKey
break
}
}
c := &wireguard.Conf{}
for _, s := range t.segments {
peer := &wireguard.Peer{
AllowedIPs: s.allowedIPs,
Endpoint: s.endpoint,
PersistentKeepalive: pka,
PresharedKey: psk,
PublicKey: s.key,
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: s.allowedIPs,
PersistentKeepaliveInterval: pka,
PresharedKey: psk,
PublicKey: s.key,
},
Endpoint: s.endpoint,
}
c.Peers = append(c.Peers, peer)
}
@ -357,11 +384,13 @@ func (t *Topology) PeerConf(name string) *wireguard.Conf {
if t.peers[i].Name == name {
continue
}
peer := &wireguard.Peer{
AllowedIPs: t.peers[i].AllowedIPs,
PersistentKeepalive: pka,
PublicKey: t.peers[i].PublicKey,
Endpoint: t.peers[i].Endpoint,
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: t.peers[i].AllowedIPs,
PersistentKeepaliveInterval: pka,
PublicKey: t.peers[i].PublicKey,
},
Endpoint: t.peers[i].Endpoint,
}
c.Peers = append(c.Peers, peer)
}
@ -382,13 +411,13 @@ func findLeader(nodes []*Node) int {
var leaders, public []int
for i := range nodes {
if nodes[i].Leader {
if isPublic(nodes[i].Endpoint.IP) {
if isPublic(nodes[i].Endpoint.IP()) {
return i
}
leaders = append(leaders, i)
}
if isPublic(nodes[i].Endpoint.IP) {
if nodes[i].Endpoint.IP() != nil && isPublic(nodes[i].Endpoint.IP()) {
public = append(public, i)
}
}
@ -408,10 +437,12 @@ func deduplicatePeerIPs(peers []*Peer) []*Peer {
p := Peer{
Name: peer.Name,
Peer: wireguard.Peer{
Endpoint: peer.Endpoint,
PersistentKeepalive: peer.PersistentKeepalive,
PresharedKey: peer.PresharedKey,
PublicKey: peer.PublicKey,
PeerConfig: wgtypes.PeerConfig{
PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval,
PresharedKey: peer.PresharedKey,
PublicKey: peer.PublicKey,
},
Endpoint: peer.Endpoint,
},
}
for _, ip := range peer.AllowedIPs {

View File

@ -18,9 +18,11 @@ import (
"net"
"strings"
"testing"
"time"
"github.com/go-kit/kit/log"
"github.com/kylelemons/godebug/pretty"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/wireguard"
)
@ -29,17 +31,25 @@ func allowedIPs(ips ...string) string {
return strings.Join(ips, ", ")
}
func mustParseCIDR(s string) (r *net.IPNet) {
func mustParseCIDR(s string) (r net.IPNet) {
if _, ip, err := net.ParseCIDR(s); err != nil {
panic("failed to parse CIDR")
} else {
r = ip
r = *ip
}
return
}
func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) {
key := []byte("private")
var (
key1 = wgtypes.Key{'k', 'e', 'y', '1'}
key2 = wgtypes.Key{'k', 'e', 'y', '2'}
key3 = wgtypes.Key{'k', 'e', 'y', '3'}
key4 = wgtypes.Key{'k', 'e', 'y', '4'}
key5 = wgtypes.Key{'k', 'e', 'y', '5'}
)
func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) {
key := wgtypes.Key{'p', 'r', 'i', 'v'}
e1 := &net.IPNet{IP: net.ParseIP("10.1.0.1").To4(), Mask: net.CIDRMask(16, 32)}
e2 := &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(16, 32)}
e3 := &net.IPNet{IP: net.ParseIP("10.1.0.3").To4(), Mask: net.CIDRMask(16, 32)}
@ -50,62 +60,63 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) {
nodes := map[string]*Node{
"a": {
Name: "a",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
InternalIP: i1,
Location: "1",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key1"),
Key: key1,
PersistentKeepalive: 25,
},
"b": {
Name: "b",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
InternalIP: i1,
Location: "2",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key2"),
AllowedLocationIPs: []*net.IPNet{i3},
Key: key2,
AllowedLocationIPs: []net.IPNet{*i3},
},
"c": {
Name: "c",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e3.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e3.IP, DefaultKiloPort),
InternalIP: i2,
// Same location as node b.
Location: "2",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.3.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key3"),
Key: key3,
},
"d": {
Name: "d",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e4.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e4.IP, DefaultKiloPort),
// Same location as node a, but without private IP
Location: "1",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.4.0"), Mask: net.CIDRMask(24, 32)},
Key: []byte("key4"),
Key: key4,
},
}
peers := map[string]*Peer{
"a": {
Name: "a",
Peer: wireguard.Peer{
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.5.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.5.0.2"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.5.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.5.0.2"), Mask: net.CIDRMask(24, 32)},
},
PublicKey: key4,
},
PublicKey: []byte("key4"),
},
},
"b": {
Name: "b",
Peer: wireguard.Peer{
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)},
},
PublicKey: key5,
},
Endpoint: &wireguard.Endpoint{
DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("192.168.0.1")},
Port: DefaultKiloPort,
},
PublicKey: []byte("key5"),
Endpoint: wireguard.NewEndpoint(net.ParseIP("192.168.0.1"), DefaultKiloPort),
},
},
}
@ -138,7 +149,7 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
@ -149,7 +160,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
@ -161,7 +172,7 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
@ -189,7 +200,7 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
@ -200,7 +211,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
@ -212,7 +223,7 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
@ -240,7 +251,7 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: DefaultKiloSubnet,
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
@ -251,7 +262,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
@ -263,7 +274,7 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
@ -291,7 +302,7 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
@ -302,7 +313,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
@ -314,7 +325,7 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["c"].Endpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
@ -325,7 +336,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w3,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
@ -353,7 +364,7 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
@ -364,7 +375,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
@ -376,7 +387,7 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["c"].Endpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
@ -387,7 +398,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w3,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
@ -415,7 +426,7 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
@ -426,7 +437,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
@ -438,7 +449,7 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["c"].Endpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
@ -449,7 +460,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w3,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
@ -477,7 +488,7 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w4, Mask: net.CIDRMask(16, 32)},
segments: []*segment{
{
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key,
persistentKeepalive: nodes["a"].PersistentKeepalive,
@ -488,7 +499,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w1,
},
{
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["b"].Endpoint,
key: nodes["b"].Key,
persistentKeepalive: nodes["b"].PersistentKeepalive,
@ -500,7 +511,7 @@ func TestNewTopology(t *testing.T) {
allowedLocationIPs: nodes["b"].AllowedLocationIPs,
},
{
allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["c"].Endpoint,
key: nodes["c"].Key,
persistentKeepalive: nodes["c"].PersistentKeepalive,
@ -511,7 +522,7 @@ func TestNewTopology(t *testing.T) {
wireGuardIP: w3,
},
{
allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["d"].Endpoint,
key: nodes["d"].Key,
persistentKeepalive: nodes["d"].PersistentKeepalive,
@ -539,7 +550,7 @@ func TestNewTopology(t *testing.T) {
}
}
func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int) *Topology {
func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration) *Topology {
topo, err := NewTopology(nodes, peers, granularity, hostname, port, key, subnet, persistentKeepalive, nil)
if err != nil {
t.Errorf("failed to generate Topology: %v", err)
@ -547,211 +558,6 @@ func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, gran
return topo
}
func TestConf(t *testing.T) {
nodes, peers, key, port := setup(t)
for _, tc := range []struct {
name string
topology *Topology
result string
}{
{
name: "logical from a",
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key2
Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 192.168.178.3/32
PersistentKeepalive = 25
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.3/32
PersistentKeepalive = 25
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
PersistentKeepalive = 25
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
PersistentKeepalive = 25
`,
},
{
name: "logical from b",
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet, nodes["b"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.3/32
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
`,
},
{
name: "logical from c",
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet, nodes["c"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.3/32
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
`,
},
{
name: "full from a",
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key2
Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32, 192.168.178.3/32
PersistentKeepalive = 25
[Peer]
PublicKey = key3
Endpoint = 10.1.0.3:51820
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
PersistentKeepalive = 25
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.4/32
PersistentKeepalive = 25
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
PersistentKeepalive = 25
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
PersistentKeepalive = 25
`,
},
{
name: "full from b",
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet, nodes["b"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer]
PublicKey = key3
Endpoint = 10.1.0.3:51820
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.4/32
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
`,
},
{
name: "full from c",
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet, nodes["c"].PersistentKeepalive),
result: `[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key1
Endpoint = 10.1.0.1:51820
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer]
PublicKey = key2
Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32, 192.168.178.3/32
[Peer]
PublicKey = key4
Endpoint = 10.1.0.4:51820
AllowedIPs = 10.2.4.0/24, 10.4.0.4/32
[Peer]
PublicKey = key4
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer]
PublicKey = key5
Endpoint = 192.168.0.1:51820
AllowedIPs = 10.5.0.3/24
`,
},
} {
conf := tc.topology.Conf()
if !conf.Equal(wireguard.Parse([]byte(tc.result))) {
buf, err := conf.Bytes()
if err != nil {
t.Errorf("test case %q: failed to render conf: %v", tc.name, err)
}
t.Errorf("test case %q: expected %s got %s", tc.name, tc.result, string(buf))
}
}
}
func TestFindLeader(t *testing.T) {
ip, e1, err := net.ParseCIDR("10.0.0.1/32")
if err != nil {
@ -767,24 +573,24 @@ func TestFindLeader(t *testing.T) {
nodes := []*Node{
{
Name: "a",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
},
{
Name: "b",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
},
{
Name: "c",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
},
{
Name: "d",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
Leader: true,
},
{
Name: "2",
Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
Leader: true,
},
}
@ -840,31 +646,38 @@ func TestDeduplicatePeerIPs(t *testing.T) {
p1 := &Peer{
Name: "1",
Peer: wireguard.Peer{
PublicKey: []byte("key1"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key1,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
},
},
},
}
p2 := &Peer{
Name: "2",
Peer: wireguard.Peer{
PublicKey: []byte("key2"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key2,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
},
},
},
}
p3 := &Peer{
Name: "3",
Peer: wireguard.Peer{
PublicKey: []byte("key3"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key3,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
},
},
},
}
@ -872,10 +685,12 @@ func TestDeduplicatePeerIPs(t *testing.T) {
p4 := &Peer{
Name: "4",
Peer: wireguard.Peer{
PublicKey: []byte("key4"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key4,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
},
},
},
}
@ -898,9 +713,11 @@ func TestDeduplicatePeerIPs(t *testing.T) {
{
Name: "2",
Peer: wireguard.Peer{
PublicKey: []byte("key2"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key2,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
},
},
},
},
@ -914,9 +731,11 @@ func TestDeduplicatePeerIPs(t *testing.T) {
{
Name: "1",
Peer: wireguard.Peer{
PublicKey: []byte("key1"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key1,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
},
},
},
},
@ -930,19 +749,25 @@ func TestDeduplicatePeerIPs(t *testing.T) {
{
Name: "2",
Peer: wireguard.Peer{
PublicKey: []byte("key2"),
PeerConfig: wgtypes.PeerConfig{
PublicKey: key2,
},
},
},
{
Name: "1",
Peer: wireguard.Peer{
PublicKey: []byte("key1"),
PeerConfig: wgtypes.PeerConfig{
PublicKey: key1,
},
},
},
{
Name: "4",
Peer: wireguard.Peer{
PublicKey: []byte("key4"),
PeerConfig: wgtypes.PeerConfig{
PublicKey: key4,
},
},
},
},
@ -954,19 +779,23 @@ func TestDeduplicatePeerIPs(t *testing.T) {
{
Name: "4",
Peer: wireguard.Peer{
PublicKey: []byte("key4"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key4,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)},
},
},
},
},
{
Name: "1",
Peer: wireguard.Peer{
PublicKey: []byte("key1"),
AllowedIPs: []*net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
PeerConfig: wgtypes.PeerConfig{
PublicKey: key1,
AllowedIPs: []net.IPNet{
{IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)},
},
},
},
},
@ -985,12 +814,12 @@ func TestFilterAllowedIPs(t *testing.T) {
topo := mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive)
for _, tc := range []struct {
name string
allowedLocationIPs map[int][]*net.IPNet
result map[int][]*net.IPNet
allowedLocationIPs map[int][]net.IPNet
result map[int][]net.IPNet
}{
{
name: "nothing to filter",
allowedLocationIPs: map[int][]*net.IPNet{
allowedLocationIPs: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.4/32"),
},
@ -1002,7 +831,7 @@ func TestFilterAllowedIPs(t *testing.T) {
mustParseCIDR("192.168.178.7/32"),
},
},
result: map[int][]*net.IPNet{
result: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.4/32"),
},
@ -1017,7 +846,7 @@ func TestFilterAllowedIPs(t *testing.T) {
},
{
name: "intersections between segments",
allowedLocationIPs: map[int][]*net.IPNet{
allowedLocationIPs: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.4/32"),
mustParseCIDR("192.168.178.8/32"),
@ -1031,7 +860,7 @@ func TestFilterAllowedIPs(t *testing.T) {
mustParseCIDR("192.168.178.4/32"),
},
},
result: map[int][]*net.IPNet{
result: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.8/32"),
},
@ -1047,7 +876,7 @@ func TestFilterAllowedIPs(t *testing.T) {
},
{
name: "intersections with wireGuardCIDR",
allowedLocationIPs: map[int][]*net.IPNet{
allowedLocationIPs: map[int][]net.IPNet{
0: {
mustParseCIDR("10.4.0.1/32"),
mustParseCIDR("192.168.178.8/32"),
@ -1060,7 +889,7 @@ func TestFilterAllowedIPs(t *testing.T) {
mustParseCIDR("192.168.178.7/32"),
},
},
result: map[int][]*net.IPNet{
result: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.8/32"),
},
@ -1075,7 +904,7 @@ func TestFilterAllowedIPs(t *testing.T) {
},
{
name: "intersections with more than one allowedLocationIPs",
allowedLocationIPs: map[int][]*net.IPNet{
allowedLocationIPs: map[int][]net.IPNet{
0: {
mustParseCIDR("192.168.178.8/32"),
},
@ -1086,7 +915,7 @@ func TestFilterAllowedIPs(t *testing.T) {
mustParseCIDR("192.168.178.7/24"),
},
},
result: map[int][]*net.IPNet{
result: map[int][]net.IPNet{
0: {},
1: {},
2: {

View File

@ -15,16 +15,15 @@
package wireguard
import (
"bufio"
"bytes"
"errors"
"fmt"
"net"
"sort"
"strconv"
"strings"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"k8s.io/apimachinery/pkg/util/validation"
)
@ -32,10 +31,6 @@ type section string
type key string
const (
separator = "="
dumpSeparator = "\t"
dumpNone = "(none)"
dumpOff = "off"
interfaceSection section = "Interface"
peerSection section = "Peer"
listenPortKey key = "ListenPort"
@ -47,56 +42,209 @@ const (
publicKeyKey key = "PublicKey"
)
type dumpInterfaceIndex int
const (
dumpInterfacePrivateKeyIndex = iota
dumpInterfacePublicKeyIndex
dumpInterfaceListenPortIndex
dumpInterfaceFWMarkIndex
dumpInterfaceLen
)
type dumpPeerIndex int
const (
dumpPeerPublicKeyIndex = iota
dumpPeerPresharedKeyIndex
dumpPeerEndpointIndex
dumpPeerAllowedIPsIndex
dumpPeerLatestHandshakeIndex
dumpPeerTransferRXIndex
dumpPeerTransferTXIndex
dumpPeerPersistentKeepaliveIndex
dumpPeerLen
)
// Conf represents a WireGuard configuration file.
type Conf struct {
Interface *Interface
Peers []*Peer
wgtypes.Config
// The Peers field is shadowed because every Peer needs the Endpoint field that contains a DNS endpoint.
Peers []Peer
}
// Interface represents the `interface` section of a WireGuard configuration.
type Interface struct {
ListenPort uint32
PrivateKey []byte
// WGConfig returns a wgytpes.Config from a Conf.
func (c *Conf) WGConfig() wgtypes.Config {
if c == nil {
// The empty Config will do nothing, when applied.
return wgtypes.Config{}
}
r := c.Config
wgPs := make([]wgtypes.PeerConfig, len(c.Peers))
for i, p := range c.Peers {
wgPs[i] = p.PeerConfig
if p.Endpoint.Resolved() {
// We can ingore the error because we already checked if the Endpoint was resolved in the above line.
wgPs[i].Endpoint, _ = p.Endpoint.UDPAddr(false)
}
wgPs[i].ReplaceAllowedIPs = true
}
r.Peers = wgPs
r.ReplacePeers = true
return r
}
// Endpoint represents a WireGuard endpoint.
type Endpoint struct {
udpAddr *net.UDPAddr
addr string
}
// ParseEndpoint returns an Endpoint from a string.
// The input should look like "10.0.0.0:100", "[ff10::10]:100"
// or "example.com:100".
func ParseEndpoint(endpoint string) *Endpoint {
if len(endpoint) == 0 {
return nil
}
hostRaw, portRaw, err := net.SplitHostPort(endpoint)
if err != nil {
return nil
}
port, err := strconv.ParseUint(portRaw, 10, 32)
if err != nil {
return nil
}
if len(validation.IsValidPortNum(int(port))) != 0 {
return nil
}
ip := net.ParseIP(hostRaw)
if ip == nil {
if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 {
return &Endpoint{
addr: endpoint,
}
}
return nil
}
// ResolveUDPAddr will not resolve the endpoint as long as a valid IP and port is given.
// This should be the case here.
u, err := net.ResolveUDPAddr("udp", endpoint)
if err != nil {
return nil
}
u.IP = cutIP(u.IP)
return &Endpoint{
udpAddr: u,
}
}
// NewEndpointFromUDPAddr returns an Endpoint from a net.UDPAddr.
func NewEndpointFromUDPAddr(u *net.UDPAddr) *Endpoint {
if u != nil {
u.IP = cutIP(u.IP)
}
return &Endpoint{
udpAddr: u,
}
}
// NewEndpoint returns an Endpoint from a net.IP and port.
func NewEndpoint(ip net.IP, port int) *Endpoint {
return &Endpoint{
udpAddr: &net.UDPAddr{
IP: cutIP(ip),
Port: port,
},
}
}
// Ready return true, if the Enpoint is ready.
// Ready means that an IP or DN and port exists.
func (e *Endpoint) Ready() bool {
if e == nil {
return false
}
return (e.udpAddr != nil && e.udpAddr.IP != nil && e.udpAddr.Port > 0) || len(e.addr) > 0
}
// Port returns the port of the Endpoint.
func (e *Endpoint) Port() int {
if !e.Ready() {
return 0
}
if e.udpAddr != nil {
return e.udpAddr.Port
}
// We can ignore the errors here bacause the returned port will be "".
// This will result to Port 0 after the conversion to and int.
_, p, _ := net.SplitHostPort(e.addr)
port, _ := strconv.ParseUint(p, 10, 32)
return int(port)
}
// HasDNS returns true if the endpoint has a DN.
func (e *Endpoint) HasDNS() bool {
return e != nil && e.addr != ""
}
// DNS returns the DN of the Endpoint.
func (e *Endpoint) DNS() string {
if e == nil {
return ""
}
_, s, _ := net.SplitHostPort(e.addr)
return s
}
// Resolved returns true, if the DN of the Endpoint was resolved
// or if the Endpoint has a resolved endpoint.
func (e *Endpoint) Resolved() bool {
return e != nil && e.udpAddr != nil
}
// UDPAddr returns the UDPAddr of the Endpoint. If resolve is false,
// UDPAddr() will not try to resolve a DN name, if the Endpoint is not yet resolved.
func (e *Endpoint) UDPAddr(resolve bool) (*net.UDPAddr, error) {
if !e.Ready() {
return nil, errors.New("Enpoint is not ready")
}
if e.udpAddr != nil {
// Make a copy of the UDPAddr to protect it from modification outside this package.
h := *e.udpAddr
return &h, nil
}
if !resolve {
return nil, errors.New("Endpoint is not resolved")
}
var err error
if e.udpAddr, err = net.ResolveUDPAddr("udp", e.addr); err != nil {
return nil, err
}
// Make a copy of the UDPAddr to protect it from modification outside this package.
h := *e.udpAddr
return &h, nil
}
// IP returns the IP address of the Enpoint or nil.
func (e *Endpoint) IP() net.IP {
if !e.Resolved() {
return nil
}
return e.udpAddr.IP
}
// String will return the endpoint as a string.
// If a DN exists, it will take prcedence over the resolved endpoint.
func (e *Endpoint) String() string {
return e.StringOpt(true)
}
// StringOpt will return the string of the Endpoint.
// If dnsFirst is false, the resolved Endpoint will
// take precedence over the DN.
func (e *Endpoint) StringOpt(dnsFirst bool) string {
if e == nil {
return ""
}
if e.udpAddr != nil && (!dnsFirst || e.addr == "") {
return e.udpAddr.String()
}
return e.addr
}
// Equal will return true, if the Enpoints are equal.
// If dnsFirst is false, the DN will only be compared if
// the IPs are nil.
func (e *Endpoint) Equal(b *Endpoint, dnsFirst bool) bool {
return e.StringOpt(dnsFirst) == b.StringOpt(dnsFirst)
}
// Peer represents a `peer` section of a WireGuard configuration.
type Peer struct {
AllowedIPs []*net.IPNet
Endpoint *Endpoint
PersistentKeepalive int
PresharedKey []byte
PublicKey []byte
// The following fields are part of the runtime information, not the configuration.
LatestHandshake time.Time
wgtypes.PeerConfig
Endpoint *Endpoint
}
// DeduplicateIPs eliminates duplicate allowed IPs.
func (p *Peer) DeduplicateIPs() {
var ips []*net.IPNet
var ips []net.IPNet
seen := make(map[string]struct{})
for _, ip := range p.AllowedIPs {
if _, ok := seen[ip.String()]; ok {
@ -108,181 +256,27 @@ func (p *Peer) DeduplicateIPs() {
p.AllowedIPs = ips
}
// Endpoint represents an `endpoint` key of a `peer` section.
type Endpoint struct {
DNSOrIP
Port uint32
}
// String prints the string representation of the endpoint.
func (e *Endpoint) String() string {
if e == nil {
return ""
}
dnsOrIP := e.DNSOrIP.String()
if e.IP != nil && len(e.IP) == net.IPv6len {
dnsOrIP = "[" + dnsOrIP + "]"
}
return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
}
// Equal compares two endpoints.
func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
if (e == nil) != (b == nil) {
return false
}
if e != nil {
if e.Port != b.Port {
return false
}
if DNSFirst {
// Check the DNS name first if it was resolved.
if e.DNS != b.DNS {
return false
}
if e.DNS == "" && !e.IP.Equal(b.IP) {
return false
}
} else {
// IPs take priority, so check them first.
if !e.IP.Equal(b.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if e.IP == nil && e.DNS != b.DNS {
return false
}
}
}
return true
}
// DNSOrIP represents either a DNS name or an IP address.
// IPs, as they are more specific, are preferred.
type DNSOrIP struct {
DNS string
IP net.IP
}
// String prints the string representation of the struct.
func (d DNSOrIP) String() string {
if d.IP != nil {
return d.IP.String()
}
return d.DNS
}
// Parse parses a given WireGuard configuration file and produces a Conf struct.
func Parse(buf []byte) *Conf {
var (
active section
kv []string
c Conf
err error
iface *Interface
i int
k key
line, v string
peer *Peer
port uint64
)
s := bufio.NewScanner(bytes.NewBuffer(buf))
for s.Scan() {
line = strings.TrimSpace(s.Text())
// Skip comments.
if strings.HasPrefix(line, "#") {
continue
}
// Line is a section title.
if strings.HasPrefix(line, "[") {
if peer != nil {
c.Peers = append(c.Peers, peer)
peer = nil
}
if iface != nil {
c.Interface = iface
iface = nil
}
active = section(strings.TrimSpace(strings.Trim(line, "[]")))
switch active {
case interfaceSection:
iface = new(Interface)
case peerSection:
peer = new(Peer)
}
continue
}
kv = strings.SplitN(line, separator, 2)
if len(kv) != 2 {
continue
}
k = key(strings.TrimSpace(kv[0]))
v = strings.TrimSpace(kv[1])
switch active {
case interfaceSection:
switch k {
case listenPortKey:
port, err = strconv.ParseUint(v, 10, 32)
if err != nil {
continue
}
iface.ListenPort = uint32(port)
case privateKeyKey:
iface.PrivateKey = []byte(v)
}
case peerSection:
switch k {
case allowedIPsKey:
err = peer.parseAllowedIPs(v)
if err != nil {
continue
}
case endpointKey:
err = peer.parseEndpoint(v)
if err != nil {
continue
}
case persistentKeepaliveKey:
i, err = strconv.Atoi(v)
if err != nil {
continue
}
peer.PersistentKeepalive = i
case presharedKeyKey:
peer.PresharedKey = []byte(v)
case publicKeyKey:
peer.PublicKey = []byte(v)
}
}
}
if peer != nil {
c.Peers = append(c.Peers, peer)
}
if iface != nil {
c.Interface = iface
}
return &c
}
// Bytes renders a WireGuard configuration to bytes.
func (c *Conf) Bytes() ([]byte, error) {
if c == nil {
return nil, nil
}
var err error
buf := bytes.NewBuffer(make([]byte, 0, 512))
if c.Interface != nil {
if c.PrivateKey != nil {
if err = writeSection(buf, interfaceSection); err != nil {
return nil, fmt.Errorf("failed to write interface: %v", err)
}
if err = writePKey(buf, privateKeyKey, c.Interface.PrivateKey); err != nil {
if err = writePKey(buf, privateKeyKey, c.PrivateKey); err != nil {
return nil, fmt.Errorf("failed to write private key: %v", err)
}
if err = writeValue(buf, listenPortKey, strconv.FormatUint(uint64(c.Interface.ListenPort), 10)); err != nil {
if err = writeValue(buf, listenPortKey, strconv.Itoa(*c.ListenPort)); err != nil {
return nil, fmt.Errorf("failed to write listen port: %v", err)
}
}
for i, p := range c.Peers {
// Add newlines to make the formatting nicer.
if i == 0 && c.Interface != nil || i != 0 {
if i == 0 && c.PrivateKey != nil || i != 0 {
if err = buf.WriteByte('\n'); err != nil {
return nil, err
}
@ -297,71 +291,103 @@ func (c *Conf) Bytes() ([]byte, error) {
if err = writeEndpoint(buf, p.Endpoint); err != nil {
return nil, fmt.Errorf("failed to write endpoint: %v", err)
}
if err = writeValue(buf, persistentKeepaliveKey, strconv.Itoa(p.PersistentKeepalive)); err != nil {
if p.PersistentKeepaliveInterval == nil {
p.PersistentKeepaliveInterval = new(time.Duration)
}
if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval/time.Second), 10)); err != nil {
return nil, fmt.Errorf("failed to write persistent keepalive: %v", err)
}
if err = writePKey(buf, presharedKeyKey, p.PresharedKey); err != nil {
return nil, fmt.Errorf("failed to write preshared key: %v", err)
}
if err = writePKey(buf, publicKeyKey, p.PublicKey); err != nil {
if err = writePKey(buf, publicKeyKey, &p.PublicKey); err != nil {
return nil, fmt.Errorf("failed to write public key: %v", err)
}
}
return buf.Bytes(), nil
}
// Equal checks if two WireGuard configurations are equivalent.
func (c *Conf) Equal(b *Conf) bool {
if (c.Interface == nil) != (b.Interface == nil) {
return false
// Equal returns true if the Conf and wgtypes.Device are equal.
func (c *Conf) Equal(d *wgtypes.Device) (bool, string) {
if c == nil || d == nil {
return c == nil && d == nil, "nil values"
}
if c.Interface != nil {
if c.Interface.ListenPort != b.Interface.ListenPort || !bytes.Equal(c.Interface.PrivateKey, b.Interface.PrivateKey) {
return false
}
if c.ListenPort == nil || *c.ListenPort != d.ListenPort {
return false, fmt.Sprintf("port: old=%q, new=\"%v\"", d.ListenPort, c.ListenPort)
}
if len(c.Peers) != len(b.Peers) {
return false
if c.PrivateKey == nil || *c.PrivateKey != d.PrivateKey {
return false, fmt.Sprintf("private key: old=\"%s...\", new=\"%s\"", d.PrivateKey.String()[0:5], c.PrivateKey.String()[0:5])
}
if len(c.Peers) != len(d.Peers) {
return false, fmt.Sprintf("number of peers: old=%d, new=%d", len(d.Peers), len(c.Peers))
}
sortPeerConfigs(d.Peers)
sortPeers(c.Peers)
sortPeers(b.Peers)
for i := range c.Peers {
if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
return false
if len(c.Peers[i].AllowedIPs) != len(d.Peers[i].AllowedIPs) {
return false, fmt.Sprintf("Peer %d allowed IP length: old=%d, new=%d", i, len(d.Peers[i].AllowedIPs), len(c.Peers[i].AllowedIPs))
}
sortCIDRs(c.Peers[i].AllowedIPs)
sortCIDRs(b.Peers[i].AllowedIPs)
sortCIDRs(d.Peers[i].AllowedIPs)
for j := range c.Peers[i].AllowedIPs {
if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() {
return false
if c.Peers[i].AllowedIPs[j].String() != d.Peers[i].AllowedIPs[j].String() {
return false, fmt.Sprintf("Peer %d allowed IP: old=%q, new=%q", i, d.Peers[i].AllowedIPs[j].String(), c.Peers[i].AllowedIPs[j].String())
}
}
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) {
return false
if c.Peers[i].Endpoint == nil || d.Peers[i].Endpoint == nil {
return c.Peers[i].Endpoint == nil && d.Peers[i].Endpoint == nil, "peer endpoints: nil value"
}
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
return false
if c.Peers[i].Endpoint.StringOpt(false) != d.Peers[i].Endpoint.String() {
return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.StringOpt(false))
}
pki := time.Duration(0)
if p := c.Peers[i].PersistentKeepaliveInterval; p != nil {
pki = *p
}
psk := wgtypes.Key{}
if p := c.Peers[i].PresharedKey; p != nil {
psk = *p
}
if pki != d.Peers[i].PersistentKeepaliveInterval || psk != d.Peers[i].PresharedKey || c.Peers[i].PublicKey != d.Peers[i].PublicKey {
return false, "persistent keepalive or pershared key"
}
}
return true
return true, ""
}
func sortPeers(peers []*Peer) {
func sortPeerConfigs(peers []wgtypes.Peer) {
sort.Slice(peers, func(i, j int) bool {
if bytes.Compare(peers[i].PublicKey, peers[j].PublicKey) < 0 {
if peers[i].PublicKey.String() < peers[j].PublicKey.String() {
return true
}
return false
})
}
func sortCIDRs(cidrs []*net.IPNet) {
func sortPeers(peers []Peer) {
sort.Slice(peers, func(i, j int) bool {
if peers[i].PublicKey.String() < peers[j].PublicKey.String() {
return true
}
return false
})
}
func sortCIDRs(cidrs []net.IPNet) {
sort.Slice(cidrs, func(i, j int) bool {
return cidrs[i].String() < cidrs[j].String()
})
}
func writeAllowedIPs(buf *bytes.Buffer, ais []*net.IPNet) error {
func cutIP(ip net.IP) net.IP {
if i4 := ip.To4(); i4 != nil {
return i4
}
return ip.To16()
}
func writeAllowedIPs(buf *bytes.Buffer, ais []net.IPNet) error {
if len(ais) == 0 {
return nil
}
@ -382,15 +408,16 @@ func writeAllowedIPs(buf *bytes.Buffer, ais []*net.IPNet) error {
return buf.WriteByte('\n')
}
func writePKey(buf *bytes.Buffer, k key, b []byte) error {
if len(b) == 0 {
func writePKey(buf *bytes.Buffer, k key, b *wgtypes.Key) error {
// Print nothing if the public key was never initialized.
if b == nil || (wgtypes.Key{}) == *b {
return nil
}
var err error
if err = writeKey(buf, k); err != nil {
return err
}
if _, err = buf.Write(b); err != nil {
if _, err = buf.Write([]byte(b.String())); err != nil {
return err
}
return buf.WriteByte('\n')
@ -408,14 +435,15 @@ func writeValue(buf *bytes.Buffer, k key, v string) error {
}
func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error {
if e == nil {
str := e.String()
if str == "" {
return nil
}
var err error
if err = writeKey(buf, endpointKey); err != nil {
return err
}
if _, err = buf.WriteString(e.String()); err != nil {
if _, err = buf.WriteString(str); err != nil {
return err
}
return buf.WriteByte('\n')
@ -443,177 +471,3 @@ func writeKey(buf *bytes.Buffer, k key) error {
_, err = buf.WriteString(" = ")
return err
}
var (
errParseEndpoint = errors.New("could not parse Endpoint")
)
func (p *Peer) parseEndpoint(v string) error {
var (
kv []string
err error
ip, ip4 net.IP
port uint64
)
kv = strings.Split(v, ":")
if len(kv) < 2 {
return errParseEndpoint
}
port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
if err != nil {
return err
}
d := DNSOrIP{}
ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
if ip == nil {
if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
return errParseEndpoint
}
d.DNS = kv[0]
} else {
if ip4 = ip.To4(); ip4 != nil {
d.IP = ip4
} else {
d.IP = ip.To16()
}
}
p.Endpoint = &Endpoint{
DNSOrIP: d,
Port: uint32(port),
}
return nil
}
func (p *Peer) parseAllowedIPs(v string) error {
var (
ai *net.IPNet
kv []string
err error
i int
ip, ip4 net.IP
)
kv = strings.Split(v, ",")
for i = range kv {
ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
if err != nil {
return err
}
if ip4 = ip.To4(); ip4 != nil {
ip = ip4
} else {
ip = ip.To16()
}
ai.IP = ip
p.AllowedIPs = append(p.AllowedIPs, ai)
}
return nil
}
// ParseDump parses a given WireGuard dump and produces a Conf struct.
func ParseDump(buf []byte) (*Conf, error) {
// from man wg, show section:
// If dump is specified, then several lines are printed;
// the first contains in order separated by tab: private-key, public-key, listen-port, fwmark.
// Subsequent lines are printed for each peer and contain in order separated by tab:
// public-key, preshared-key, endpoint, allowed-ips, latest-handshake, transfer-rx, transfer-tx, persistent-keepalive.
var (
active section
values []string
c Conf
err error
iface *Interface
peer *Peer
port uint64
sec int64
pka int
line int
)
// First line is Interface
active = interfaceSection
s := bufio.NewScanner(bytes.NewBuffer(buf))
for s.Scan() {
values = strings.Split(s.Text(), dumpSeparator)
switch active {
case interfaceSection:
if len(values) < dumpInterfaceLen {
return nil, fmt.Errorf("invalid interface line: missing fields (%d < %d)", len(values), dumpInterfaceLen)
}
iface = new(Interface)
for i := range values {
switch i {
case dumpInterfacePrivateKeyIndex:
iface.PrivateKey = []byte(values[i])
case dumpInterfaceListenPortIndex:
port, err = strconv.ParseUint(values[i], 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid interface line: error parsing listen-port: %w", err)
}
iface.ListenPort = uint32(port)
}
}
c.Interface = iface
// Next lines are Peers
active = peerSection
case peerSection:
if len(values) < dumpPeerLen {
return nil, fmt.Errorf("invalid peer line %d: missing fields (%d < %d)", line, len(values), dumpPeerLen)
}
peer = new(Peer)
for i := range values {
switch i {
case dumpPeerPublicKeyIndex:
peer.PublicKey = []byte(values[i])
case dumpPeerPresharedKeyIndex:
if values[i] == dumpNone {
continue
}
peer.PresharedKey = []byte(values[i])
case dumpPeerEndpointIndex:
if values[i] == dumpNone {
continue
}
err = peer.parseEndpoint(values[i])
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing endpoint: %w", line, err)
}
case dumpPeerAllowedIPsIndex:
if values[i] == dumpNone {
continue
}
err = peer.parseAllowedIPs(values[i])
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing allowed-ips: %w", line, err)
}
case dumpPeerLatestHandshakeIndex:
if values[i] == "0" {
// Use go zero value, not unix 0 timestamp.
peer.LatestHandshake = time.Time{}
continue
}
sec, err = strconv.ParseInt(values[i], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing latest-handshake: %w", line, err)
}
peer.LatestHandshake = time.Unix(sec, 0)
case dumpPeerPersistentKeepaliveIndex:
if values[i] == dumpOff {
continue
}
pka, err = strconv.Atoi(values[i])
if err != nil {
return nil, fmt.Errorf("invalid peer line %d: error parsing persistent-keepalive: %w", line, err)
}
peer.PersistentKeepalive = pka
}
}
c.Peers = append(c.Peers, peer)
peer = nil
}
line++
}
return &c, nil
}

View File

@ -1,4 +1,4 @@
// Copyright 2019 the Kilo authors
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -21,336 +21,431 @@ import (
"github.com/kylelemons/godebug/pretty"
)
func TestCompareConf(t *testing.T) {
for _, tc := range []struct {
func TestNewEndpoint(t *testing.T) {
for i, tc := range []struct {
name string
a []byte
b []byte
out bool
ip net.IP
port int
out *Endpoint
}{
{
name: "empty",
a: []byte{},
b: []byte{},
out: true,
name: "no ip, no port",
out: &Endpoint{
udpAddr: &net.UDPAddr{},
},
},
{
name: "key and value order",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
ListenPort = 51820
PrivateKey = private
[Peer]
PublicKey = key
AllowedIPs = 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 10.2.2.0/24
PresharedKey = psk
Endpoint = 10.1.0.2:51820
`),
out: true,
name: "only port",
ip: nil,
port: 99,
out: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 99,
},
},
},
{
name: "whitespace",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
PrivateKey=private
ListenPort=51820
[Peer]
Endpoint=10.1.0.2:51820
PresharedKey = psk
PublicKey=key
AllowedIPs=10.2.2.0/24,192.168.0.1/32,10.2.3.0/24,192.168.0.2/32,10.4.0.2/32
`),
out: true,
name: "only ipv4",
ip: net.ParseIP("10.0.0.0"),
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0").To4(),
},
},
},
{
name: "missing key",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
out: false,
name: "only ipv6",
ip: net.ParseIP("ff50::10"),
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff50::10").To16(),
},
},
},
{
name: "different value",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PublicKey = key2
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
out: false,
name: "ipv4",
ip: net.ParseIP("10.0.0.0"),
port: 1000,
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0").To4(),
Port: 1000,
},
},
},
{
name: "section order",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
[Interface]
PrivateKey = private
ListenPort = 51820
`),
out: true,
name: "ipv6",
ip: net.ParseIP("ff50::10"),
port: 1000,
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff50::10").To16(),
Port: 1000,
},
},
},
{
name: "out of order peers",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk2
PublicKey = key2
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk1
PublicKey = key1
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk1
PublicKey = key1
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk2
PublicKey = key2
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
out: true,
},
{
name: "one empty",
a: []byte(`[Interface]
PrivateKey = private
ListenPort = 51820
[Peer]
Endpoint = 10.1.0.2:51820
PresharedKey = psk
PublicKey = key
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
`),
b: []byte(``),
out: false,
name: "ipv6",
ip: net.ParseIP("fc00:f853:ccd:e793::3"),
port: 51820,
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("fc00:f853:ccd:e793::3").To16(),
Port: 51820,
},
},
},
} {
equal := Parse(tc.a).Equal(Parse(tc.b))
if equal != tc.out {
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
out := NewEndpoint(tc.ip, tc.port)
if diff := pretty.Compare(out, tc.out); diff != "" {
t.Errorf("%d %s: got diff:\n%s\n", i, tc.name, diff)
}
}
}
func TestCompareEndpoint(t *testing.T) {
for _, tc := range []struct {
name string
a *Endpoint
b *Endpoint
dnsFirst bool
out bool
}{
{
name: "both nil",
a: nil,
b: nil,
out: true,
},
{
name: "a nil",
a: nil,
b: &Endpoint{},
out: false,
},
{
name: "b nil",
a: &Endpoint{},
b: nil,
out: false,
},
{
name: "zero",
a: &Endpoint{},
b: &Endpoint{},
out: true,
},
{
name: "diff port",
a: &Endpoint{Port: 1234},
b: &Endpoint{Port: 5678},
out: false,
},
{
name: "same IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
out: true,
},
{
name: "diff IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2")}},
out: false,
},
{
name: "same IP ignore DNS",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "b"}},
out: true,
},
{
name: "no IP check DNS",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
out: false,
},
{
name: "no IP check DNS (same)",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
out: true,
},
{
name: "DNS first, ignore IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}},
dnsFirst: true,
out: true,
},
{
name: "DNS first",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
dnsFirst: true,
out: false,
},
{
name: "DNS first, no DNS compare IP",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}},
dnsFirst: true,
out: false,
},
{
name: "DNS first, no DNS compare IP (same)",
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
dnsFirst: true,
out: true,
},
} {
equal := tc.a.Equal(tc.b, tc.dnsFirst)
if equal != tc.out {
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
}
}
}
func TestCompareDumpConf(t *testing.T) {
for _, tc := range []struct {
func TestParseEndpoint(t *testing.T) {
for i, tc := range []struct {
name string
d []byte
c []byte
str string
out *Endpoint
}{
{
name: "empty",
d: []byte{},
c: []byte{},
name: "no ip, no port",
},
{
name: "redacted copy from wg output",
d: []byte(`private B7qk8EMlob0nfado0ABM6HulUV607r4yqtBKjhap7S4= 51820 off
key1 (none) 10.254.1.1:51820 100.64.1.0/24,192.168.0.125/32,10.4.0.1/32 1619012801 67048 34952 10
key2 (none) 10.254.2.1:51820 100.64.4.0/24,10.69.76.55/32,100.64.3.0/24,10.66.25.131/32,10.4.0.2/32 1619013058 1134456 10077852 10`),
c: []byte(`[Interface]
ListenPort = 51820
PrivateKey = private
[Peer]
PublicKey = key1
AllowedIPs = 100.64.1.0/24, 192.168.0.125/32, 10.4.0.1/32
Endpoint = 10.254.1.1:51820
PersistentKeepalive = 10
[Peer]
PublicKey = key2
AllowedIPs = 100.64.4.0/24, 10.69.76.55/32, 100.64.3.0/24, 10.66.25.131/32, 10.4.0.2/32
Endpoint = 10.254.2.1:51820
PersistentKeepalive = 10`),
name: "only port",
str: ":1000",
},
{
name: "only ipv4",
str: "10.0.0.0",
},
{
name: "only ipv6",
str: "ff50::10",
},
{
name: "ipv4",
str: "10.0.0.0:1000",
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0").To4(),
Port: 1000,
},
},
},
{
name: "ipv6",
str: "[ff50::10]:1000",
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff50::10").To16(),
Port: 1000,
},
},
},
} {
dumpConf, _ := ParseDump(tc.d)
conf := Parse(tc.c)
// Equal will ignore runtime fields and only compare configuration fields.
if !dumpConf.Equal(conf) {
diff := pretty.Compare(dumpConf, conf)
t.Errorf("test case %q: got diff: %v", tc.name, diff)
out := ParseEndpoint(tc.str)
if diff := pretty.Compare(out, tc.out); diff != "" {
t.Errorf("ParseEndpoint %s(%d): got diff:\n%s\n", tc.name, i, diff)
}
}
}
func TestNewEndpointFromUDPAddr(t *testing.T) {
for i, tc := range []struct {
name string
u *net.UDPAddr
out *Endpoint
}{
{
name: "no ip, no port",
out: &Endpoint{
addr: "",
},
},
{
name: "only port",
u: &net.UDPAddr{
Port: 1000,
},
out: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
addr: "",
},
},
{
name: "only ipv4",
u: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0"),
},
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0").To4(),
},
addr: "",
},
},
{
name: "only ipv6",
u: &net.UDPAddr{
IP: net.ParseIP("ff60::10"),
},
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff60::10").To16(),
},
},
},
{
name: "ipv4",
u: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0"),
Port: 1000,
},
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0").To4(),
Port: 1000,
},
},
},
{
name: "ipv6",
u: &net.UDPAddr{
IP: net.ParseIP("ff50::10"),
Port: 1000,
},
out: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff50::10").To16(),
Port: 1000,
},
},
},
} {
out := NewEndpointFromUDPAddr(tc.u)
if diff := pretty.Compare(out, tc.out); diff != "" {
t.Errorf("ParseEndpoint %s(%d): got diff:\n%s\n", tc.name, i, diff)
}
}
}
func TestReady(t *testing.T) {
for i, tc := range []struct {
name string
in *Endpoint
r bool
}{
{
name: "nil",
r: false,
},
{
name: "no ip, no port",
in: &Endpoint{
addr: "",
udpAddr: &net.UDPAddr{},
},
r: false,
},
{
name: "only port",
in: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
r: false,
},
{
name: "only ipv4",
in: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0"),
},
},
r: false,
},
{
name: "only ipv6",
in: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff60::10"),
},
},
r: false,
},
{
name: "ipv4",
in: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("10.0.0.0"),
Port: 1000,
},
},
r: true,
},
{
name: "ipv6",
in: &Endpoint{
udpAddr: &net.UDPAddr{
IP: net.ParseIP("ff50::10"),
Port: 1000,
},
},
r: true,
},
} {
if tc.r != tc.in.Ready() {
t.Errorf("Endpoint.Ready() %s(%d): expected=%v\tgot=%v\n", tc.name, i, tc.r, tc.in.Ready())
}
}
}
func TestEqual(t *testing.T) {
for i, tc := range []struct {
name string
a *Endpoint
b *Endpoint
df bool
r bool
}{
{
name: "nil dns last",
r: true,
},
{
name: "nil dns first",
df: true,
r: true,
},
{
name: "equal: only port",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
r: true,
},
{
name: "not equal: only port",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
},
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1001,
},
},
r: false,
},
{
name: "equal dns first",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
r: true,
},
{
name: "equal dns last",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "foo",
},
r: true,
},
{
name: "unequal dns first",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "example.com:1000",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "foo",
},
df: true,
r: false,
},
{
name: "unequal dns last",
a: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("10.0.0.0"),
},
addr: "foo",
},
b: &Endpoint{
udpAddr: &net.UDPAddr{
Port: 1000,
IP: net.ParseIP("11.0.0.0"),
},
addr: "foo",
},
r: false,
},
{
name: "unequal dns last empty IP",
a: &Endpoint{
addr: "foo",
},
b: &Endpoint{
addr: "bar",
},
r: false,
},
{
name: "equal dns last empty IP",
a: &Endpoint{
addr: "foo",
},
b: &Endpoint{
addr: "foo",
},
r: true,
},
} {
if out := tc.a.Equal(tc.b, tc.df); out != tc.r {
t.Errorf("ParseEndpoint %s(%d): expected: %v\tgot: %v\n", tc.name, i, tc.r, out)
}
}
}

View File

@ -18,9 +18,7 @@
package wireguard
import (
"bytes"
"fmt"
"os/exec"
"github.com/vishvananda/netlink"
)
@ -65,74 +63,3 @@ func New(name string, mtu uint) (int, bool, error) {
}
return link.Attrs().Index, true, nil
}
// Keys generates a WireGuard private and public key-pair.
func Keys() ([]byte, []byte, error) {
private, err := GenKey()
if err != nil {
return nil, nil, fmt.Errorf("failed to generate private key: %v", err)
}
public, err := PubKey(private)
return private, public, err
}
// GenKey generates a WireGuard private key.
func GenKey() ([]byte, error) {
key, err := exec.Command("wg", "genkey").Output()
return bytes.Trim(key, "\n"), err
}
// PubKey generates a WireGuard public key for a given private key.
func PubKey(key []byte) ([]byte, error) {
cmd := exec.Command("wg", "pubkey")
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("failed to open pipe to stdin: %v", err)
}
go func() {
defer stdin.Close()
stdin.Write(key)
}()
public, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("failed to generate public key: %v", err)
}
return bytes.Trim(public, "\n"), nil
}
// SetConf applies a WireGuard configuration file to the given interface.
func SetConf(iface string, path string) error {
cmd := exec.Command("wg", "setconf", iface, path)
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to apply the WireGuard configuration: %s", stderr.String())
}
return nil
}
// ShowConf gets the WireGuard configuration for the given interface.
func ShowConf(iface string) ([]byte, error) {
cmd := exec.Command("wg", "showconf", iface)
var stderr, stdout bytes.Buffer
cmd.Stderr = &stderr
cmd.Stdout = &stdout
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("failed to read the WireGuard configuration: %s", stderr.String())
}
return stdout.Bytes(), nil
}
// ShowDump gets the WireGuard configuration and runtime information for the given interface.
func ShowDump(iface string) ([]byte, error) {
cmd := exec.Command("wg", "show", iface, "dump")
var stderr, stdout bytes.Buffer
cmd.Stderr = &stderr
cmd.Stdout = &stdout
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("failed to read the WireGuard dump output: %s", stderr.String())
}
return stdout.Bytes(), nil
}

8
vendor/github.com/josharian/native/doc.go generated vendored Normal file
View File

@ -0,0 +1,8 @@
// Package native provides easy access to native byte order.
//
// Usage: use native.Endian where you need the native binary.ByteOrder.
//
// Please think twice before using this package.
// It can break program portability.
// Native byte order is usually not the right answer.
package native

7
vendor/github.com/josharian/native/endian_big.go generated vendored Normal file
View File

@ -0,0 +1,7 @@
// +build mips mips64 ppc64 s390x
package native
import "encoding/binary"
var Endian = binary.BigEndian

26
vendor/github.com/josharian/native/endian_generic.go generated vendored Normal file
View File

@ -0,0 +1,26 @@
// +build !mips,!mips64,!ppc64,!s390x,!amd64,!386,!arm,!arm64,!mipsle,!mips64le,!ppc64le,!riscv64,!wasm
// This file is a fallback, so that package native doesn't break
// the instant the Go project adds support for a new architecture.
//
package native
import (
"encoding/binary"
"log"
"runtime"
"unsafe"
)
var Endian binary.ByteOrder
func init() {
b := uint16(0xff) // one byte
if *(*byte)(unsafe.Pointer(&b)) == 0 {
Endian = binary.BigEndian
} else {
Endian = binary.LittleEndian
}
log.Printf("github.com/josharian/native: unrecognized arch %v (%v), please file an issue", runtime.GOARCH, Endian)
}

7
vendor/github.com/josharian/native/endian_little.go generated vendored Normal file
View File

@ -0,0 +1,7 @@
// +build amd64 386 arm arm64 mipsle mips64le ppc64le riscv64 wasm
package native
import "encoding/binary"
var Endian = binary.LittleEndian

7
vendor/github.com/josharian/native/license generated vendored Normal file
View File

@ -0,0 +1,7 @@
Copyright 2020 Josh Bleecher Snyder
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

10
vendor/github.com/josharian/native/readme.md generated vendored Normal file
View File

@ -0,0 +1,10 @@
Package native provides easy access to native byte order.
`go get github.com/josharian/native`
Usage: Use `native.Endian` where you need the native binary.ByteOrder.
Please think twice before using this package.
It can break program portability.
Native byte order is usually not the right answer.

9
vendor/github.com/mdlayher/genetlink/CHANGELOG.md generated vendored Normal file
View File

@ -0,0 +1,9 @@
# CHANGELOG
## Unreleased
n/a
## v1.0.0
- Initial stable commit.

10
vendor/github.com/mdlayher/genetlink/LICENSE.md generated vendored Normal file
View File

@ -0,0 +1,10 @@
MIT License
===========
Copyright (C) 2016-2017 Matt Layher
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

27
vendor/github.com/mdlayher/genetlink/README.md generated vendored Normal file
View File

@ -0,0 +1,27 @@
genetlink [![builds.sr.ht status](https://builds.sr.ht/~mdlayher/genetlink.svg)](https://builds.sr.ht/~mdlayher/genetlink?) [![GoDoc](https://godoc.org/github.com/mdlayher/genetlink?status.svg)](https://godoc.org/github.com/mdlayher/genetlink) [![Go Report Card](https://goreportcard.com/badge/github.com/mdlayher/genetlink)](https://goreportcard.com/report/github.com/mdlayher/genetlink)
=========
Package `genetlink` implements generic netlink interactions and data types.
MIT Licensed.
For more information about how netlink and generic netlink work,
check out my blog series on [Linux, Netlink, and Go](https://mdlayher.com/blog/linux-netlink-and-go-part-1-netlink/).
If you have any questions or you'd like some guidance, please join us on
[Gophers Slack](https://invite.slack.golangbridge.org) in the `#networking`
channel!
## Stability
See the [CHANGELOG](./CHANGELOG.md) file for a description of changes between
releases.
This package has reached v1.0.0 and any future breaking API changes will prompt
the release of a new major version. Features and bug fixes will continue to
occur in the v1.x.x series.
The general policy of this package is to only support the latest, stable version
of Go. Compatibility shims may be added for prior versions of Go on an as-needed
basis. If you would like to raise a concern, please [file an issue](https://github.com/mdlayher/genetlink/issues/new).
**If you depend on this package in your applications, please use Go modules.**

235
vendor/github.com/mdlayher/genetlink/conn.go generated vendored Normal file
View File

@ -0,0 +1,235 @@
package genetlink
import (
"syscall"
"time"
"github.com/mdlayher/netlink"
"golang.org/x/net/bpf"
)
// Protocol is the netlink protocol constant used to specify generic netlink.
const Protocol = 0x10 // unix.NETLINK_GENERIC
// A Conn is a generic netlink connection. A Conn can be used to send and
// receive generic netlink messages to and from netlink.
//
// A Conn is safe for concurrent use, but to avoid contention in
// high-throughput applications, the caller should almost certainly create a
// pool of Conns and distribute them among workers.
type Conn struct {
// Operating system-specific netlink connection.
c *netlink.Conn
}
// Dial dials a generic netlink connection. Config specifies optional
// configuration for the underlying netlink connection. If config is
// nil, a default configuration will be used.
func Dial(config *netlink.Config) (*Conn, error) {
c, err := netlink.Dial(Protocol, config)
if err != nil {
return nil, err
}
return NewConn(c), nil
}
// NewConn creates a Conn that wraps an existing *netlink.Conn for
// generic netlink communications.
//
// NewConn is primarily useful for tests. Most applications should use
// Dial instead.
func NewConn(c *netlink.Conn) *Conn {
return &Conn{c: c}
}
// Close closes the connection. Close will unblock any concurrent calls to
// Receive which are waiting on a response from the kernel.
func (c *Conn) Close() error {
return c.c.Close()
}
// GetFamily retrieves a generic netlink family with the specified name.
//
// If the family does not exist, the error value can be checked using
// netlink.IsNotExist.
func (c *Conn) GetFamily(name string) (Family, error) {
return c.getFamily(name)
}
// ListFamilies retrieves all registered generic netlink families.
func (c *Conn) ListFamilies() ([]Family, error) {
return c.listFamilies()
}
// JoinGroup joins a netlink multicast group by its ID.
func (c *Conn) JoinGroup(group uint32) error {
return c.c.JoinGroup(group)
}
// LeaveGroup leaves a netlink multicast group by its ID.
func (c *Conn) LeaveGroup(group uint32) error {
return c.c.LeaveGroup(group)
}
// SetBPF attaches an assembled BPF program to a Conn.
func (c *Conn) SetBPF(filter []bpf.RawInstruction) error {
return c.c.SetBPF(filter)
}
// RemoveBPF removes a BPF filter from a Conn.
func (c *Conn) RemoveBPF() error {
return c.c.RemoveBPF()
}
// SetOption enables or disables a netlink socket option for the Conn.
func (c *Conn) SetOption(option netlink.ConnOption, enable bool) error {
return c.c.SetOption(option, enable)
}
// SetReadBuffer sets the size of the operating system's receive buffer
// associated with the Conn.
func (c *Conn) SetReadBuffer(bytes int) error {
return c.c.SetReadBuffer(bytes)
}
// SetWriteBuffer sets the size of the operating system's transmit buffer
// associated with the Conn.
func (c *Conn) SetWriteBuffer(bytes int) error {
return c.c.SetWriteBuffer(bytes)
}
// SyscallConn returns a raw network connection. This implements the
// syscall.Conn interface.
//
// On Go 1.12+, all methods of the returned syscall.RawConn are supported and
// the Conn is integrated with the runtime network poller. On versions of Go
// prior to Go 1.12, only the Control method of the returned syscall.RawConn
// is implemented.
//
// SyscallConn is intended for advanced use cases, such as getting and setting
// arbitrary socket options using the netlink socket's file descriptor.
//
// Once invoked, it is the caller's responsibility to ensure that operations
// performed using Conn and the syscall.RawConn do not conflict with
// each other.
func (c *Conn) SyscallConn() (syscall.RawConn, error) {
return c.c.SyscallConn()
}
// SetDeadline sets the read and write deadlines associated with the connection.
//
// Deadline functionality is only supported on Go 1.12+. Calling this function
// on older versions of Go will result in an error.
func (c *Conn) SetDeadline(t time.Time) error {
return c.c.SetDeadline(t)
}
// SetReadDeadline sets the read deadline associated with the connection.
//
// Deadline functionality is only supported on Go 1.12+. Calling this function
// on older versions of Go will result in an error.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline associated with the connection.
//
// Deadline functionality is only supported on Go 1.12+. Calling this function
// on older versions of Go will result in an error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.c.SetWriteDeadline(t)
}
// Send sends a single Message to netlink, wrapping it in a netlink.Message
// using the specified generic netlink family and flags. On success, Send
// returns a copy of the netlink.Message with all parameters populated, for
// later validation.
func (c *Conn) Send(m Message, family uint16, flags netlink.HeaderFlags) (netlink.Message, error) {
nm, err := packMessage(m, family, flags)
if err != nil {
return netlink.Message{}, err
}
reqnm, err := c.c.Send(nm)
if err != nil {
return netlink.Message{}, err
}
return reqnm, nil
}
// Receive receives one or more Messages from netlink. The netlink.Messages
// used to wrap each Message are available for later validation.
func (c *Conn) Receive() ([]Message, []netlink.Message, error) {
msgs, err := c.c.Receive()
if err != nil {
return nil, nil, err
}
gmsgs, err := unpackMessages(msgs)
if err != nil {
return nil, nil, err
}
return gmsgs, msgs, nil
}
// Execute sends a single Message to netlink using Send, receives one or more
// replies using Receive, and then checks the validity of the replies against
// the request using netlink.Validate.
//
// Execute acquires a lock for the duration of the function call which blocks
// concurrent calls to Send and Receive, in order to ensure consistency between
// generic netlink request/reply messages.
//
// See the documentation of Send, Receive, and netlink.Validate for details
// about each function.
func (c *Conn) Execute(m Message, family uint16, flags netlink.HeaderFlags) ([]Message, error) {
nm, err := packMessage(m, family, flags)
if err != nil {
return nil, err
}
// Locking behavior handled by netlink.Conn.Execute.
msgs, err := c.c.Execute(nm)
if err != nil {
return nil, err
}
return unpackMessages(msgs)
}
// packMessage packs a generic netlink Message into a netlink.Message with the
// appropriate generic netlink family and netlink flags.
func packMessage(m Message, family uint16, flags netlink.HeaderFlags) (netlink.Message, error) {
nm := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType(family),
Flags: flags,
},
}
mb, err := m.MarshalBinary()
if err != nil {
return netlink.Message{}, err
}
nm.Data = mb
return nm, nil
}
// unpackMessages unpacks generic netlink Messages from a slice of netlink.Messages.
func unpackMessages(msgs []netlink.Message) ([]Message, error) {
gmsgs := make([]Message, 0, len(msgs))
for _, nm := range msgs {
var gm Message
if err := (&gm).UnmarshalBinary(nm.Data); err != nil {
return nil, err
}
gmsgs = append(gmsgs, gm)
}
return gmsgs, nil
}

6
vendor/github.com/mdlayher/genetlink/doc.go generated vendored Normal file
View File

@ -0,0 +1,6 @@
// Package genetlink implements generic netlink interactions and data types.
//
// If you have any questions or you'd like some guidance, please join us on
// Gophers Slack (https://invite.slack.golangbridge.org) in the #networking
// channel!
package genetlink

17
vendor/github.com/mdlayher/genetlink/family.go generated vendored Normal file
View File

@ -0,0 +1,17 @@
package genetlink
// A Family is a generic netlink family.
type Family struct {
ID uint16
Version uint8
Name string
Groups []MulticastGroup
}
// A MulticastGroup is a generic netlink multicast group, which can be joined
// for notifications from generic netlink families when specific events take
// place.
type MulticastGroup struct {
ID uint32
Name string
}

146
vendor/github.com/mdlayher/genetlink/family_linux.go generated vendored Normal file
View File

@ -0,0 +1,146 @@
//+build linux
package genetlink
import (
"errors"
"fmt"
"math"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nlenc"
"golang.org/x/sys/unix"
)
// errInvalidFamilyVersion is returned when a family's version is greater
// than an 8-bit integer.
var errInvalidFamilyVersion = errors.New("invalid family version attribute")
// getFamily retrieves a generic netlink family with the specified name.
func (c *Conn) getFamily(name string) (Family, error) {
b, err := netlink.MarshalAttributes([]netlink.Attribute{{
Type: unix.CTRL_ATTR_FAMILY_NAME,
Data: nlenc.Bytes(name),
}})
if err != nil {
return Family{}, err
}
req := Message{
Header: Header{
Command: unix.CTRL_CMD_GETFAMILY,
// TODO(mdlayher): grab nlctrl version?
Version: 1,
},
Data: b,
}
msgs, err := c.Execute(req, unix.GENL_ID_CTRL, netlink.Request)
if err != nil {
return Family{}, err
}
// TODO(mdlayher): consider interpreting generic netlink header values
families, err := buildFamilies(msgs)
if err != nil {
return Family{}, err
}
if len(families) != 1 {
// If this were to ever happen, netlink must be in a state where
// its answers cannot be trusted
panic(fmt.Sprintf("netlink returned multiple families for name: %q", name))
}
return families[0], nil
}
// listFamilies retrieves all registered generic netlink families.
func (c *Conn) listFamilies() ([]Family, error) {
req := Message{
Header: Header{
Command: unix.CTRL_CMD_GETFAMILY,
// TODO(mdlayher): grab nlctrl version?
Version: 1,
},
}
flags := netlink.Request | netlink.Dump
msgs, err := c.Execute(req, unix.GENL_ID_CTRL, flags)
if err != nil {
return nil, err
}
return buildFamilies(msgs)
}
// buildFamilies builds a slice of Families by parsing attributes from the
// input Messages.
func buildFamilies(msgs []Message) ([]Family, error) {
families := make([]Family, 0, len(msgs))
for _, m := range msgs {
var f Family
if err := (&f).parseAttributes(m.Data); err != nil {
return nil, err
}
families = append(families, f)
}
return families, nil
}
// parseAttributes decodes netlink attributes into a Family's fields.
func (f *Family) parseAttributes(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return err
}
for ad.Next() {
switch ad.Type() {
case unix.CTRL_ATTR_FAMILY_ID:
f.ID = ad.Uint16()
case unix.CTRL_ATTR_FAMILY_NAME:
f.Name = ad.String()
case unix.CTRL_ATTR_VERSION:
v := ad.Uint32()
if v > math.MaxUint8 {
return errInvalidFamilyVersion
}
f.Version = uint8(v)
case unix.CTRL_ATTR_MCAST_GROUPS:
ad.Nested(func(nad *netlink.AttributeDecoder) error {
f.Groups = parseMulticastGroups(nad)
return nil
})
}
}
return ad.Err()
}
// parseMulticastGroups parses an array of multicast group nested attributes
// into a slice of MulticastGroups.
func parseMulticastGroups(ad *netlink.AttributeDecoder) []MulticastGroup {
groups := make([]MulticastGroup, 0, ad.Len())
for ad.Next() {
ad.Nested(func(nad *netlink.AttributeDecoder) error {
var g MulticastGroup
for nad.Next() {
switch nad.Type() {
case unix.CTRL_ATTR_MCAST_GRP_NAME:
g.Name = nad.String()
case unix.CTRL_ATTR_MCAST_GRP_ID:
g.ID = nad.Uint32()
}
}
groups = append(groups, g)
return nil
})
}
return groups
}

23
vendor/github.com/mdlayher/genetlink/family_others.go generated vendored Normal file
View File

@ -0,0 +1,23 @@
//+build !linux
package genetlink
import (
"fmt"
"runtime"
)
// errUnimplemented is returned by all functions on platforms that
// cannot make use of generic netlink.
var errUnimplemented = fmt.Errorf("generic netlink not implemented on %s/%s",
runtime.GOOS, runtime.GOARCH)
// getFamily always returns an error.
func (c *Conn) getFamily(name string) (Family, error) {
return Family{}, errUnimplemented
}
// listFamilies always returns an error.
func (c *Conn) listFamilies() ([]Family, error) {
return nil, errUnimplemented
}

20
vendor/github.com/mdlayher/genetlink/fuzz.go generated vendored Normal file
View File

@ -0,0 +1,20 @@
//+build gofuzz
package genetlink
func Fuzz(data []byte) int {
return fuzzMessage(data)
}
func fuzzMessage(data []byte) int {
var m Message
if err := (&m).UnmarshalBinary(data); err != nil {
return 0
}
if _, err := m.MarshalBinary(); err != nil {
panic(err)
}
return 1
}

61
vendor/github.com/mdlayher/genetlink/message.go generated vendored Normal file
View File

@ -0,0 +1,61 @@
package genetlink
import "errors"
// errInvalidMessage is returned when a Message is malformed.
var errInvalidMessage = errors.New("generic netlink message is invalid or too short")
// A Header is a generic netlink header. A Header is sent and received with
// each generic netlink message to indicate metadata regarding a Message.
type Header struct {
// Command specifies a command to issue to netlink.
Command uint8
// Version specifies the version of a command to use.
Version uint8
}
// headerLen is the length of a Header.
const headerLen = 4 // unix.GENL_HDRLEN
// A Message is a generic netlink message. It contains a Header and an
// arbitrary byte payload, which may be decoded using information from the
// Header.
//
// Data is encoded using the native endianness of the host system. Use
// the netlink.AttributeDecoder and netlink.AttributeEncoder types to decode
// and encode data.
type Message struct {
Header Header
Data []byte
}
// MarshalBinary marshals a Message into a byte slice.
func (m Message) MarshalBinary() ([]byte, error) {
b := make([]byte, headerLen)
b[0] = m.Header.Command
b[1] = m.Header.Version
// b[2] and b[3] are padding bytes and set to zero
return append(b, m.Data...), nil
}
// UnmarshalBinary unmarshals the contents of a byte slice into a Message.
func (m *Message) UnmarshalBinary(b []byte) error {
if len(b) < headerLen {
return errInvalidMessage
}
// Don't allow reserved pad bytes to be set
if b[2] != 0 || b[3] != 0 {
return errInvalidMessage
}
m.Header.Command = b[0]
m.Header.Version = b[1]
m.Data = b[4:]
return nil
}

3
vendor/github.com/mdlayher/netlink/.gitignore generated vendored Normal file
View File

@ -0,0 +1,3 @@
netlink.test
netlink-fuzz.zip
testdata/

100
vendor/github.com/mdlayher/netlink/CHANGELOG.md generated vendored Normal file
View File

@ -0,0 +1,100 @@
# CHANGELOG
## v1.4.1
- [Improvement]: significant runtime network poller integration cleanup through
the use of `github.com/mdlayher/socket`.
## v1.4.0
- [New API] [#185](https://github.com/mdlayher/netlink/pull/185): the
`netlink.AttributeDecoder` and `netlink.AttributeEncoder` types now have
methods for dealing with signed integers: `Int8`, `Int16`, `Int32`, and
`Int64`. These are necessary for working with rtnetlink's XDP APIs. Thanks
@fbegyn.
## v1.3.2
- [Improvement]
[commit](https://github.com/mdlayher/netlink/commit/ebc6e2e28bcf1a0671411288423d8116ff924d6d):
`github.com/google/go-cmp` is no longer a (non-test) dependency of this module.
## v1.3.1
- [Improvement]: many internal cleanups and simplifications. The library is now
slimmer and features less internal indirection. There are no user-facing
changes in this release.
## v1.3.0
- [New API] [#176](https://github.com/mdlayher/netlink/pull/176):
`netlink.OpError` now has `Message` and `Offset` fields which are populated
when the kernel returns netlink extended acknowledgement data along with an
error code. The caller can turn on this option by using
`netlink.Conn.SetOption(netlink.ExtendedAcknowledge, true)`.
- [New API]
[commit](https://github.com/mdlayher/netlink/commit/beba85e0372133b6d57221191d2c557727cd1499):
the `netlink.GetStrictCheck` option can be used to tell the kernel to be more
strict when parsing requests. This enables more safety checks and can allow
the kernel to perform more advanced request filtering in subsystems such as
route netlink.
## v1.2.1
- [Bug Fix]
[commit](https://github.com/mdlayher/netlink/commit/d81418f81b0bfa2465f33790a85624c63d6afe3d):
`netlink.SetBPF` will no longer panic if an empty BPF filter is set.
- [Improvement]
[commit](https://github.com/mdlayher/netlink/commit/8014f9a7dbf4fd7b84a1783dd7b470db9113ff36):
the library now uses https://github.com/josharian/native to provide the
system's native endianness at compile time, rather than re-computing it many
times at runtime.
## v1.2.0
**This is the first release of package netlink that only supports Go 1.12+. Users on older versions must use v1.1.1.**
- [Improvement] [#173](https://github.com/mdlayher/netlink/pull/173): support
for Go 1.11 and below has been dropped. All users are highly recommended to
use a stable and supported release of Go for their applications.
- [Performance] [#171](https://github.com/mdlayher/netlink/pull/171):
`netlink.Conn` no longer requires a locked OS thread for the vast majority of
operations, which should result in a significant speedup for highly concurrent
callers. Thanks @ti-mo.
- [Bug Fix] [#169](https://github.com/mdlayher/netlink/pull/169): calls to
`netlink.Conn.Close` are now able to unblock concurrent calls to
`netlink.Conn.Receive` and other blocking operations.
## v1.1.1
**This is the last release of package netlink that supports Go 1.11.**
- [Improvement] [#165](https://github.com/mdlayher/netlink/pull/165):
`netlink.Conn` `SetReadBuffer` and `SetWriteBuffer` methods now attempt the
`SO_*BUFFORCE` socket options to possibly ignore system limits given elevated
caller permissions. Thanks @MarkusBauer.
- [Note]
[commit](https://github.com/mdlayher/netlink/commit/c5f8ab79aa345dcfcf7f14d746659ca1b80a0ecc):
`netlink.Conn.Close` has had a long-standing bug
[#162](https://github.com/mdlayher/netlink/pull/162) related to internal
concurrency handling where a call to `Close` is not sufficient to unblock
pending reads. To effectively fix this issue, it is necessary to drop support
for Go 1.11 and below. This will be fixed in a future release, but a
workaround is noted in the method documentation as of now.
## v1.1.0
- [New API] [#157](https://github.com/mdlayher/netlink/pull/157): the
`netlink.AttributeDecoder.TypeFlags` method enables retrieval of the type bits
stored in a netlink attribute's type field, because the existing `Type` method
masks away these bits. Thanks @ti-mo!
- [Performance] [#157](https://github.com/mdlayher/netlink/pull/157): `netlink.AttributeDecoder`
now decodes netlink attributes on demand, enabling callers who only need a
limited number of attributes to exit early from decoding loops. Thanks @ti-mo!
- [Improvement] [#161](https://github.com/mdlayher/netlink/pull/161): `netlink.Conn`
system calls are now ready for Go 1.14+'s changes to goroutine preemption.
See the PR for details.
## v1.0.0
- Initial stable commit.

9
vendor/github.com/mdlayher/netlink/LICENSE.md generated vendored Normal file
View File

@ -0,0 +1,9 @@
# MIT License
Copyright (C) 2016-2021 Matt Layher
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

51
vendor/github.com/mdlayher/netlink/README.md generated vendored Normal file
View File

@ -0,0 +1,51 @@
# netlink [![Test Status](https://github.com/mdlayher/netlink/workflows/Linux%20Test/badge.svg)](https://github.com/mdlayher/netlink/actions) [![Go Reference](https://pkg.go.dev/badge/github.com/mdlayher/netlink.svg)](https://pkg.go.dev/github.com/mdlayher/netlink) [![Go Report Card](https://goreportcard.com/badge/github.com/mdlayher/netlink)](https://goreportcard.com/report/github.com/mdlayher/netlink)
Package `netlink` provides low-level access to Linux netlink sockets.
MIT Licensed.
For more information about how netlink works, check out my blog series
on [Linux, Netlink, and Go](https://mdlayher.com/blog/linux-netlink-and-go-part-1-netlink/).
If you have any questions or you'd like some guidance, please join us on
[Gophers Slack](https://invite.slack.golangbridge.org) in the `#networking`
channel!
## Stability
See the [CHANGELOG](./CHANGELOG.md) file for a description of changes between
releases.
This package has a stable v1 API and any future breaking changes will prompt
the release of a new major version. Features and bug fixes will continue to
occur in the v1.x.x series.
In order to reduce the maintenance burden, this package is only supported on
Go 1.12+. Older versions of Go lack critical features and APIs which are
necessary for this package to function correctly.
**If you depend on this package in your applications, please use Go modules.**
## Design
A [number of netlink packages](https://godoc.org/?q=netlink) are already
available for Go, but I wasn't able to find one that aligned with what
I wanted in a netlink package:
- Straightforward, idiomatic API
- Well tested
- Well documented
- Doesn't use package/global variables or state
- Doesn't necessarily need root to work
My goal for this package is to use it as a building block for the creation
of other netlink family packages.
## Ecosystem
Over time, an ecosystem of Go packages has developed around package `netlink`.
Many of these packages provide building blocks for further interactions with
various netlink families, such as `NETLINK_GENERIC` or `NETLINK_ROUTE`.
To have your package included in this diagram, please send a pull request!
![netlink ecosystem](./netlink.svg)

37
vendor/github.com/mdlayher/netlink/align.go generated vendored Normal file
View File

@ -0,0 +1,37 @@
package netlink
import "unsafe"
// Functions and values used to properly align netlink messages, headers,
// and attributes. Definitions taken from Linux kernel source.
// #define NLMSG_ALIGNTO 4U
const nlmsgAlignTo = 4
// #define NLMSG_ALIGN(len) ( ((len)+NLMSG_ALIGNTO-1) & ~(NLMSG_ALIGNTO-1) )
func nlmsgAlign(len int) int {
return ((len) + nlmsgAlignTo - 1) & ^(nlmsgAlignTo - 1)
}
// #define NLMSG_LENGTH(len) ((len) + NLMSG_HDRLEN)
func nlmsgLength(len int) int {
return len + nlmsgHeaderLen
}
// #define NLMSG_HDRLEN ((int) NLMSG_ALIGN(sizeof(struct nlmsghdr)))
var nlmsgHeaderLen = nlmsgAlign(int(unsafe.Sizeof(Header{})))
// #define NLA_ALIGNTO 4
const nlaAlignTo = 4
// #define NLA_ALIGN(len) (((len) + NLA_ALIGNTO - 1) & ~(NLA_ALIGNTO - 1))
func nlaAlign(len int) int {
return ((len) + nlaAlignTo - 1) & ^(nlaAlignTo - 1)
}
// Because this package's Attribute type contains a byte slice, unsafe.Sizeof
// can't be used to determine the correct length.
const sizeofAttribute = 4
// #define NLA_HDRLEN ((int) NLA_ALIGN(sizeof(struct nlattr)))
var nlaHeaderLen = nlaAlign(sizeofAttribute)

707
vendor/github.com/mdlayher/netlink/attribute.go generated vendored Normal file
View File

@ -0,0 +1,707 @@
package netlink
import (
"encoding/binary"
"errors"
"fmt"
"github.com/josharian/native"
"github.com/mdlayher/netlink/nlenc"
)
// errInvalidAttribute specifies if an Attribute's length is incorrect.
var errInvalidAttribute = errors.New("invalid attribute; length too short or too large")
// An Attribute is a netlink attribute. Attributes are packed and unpacked
// to and from the Data field of Message for some netlink families.
type Attribute struct {
// Length of an Attribute, including this field and Type.
Length uint16
// The type of this Attribute, typically matched to a constant. Note that
// flags such as Nested and NetByteOrder must be handled manually when
// working with Attribute structures directly.
Type uint16
// An arbitrary payload which is specified by Type.
Data []byte
}
// marshal marshals the contents of a into b and returns the number of bytes
// written to b, including attribute alignment padding.
func (a *Attribute) marshal(b []byte) (int, error) {
if int(a.Length) < nlaHeaderLen {
return 0, errInvalidAttribute
}
nlenc.PutUint16(b[0:2], a.Length)
nlenc.PutUint16(b[2:4], a.Type)
n := copy(b[nlaHeaderLen:], a.Data)
return nlaHeaderLen + nlaAlign(n), nil
}
// unmarshal unmarshals the contents of a byte slice into an Attribute.
func (a *Attribute) unmarshal(b []byte) error {
if len(b) < nlaHeaderLen {
return errInvalidAttribute
}
a.Length = nlenc.Uint16(b[0:2])
a.Type = nlenc.Uint16(b[2:4])
if int(a.Length) > len(b) {
return errInvalidAttribute
}
switch {
// No length, no data
case a.Length == 0:
a.Data = make([]byte, 0)
// Not enough length for any data
case int(a.Length) < nlaHeaderLen:
return errInvalidAttribute
// Data present
case int(a.Length) >= nlaHeaderLen:
a.Data = make([]byte, len(b[nlaHeaderLen:a.Length]))
copy(a.Data, b[nlaHeaderLen:a.Length])
}
return nil
}
// MarshalAttributes packs a slice of Attributes into a single byte slice.
// In most cases, the Length field of each Attribute should be set to 0, so it
// can be calculated and populated automatically for each Attribute.
//
// It is recommend to use the AttributeEncoder type where possible instead of
// calling MarshalAttributes and using package nlenc functions directly.
func MarshalAttributes(attrs []Attribute) ([]byte, error) {
// Count how many bytes we should allocate to store each attribute's contents.
var c int
for _, a := range attrs {
c += nlaHeaderLen + nlaAlign(len(a.Data))
}
// Advance through b with idx to place attribute data at the correct offset.
var idx int
b := make([]byte, c)
for _, a := range attrs {
// Infer the length of attribute if zero.
if a.Length == 0 {
a.Length = uint16(nlaHeaderLen + len(a.Data))
}
// Marshal a into b and advance idx to show many bytes are occupied.
n, err := a.marshal(b[idx:])
if err != nil {
return nil, err
}
idx += n
}
return b, nil
}
// UnmarshalAttributes unpacks a slice of Attributes from a single byte slice.
//
// It is recommend to use the AttributeDecoder type where possible instead of calling
// UnmarshalAttributes and using package nlenc functions directly.
func UnmarshalAttributes(b []byte) ([]Attribute, error) {
ad, err := NewAttributeDecoder(b)
if err != nil {
return nil, err
}
// Return a nil slice when there are no attributes to decode.
if ad.Len() == 0 {
return nil, nil
}
attrs := make([]Attribute, 0, ad.Len())
for ad.Next() {
if ad.attr().Length != 0 {
attrs = append(attrs, ad.attr())
}
}
if err := ad.Err(); err != nil {
return nil, err
}
return attrs, nil
}
// An AttributeDecoder provides a safe, iterator-like, API around attribute
// decoding.
//
// It is recommend to use an AttributeDecoder where possible instead of calling
// UnmarshalAttributes and using package nlenc functions directly.
//
// The Err method must be called after the Next method returns false to determine
// if any errors occurred during iteration.
type AttributeDecoder struct {
// ByteOrder defines a specific byte order to use when processing integer
// attributes. ByteOrder should be set immediately after creating the
// AttributeDecoder: before any attributes are parsed.
//
// If not set, the native byte order will be used.
ByteOrder binary.ByteOrder
// The current attribute being worked on.
a Attribute
// The slice of input bytes and its iterator index.
b []byte
i int
length int
// Any error encountered while decoding attributes.
err error
}
// NewAttributeDecoder creates an AttributeDecoder that unpacks Attributes
// from b and prepares the decoder for iteration.
func NewAttributeDecoder(b []byte) (*AttributeDecoder, error) {
ad := &AttributeDecoder{
// By default, use native byte order.
ByteOrder: native.Endian,
b: b,
}
var err error
ad.length, err = ad.available()
if err != nil {
return nil, err
}
return ad, nil
}
// Next advances the decoder to the next netlink attribute. It returns false
// when no more attributes are present, or an error was encountered.
func (ad *AttributeDecoder) Next() bool {
if ad.err != nil {
// Hit an error, stop iteration.
return false
}
// Exit if array pointer is at or beyond the end of the slice.
if ad.i >= len(ad.b) {
return false
}
if err := ad.a.unmarshal(ad.b[ad.i:]); err != nil {
ad.err = err
return false
}
// Advance the pointer by at least one header's length.
if int(ad.a.Length) < nlaHeaderLen {
ad.i += nlaHeaderLen
} else {
ad.i += nlaAlign(int(ad.a.Length))
}
return true
}
// Type returns the Attribute.Type field of the current netlink attribute
// pointed to by the decoder.
//
// Type masks off the high bits of the netlink attribute type which may contain
// the Nested and NetByteOrder flags. These can be obtained by calling TypeFlags.
func (ad *AttributeDecoder) Type() uint16 {
// Mask off any flags stored in the high bits.
return ad.a.Type & attrTypeMask
}
// TypeFlags returns the two high bits of the Attribute.Type field of the current
// netlink attribute pointed to by the decoder.
//
// These bits of the netlink attribute type are used for the Nested and NetByteOrder
// flags, available as the Nested and NetByteOrder constants in this package.
func (ad *AttributeDecoder) TypeFlags() uint16 {
return ad.a.Type & ^attrTypeMask
}
// Len returns the number of netlink attributes pointed to by the decoder.
func (ad *AttributeDecoder) Len() int { return ad.length }
// count scans the input slice to count the number of netlink attributes
// that could be decoded by Next().
func (ad *AttributeDecoder) available() (int, error) {
var i, count int
for {
// No more data to read.
if i >= len(ad.b) {
break
}
// Make sure there's at least a header's worth
// of data to read on each iteration.
if len(ad.b[i:]) < nlaHeaderLen {
return 0, errInvalidAttribute
}
// Extract the length of the attribute.
l := int(nlenc.Uint16(ad.b[i : i+2]))
// Ignore zero-length attributes.
if l != 0 {
count++
}
// Advance by at least a header's worth of bytes.
if l < nlaHeaderLen {
l = nlaHeaderLen
}
i += nlaAlign(l)
}
return count, nil
}
// attr returns the current Attribute pointed to by the decoder.
func (ad *AttributeDecoder) attr() Attribute {
return ad.a
}
// data returns the Data field of the current Attribute pointed to by the decoder.
func (ad *AttributeDecoder) data() []byte {
return ad.a.Data
}
// Err returns the first error encountered by the decoder.
func (ad *AttributeDecoder) Err() error {
return ad.err
}
// Bytes returns the raw bytes of the current Attribute's data.
func (ad *AttributeDecoder) Bytes() []byte {
src := ad.data()
dest := make([]byte, len(src))
copy(dest, src)
return dest
}
// String returns the string representation of the current Attribute's data.
func (ad *AttributeDecoder) String() string {
if ad.err != nil {
return ""
}
return nlenc.String(ad.data())
}
// Uint8 returns the uint8 representation of the current Attribute's data.
func (ad *AttributeDecoder) Uint8() uint8 {
if ad.err != nil {
return 0
}
b := ad.data()
if len(b) != 1 {
ad.err = fmt.Errorf("netlink: attribute %d is not a uint8; length: %d", ad.Type(), len(b))
return 0
}
return uint8(b[0])
}
// Uint16 returns the uint16 representation of the current Attribute's data.
func (ad *AttributeDecoder) Uint16() uint16 {
if ad.err != nil {
return 0
}
b := ad.data()
if len(b) != 2 {
ad.err = fmt.Errorf("netlink: attribute %d is not a uint16; length: %d", ad.Type(), len(b))
return 0
}
return ad.ByteOrder.Uint16(b)
}
// Uint32 returns the uint32 representation of the current Attribute's data.
func (ad *AttributeDecoder) Uint32() uint32 {
if ad.err != nil {
return 0
}
b := ad.data()
if len(b) != 4 {
ad.err = fmt.Errorf("netlink: attribute %d is not a uint32; length: %d", ad.Type(), len(b))
return 0
}
return ad.ByteOrder.Uint32(b)
}
// Uint64 returns the uint64 representation of the current Attribute's data.
func (ad *AttributeDecoder) Uint64() uint64 {
if ad.err != nil {
return 0
}
b := ad.data()
if len(b) != 8 {
ad.err = fmt.Errorf("netlink: attribute %d is not a uint64; length: %d", ad.Type(), len(b))
return 0
}
return ad.ByteOrder.Uint64(b)
}
// Int8 returns the Int8 representation of the current Attribute's data.
func (ad *AttributeDecoder) Int8() int8 {
if ad.err != nil {
return 0
}
b := ad.data()
if len(b) != 1 {
ad.err = fmt.Errorf("netlink: attribute %d is not a int8; length: %d", ad.Type(), len(b))
return 0
}
return int8(b[0])
}
// Int16 returns the Int16 representation of the current Attribute's data.
func (ad *AttributeDecoder) Int16() int16 {
if ad.err != nil {
return 0
}
b := ad.data()
if len(b) != 2 {
ad.err = fmt.Errorf("netlink: attribute %d is not a int16; length: %d", ad.Type(), len(b))
return 0
}
return int16(ad.ByteOrder.Uint16(b))
}
// Int32 returns the Int32 representation of the current Attribute's data.
func (ad *AttributeDecoder) Int32() int32 {
if ad.err != nil {
return 0
}
b := ad.data()
if len(b) != 4 {
ad.err = fmt.Errorf("netlink: attribute %d is not a int32; length: %d", ad.Type(), len(b))
return 0
}
return int32(ad.ByteOrder.Uint32(b))
}
// Int64 returns the Int64 representation of the current Attribute's data.
func (ad *AttributeDecoder) Int64() int64 {
if ad.err != nil {
return 0
}
b := ad.data()
if len(b) != 8 {
ad.err = fmt.Errorf("netlink: attribute %d is not a int64; length: %d", ad.Type(), len(b))
return 0
}
return int64(ad.ByteOrder.Uint64(b))
}
// Flag returns a boolean representing the Attribute.
func (ad *AttributeDecoder) Flag() bool {
if ad.err != nil {
return false
}
b := ad.data()
if len(b) != 0 {
ad.err = fmt.Errorf("netlink: attribute %d is not a flag; length: %d", ad.Type(), len(b))
return false
}
return true
}
// Do is a general purpose function which allows access to the current data
// pointed to by the AttributeDecoder.
//
// Do can be used to allow parsing arbitrary data within the context of the
// decoder. Do is most useful when dealing with nested attributes, attribute
// arrays, or decoding arbitrary types (such as C structures) which don't fit
// cleanly into a typical unsigned integer value.
//
// The function fn should not retain any reference to the data b outside of the
// scope of the function.
func (ad *AttributeDecoder) Do(fn func(b []byte) error) {
if ad.err != nil {
return
}
b := ad.data()
if err := fn(b); err != nil {
ad.err = err
}
}
// Nested decodes data into a nested AttributeDecoder to handle nested netlink
// attributes. When calling Nested, the Err method does not need to be called on
// the nested AttributeDecoder.
//
// The nested AttributeDecoder nad inherits the same ByteOrder setting as the
// top-level AttributeDecoder ad.
func (ad *AttributeDecoder) Nested(fn func(nad *AttributeDecoder) error) {
// Because we are wrapping Do, there is no need to check ad.err immediately.
ad.Do(func(b []byte) error {
nad, err := NewAttributeDecoder(b)
if err != nil {
return err
}
nad.ByteOrder = ad.ByteOrder
if err := fn(nad); err != nil {
return err
}
return nad.Err()
})
}
// An AttributeEncoder provides a safe way to encode attributes.
//
// It is recommended to use an AttributeEncoder where possible instead of
// calling MarshalAttributes or using package nlenc directly.
//
// Errors from intermediate encoding steps are returned in the call to
// Encode.
type AttributeEncoder struct {
// ByteOrder defines a specific byte order to use when processing integer
// attributes. ByteOrder should be set immediately after creating the
// AttributeEncoder: before any attributes are encoded.
//
// If not set, the native byte order will be used.
ByteOrder binary.ByteOrder
attrs []Attribute
err error
}
// NewAttributeEncoder creates an AttributeEncoder that encodes Attributes.
func NewAttributeEncoder() *AttributeEncoder {
return &AttributeEncoder{
ByteOrder: native.Endian,
}
}
// Uint8 encodes uint8 data into an Attribute specified by typ.
func (ae *AttributeEncoder) Uint8(typ uint16, v uint8) {
if ae.err != nil {
return
}
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: []byte{v},
})
}
// Uint16 encodes uint16 data into an Attribute specified by typ.
func (ae *AttributeEncoder) Uint16(typ uint16, v uint16) {
if ae.err != nil {
return
}
b := make([]byte, 2)
ae.ByteOrder.PutUint16(b, v)
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: b,
})
}
// Uint32 encodes uint32 data into an Attribute specified by typ.
func (ae *AttributeEncoder) Uint32(typ uint16, v uint32) {
if ae.err != nil {
return
}
b := make([]byte, 4)
ae.ByteOrder.PutUint32(b, v)
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: b,
})
}
// Uint64 encodes uint64 data into an Attribute specified by typ.
func (ae *AttributeEncoder) Uint64(typ uint16, v uint64) {
if ae.err != nil {
return
}
b := make([]byte, 8)
ae.ByteOrder.PutUint64(b, v)
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: b,
})
}
// Int8 encodes int8 data into an Attribute specified by typ.
func (ae *AttributeEncoder) Int8(typ uint16, v int8) {
if ae.err != nil {
return
}
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: []byte{uint8(v)},
})
}
// Int16 encodes int16 data into an Attribute specified by typ.
func (ae *AttributeEncoder) Int16(typ uint16, v int16) {
if ae.err != nil {
return
}
b := make([]byte, 2)
ae.ByteOrder.PutUint16(b, uint16(v))
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: b,
})
}
// Int32 encodes int32 data into an Attribute specified by typ.
func (ae *AttributeEncoder) Int32(typ uint16, v int32) {
if ae.err != nil {
return
}
b := make([]byte, 4)
ae.ByteOrder.PutUint32(b, uint32(v))
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: b,
})
}
// Int64 encodes int64 data into an Attribute specified by typ.
func (ae *AttributeEncoder) Int64(typ uint16, v int64) {
if ae.err != nil {
return
}
b := make([]byte, 8)
ae.ByteOrder.PutUint64(b, uint64(v))
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: b,
})
}
// Flag encodes a flag into an Attribute specidied by typ.
func (ae *AttributeEncoder) Flag(typ uint16, v bool) {
// Only set flag on no previous error or v == true.
if ae.err != nil || !v {
return
}
// Flags have no length or data fields.
ae.attrs = append(ae.attrs, Attribute{Type: typ})
}
// String encodes string s as a null-terminated string into an Attribute
// specified by typ.
func (ae *AttributeEncoder) String(typ uint16, s string) {
if ae.err != nil {
return
}
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: nlenc.Bytes(s),
})
}
// Bytes embeds raw byte data into an Attribute specified by typ.
func (ae *AttributeEncoder) Bytes(typ uint16, b []byte) {
if ae.err != nil {
return
}
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: b,
})
}
// Do is a general purpose function to encode arbitrary data into an attribute
// specified by typ.
//
// Do is especially helpful in encoding nested attributes, attribute arrays,
// or encoding arbitrary types (such as C structures) which don't fit cleanly
// into an unsigned integer value.
func (ae *AttributeEncoder) Do(typ uint16, fn func() ([]byte, error)) {
if ae.err != nil {
return
}
b, err := fn()
if err != nil {
ae.err = err
return
}
ae.attrs = append(ae.attrs, Attribute{
Type: typ,
Data: b,
})
}
// Nested embeds data produced by a nested AttributeEncoder and flags that data
// with the Nested flag. When calling Nested, the Encode method should not be
// called on the nested AttributeEncoder.
//
// The nested AttributeEncoder nae inherits the same ByteOrder setting as the
// top-level AttributeEncoder ae.
func (ae *AttributeEncoder) Nested(typ uint16, fn func(nae *AttributeEncoder) error) {
// Because we are wrapping Do, there is no need to check ae.err immediately.
ae.Do(Nested|typ, func() ([]byte, error) {
nae := NewAttributeEncoder()
nae.ByteOrder = ae.ByteOrder
if err := fn(nae); err != nil {
return nil, err
}
return nae.Encode()
})
}
// Encode returns the encoded bytes representing the attributes.
func (ae *AttributeEncoder) Encode() ([]byte, error) {
if ae.err != nil {
return nil, ae.err
}
return MarshalAttributes(ae.attrs)
}

561
vendor/github.com/mdlayher/netlink/conn.go generated vendored Normal file
View File

@ -0,0 +1,561 @@
package netlink
import (
"math/rand"
"sync"
"sync/atomic"
"syscall"
"time"
"golang.org/x/net/bpf"
)
// A Conn is a connection to netlink. A Conn can be used to send and
// receives messages to and from netlink.
//
// A Conn is safe for concurrent use, but to avoid contention in
// high-throughput applications, the caller should almost certainly create a
// pool of Conns and distribute them among workers.
//
// A Conn is capable of manipulating netlink subsystems from within a specific
// Linux network namespace, but special care must be taken when doing so. See
// the documentation of Config for details.
type Conn struct {
// Atomics must come first.
//
// seq is an atomically incremented integer used to provide sequence
// numbers when Conn.Send is called.
seq uint32
// mu serializes access to the netlink socket for the request/response
// transaction within Execute.
mu sync.RWMutex
// sock is the operating system-specific implementation of
// a netlink sockets connection.
sock Socket
// pid is the PID assigned by netlink.
pid uint32
// d provides debugging capabilities for a Conn if not nil.
d *debugger
}
// A Socket is an operating-system specific implementation of netlink
// sockets used by Conn.
type Socket interface {
Close() error
Send(m Message) error
SendMessages(m []Message) error
Receive() ([]Message, error)
}
// Dial dials a connection to netlink, using the specified netlink family.
// Config specifies optional configuration for Conn. If config is nil, a default
// configuration will be used.
func Dial(family int, config *Config) (*Conn, error) {
// Use OS-specific dial() to create Socket
c, pid, err := dial(family, config)
if err != nil {
return nil, err
}
return NewConn(c, pid), nil
}
// NewConn creates a Conn using the specified Socket and PID for netlink
// communications.
//
// NewConn is primarily useful for tests. Most applications should use
// Dial instead.
func NewConn(sock Socket, pid uint32) *Conn {
// Seed the sequence number using a random number generator.
r := rand.New(rand.NewSource(time.Now().UnixNano()))
seq := r.Uint32()
// Configure a debugger if arguments are set.
var d *debugger
if len(debugArgs) > 0 {
d = newDebugger(debugArgs)
}
return &Conn{
seq: seq,
sock: sock,
pid: pid,
d: d,
}
}
// debug executes fn with the debugger if the debugger is not nil.
func (c *Conn) debug(fn func(d *debugger)) {
if c.d == nil {
return
}
fn(c.d)
}
// Close closes the connection and unblocks any pending read operations.
func (c *Conn) Close() error {
// Close does not acquire a lock because it must be able to interrupt any
// blocked system calls, such as when Receive is waiting on a multicast
// group message.
//
// We rely on the kernel to deal with concurrent operations to the netlink
// socket itself.
return newOpError("close", c.sock.Close())
}
// Execute sends a single Message to netlink using Send, receives one or more
// replies using Receive, and then checks the validity of the replies against
// the request using Validate.
//
// Execute acquires a lock for the duration of the function call which blocks
// concurrent calls to Send, SendMessages, and Receive, in order to ensure
// consistency between netlink request/reply messages.
//
// See the documentation of Send, Receive, and Validate for details about
// each function.
func (c *Conn) Execute(m Message) ([]Message, error) {
// Acquire the write lock and invoke the internal implementations of Send
// and Receive which require the lock already be held.
c.mu.Lock()
defer c.mu.Unlock()
req, err := c.lockedSend(m)
if err != nil {
return nil, err
}
res, err := c.lockedReceive()
if err != nil {
return nil, err
}
if err := Validate(req, res); err != nil {
return nil, err
}
return res, nil
}
// SendMessages sends multiple Messages to netlink. The handling of
// a Header's Length, Sequence and PID fields is the same as when
// calling Send.
func (c *Conn) SendMessages(msgs []Message) ([]Message, error) {
// Wait for any concurrent calls to Execute to finish before proceeding.
c.mu.RLock()
defer c.mu.RUnlock()
for i := range msgs {
c.fixMsg(&msgs[i], nlmsgLength(len(msgs[i].Data)))
}
c.debug(func(d *debugger) {
for _, m := range msgs {
d.debugf(1, "send msgs: %+v", m)
}
})
if err := c.sock.SendMessages(msgs); err != nil {
c.debug(func(d *debugger) {
d.debugf(1, "send msgs: err: %v", err)
})
return nil, newOpError("send-messages", err)
}
return msgs, nil
}
// Send sends a single Message to netlink. In most cases, a Header's Length,
// Sequence, and PID fields should be set to 0, so they can be populated
// automatically before the Message is sent. On success, Send returns a copy
// of the Message with all parameters populated, for later validation.
//
// If Header.Length is 0, it will be automatically populated using the
// correct length for the Message, including its payload.
//
// If Header.Sequence is 0, it will be automatically populated using the
// next sequence number for this connection.
//
// If Header.PID is 0, it will be automatically populated using a PID
// assigned by netlink.
func (c *Conn) Send(m Message) (Message, error) {
// Wait for any concurrent calls to Execute to finish before proceeding.
c.mu.RLock()
defer c.mu.RUnlock()
return c.lockedSend(m)
}
// lockedSend implements Send, but must be called with c.mu acquired for reading.
// We rely on the kernel to deal with concurrent reads and writes to the netlink
// socket itself.
func (c *Conn) lockedSend(m Message) (Message, error) {
c.fixMsg(&m, nlmsgLength(len(m.Data)))
c.debug(func(d *debugger) {
d.debugf(1, "send: %+v", m)
})
if err := c.sock.Send(m); err != nil {
c.debug(func(d *debugger) {
d.debugf(1, "send: err: %v", err)
})
return Message{}, newOpError("send", err)
}
return m, nil
}
// Receive receives one or more messages from netlink. Multi-part messages are
// handled transparently and returned as a single slice of Messages, with the
// final empty "multi-part done" message removed.
//
// If any of the messages indicate a netlink error, that error will be returned.
func (c *Conn) Receive() ([]Message, error) {
// Wait for any concurrent calls to Execute to finish before proceeding.
c.mu.RLock()
defer c.mu.RUnlock()
return c.lockedReceive()
}
// lockedReceive implements Receive, but must be called with c.mu acquired for reading.
// We rely on the kernel to deal with concurrent reads and writes to the netlink
// socket itself.
func (c *Conn) lockedReceive() ([]Message, error) {
msgs, err := c.receive()
if err != nil {
c.debug(func(d *debugger) {
d.debugf(1, "recv: err: %v", err)
})
return nil, err
}
c.debug(func(d *debugger) {
for _, m := range msgs {
d.debugf(1, "recv: %+v", m)
}
})
// When using nltest, it's possible for zero messages to be returned by receive.
if len(msgs) == 0 {
return msgs, nil
}
// Trim the final message with multi-part done indicator if
// present.
if m := msgs[len(msgs)-1]; m.Header.Flags&Multi != 0 && m.Header.Type == Done {
return msgs[:len(msgs)-1], nil
}
return msgs, nil
}
// receive is the internal implementation of Conn.Receive, which can be called
// recursively to handle multi-part messages.
func (c *Conn) receive() ([]Message, error) {
// NB: All non-nil errors returned from this function *must* be of type
// OpError in order to maintain the appropriate contract with callers of
// this package.
//
// This contract also applies to functions called within this function,
// such as checkMessage.
var res []Message
for {
msgs, err := c.sock.Receive()
if err != nil {
return nil, newOpError("receive", err)
}
// If this message is multi-part, we will need to continue looping to
// drain all the messages from the socket.
var multi bool
for _, m := range msgs {
if err := checkMessage(m); err != nil {
return nil, err
}
// Does this message indicate a multi-part message?
if m.Header.Flags&Multi == 0 {
// No, check the next messages.
continue
}
// Does this message indicate the last message in a series of
// multi-part messages from a single read?
multi = m.Header.Type != Done
}
res = append(res, msgs...)
if !multi {
// No more messages coming.
return res, nil
}
}
}
// A groupJoinLeaver is a Socket that supports joining and leaving
// netlink multicast groups.
type groupJoinLeaver interface {
Socket
JoinGroup(group uint32) error
LeaveGroup(group uint32) error
}
// JoinGroup joins a netlink multicast group by its ID.
func (c *Conn) JoinGroup(group uint32) error {
conn, ok := c.sock.(groupJoinLeaver)
if !ok {
return notSupported("join-group")
}
return newOpError("join-group", conn.JoinGroup(group))
}
// LeaveGroup leaves a netlink multicast group by its ID.
func (c *Conn) LeaveGroup(group uint32) error {
conn, ok := c.sock.(groupJoinLeaver)
if !ok {
return notSupported("leave-group")
}
return newOpError("leave-group", conn.LeaveGroup(group))
}
// A bpfSetter is a Socket that supports setting and removing BPF filters.
type bpfSetter interface {
Socket
bpf.Setter
RemoveBPF() error
}
// SetBPF attaches an assembled BPF program to a Conn.
func (c *Conn) SetBPF(filter []bpf.RawInstruction) error {
conn, ok := c.sock.(bpfSetter)
if !ok {
return notSupported("set-bpf")
}
return newOpError("set-bpf", conn.SetBPF(filter))
}
// RemoveBPF removes a BPF filter from a Conn.
func (c *Conn) RemoveBPF() error {
conn, ok := c.sock.(bpfSetter)
if !ok {
return notSupported("remove-bpf")
}
return newOpError("remove-bpf", conn.RemoveBPF())
}
// A deadlineSetter is a Socket that supports setting deadlines.
type deadlineSetter interface {
Socket
SetDeadline(time.Time) error
SetReadDeadline(time.Time) error
SetWriteDeadline(time.Time) error
}
// SetDeadline sets the read and write deadlines associated with the connection.
func (c *Conn) SetDeadline(t time.Time) error {
conn, ok := c.sock.(deadlineSetter)
if !ok {
return notSupported("set-deadline")
}
return newOpError("set-deadline", conn.SetDeadline(t))
}
// SetReadDeadline sets the read deadline associated with the connection.
func (c *Conn) SetReadDeadline(t time.Time) error {
conn, ok := c.sock.(deadlineSetter)
if !ok {
return notSupported("set-read-deadline")
}
return newOpError("set-read-deadline", conn.SetReadDeadline(t))
}
// SetWriteDeadline sets the write deadline associated with the connection.
func (c *Conn) SetWriteDeadline(t time.Time) error {
conn, ok := c.sock.(deadlineSetter)
if !ok {
return notSupported("set-write-deadline")
}
return newOpError("set-write-deadline", conn.SetWriteDeadline(t))
}
// A ConnOption is a boolean option that may be set for a Conn.
type ConnOption int
// Possible ConnOption values. These constants are equivalent to the Linux
// setsockopt boolean options for netlink sockets.
const (
PacketInfo ConnOption = iota
BroadcastError
NoENOBUFS
ListenAllNSID
CapAcknowledge
ExtendedAcknowledge
GetStrictCheck
)
// An optionSetter is a Socket that supports setting netlink options.
type optionSetter interface {
Socket
SetOption(option ConnOption, enable bool) error
}
// SetOption enables or disables a netlink socket option for the Conn.
func (c *Conn) SetOption(option ConnOption, enable bool) error {
conn, ok := c.sock.(optionSetter)
if !ok {
return notSupported("set-option")
}
return newOpError("set-option", conn.SetOption(option, enable))
}
// A bufferSetter is a Socket that supports setting connection buffer sizes.
type bufferSetter interface {
Socket
SetReadBuffer(bytes int) error
SetWriteBuffer(bytes int) error
}
// SetReadBuffer sets the size of the operating system's receive buffer
// associated with the Conn.
func (c *Conn) SetReadBuffer(bytes int) error {
conn, ok := c.sock.(bufferSetter)
if !ok {
return notSupported("set-read-buffer")
}
return newOpError("set-read-buffer", conn.SetReadBuffer(bytes))
}
// SetWriteBuffer sets the size of the operating system's transmit buffer
// associated with the Conn.
func (c *Conn) SetWriteBuffer(bytes int) error {
conn, ok := c.sock.(bufferSetter)
if !ok {
return notSupported("set-write-buffer")
}
return newOpError("set-write-buffer", conn.SetWriteBuffer(bytes))
}
// A syscallConner is a Socket that supports syscall.Conn.
type syscallConner interface {
Socket
syscall.Conn
}
var _ syscall.Conn = &Conn{}
// SyscallConn returns a raw network connection. This implements the
// syscall.Conn interface.
//
// SyscallConn is intended for advanced use cases, such as getting and setting
// arbitrary socket options using the netlink socket's file descriptor.
//
// Once invoked, it is the caller's responsibility to ensure that operations
// performed using Conn and the syscall.RawConn do not conflict with
// each other.
func (c *Conn) SyscallConn() (syscall.RawConn, error) {
sc, ok := c.sock.(syscallConner)
if !ok {
return nil, notSupported("syscall-conn")
}
// TODO(mdlayher): mutex or similar to enforce syscall.RawConn contract of
// FD remaining valid for duration of calls?
return sc.SyscallConn()
}
// fixMsg updates the fields of m using the logic specified in Send.
func (c *Conn) fixMsg(m *Message, ml int) {
if m.Header.Length == 0 {
m.Header.Length = uint32(nlmsgAlign(ml))
}
if m.Header.Sequence == 0 {
m.Header.Sequence = c.nextSequence()
}
if m.Header.PID == 0 {
m.Header.PID = c.pid
}
}
// nextSequence atomically increments Conn's sequence number and returns
// the incremented value.
func (c *Conn) nextSequence() uint32 {
return atomic.AddUint32(&c.seq, 1)
}
// Validate validates one or more reply Messages against a request Message,
// ensuring that they contain matching sequence numbers and PIDs.
func Validate(request Message, replies []Message) error {
for _, m := range replies {
// Check for mismatched sequence, unless:
// - request had no sequence, meaning we are probably validating
// a multicast reply
if m.Header.Sequence != request.Header.Sequence && request.Header.Sequence != 0 {
return newOpError("validate", errMismatchedSequence)
}
// Check for mismatched PID, unless:
// - request had no PID, meaning we are either:
// - validating a multicast reply
// - netlink has not yet assigned us a PID
// - response had no PID, meaning it's from the kernel as a multicast reply
if m.Header.PID != request.Header.PID && request.Header.PID != 0 && m.Header.PID != 0 {
return newOpError("validate", errMismatchedPID)
}
}
return nil
}
// Config contains options for a Conn.
type Config struct {
// Groups is a bitmask which specifies multicast groups. If set to 0,
// no multicast group subscriptions will be made.
Groups uint32
// NetNS specifies the network namespace the Conn will operate in.
//
// If set (non-zero), Conn will enter the specified network namespace and
// an error will occur in Dial if the operation fails.
//
// If not set (zero), a best-effort attempt will be made to enter the
// network namespace of the calling thread: this means that any changes made
// to the calling thread's network namespace will also be reflected in Conn.
// If this operation fails (due to lack of permissions or because network
// namespaces are disabled by kernel configuration), Dial will not return
// an error, and the Conn will operate in the default network namespace of
// the process. This enables non-privileged use of Conn in applications
// which do not require elevated privileges.
//
// Entering a network namespace is a privileged operation (root or
// CAP_SYS_ADMIN are required), and most applications should leave this set
// to 0.
NetNS int
// DisableNSLockThread is deprecated and has no effect.
DisableNSLockThread bool
}

254
vendor/github.com/mdlayher/netlink/conn_linux.go generated vendored Normal file
View File

@ -0,0 +1,254 @@
//+build linux
package netlink
import (
"os"
"runtime"
"syscall"
"time"
"unsafe"
"github.com/mdlayher/socket"
"golang.org/x/net/bpf"
"golang.org/x/sys/unix"
)
var _ Socket = &conn{}
// A conn is the Linux implementation of a netlink sockets connection.
type conn struct {
s *socket.Conn
}
// dial is the entry point for Dial. dial opens a netlink socket using
// system calls, and returns its PID.
func dial(family int, config *Config) (*conn, uint32, error) {
if config == nil {
config = &Config{}
}
// The caller has indicated it wants the netlink socket to be created
// inside another network namespace.
if config.NetNS != 0 {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
// Retrieve and store the calling OS thread's network namespace so
// the thread can be reassigned to it after creating a socket in another
// network namespace.
threadNS, err := threadNetNS()
if err != nil {
return nil, 0, err
}
// Always close the netns handle created above.
defer threadNS.Close()
// Assign the current OS thread the goroutine is locked to to the given
// network namespace.
if err := threadNS.Set(config.NetNS); err != nil {
return nil, 0, err
}
// Thread's namespace has been successfully set. Return the thread
// back to its original namespace after attempting to create the
// netlink socket.
defer threadNS.Restore()
}
// Prepare the netlink socket.
s, err := socket.Socket(unix.AF_NETLINK, unix.SOCK_RAW, family, "netlink")
if err != nil {
return nil, 0, err
}
return newConn(s, config)
}
// newConn binds a connection to netlink using the input *socket.Conn.
func newConn(s *socket.Conn, config *Config) (*conn, uint32, error) {
if config == nil {
config = &Config{}
}
addr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: config.Groups,
}
// Socket must be closed in the event of any system call errors, to avoid
// leaking file descriptors.
if err := s.Bind(addr); err != nil {
_ = s.Close()
return nil, 0, err
}
sa, err := s.Getsockname()
if err != nil {
_ = s.Close()
return nil, 0, err
}
return &conn{
s: s,
}, sa.(*unix.SockaddrNetlink).Pid, nil
}
// SendMessages serializes multiple Messages and sends them to netlink.
func (c *conn) SendMessages(messages []Message) error {
var buf []byte
for _, m := range messages {
b, err := m.MarshalBinary()
if err != nil {
return err
}
buf = append(buf, b...)
}
sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK}
return c.s.Sendmsg(buf, nil, sa, 0)
}
// Send sends a single Message to netlink.
func (c *conn) Send(m Message) error {
b, err := m.MarshalBinary()
if err != nil {
return err
}
sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK}
return c.s.Sendmsg(b, nil, sa, 0)
}
// Receive receives one or more Messages from netlink.
func (c *conn) Receive() ([]Message, error) {
b := make([]byte, os.Getpagesize())
for {
// Peek at the buffer to see how many bytes are available.
//
// TODO(mdlayher): deal with OOB message data if available, such as
// when PacketInfo ConnOption is true.
n, _, _, _, err := c.s.Recvmsg(b, nil, unix.MSG_PEEK)
if err != nil {
return nil, err
}
// Break when we can read all messages
if n < len(b) {
break
}
// Double in size if not enough bytes
b = make([]byte, len(b)*2)
}
// Read out all available messages
n, _, _, _, err := c.s.Recvmsg(b, nil, 0)
if err != nil {
return nil, err
}
raw, err := syscall.ParseNetlinkMessage(b[:nlmsgAlign(n)])
if err != nil {
return nil, err
}
msgs := make([]Message, 0, len(raw))
for _, r := range raw {
m := Message{
Header: sysToHeader(r.Header),
Data: r.Data,
}
msgs = append(msgs, m)
}
return msgs, nil
}
// Close closes the connection.
func (c *conn) Close() error { return c.s.Close() }
// JoinGroup joins a multicast group by ID.
func (c *conn) JoinGroup(group uint32) error {
return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_ADD_MEMBERSHIP, int(group))
}
// LeaveGroup leaves a multicast group by ID.
func (c *conn) LeaveGroup(group uint32) error {
return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_DROP_MEMBERSHIP, int(group))
}
// SetBPF attaches an assembled BPF program to a conn.
func (c *conn) SetBPF(filter []bpf.RawInstruction) error { return c.s.SetBPF(filter) }
// RemoveBPF removes a BPF filter from a conn.
func (c *conn) RemoveBPF() error { return c.s.RemoveBPF() }
// SetOption enables or disables a netlink socket option for the Conn.
func (c *conn) SetOption(option ConnOption, enable bool) error {
o, ok := linuxOption(option)
if !ok {
// Return the typical Linux error for an unknown ConnOption.
return os.NewSyscallError("setsockopt", unix.ENOPROTOOPT)
}
var v int
if enable {
v = 1
}
return c.s.SetsockoptInt(unix.SOL_NETLINK, o, v)
}
func (c *conn) SetDeadline(t time.Time) error { return c.s.SetDeadline(t) }
func (c *conn) SetReadDeadline(t time.Time) error { return c.s.SetReadDeadline(t) }
func (c *conn) SetWriteDeadline(t time.Time) error { return c.s.SetWriteDeadline(t) }
// SetReadBuffer sets the size of the operating system's receive buffer
// associated with the Conn.
func (c *conn) SetReadBuffer(bytes int) error { return c.s.SetReadBuffer(bytes) }
// SetReadBuffer sets the size of the operating system's transmit buffer
// associated with the Conn.
func (c *conn) SetWriteBuffer(bytes int) error { return c.s.SetWriteBuffer(bytes) }
// SyscallConn returns a raw network connection.
func (c *conn) SyscallConn() (syscall.RawConn, error) { return c.s.SyscallConn() }
// linuxOption converts a ConnOption to its Linux value.
func linuxOption(o ConnOption) (int, bool) {
switch o {
case PacketInfo:
return unix.NETLINK_PKTINFO, true
case BroadcastError:
return unix.NETLINK_BROADCAST_ERROR, true
case NoENOBUFS:
return unix.NETLINK_NO_ENOBUFS, true
case ListenAllNSID:
return unix.NETLINK_LISTEN_ALL_NSID, true
case CapAcknowledge:
return unix.NETLINK_CAP_ACK, true
case ExtendedAcknowledge:
return unix.NETLINK_EXT_ACK, true
case GetStrictCheck:
return unix.NETLINK_GET_STRICT_CHK, true
default:
return 0, false
}
}
// sysToHeader converts a syscall.NlMsghdr to a Header.
func sysToHeader(r syscall.NlMsghdr) Header {
// NB: the memory layout of Header and syscall.NlMsgHdr must be
// exactly the same for this unsafe cast to work
return *(*Header)(unsafe.Pointer(&r))
}
// newError converts an error number from netlink into the appropriate
// system call error for Linux.
func newError(errno int) error {
return syscall.Errno(errno)
}

29
vendor/github.com/mdlayher/netlink/conn_others.go generated vendored Normal file
View File

@ -0,0 +1,29 @@
//+build !linux
package netlink
import (
"fmt"
"runtime"
)
// errUnimplemented is returned by all functions on platforms that
// cannot make use of netlink sockets.
var errUnimplemented = fmt.Errorf("netlink: not implemented on %s/%s",
runtime.GOOS, runtime.GOARCH)
var _ Socket = &conn{}
// A conn is the no-op implementation of a netlink sockets connection.
type conn struct{}
// All cross-platform functions and Socket methods are unimplemented outside
// of Linux.
func dial(_ int, _ *Config) (*conn, uint32, error) { return nil, 0, errUnimplemented }
func newError(_ int) error { return errUnimplemented }
func (c *conn) Send(_ Message) error { return errUnimplemented }
func (c *conn) SendMessages(_ []Message) error { return errUnimplemented }
func (c *conn) Receive() ([]Message, error) { return nil, errUnimplemented }
func (c *conn) Close() error { return errUnimplemented }

69
vendor/github.com/mdlayher/netlink/debug.go generated vendored Normal file
View File

@ -0,0 +1,69 @@
package netlink
import (
"fmt"
"log"
"os"
"strconv"
"strings"
)
// Arguments used to create a debugger.
var debugArgs []string
func init() {
// Is netlink debugging enabled?
s := os.Getenv("NLDEBUG")
if s == "" {
return
}
debugArgs = strings.Split(s, ",")
}
// A debugger is used to provide debugging information about a netlink connection.
type debugger struct {
Log *log.Logger
Level int
}
// newDebugger creates a debugger by parsing key=value arguments.
func newDebugger(args []string) *debugger {
d := &debugger{
Log: log.New(os.Stderr, "nl: ", 0),
Level: 1,
}
for _, a := range args {
kv := strings.Split(a, "=")
if len(kv) != 2 {
// Ignore malformed pairs and assume callers wants defaults.
continue
}
switch kv[0] {
// Select the log level for the debugger.
case "level":
level, err := strconv.Atoi(kv[1])
if err != nil {
panicf("netlink: invalid NLDEBUG level: %q", a)
}
d.Level = level
}
}
return d
}
// debugf prints debugging information at the specified level, if d.Level is
// high enough to print the message.
func (d *debugger) debugf(level int, format string, v ...interface{}) {
if d.Level >= level {
d.Log.Printf(format, v...)
}
}
func panicf(format string, a ...interface{}) {
panic(fmt.Sprintf(format, a...))
}

36
vendor/github.com/mdlayher/netlink/doc.go generated vendored Normal file
View File

@ -0,0 +1,36 @@
// Package netlink provides low-level access to Linux netlink sockets.
//
// If you have any questions or you'd like some guidance, please join us on
// Gophers Slack (https://invite.slack.golangbridge.org) in the #networking
// channel!
//
//
// Network namespaces
//
// This package is aware of Linux network namespaces, and can enter different
// network namespaces either implicitly or explicitly, depending on
// configuration. The Config structure passed to Dial to create a Conn controls
// these behaviors. See the documentation of Config.NetNS for details.
//
//
// Debugging
//
// This package supports rudimentary netlink connection debugging support.
// To enable this, run your binary with the NLDEBUG environment variable set.
// Debugging information will be output to stderr with a prefix of "nl:".
//
// To use the debugging defaults, use:
//
// $ NLDEBUG=1 ./nlctl
//
// To configure individual aspects of the debugger, pass key/value options such
// as:
//
// $ NLDEBUG=level=1 ./nlctl
//
// Available key/value debugger options include:
//
// level=N: specify the debugging level (only "1" is currently supported)
package netlink
//go:generate dot netlink.dot -T svg -o netlink.svg

143
vendor/github.com/mdlayher/netlink/errors.go generated vendored Normal file
View File

@ -0,0 +1,143 @@
package netlink
import (
"errors"
"fmt"
"net"
"os"
"strings"
)
// Error messages which can be returned by Validate.
var (
errMismatchedSequence = errors.New("mismatched sequence in netlink reply")
errMismatchedPID = errors.New("mismatched PID in netlink reply")
errShortErrorMessage = errors.New("not enough data for netlink error code")
)
// Errors which can be returned by a Socket that does not implement
// all exposed methods of Conn.
var errNotSupported = errors.New("operation not supported")
// notSupported provides a concise constructor for "not supported" errors.
func notSupported(op string) error {
return newOpError(op, errNotSupported)
}
// IsNotExist determines if an error is produced as the result of querying some
// file, object, resource, etc. which does not exist. Users of this package
// should always use netlink.IsNotExist, rather than os.IsNotExist, when
// checking for specific netlink-related errors.
//
// Errors types created by this package, such as OpError, can be used with
// IsNotExist, but this function also defers to the behavior of os.IsNotExist
// for unrecognized error types.
//
// Deprecated: make use of errors.Unwrap and errors.Is in Go 1.13+.
func IsNotExist(err error) bool {
switch err := err.(type) {
case *OpError:
// Unwrap the inner error and use the stdlib's logic.
return os.IsNotExist(err.Err)
default:
return os.IsNotExist(err)
}
}
var (
_ error = &OpError{}
_ net.Error = &OpError{}
// Ensure compatibility with Go 1.13+ errors package.
_ interface{ Unwrap() error } = &OpError{}
)
// An OpError is an error produced as the result of a failed netlink operation.
type OpError struct {
// Op is the operation which caused this OpError, such as "send"
// or "receive".
Op string
// Err is the underlying error which caused this OpError.
//
// If Err was produced by a system call error, Err will be of type
// *os.SyscallError. If Err was produced by an error code in a netlink
// message, Err will contain a raw error value type such as a unix.Errno.
//
// Most callers should inspect Err using errors.Is from the standard
// library.
Err error
// Message and Offset contain additional error information provided by the
// kernel when the ExtendedAcknowledge option is set on a Conn and the
// kernel indicates the AcknowledgeTLVs flag in a response. If this option
// is not set, both of these fields will be empty.
Message string
Offset int
}
// newOpError is a small wrapper for creating an OpError. As a convenience, it
// returns nil if the input err is nil: akin to os.NewSyscallError.
func newOpError(op string, err error) error {
if err == nil {
return nil
}
return &OpError{
Op: op,
Err: err,
}
}
func (e *OpError) Error() string {
if e == nil {
return "<nil>"
}
var sb strings.Builder
_, _ = sb.WriteString(fmt.Sprintf("netlink %s: %v", e.Op, e.Err))
if e.Message != "" || e.Offset != 0 {
_, _ = sb.WriteString(fmt.Sprintf(", offset: %d, message: %q",
e.Offset, e.Message))
}
return sb.String()
}
// Unwrap unwraps the internal Err field for use with errors.Unwrap.
func (e *OpError) Unwrap() error { return e.Err }
// Portions of this code taken from the Go standard library:
//
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
type timeout interface {
Timeout() bool
}
// Timeout reports whether the error was caused by an I/O timeout.
func (e *OpError) Timeout() bool {
if ne, ok := e.Err.(*os.SyscallError); ok {
t, ok := ne.Err.(timeout)
return ok && t.Timeout()
}
t, ok := e.Err.(timeout)
return ok && t.Timeout()
}
type temporary interface {
Temporary() bool
}
// Temporary reports whether an operation may succeed if retried.
func (e *OpError) Temporary() bool {
if ne, ok := e.Err.(*os.SyscallError); ok {
t, ok := ne.Err.(temporary)
return ok && t.Temporary()
}
t, ok := e.Err.(temporary)
return ok && t.Temporary()
}

81
vendor/github.com/mdlayher/netlink/fuzz.go generated vendored Normal file
View File

@ -0,0 +1,81 @@
//+build gofuzz
package netlink
import "github.com/google/go-cmp/cmp"
func fuzz(b1 []byte) int {
// 1. unmarshal, marshal, unmarshal again to check m1 and m2 for equality
// after a round trip. checkMessage is also used because there is a fair
// amount of tricky logic around testing for presence of error headers and
// extended acknowledgement attributes.
var m1 Message
if err := m1.UnmarshalBinary(b1); err != nil {
return 0
}
if err := checkMessage(m1); err != nil {
return 0
}
b2, err := m1.MarshalBinary()
if err != nil {
panicf("failed to marshal m1: %v", err)
}
var m2 Message
if err := m2.UnmarshalBinary(b2); err != nil {
panicf("failed to unmarshal m2: %v", err)
}
if err := checkMessage(m2); err != nil {
panicf("failed to check m2: %v", err)
}
if diff := cmp.Diff(m1, m2); diff != "" {
panicf("unexpected Message (-want +got):\n%s", diff)
}
// 2. marshal again and compare b2 and b3 (b1 may have reserved bytes set
// which we ignore and fill with zeros when marshaling) for equality.
b3, err := m2.MarshalBinary()
if err != nil {
panicf("failed to marshal m2: %v", err)
}
if diff := cmp.Diff(b2, b3); diff != "" {
panicf("unexpected message bytes (-want +got):\n%s", diff)
}
// 3. unmarshal any possible attributes from m1's data and marshal them
// again for comparison.
a1, err := UnmarshalAttributes(m1.Data)
if err != nil {
return 0
}
ab1, err := MarshalAttributes(a1)
if err != nil {
panicf("failed to marshal a1: %v", err)
}
a2, err := UnmarshalAttributes(ab1)
if err != nil {
panicf("failed to unmarshal a2: %v", err)
}
if diff := cmp.Diff(a1, a2); diff != "" {
panicf("unexpected Attributes (-want +got):\n%s", diff)
}
ab2, err := MarshalAttributes(a2)
if err != nil {
panicf("failed to marshal a2: %v", err)
}
if diff := cmp.Diff(ab1, ab2); diff != "" {
panicf("unexpected attribute bytes (-want +got):\n%s", diff)
}
return 1
}

347
vendor/github.com/mdlayher/netlink/message.go generated vendored Normal file
View File

@ -0,0 +1,347 @@
package netlink
import (
"errors"
"fmt"
"unsafe"
"github.com/mdlayher/netlink/nlenc"
)
// Flags which may apply to netlink attribute types when communicating with
// certain netlink families.
const (
Nested uint16 = 0x8000
NetByteOrder uint16 = 0x4000
// attrTypeMask masks off Type bits used for the above flags.
attrTypeMask uint16 = 0x3fff
)
// Various errors which may occur when attempting to marshal or unmarshal
// a Message to and from its binary form.
var (
errIncorrectMessageLength = errors.New("netlink message header length incorrect")
errShortMessage = errors.New("not enough data to create a netlink message")
errUnalignedMessage = errors.New("input data is not properly aligned for netlink message")
)
// HeaderFlags specify flags which may be present in a Header.
type HeaderFlags uint16
const (
// General netlink communication flags.
// Request indicates a request to netlink.
Request HeaderFlags = 1
// Multi indicates a multi-part message, terminated by Done on the
// last message.
Multi HeaderFlags = 2
// Acknowledge requests that netlink reply with an acknowledgement
// using Error and, if needed, an error code.
Acknowledge HeaderFlags = 4
// Echo requests that netlink echo this request back to the sender.
Echo HeaderFlags = 8
// DumpInterrupted indicates that a dump was inconsistent due to a
// sequence change.
DumpInterrupted HeaderFlags = 16
// DumpFiltered indicates that a dump was filtered as requested.
DumpFiltered HeaderFlags = 32
// Flags used to retrieve data from netlink.
// Root requests that netlink return a complete table instead of a
// single entry.
Root HeaderFlags = 0x100
// Match requests that netlink return a list of all matching entries.
Match HeaderFlags = 0x200
// Atomic requests that netlink send an atomic snapshot of its entries.
// Requires CAP_NET_ADMIN or an effective UID of 0.
Atomic HeaderFlags = 0x400
// Dump requests that netlink return a complete list of all entries.
Dump HeaderFlags = Root | Match
// Flags used to create objects.
// Replace indicates request replaces an existing matching object.
Replace HeaderFlags = 0x100
// Excl indicates request does not replace the object if it already exists.
Excl HeaderFlags = 0x200
// Create indicates request creates an object if it doesn't already exist.
Create HeaderFlags = 0x400
// Append indicates request adds to the end of the object list.
Append HeaderFlags = 0x800
// Flags for extended acknowledgements.
// Capped indicates the size of a request was capped in an extended
// acknowledgement.
Capped HeaderFlags = 0x100
// AcknowledgeTLVs indicates the presence of netlink extended
// acknowledgement TLVs in a response.
AcknowledgeTLVs HeaderFlags = 0x200
)
// String returns the string representation of a HeaderFlags.
func (f HeaderFlags) String() string {
names := []string{
"request",
"multi",
"acknowledge",
"echo",
"dumpinterrupted",
"dumpfiltered",
}
var s string
left := uint(f)
for i, name := range names {
if f&(1<<uint(i)) != 0 {
if s != "" {
s += "|"
}
s += name
left ^= (1 << uint(i))
}
}
if s == "" && left == 0 {
s = "0"
}
if left > 0 {
if s != "" {
s += "|"
}
s += fmt.Sprintf("%#x", left)
}
return s
}
// HeaderType specifies the type of a Header.
type HeaderType uint16
const (
// Noop indicates that no action was taken.
Noop HeaderType = 0x1
// Error indicates an error code is present, which is also used to indicate
// success when the code is 0.
Error HeaderType = 0x2
// Done indicates the end of a multi-part message.
Done HeaderType = 0x3
// Overrun indicates that data was lost from this message.
Overrun HeaderType = 0x4
)
// String returns the string representation of a HeaderType.
func (t HeaderType) String() string {
switch t {
case Noop:
return "noop"
case Error:
return "error"
case Done:
return "done"
case Overrun:
return "overrun"
default:
return fmt.Sprintf("unknown(%d)", t)
}
}
// NB: the memory layout of Header and Linux's syscall.NlMsgHdr must be
// exactly the same. Cannot reorder, change data type, add, or remove fields.
// Named types of the same size (e.g. HeaderFlags is a uint16) are okay.
// A Header is a netlink header. A Header is sent and received with each
// Message to indicate metadata regarding a Message.
type Header struct {
// Length of a Message, including this Header.
Length uint32
// Contents of a Message.
Type HeaderType
// Flags which may be used to modify a request or response.
Flags HeaderFlags
// The sequence number of a Message.
Sequence uint32
// The process ID of the sending process.
PID uint32
}
// A Message is a netlink message. It contains a Header and an arbitrary
// byte payload, which may be decoded using information from the Header.
//
// Data is often populated with netlink attributes. For easy encoding and
// decoding of attributes, see the AttributeDecoder and AttributeEncoder types.
type Message struct {
Header Header
Data []byte
}
// MarshalBinary marshals a Message into a byte slice.
func (m Message) MarshalBinary() ([]byte, error) {
ml := nlmsgAlign(int(m.Header.Length))
if ml < nlmsgHeaderLen || ml != int(m.Header.Length) {
return nil, errIncorrectMessageLength
}
b := make([]byte, ml)
nlenc.PutUint32(b[0:4], m.Header.Length)
nlenc.PutUint16(b[4:6], uint16(m.Header.Type))
nlenc.PutUint16(b[6:8], uint16(m.Header.Flags))
nlenc.PutUint32(b[8:12], m.Header.Sequence)
nlenc.PutUint32(b[12:16], m.Header.PID)
copy(b[16:], m.Data)
return b, nil
}
// UnmarshalBinary unmarshals the contents of a byte slice into a Message.
func (m *Message) UnmarshalBinary(b []byte) error {
if len(b) < nlmsgHeaderLen {
return errShortMessage
}
if len(b) != nlmsgAlign(len(b)) {
return errUnalignedMessage
}
// Don't allow misleading length
m.Header.Length = nlenc.Uint32(b[0:4])
if int(m.Header.Length) != len(b) {
return errShortMessage
}
m.Header.Type = HeaderType(nlenc.Uint16(b[4:6]))
m.Header.Flags = HeaderFlags(nlenc.Uint16(b[6:8]))
m.Header.Sequence = nlenc.Uint32(b[8:12])
m.Header.PID = nlenc.Uint32(b[12:16])
m.Data = b[16:]
return nil
}
// checkMessage checks a single Message for netlink errors.
func checkMessage(m Message) error {
// NB: All non-nil errors returned from this function *must* be of type
// OpError in order to maintain the appropriate contract with callers of
// this package.
// The libnl documentation indicates that type error can
// contain error codes:
// https://www.infradead.org/~tgr/libnl/doc/core.html#core_errmsg.
//
// However, rtnetlink at least seems to also allow errors to occur at the
// end of a multipart message with done/multi and an error number.
var hasHeader bool
switch {
case m.Header.Type == Error:
// Error code followed by nlmsghdr/ext ack attributes.
hasHeader = true
case m.Header.Type == Done && m.Header.Flags&Multi != 0:
// If no data, there must be no error number so just exit early. Some
// of the unit tests hard-coded this but I don't actually know if this
// case occurs in the wild.
if len(m.Data) == 0 {
return nil
}
// Done|Multi potentially followed by ext ack attributes.
default:
// Neither, nothing to do.
return nil
}
// Errno occupies 4 bytes.
const endErrno = 4
if len(m.Data) < endErrno {
return newOpError("receive", errShortErrorMessage)
}
c := nlenc.Int32(m.Data[:endErrno])
if c == 0 {
// 0 indicates no error.
return nil
}
oerr := &OpError{
Op: "receive",
// Error code is a negative integer, convert it into an OS-specific raw
// system call error, but do not wrap with os.NewSyscallError to signify
// that this error was produced by a netlink message; not a system call.
Err: newError(-1 * int(c)),
}
// TODO(mdlayher): investigate the Capped flag.
if m.Header.Flags&AcknowledgeTLVs == 0 {
// No extended acknowledgement.
return oerr
}
// Flags indicate an extended acknowledgement. The type/flags combination
// checked above determines the offset where the TLVs occur.
var off int
if hasHeader {
// There is an nlmsghdr preceding the TLVs.
if len(m.Data) < endErrno+nlmsgHeaderLen {
return newOpError("receive", errShortErrorMessage)
}
// The TLVs should be at the offset indicated by the nlmsghdr.length,
// plus the offset where the header began. But make sure the calculated
// offset is still in-bounds.
h := *(*Header)(unsafe.Pointer(&m.Data[endErrno : endErrno+nlmsgHeaderLen][0]))
off = endErrno + int(h.Length)
if len(m.Data) < off {
return newOpError("receive", errShortErrorMessage)
}
} else {
// There is no nlmsghdr preceding the TLVs, parse them directly.
off = endErrno
}
ad, err := NewAttributeDecoder(m.Data[off:])
if err != nil {
// Malformed TLVs, just return the OpError with the info we have.
return oerr
}
for ad.Next() {
switch ad.Type() {
case 1: // unix.NLMSGERR_ATTR_MSG
oerr.Message = ad.String()
case 2: // unix.NLMSGERR_ATTR_OFFS
oerr.Offset = int(ad.Uint32())
}
}
// Explicitly ignore ad.Err: malformed TLVs, just return the OpError with
// the info we have.
return oerr
}

87
vendor/github.com/mdlayher/netlink/netlink.dot generated vendored Normal file
View File

@ -0,0 +1,87 @@
digraph {
rankdir = LR
subgraph cluster_netlink {
"github.com/mdlayher/netlink" [URL="https://github.com/mdlayher/netlink"]
}
subgraph cluster_connector {
label = "NETLINK_CONNECTOR";
{
"github.com/fearful-symmetry/garlic" [URL="https://github.com/fearful-symmetry/garlic"]
} -> "github.com/mdlayher/netlink"
}
subgraph cluster_crypto {
label = "NETLINK_CRYPTO";
{
"github.com/mdlayher/cryptonl" [URL="https://github.com/mdlayher/cryptonl"]
} -> "github.com/mdlayher/netlink"
}
subgraph cluster_generic {
label = "NETLINK_GENERIC (genetlink)";
"github.com/mdlayher/genetlink" [URL="https://github.com/mdlayher/genetlink"]
"github.com/mdlayher/genetlink" -> "github.com/mdlayher/netlink"
{
"github.com/axatrax/l2tp" [URL="https://github.com/axatrax/l2tp"]
"github.com/digitalocean/go-openvswitch" [URL="https://github.com/digitalocean/go-openvswitch"]
"github.com/mdlayher/devlink" [URL="https://github.com/mdlayher/devlink"]
"github.com/mdlayher/ethtool" [URL="https://github.com/mdlayher/ethtool"]
"github.com/mdlayher/quota" [URL="https://github.com/mdlayher/quota"]
"github.com/mdlayher/taskstats" [URL="https://github.com/mdlayher/taskstats"]
"github.com/mdlayher/wifi" [URL="https://github.com/mdlayher/wifi"]
"github.com/Merovius/nbd" [URL="https://github.com/Merovius/nbd"]
"github.com/rtr7/router7" [URL="https://github.com/rtr7/router7"]
"github.com/u-root/u-bmc" [URL="https://github.com/u-root/u-bmc"]
"golang.zx2c4.com/wireguard/wgctrl" [URL="https://golang.zx2c4.com/wireguard/wgctrl"]
} -> "github.com/mdlayher/genetlink"
}
subgraph cluster_kobject_uevent {
label = "NETLINK_KOBJECT_UEVENT";
{
"github.com/mdlayher/kobject" [URL="https://github.com/mdlayher/kobject"]
} -> "github.com/mdlayher/netlink"
}
subgraph cluster_netfilter {
label = "NETLINK_NETFILTER (nfnetlink)";
{
"github.com/florianl/go-conntrack" [URL="https://github.com/florianl/go-conntrack"]
"github.com/florianl/go-nflog" [URL="https://github.com/florianl/go-nflog"]
"github.com/florianl/go-nfqueue" [URL="https://github.com/florianl/go-nfqueue"]
"github.com/google/nftables" [URL="https://github.com/google/nftables"]
"github.com/ti-mo/netfilter" [URL="https://github.com/ti-mo/netfilter"]
} -> "github.com/mdlayher/netlink"
{
"github.com/ti-mo/conntrack" [URL="https://github.com/ti-mo/conntrack"]
} -> "github.com/ti-mo/netfilter"
}
subgraph cluster_route {
label = "NETLINK_ROUTE (rtnetlink)";
{
"github.com/ema/qdisc" [URL="https://github.com/ema/qdisc"]
"github.com/florianl/go-tc" [URL="https://github.com/florianl/go-tc"]
"github.com/jsimonetti/rtnetlink" [URL="https://github.com/jsimonetti/rtnetlink"]
"gitlab.com/mergetb/tech/rtnl" [URL="https://gitlab.com/mergetb/tech/rtnl"]
} -> "github.com/mdlayher/netlink"
}
subgraph cluster_w1 {
label = "NETLINK_W1";
{
"github.com/SpComb/go-onewire" [URL="https://github.com/SpComb/go-onewire"]
} -> "github.com/mdlayher/netlink"
}
}

451
vendor/github.com/mdlayher/netlink/netlink.svg generated vendored Normal file
View File

@ -0,0 +1,451 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<!-- Generated by graphviz version 2.43.0 (0)
-->
<!-- Title: %3 Pages: 1 -->
<svg width="1148pt" height="1515pt"
viewBox="0.00 0.00 1148.01 1515.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1511)">
<title>%3</title>
<polygon fill="white" stroke="transparent" points="-4,4 -4,-1511 1144.01,-1511 1144.01,4 -4,4"/>
<g id="clust1" class="cluster">
<title>cluster_netlink</title>
<polygon fill="none" stroke="black" points="826.13,-417 826.13,-469 1132.01,-469 1132.01,-417 826.13,-417"/>
</g>
<g id="clust2" class="cluster">
<title>cluster_connector</title>
<polygon fill="none" stroke="black" points="439.16,-1424 439.16,-1499 806.13,-1499 806.13,-1424 439.16,-1424"/>
<text text-anchor="middle" x="622.65" y="-1483.8" font-family="Times,serif" font-size="14.00">NETLINK_CONNECTOR</text>
</g>
<g id="clust4" class="cluster">
<title>cluster_crypto</title>
<polygon fill="none" stroke="black" points="463.21,-1341 463.21,-1416 782.09,-1416 782.09,-1341 463.21,-1341"/>
<text text-anchor="middle" x="622.65" y="-1400.8" font-family="Times,serif" font-size="14.00">NETLINK_CRYPTO</text>
</g>
<g id="clust6" class="cluster">
<title>cluster_generic</title>
<polygon fill="none" stroke="black" points="8,-718 8,-1333 786.64,-1333 786.64,-718 8,-718"/>
<text text-anchor="middle" x="397.32" y="-1317.8" font-family="Times,serif" font-size="14.00">NETLINK_GENERIC (genetlink)</text>
</g>
<g id="clust8" class="cluster">
<title>cluster_kobject_uevent</title>
<polygon fill="none" stroke="black" points="468.41,-635 468.41,-710 776.89,-710 776.89,-635 468.41,-635"/>
<text text-anchor="middle" x="622.65" y="-694.8" font-family="Times,serif" font-size="14.00">NETLINK_KOBJECT_UEVENT</text>
</g>
<g id="clust10" class="cluster">
<title>cluster_netfilter</title>
<polygon fill="none" stroke="black" points="44.4,-336 44.4,-627 760.64,-627 760.64,-336 44.4,-336"/>
<text text-anchor="middle" x="402.52" y="-611.8" font-family="Times,serif" font-size="14.00">NETLINK_NETFILTER (nfnetlink)</text>
</g>
<g id="clust13" class="cluster">
<title>cluster_route</title>
<polygon fill="none" stroke="black" points="458.66,-91 458.66,-328 786.64,-328 786.64,-91 458.66,-91"/>
<text text-anchor="middle" x="622.65" y="-312.8" font-family="Times,serif" font-size="14.00">NETLINK_ROUTE (rtnetlink)</text>
</g>
<g id="clust15" class="cluster">
<title>cluster_w1</title>
<polygon fill="none" stroke="black" points="455.41,-8 455.41,-83 789.89,-83 789.89,-8 455.41,-8"/>
<text text-anchor="middle" x="622.65" y="-67.8" font-family="Times,serif" font-size="14.00">NETLINK_W1</text>
</g>
<!-- github.com/mdlayher/netlink -->
<g id="node1" class="node">
<title>github.com/mdlayher/netlink</title>
<g id="a_node1"><a xlink:href="https://github.com/mdlayher/netlink" xlink:title="github.com/mdlayher/netlink">
<ellipse fill="none" stroke="black" cx="979.07" cy="-443" rx="144.87" ry="18"/>
<text text-anchor="middle" x="979.07" y="-439.3" font-family="Times,serif" font-size="14.00">github.com/mdlayher/netlink</text>
</a>
</g>
</g>
<!-- github.com/fearful&#45;symmetry/garlic -->
<g id="node2" class="node">
<title>github.com/fearful&#45;symmetry/garlic</title>
<g id="a_node2"><a xlink:href="https://github.com/fearful-symmetry/garlic" xlink:title="github.com/fearful&#45;symmetry/garlic">
<ellipse fill="none" stroke="black" cx="622.65" cy="-1450" rx="175.47" ry="18"/>
<text text-anchor="middle" x="622.65" y="-1446.3" font-family="Times,serif" font-size="14.00">github.com/fearful&#45;symmetry/garlic</text>
</a>
</g>
</g>
<!-- github.com/fearful&#45;symmetry/garlic&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge1" class="edge">
<title>github.com/fearful&#45;symmetry/garlic&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M775.09,-1441.05C786.54,-1435.8 797.12,-1428.92 806.13,-1420 945.48,-1282.21 972.95,-617.34 977.34,-471.37"/>
<polygon fill="black" stroke="black" points="980.84,-471.13 977.63,-461.04 973.85,-470.93 980.84,-471.13"/>
</g>
<!-- github.com/mdlayher/cryptonl -->
<g id="node3" class="node">
<title>github.com/mdlayher/cryptonl</title>
<g id="a_node3"><a xlink:href="https://github.com/mdlayher/cryptonl" xlink:title="github.com/mdlayher/cryptonl">
<ellipse fill="none" stroke="black" cx="622.65" cy="-1367" rx="151.37" ry="18"/>
<text text-anchor="middle" x="622.65" y="-1363.3" font-family="Times,serif" font-size="14.00">github.com/mdlayher/cryptonl</text>
</a>
</g>
</g>
<!-- github.com/mdlayher/cryptonl&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge2" class="edge">
<title>github.com/mdlayher/cryptonl&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M766.97,-1361.23C781.49,-1355.68 794.97,-1347.86 806.13,-1337 934.19,-1212.36 970.17,-610.39 976.82,-471.59"/>
<polygon fill="black" stroke="black" points="980.33,-471.5 977.3,-461.35 973.34,-471.18 980.33,-471.5"/>
</g>
<!-- github.com/mdlayher/genetlink -->
<g id="node4" class="node">
<title>github.com/mdlayher/genetlink</title>
<g id="a_node4"><a xlink:href="https://github.com/mdlayher/genetlink" xlink:title="github.com/mdlayher/genetlink">
<ellipse fill="none" stroke="black" cx="622.65" cy="-987" rx="155.97" ry="18"/>
<text text-anchor="middle" x="622.65" y="-983.3" font-family="Times,serif" font-size="14.00">github.com/mdlayher/genetlink</text>
</a>
</g>
</g>
<!-- github.com/mdlayher/genetlink&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge3" class="edge">
<title>github.com/mdlayher/genetlink&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M635.97,-968.89C665.62,-925.25 743.13,-810.74 806.13,-714 863.98,-625.18 930.8,-518.71 961.32,-469.88"/>
<polygon fill="black" stroke="black" points="964.42,-471.51 966.75,-461.17 958.48,-467.8 964.42,-471.51"/>
</g>
<!-- github.com/axatrax/l2tp -->
<g id="node5" class="node">
<title>github.com/axatrax/l2tp</title>
<g id="a_node5"><a xlink:href="https://github.com/axatrax/l2tp" xlink:title="github.com/axatrax/l2tp">
<ellipse fill="none" stroke="black" cx="213.58" cy="-1284" rx="122.38" ry="18"/>
<text text-anchor="middle" x="213.58" y="-1280.3" font-family="Times,serif" font-size="14.00">github.com/axatrax/l2tp</text>
</a>
</g>
</g>
<!-- github.com/axatrax/l2tp&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge4" class="edge">
<title>github.com/axatrax/l2tp&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M335.31,-1282.19C361.59,-1277.86 388.33,-1270.15 411.16,-1257 512.47,-1198.65 582.36,-1070.58 609.3,-1014.37"/>
<polygon fill="black" stroke="black" points="612.6,-1015.59 613.69,-1005.05 606.26,-1012.61 612.6,-1015.59"/>
</g>
<!-- github.com/digitalocean/go&#45;openvswitch -->
<g id="node6" class="node">
<title>github.com/digitalocean/go&#45;openvswitch</title>
<g id="a_node6"><a xlink:href="https://github.com/digitalocean/go-openvswitch" xlink:title="github.com/digitalocean/go&#45;openvswitch">
<ellipse fill="none" stroke="black" cx="213.58" cy="-1230" rx="197.66" ry="18"/>
<text text-anchor="middle" x="213.58" y="-1226.3" font-family="Times,serif" font-size="14.00">github.com/digitalocean/go&#45;openvswitch</text>
</a>
</g>
</g>
<!-- github.com/digitalocean/go&#45;openvswitch&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge5" class="edge">
<title>github.com/digitalocean/go&#45;openvswitch&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M369.18,-1218.82C383.78,-1214.81 398.02,-1209.64 411.16,-1203 500.09,-1158.12 573.09,-1060.66 604.72,-1013.55"/>
<polygon fill="black" stroke="black" points="607.72,-1015.37 610.32,-1005.1 601.88,-1011.51 607.72,-1015.37"/>
</g>
<!-- github.com/mdlayher/devlink -->
<g id="node7" class="node">
<title>github.com/mdlayher/devlink</title>
<g id="a_node7"><a xlink:href="https://github.com/mdlayher/devlink" xlink:title="github.com/mdlayher/devlink">
<ellipse fill="none" stroke="black" cx="213.58" cy="-1176" rx="146.47" ry="18"/>
<text text-anchor="middle" x="213.58" y="-1172.3" font-family="Times,serif" font-size="14.00">github.com/mdlayher/devlink</text>
</a>
</g>
</g>
<!-- github.com/mdlayher/devlink&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge6" class="edge">
<title>github.com/mdlayher/devlink&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M344.3,-1167.91C367.14,-1163.77 390.27,-1157.74 411.16,-1149 488.01,-1116.85 561.22,-1049.4 597.65,-1012.49"/>
<polygon fill="black" stroke="black" points="600.38,-1014.7 604.86,-1005.1 595.37,-1009.81 600.38,-1014.7"/>
</g>
<!-- github.com/mdlayher/ethtool -->
<g id="node8" class="node">
<title>github.com/mdlayher/ethtool</title>
<g id="a_node8"><a xlink:href="https://github.com/mdlayher/ethtool" xlink:title="github.com/mdlayher/ethtool">
<ellipse fill="none" stroke="black" cx="213.58" cy="-1122" rx="145.67" ry="18"/>
<text text-anchor="middle" x="213.58" y="-1118.3" font-family="Times,serif" font-size="14.00">github.com/mdlayher/ethtool</text>
</a>
</g>
</g>
<!-- github.com/mdlayher/ethtool&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge7" class="edge">
<title>github.com/mdlayher/ethtool&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M335.04,-1112C360.53,-1108.08 387.02,-1102.61 411.16,-1095 475.99,-1074.57 544.96,-1035.42 585.63,-1010.3"/>
<polygon fill="black" stroke="black" points="587.73,-1013.11 594.36,-1004.85 584.02,-1007.17 587.73,-1013.11"/>
</g>
<!-- github.com/mdlayher/quota -->
<g id="node9" class="node">
<title>github.com/mdlayher/quota</title>
<g id="a_node9"><a xlink:href="https://github.com/mdlayher/quota" xlink:title="github.com/mdlayher/quota">
<ellipse fill="none" stroke="black" cx="213.58" cy="-1068" rx="139.18" ry="18"/>
<text text-anchor="middle" x="213.58" y="-1064.3" font-family="Times,serif" font-size="14.00">github.com/mdlayher/quota</text>
</a>
</g>
</g>
<!-- github.com/mdlayher/quota&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge8" class="edge">
<title>github.com/mdlayher/quota&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M318,-1056.06C348.15,-1051.97 381.08,-1046.91 411.16,-1041 459.39,-1031.54 513.03,-1017.72 554.09,-1006.43"/>
<polygon fill="black" stroke="black" points="555.28,-1009.74 563.98,-1003.7 553.41,-1002.99 555.28,-1009.74"/>
</g>
<!-- github.com/mdlayher/taskstats -->
<g id="node10" class="node">
<title>github.com/mdlayher/taskstats</title>
<g id="a_node10"><a xlink:href="https://github.com/mdlayher/taskstats" xlink:title="github.com/mdlayher/taskstats">
<ellipse fill="none" stroke="black" cx="213.58" cy="-1014" rx="155.17" ry="18"/>
<text text-anchor="middle" x="213.58" y="-1010.3" font-family="Times,serif" font-size="14.00">github.com/mdlayher/taskstats</text>
</a>
</g>
</g>
<!-- github.com/mdlayher/taskstats&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge9" class="edge">
<title>github.com/mdlayher/taskstats&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M348.72,-1005.1C389.7,-1002.38 434.95,-999.38 476.62,-996.62"/>
<polygon fill="black" stroke="black" points="477.1,-1000.1 486.85,-995.94 476.64,-993.11 477.1,-1000.1"/>
</g>
<!-- github.com/mdlayher/wifi -->
<g id="node11" class="node">
<title>github.com/mdlayher/wifi</title>
<g id="a_node11"><a xlink:href="https://github.com/mdlayher/wifi" xlink:title="github.com/mdlayher/wifi">
<ellipse fill="none" stroke="black" cx="213.58" cy="-960" rx="129.18" ry="18"/>
<text text-anchor="middle" x="213.58" y="-956.3" font-family="Times,serif" font-size="14.00">github.com/mdlayher/wifi</text>
</a>
</g>
</g>
<!-- github.com/mdlayher/wifi&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge10" class="edge">
<title>github.com/mdlayher/wifi&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M330.69,-967.7C376.21,-970.72 428.94,-974.22 476.9,-977.4"/>
<polygon fill="black" stroke="black" points="476.85,-980.9 487.06,-978.07 477.31,-973.92 476.85,-980.9"/>
</g>
<!-- github.com/Merovius/nbd -->
<g id="node12" class="node">
<title>github.com/Merovius/nbd</title>
<g id="a_node12"><a xlink:href="https://github.com/Merovius/nbd" xlink:title="github.com/Merovius/nbd">
<ellipse fill="none" stroke="black" cx="213.58" cy="-906" rx="129.98" ry="18"/>
<text text-anchor="middle" x="213.58" y="-902.3" font-family="Times,serif" font-size="14.00">github.com/Merovius/nbd</text>
</a>
</g>
</g>
<!-- github.com/Merovius/nbd&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge11" class="edge">
<title>github.com/Merovius/nbd&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M314.2,-917.43C345.37,-921.61 379.81,-926.85 411.16,-933 459.39,-942.46 513.03,-956.28 554.09,-967.57"/>
<polygon fill="black" stroke="black" points="553.41,-971.01 563.98,-970.3 555.28,-964.26 553.41,-971.01"/>
</g>
<!-- github.com/rtr7/router7 -->
<g id="node13" class="node">
<title>github.com/rtr7/router7</title>
<g id="a_node13"><a xlink:href="https://github.com/rtr7/router7" xlink:title="github.com/rtr7/router7">
<ellipse fill="none" stroke="black" cx="213.58" cy="-852" rx="122.38" ry="18"/>
<text text-anchor="middle" x="213.58" y="-848.3" font-family="Times,serif" font-size="14.00">github.com/rtr7/router7</text>
</a>
</g>
</g>
<!-- github.com/rtr7/router7&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge12" class="edge">
<title>github.com/rtr7/router7&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M322.65,-860.2C351.82,-864.2 383.07,-870.15 411.16,-879 475.99,-899.43 544.96,-938.58 585.63,-963.7"/>
<polygon fill="black" stroke="black" points="584.02,-966.83 594.36,-969.15 587.73,-960.89 584.02,-966.83"/>
</g>
<!-- github.com/u&#45;root/u&#45;bmc -->
<g id="node14" class="node">
<title>github.com/u&#45;root/u&#45;bmc</title>
<g id="a_node14"><a xlink:href="https://github.com/u-root/u-bmc" xlink:title="github.com/u&#45;root/u&#45;bmc">
<ellipse fill="none" stroke="black" cx="213.58" cy="-798" rx="124.58" ry="18"/>
<text text-anchor="middle" x="213.58" y="-794.3" font-family="Times,serif" font-size="14.00">github.com/u&#45;root/u&#45;bmc</text>
</a>
</g>
</g>
<!-- github.com/u&#45;root/u&#45;bmc&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge13" class="edge">
<title>github.com/u&#45;root/u&#45;bmc&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M331.58,-803.97C358.44,-808.07 386.38,-814.63 411.16,-825 488.01,-857.15 561.22,-924.6 597.65,-961.51"/>
<polygon fill="black" stroke="black" points="595.37,-964.19 604.86,-968.9 600.38,-959.3 595.37,-964.19"/>
</g>
<!-- golang.zx2c4.com/wireguard/wgctrl -->
<g id="node15" class="node">
<title>golang.zx2c4.com/wireguard/wgctrl</title>
<g id="a_node15"><a xlink:href="https://golang.zx2c4.com/wireguard/wgctrl" xlink:title="golang.zx2c4.com/wireguard/wgctrl">
<ellipse fill="none" stroke="black" cx="213.58" cy="-744" rx="176.57" ry="18"/>
<text text-anchor="middle" x="213.58" y="-740.3" font-family="Times,serif" font-size="14.00">golang.zx2c4.com/wireguard/wgctrl</text>
</a>
</g>
</g>
<!-- golang.zx2c4.com/wireguard/wgctrl&#45;&gt;github.com/mdlayher/genetlink -->
<g id="edge14" class="edge">
<title>golang.zx2c4.com/wireguard/wgctrl&#45;&gt;github.com/mdlayher/genetlink</title>
<path fill="none" stroke="black" d="M363.27,-753.62C379.93,-757.83 396.25,-763.47 411.16,-771 500.09,-815.88 573.09,-913.34 604.72,-960.45"/>
<polygon fill="black" stroke="black" points="601.88,-962.49 610.32,-968.9 607.72,-958.63 601.88,-962.49"/>
</g>
<!-- github.com/mdlayher/kobject -->
<g id="node16" class="node">
<title>github.com/mdlayher/kobject</title>
<g id="a_node16"><a xlink:href="https://github.com/mdlayher/kobject" xlink:title="github.com/mdlayher/kobject">
<ellipse fill="none" stroke="black" cx="622.65" cy="-661" rx="146.47" ry="18"/>
<text text-anchor="middle" x="622.65" y="-657.3" font-family="Times,serif" font-size="14.00">github.com/mdlayher/kobject</text>
</a>
</g>
</g>
<!-- github.com/mdlayher/kobject&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge15" class="edge">
<title>github.com/mdlayher/kobject&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M749.05,-651.76C768.9,-647.13 788.63,-640.48 806.13,-631 877.87,-592.14 935.5,-511.67 962.1,-469.73"/>
<polygon fill="black" stroke="black" points="965.13,-471.49 967.45,-461.15 959.19,-467.79 965.13,-471.49"/>
</g>
<!-- github.com/florianl/go&#45;conntrack -->
<g id="node17" class="node">
<title>github.com/florianl/go&#45;conntrack</title>
<g id="a_node17"><a xlink:href="https://github.com/florianl/go-conntrack" xlink:title="github.com/florianl/go&#45;conntrack">
<ellipse fill="none" stroke="black" cx="213.58" cy="-524" rx="161.37" ry="18"/>
<text text-anchor="middle" x="213.58" y="-520.3" font-family="Times,serif" font-size="14.00">github.com/florianl/go&#45;conntrack</text>
</a>
</g>
</g>
<!-- github.com/florianl/go&#45;conntrack&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge16" class="edge">
<title>github.com/florianl/go&#45;conntrack&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M346.97,-513.88C467.01,-504.18 648.72,-488.2 806.13,-469 830.23,-466.06 856.16,-462.4 880.39,-458.77"/>
<polygon fill="black" stroke="black" points="880.93,-462.22 890.29,-457.27 879.88,-455.3 880.93,-462.22"/>
</g>
<!-- github.com/florianl/go&#45;nflog -->
<g id="node18" class="node">
<title>github.com/florianl/go&#45;nflog</title>
<g id="a_node18"><a xlink:href="https://github.com/florianl/go-nflog" xlink:title="github.com/florianl/go&#45;nflog">
<ellipse fill="none" stroke="black" cx="213.58" cy="-470" rx="138.38" ry="18"/>
<text text-anchor="middle" x="213.58" y="-466.3" font-family="Times,serif" font-size="14.00">github.com/florianl/go&#45;nflog</text>
</a>
</g>
</g>
<!-- github.com/florianl/go&#45;nflog&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge17" class="edge">
<title>github.com/florianl/go&#45;nflog&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M347.57,-465.3C481.99,-460.54 689.29,-453.21 829.15,-448.27"/>
<polygon fill="black" stroke="black" points="829.65,-451.75 839.52,-447.9 829.4,-444.76 829.65,-451.75"/>
</g>
<!-- github.com/florianl/go&#45;nfqueue -->
<g id="node19" class="node">
<title>github.com/florianl/go&#45;nfqueue</title>
<g id="a_node19"><a xlink:href="https://github.com/florianl/go-nfqueue" xlink:title="github.com/florianl/go&#45;nfqueue">
<ellipse fill="none" stroke="black" cx="213.58" cy="-416" rx="153.27" ry="18"/>
<text text-anchor="middle" x="213.58" y="-412.3" font-family="Times,serif" font-size="14.00">github.com/florianl/go&#45;nfqueue</text>
</a>
</g>
</g>
<!-- github.com/florianl/go&#45;nfqueue&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge18" class="edge">
<title>github.com/florianl/go&#45;nfqueue&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M360.89,-421.17C495.43,-425.93 693.98,-432.95 829.3,-437.74"/>
<polygon fill="black" stroke="black" points="829.22,-441.24 839.34,-438.09 829.47,-434.24 829.22,-441.24"/>
</g>
<!-- github.com/google/nftables -->
<g id="node20" class="node">
<title>github.com/google/nftables</title>
<g id="a_node20"><a xlink:href="https://github.com/google/nftables" xlink:title="github.com/google/nftables">
<ellipse fill="none" stroke="black" cx="213.58" cy="-362" rx="137.28" ry="18"/>
<text text-anchor="middle" x="213.58" y="-358.3" font-family="Times,serif" font-size="14.00">github.com/google/nftables</text>
</a>
</g>
</g>
<!-- github.com/google/nftables&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge19" class="edge">
<title>github.com/google/nftables&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M329.52,-371.61C448.83,-381.94 640.78,-399.5 806.13,-419 828.43,-421.63 852.3,-424.76 874.95,-427.88"/>
<polygon fill="black" stroke="black" points="874.53,-431.35 884.92,-429.26 875.49,-424.42 874.53,-431.35"/>
</g>
<!-- github.com/ti&#45;mo/netfilter -->
<g id="node21" class="node">
<title>github.com/ti&#45;mo/netfilter</title>
<g id="a_node21"><a xlink:href="https://github.com/ti-mo/netfilter" xlink:title="github.com/ti&#45;mo/netfilter">
<ellipse fill="none" stroke="black" cx="622.65" cy="-572" rx="129.98" ry="18"/>
<text text-anchor="middle" x="622.65" y="-568.3" font-family="Times,serif" font-size="14.00">github.com/ti&#45;mo/netfilter</text>
</a>
</g>
</g>
<!-- github.com/ti&#45;mo/netfilter&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge20" class="edge">
<title>github.com/ti&#45;mo/netfilter&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M703.9,-557.91C736.18,-551.12 773.43,-541.84 806.13,-530 854.07,-512.65 905.86,-485.13 939.93,-465.69"/>
<polygon fill="black" stroke="black" points="941.7,-468.71 948.62,-460.69 938.21,-462.64 941.7,-468.71"/>
</g>
<!-- github.com/ti&#45;mo/conntrack -->
<g id="node22" class="node">
<title>github.com/ti&#45;mo/conntrack</title>
<g id="a_node22"><a xlink:href="https://github.com/ti-mo/conntrack" xlink:title="github.com/ti&#45;mo/conntrack">
<ellipse fill="none" stroke="black" cx="213.58" cy="-578" rx="138.38" ry="18"/>
<text text-anchor="middle" x="213.58" y="-574.3" font-family="Times,serif" font-size="14.00">github.com/ti&#45;mo/conntrack</text>
</a>
</g>
</g>
<!-- github.com/ti&#45;mo/conntrack&#45;&gt;github.com/ti&#45;mo/netfilter -->
<g id="edge21" class="edge">
<title>github.com/ti&#45;mo/conntrack&#45;&gt;github.com/ti&#45;mo/netfilter</title>
<path fill="none" stroke="black" d="M351.26,-575.99C393.71,-575.36 440.52,-574.67 483.09,-574.04"/>
<polygon fill="black" stroke="black" points="483.29,-577.54 493.24,-573.89 483.19,-570.54 483.29,-577.54"/>
</g>
<!-- github.com/ema/qdisc -->
<g id="node23" class="node">
<title>github.com/ema/qdisc</title>
<g id="a_node23"><a xlink:href="https://github.com/ema/qdisc" xlink:title="github.com/ema/qdisc">
<ellipse fill="none" stroke="black" cx="622.65" cy="-279" rx="112.38" ry="18"/>
<text text-anchor="middle" x="622.65" y="-275.3" font-family="Times,serif" font-size="14.00">github.com/ema/qdisc</text>
</a>
</g>
</g>
<!-- github.com/ema/qdisc&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge22" class="edge">
<title>github.com/ema/qdisc&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M690.34,-293.42C725.65,-302.27 769.22,-315.16 806.13,-332 858.51,-355.89 913.46,-394 946.74,-418.77"/>
<polygon fill="black" stroke="black" points="945.03,-421.86 955.13,-425.06 949.24,-416.26 945.03,-421.86"/>
</g>
<!-- github.com/florianl/go&#45;tc -->
<g id="node24" class="node">
<title>github.com/florianl/go&#45;tc</title>
<g id="a_node24"><a xlink:href="https://github.com/florianl/go-tc" xlink:title="github.com/florianl/go&#45;tc">
<ellipse fill="none" stroke="black" cx="622.65" cy="-225" rx="124.28" ry="18"/>
<text text-anchor="middle" x="622.65" y="-221.3" font-family="Times,serif" font-size="14.00">github.com/florianl/go&#45;tc</text>
</a>
</g>
</g>
<!-- github.com/florianl/go&#45;tc&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge23" class="edge">
<title>github.com/florianl/go&#45;tc&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M741.38,-230.26C763.78,-234.58 786.38,-241.4 806.13,-252 879.02,-291.1 936.52,-373.74 962.68,-416.35"/>
<polygon fill="black" stroke="black" points="959.77,-418.3 967.93,-425.05 965.76,-414.68 959.77,-418.3"/>
</g>
<!-- github.com/jsimonetti/rtnetlink -->
<g id="node25" class="node">
<title>github.com/jsimonetti/rtnetlink</title>
<g id="a_node25"><a xlink:href="https://github.com/jsimonetti/rtnetlink" xlink:title="github.com/jsimonetti/rtnetlink">
<ellipse fill="none" stroke="black" cx="622.65" cy="-171" rx="155.97" ry="18"/>
<text text-anchor="middle" x="622.65" y="-167.3" font-family="Times,serif" font-size="14.00">github.com/jsimonetti/rtnetlink</text>
</a>
</g>
</g>
<!-- github.com/jsimonetti/rtnetlink&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge24" class="edge">
<title>github.com/jsimonetti/rtnetlink&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M761.97,-179.11C777.55,-183.56 792.64,-189.67 806.13,-198 891.51,-250.68 945.9,-363.24 967.48,-415.41"/>
<polygon fill="black" stroke="black" points="964.25,-416.75 971.25,-424.71 970.74,-414.12 964.25,-416.75"/>
</g>
<!-- gitlab.com/mergetb/tech/rtnl -->
<g id="node26" class="node">
<title>gitlab.com/mergetb/tech/rtnl</title>
<g id="a_node26"><a xlink:href="https://gitlab.com/mergetb/tech/rtnl" xlink:title="gitlab.com/mergetb/tech/rtnl">
<ellipse fill="none" stroke="black" cx="622.65" cy="-117" rx="144.87" ry="18"/>
<text text-anchor="middle" x="622.65" y="-113.3" font-family="Times,serif" font-size="14.00">gitlab.com/mergetb/tech/rtnl</text>
</a>
</g>
</g>
<!-- gitlab.com/mergetb/tech/rtnl&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge25" class="edge">
<title>gitlab.com/mergetb/tech/rtnl&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M759.44,-122.98C776.06,-127.69 792.1,-134.45 806.13,-144 904.62,-211.07 953.88,-354.53 970.94,-415.1"/>
<polygon fill="black" stroke="black" points="967.56,-416.02 973.58,-424.75 974.32,-414.18 967.56,-416.02"/>
</g>
<!-- github.com/SpComb/go&#45;onewire -->
<g id="node27" class="node">
<title>github.com/SpComb/go&#45;onewire</title>
<g id="a_node27"><a xlink:href="https://github.com/SpComb/go-onewire" xlink:title="github.com/SpComb/go&#45;onewire">
<ellipse fill="none" stroke="black" cx="622.65" cy="-34" rx="159.47" ry="18"/>
<text text-anchor="middle" x="622.65" y="-30.3" font-family="Times,serif" font-size="14.00">github.com/SpComb/go&#45;onewire</text>
</a>
</g>
</g>
<!-- github.com/SpComb/go&#45;onewire&#45;&gt;github.com/mdlayher/netlink -->
<g id="edge26" class="edge">
<title>github.com/SpComb/go&#45;onewire&#45;&gt;github.com/mdlayher/netlink</title>
<path fill="none" stroke="black" d="M725.73,-47.78C754.17,-55.59 783.46,-67.83 806.13,-87 912.36,-176.8 958.15,-347.52 972.62,-414.77"/>
<polygon fill="black" stroke="black" points="969.21,-415.57 974.68,-424.64 976.07,-414.14 969.21,-415.57"/>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 26 KiB

93
vendor/github.com/mdlayher/netlink/netns_linux.go generated vendored Normal file
View File

@ -0,0 +1,93 @@
//+build linux
package netlink
import (
"errors"
"fmt"
"os"
"golang.org/x/sys/unix"
)
// errNetNSDisabled is returned when network namespaces are unavailable on
// a given system.
var errNetNSDisabled = errors.New("netlink: network namespaces are not enabled on this system")
// A netNS is a handle that can manipulate network namespaces.
//
// Operations performed on a netNS must use runtime.LockOSThread before
// manipulating any network namespaces.
type netNS struct {
// The handle to a network namespace.
f *os.File
// Indicates if network namespaces are disabled on this system, and thus
// operations should become a no-op or return errors.
disabled bool
}
// threadNetNS constructs a netNS using the network namespace of the calling
// thread. If the namespace is not the default namespace, runtime.LockOSThread
// should be invoked first.
func threadNetNS() (*netNS, error) {
return fileNetNS(fmt.Sprintf("/proc/self/task/%d/ns/net", unix.Gettid()))
}
// fileNetNS opens file and creates a netNS. fileNetNS should only be called
// directly in tests.
func fileNetNS(file string) (*netNS, error) {
f, err := os.Open(file)
switch {
case err == nil:
return &netNS{f: f}, nil
case os.IsNotExist(err):
// Network namespaces are not enabled on this system. Use this signal
// to return errors elsewhere if the caller explicitly asks for a
// network namespace to be set.
return &netNS{disabled: true}, nil
default:
return nil, err
}
}
// Close releases the handle to a network namespace.
func (n *netNS) Close() error {
return n.do(func() error {
return n.f.Close()
})
}
// FD returns a file descriptor which represents the network namespace.
func (n *netNS) FD() int {
if n.disabled {
// No reasonable file descriptor value in this case, so specify a
// non-existent one.
return -1
}
return int(n.f.Fd())
}
// Restore restores the original network namespace for the calling thread.
func (n *netNS) Restore() error {
return n.do(func() error {
return n.Set(n.FD())
})
}
// Set sets a new network namespace for the current thread using fd.
func (n *netNS) Set(fd int) error {
return n.do(func() error {
return os.NewSyscallError("setns", unix.Setns(fd, unix.CLONE_NEWNET))
})
}
// do runs fn if network namespaces are enabled on this system.
func (n *netNS) do(fn func() error) error {
if n.disabled {
return errNetNSDisabled
}
return fn()
}

15
vendor/github.com/mdlayher/netlink/nlenc/doc.go generated vendored Normal file
View File

@ -0,0 +1,15 @@
// Package nlenc implements encoding and decoding functions for netlink
// messages and attributes.
package nlenc
import (
"encoding/binary"
"github.com/josharian/native"
)
// NativeEndian returns the native byte order of this system.
func NativeEndian() binary.ByteOrder {
// TODO(mdlayher): consider deprecating and removing this function for v2.
return native.Endian
}

150
vendor/github.com/mdlayher/netlink/nlenc/int.go generated vendored Normal file
View File

@ -0,0 +1,150 @@
package nlenc
import (
"fmt"
"unsafe"
)
// PutUint8 encodes a uint8 into b.
// If b is not exactly 1 byte in length, PutUint8 will panic.
func PutUint8(b []byte, v uint8) {
if l := len(b); l != 1 {
panic(fmt.Sprintf("PutUint8: unexpected byte slice length: %d", l))
}
b[0] = v
}
// PutUint16 encodes a uint16 into b using the host machine's native endianness.
// If b is not exactly 2 bytes in length, PutUint16 will panic.
func PutUint16(b []byte, v uint16) {
if l := len(b); l != 2 {
panic(fmt.Sprintf("PutUint16: unexpected byte slice length: %d", l))
}
*(*uint16)(unsafe.Pointer(&b[0])) = v
}
// PutUint32 encodes a uint32 into b using the host machine's native endianness.
// If b is not exactly 4 bytes in length, PutUint32 will panic.
func PutUint32(b []byte, v uint32) {
if l := len(b); l != 4 {
panic(fmt.Sprintf("PutUint32: unexpected byte slice length: %d", l))
}
*(*uint32)(unsafe.Pointer(&b[0])) = v
}
// PutUint64 encodes a uint64 into b using the host machine's native endianness.
// If b is not exactly 8 bytes in length, PutUint64 will panic.
func PutUint64(b []byte, v uint64) {
if l := len(b); l != 8 {
panic(fmt.Sprintf("PutUint64: unexpected byte slice length: %d", l))
}
*(*uint64)(unsafe.Pointer(&b[0])) = v
}
// PutInt32 encodes a int32 into b using the host machine's native endianness.
// If b is not exactly 4 bytes in length, PutInt32 will panic.
func PutInt32(b []byte, v int32) {
if l := len(b); l != 4 {
panic(fmt.Sprintf("PutInt32: unexpected byte slice length: %d", l))
}
*(*int32)(unsafe.Pointer(&b[0])) = v
}
// Uint8 decodes a uint8 from b.
// If b is not exactly 1 byte in length, Uint8 will panic.
func Uint8(b []byte) uint8 {
if l := len(b); l != 1 {
panic(fmt.Sprintf("Uint8: unexpected byte slice length: %d", l))
}
return b[0]
}
// Uint16 decodes a uint16 from b using the host machine's native endianness.
// If b is not exactly 2 bytes in length, Uint16 will panic.
func Uint16(b []byte) uint16 {
if l := len(b); l != 2 {
panic(fmt.Sprintf("Uint16: unexpected byte slice length: %d", l))
}
return *(*uint16)(unsafe.Pointer(&b[0]))
}
// Uint32 decodes a uint32 from b using the host machine's native endianness.
// If b is not exactly 4 bytes in length, Uint32 will panic.
func Uint32(b []byte) uint32 {
if l := len(b); l != 4 {
panic(fmt.Sprintf("Uint32: unexpected byte slice length: %d", l))
}
return *(*uint32)(unsafe.Pointer(&b[0]))
}
// Uint64 decodes a uint64 from b using the host machine's native endianness.
// If b is not exactly 8 bytes in length, Uint64 will panic.
func Uint64(b []byte) uint64 {
if l := len(b); l != 8 {
panic(fmt.Sprintf("Uint64: unexpected byte slice length: %d", l))
}
return *(*uint64)(unsafe.Pointer(&b[0]))
}
// Int32 decodes an int32 from b using the host machine's native endianness.
// If b is not exactly 4 bytes in length, Int32 will panic.
func Int32(b []byte) int32 {
if l := len(b); l != 4 {
panic(fmt.Sprintf("Int32: unexpected byte slice length: %d", l))
}
return *(*int32)(unsafe.Pointer(&b[0]))
}
// Uint8Bytes encodes a uint8 into a newly-allocated byte slice. It is a
// shortcut for allocating a new byte slice and filling it using PutUint8.
func Uint8Bytes(v uint8) []byte {
b := make([]byte, 1)
PutUint8(b, v)
return b
}
// Uint16Bytes encodes a uint16 into a newly-allocated byte slice using the
// host machine's native endianness. It is a shortcut for allocating a new
// byte slice and filling it using PutUint16.
func Uint16Bytes(v uint16) []byte {
b := make([]byte, 2)
PutUint16(b, v)
return b
}
// Uint32Bytes encodes a uint32 into a newly-allocated byte slice using the
// host machine's native endianness. It is a shortcut for allocating a new
// byte slice and filling it using PutUint32.
func Uint32Bytes(v uint32) []byte {
b := make([]byte, 4)
PutUint32(b, v)
return b
}
// Uint64Bytes encodes a uint64 into a newly-allocated byte slice using the
// host machine's native endianness. It is a shortcut for allocating a new
// byte slice and filling it using PutUint64.
func Uint64Bytes(v uint64) []byte {
b := make([]byte, 8)
PutUint64(b, v)
return b
}
// Int32Bytes encodes a int32 into a newly-allocated byte slice using the
// host machine's native endianness. It is a shortcut for allocating a new
// byte slice and filling it using PutInt32.
func Int32Bytes(v int32) []byte {
b := make([]byte, 4)
PutInt32(b, v)
return b
}

18
vendor/github.com/mdlayher/netlink/nlenc/string.go generated vendored Normal file
View File

@ -0,0 +1,18 @@
package nlenc
import "bytes"
// Bytes returns a null-terminated byte slice with the contents of s.
func Bytes(s string) []byte {
return append([]byte(s), 0x00)
}
// String returns a string with the contents of b from a null-terminated
// byte slice.
func String(b []byte) string {
// If the string has more than one NULL terminator byte, we want to remove
// all of them before returning the string to the caller; hence the use of
// strings.TrimRight instead of strings.TrimSuffix (which previously only
// removed a single NULL).
return string(bytes.TrimRight(b, "\x00"))
}

9
vendor/github.com/mdlayher/socket/LICENSE.md generated vendored Normal file
View File

@ -0,0 +1,9 @@
# MIT License
Copyright (C) 2021 Matt Layher
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

14
vendor/github.com/mdlayher/socket/README.md generated vendored Normal file
View File

@ -0,0 +1,14 @@
# socket [![Test Status](https://github.com/mdlayher/socket/workflows/Test/badge.svg)](https://github.com/mdlayher/socket/actions) [![Go Reference](https://pkg.go.dev/badge/github.com/mdlayher/socket.svg)](https://pkg.go.dev/github.com/mdlayher/socket) [![Go Report Card](https://goreportcard.com/badge/github.com/mdlayher/socket)](https://goreportcard.com/report/github.com/mdlayher/socket)
Package `socket` provides a low-level network connection type which integrates
with Go's runtime network poller to provide asynchronous I/O and deadline
support. MIT Licensed.
This package focuses on UNIX-like operating systems which make use of BSD
sockets system call APIs. It is meant to be used as a foundation for the
creation of operating system-specific socket packages, for socket families such
as Linux's `AF_NETLINK`, `AF_PACKET`, or `AF_VSOCK`. This package should not be
used directly in end user applications.
Any use of package socket should be guarded by build tags, as one would also
use when importing the `syscall` or `golang.org/x/sys` packages.

23
vendor/github.com/mdlayher/socket/accept.go generated vendored Normal file
View File

@ -0,0 +1,23 @@
//go:build !dragonfly && !freebsd && !illumos && !linux
// +build !dragonfly,!freebsd,!illumos,!linux
package socket
import (
"fmt"
"runtime"
"golang.org/x/sys/unix"
)
const sysAccept = "accept"
// accept wraps accept(2).
func accept(fd, flags int) (int, unix.Sockaddr, error) {
if flags != 0 {
// These operating systems have no support for flags to accept(2).
return 0, nil, fmt.Errorf("socket: Conn.Accept flags are ineffective on %s", runtime.GOOS)
}
return unix.Accept(fd)
}

15
vendor/github.com/mdlayher/socket/accept4.go generated vendored Normal file
View File

@ -0,0 +1,15 @@
//go:build dragonfly || freebsd || illumos || linux
// +build dragonfly freebsd illumos linux
package socket
import (
"golang.org/x/sys/unix"
)
const sysAccept = "accept4"
// accept wraps accept4(2).
func accept(fd, flags int) (int, unix.Sockaddr, error) {
return unix.Accept4(fd, flags)
}

496
vendor/github.com/mdlayher/socket/conn.go generated vendored Normal file
View File

@ -0,0 +1,496 @@
package socket
import (
"os"
"sync/atomic"
"syscall"
"time"
"golang.org/x/sys/unix"
)
// A Conn is a low-level network connection which integrates with Go's runtime
// network poller to provide asynchronous I/O and deadline support.
type Conn struct {
// Indicates whether or not Conn.Close has been called. Must be accessed
// atomically. Atomics definitions must come first in the Conn struct.
closed uint32
// A unique name for the Conn which is also associated with derived file
// descriptors such as those created by accept(2).
name string
// Provides access to the underlying file registered with the runtime
// network poller, and arbitrary raw I/O calls.
fd *os.File
rc syscall.RawConn
}
// High-level methods which provide convenience over raw system calls.
// Close closes the underlying file descriptor for the Conn, which also causes
// all in-flight I/O operations to immediately unblock and return errors. Any
// subsequent uses of Conn will result in EBADF.
func (c *Conn) Close() error {
// The caller has expressed an intent to close the socket, so immediately
// increment s.closed to force further calls to result in EBADF before also
// closing the file descriptor to unblock any outstanding operations.
//
// Because other operations simply check for s.closed != 0, we will permit
// double Close, which would increment s.closed beyond 1.
if atomic.AddUint32(&c.closed, 1) != 1 {
// Multiple Close calls.
return nil
}
return os.NewSyscallError("close", c.fd.Close())
}
// Read implements io.Reader by reading directly from the underlying file
// descriptor.
func (c *Conn) Read(b []byte) (int, error) { return c.fd.Read(b) }
// Write implements io.Writer by writing directly to the underlying file
// descriptor.
func (c *Conn) Write(b []byte) (int, error) { return c.fd.Write(b) }
// SetDeadline sets both the read and write deadlines associated with the Conn.
func (c *Conn) SetDeadline(t time.Time) error { return c.fd.SetDeadline(t) }
// SetReadDeadline sets the read deadline associated with the Conn.
func (c *Conn) SetReadDeadline(t time.Time) error { return c.fd.SetReadDeadline(t) }
// SetWriteDeadline sets the write deadline associated with the Conn.
func (c *Conn) SetWriteDeadline(t time.Time) error { return c.fd.SetWriteDeadline(t) }
// ReadBuffer gets the size of the operating system's receive buffer associated
// with the Conn.
func (c *Conn) ReadBuffer() (int, error) {
return c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_RCVBUF)
}
// WriteBuffer gets the size of the operating system's transmit buffer
// associated with the Conn.
func (c *Conn) WriteBuffer() (int, error) {
return c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_SNDBUF)
}
// SetReadBuffer sets the size of the operating system's receive buffer
// associated with the Conn.
//
// When called with elevated privileges on Linux, the SO_RCVBUFFORCE option will
// be used to override operating system limits. Otherwise SO_RCVBUF is used
// (which obeys operating system limits).
func (c *Conn) SetReadBuffer(bytes int) error { return c.setReadBuffer(bytes) }
// SetWriteBuffer sets the size of the operating system's transmit buffer
// associated with the Conn.
//
// When called with elevated privileges on Linux, the SO_SNDBUFFORCE option will
// be used to override operating system limits. Otherwise SO_SNDBUF is used
// (which obeys operating system limits).
func (c *Conn) SetWriteBuffer(bytes int) error { return c.setWriteBuffer(bytes) }
// SyscallConn returns a raw network connection. This implements the
// syscall.Conn interface.
//
// SyscallConn is intended for advanced use cases, such as getting and setting
// arbitrary socket options using the socket's file descriptor. If possible,
// those operations should be performed using methods on Conn instead.
//
// Once invoked, it is the caller's responsibility to ensure that operations
// performed using Conn and the syscall.RawConn do not conflict with each other.
func (c *Conn) SyscallConn() (syscall.RawConn, error) {
if atomic.LoadUint32(&c.closed) != 0 {
return nil, os.NewSyscallError("syscallconn", unix.EBADF)
}
// TODO(mdlayher): mutex or similar to enforce syscall.RawConn contract of
// FD remaining valid for duration of calls?
return c.rc, nil
}
// Socket wraps the socket(2) system call to produce a Conn. domain, typ, and
// proto are passed directly to socket(2), and name should be a unique name for
// the socket type such as "netlink" or "vsock".
//
// If the operating system supports SOCK_CLOEXEC and SOCK_NONBLOCK, they are
// automatically applied to typ to mirror the standard library's socket flag
// behaviors.
func Socket(domain, typ, proto int, name string) (*Conn, error) {
var (
fd int
err error
)
for {
fd, err = unix.Socket(domain, typ|socketFlags, proto)
switch {
case err == nil:
// Some OSes already set CLOEXEC with typ.
if !flagCLOEXEC {
unix.CloseOnExec(fd)
}
// No error, prepare the Conn.
return newConn(fd, name)
case !ready(err):
// System call interrupted or not ready, try again.
continue
case err == unix.EINVAL, err == unix.EPROTONOSUPPORT:
// On Linux, SOCK_NONBLOCK and SOCK_CLOEXEC were introduced in
// 2.6.27. On FreeBSD, both flags were introduced in FreeBSD 10.
// EINVAL and EPROTONOSUPPORT check for earlier versions of these
// OSes respectively.
//
// Mirror what the standard library does when creating file
// descriptors: avoid racing a fork/exec with the creation of new
// file descriptors, so that child processes do not inherit socket
// file descriptors unexpectedly.
//
// For a more thorough explanation, see similar work in the Go tree:
// func sysSocket in net/sock_cloexec.go, as well as the detailed
// comment in syscall/exec_unix.go.
syscall.ForkLock.RLock()
fd, err = unix.Socket(domain, typ, proto)
if err == nil {
unix.CloseOnExec(fd)
}
syscall.ForkLock.RUnlock()
return newConn(fd, name)
default:
// Unhandled error.
return nil, os.NewSyscallError("socket", err)
}
}
}
// TODO(mdlayher): consider exporting newConn as New?
// newConn wraps an existing file descriptor to create a Conn. name should be a
// unique name for the socket type such as "netlink" or "vsock".
func newConn(fd int, name string) (*Conn, error) {
// All Conn I/O is nonblocking for integration with Go's runtime network
// poller. Depending on the OS this might already be set but it can't hurt
// to set it again.
if err := unix.SetNonblock(fd, true); err != nil {
return nil, os.NewSyscallError("setnonblock", err)
}
// os.NewFile registers the non-blocking file descriptor with the runtime
// poller, which is then used for most subsequent operations except those
// that require raw I/O via SyscallConn.
//
// See also: https://golang.org/pkg/os/#NewFile
f := os.NewFile(uintptr(fd), name)
rc, err := f.SyscallConn()
if err != nil {
return nil, err
}
return &Conn{
name: name,
fd: f,
rc: rc,
}, nil
}
// Low-level methods which provide raw system call access.
// Accept wraps accept(2) or accept4(2) depending on the operating system, but
// returns a Conn for the accepted connection rather than a raw file descriptor.
//
// If the operating system supports accept4(2) (which allows flags),
// SOCK_CLOEXEC and SOCK_NONBLOCK are automatically applied to flags to mirror
// the standard library's socket flag behaviors.
//
// If the operating system only supports accept(2) (which does not allow flags)
// and flags is not zero, an error will be returned.
func (c *Conn) Accept(flags int) (*Conn, unix.Sockaddr, error) {
var (
nfd int
sa unix.Sockaddr
err error
)
doErr := c.read(sysAccept, func(fd int) error {
// Either accept(2) or accept4(2) depending on the OS.
nfd, sa, err = accept(fd, flags|socketFlags)
return err
})
if doErr != nil {
return nil, nil, doErr
}
if err != nil {
// sysAccept is either "accept" or "accept4" depending on the OS.
return nil, nil, os.NewSyscallError(sysAccept, err)
}
// Successfully accepted a connection, wrap it in a Conn for use by the
// caller.
ac, err := newConn(nfd, c.name)
if err != nil {
return nil, nil, err
}
return ac, sa, nil
}
// Bind wraps bind(2).
func (c *Conn) Bind(sa unix.Sockaddr) error {
const op = "bind"
var err error
doErr := c.control(op, func(fd int) error {
err = unix.Bind(fd, sa)
return err
})
if doErr != nil {
return doErr
}
return os.NewSyscallError(op, err)
}
// Connect wraps connect(2).
func (c *Conn) Connect(sa unix.Sockaddr) error {
const op = "connect"
var err error
doErr := c.write(op, func(fd int) error {
err = unix.Connect(fd, sa)
return err
})
if doErr != nil {
return doErr
}
if err == unix.EISCONN {
// Darwin reports EISCONN if already connected, but the socket is
// established and we don't need to report an error.
return nil
}
return os.NewSyscallError(op, err)
}
// Getsockname wraps getsockname(2).
func (c *Conn) Getsockname() (unix.Sockaddr, error) {
const op = "getsockname"
var (
sa unix.Sockaddr
err error
)
doErr := c.control(op, func(fd int) error {
sa, err = unix.Getsockname(fd)
return err
})
if doErr != nil {
return nil, doErr
}
return sa, os.NewSyscallError(op, err)
}
// GetsockoptInt wraps getsockopt(2) for integer values.
func (c *Conn) GetsockoptInt(level, opt int) (int, error) {
const op = "getsockopt"
var (
value int
err error
)
doErr := c.control(op, func(fd int) error {
value, err = unix.GetsockoptInt(fd, level, opt)
return err
})
if doErr != nil {
return 0, doErr
}
return value, os.NewSyscallError(op, err)
}
// Listen wraps listen(2).
func (c *Conn) Listen(n int) error {
const op = "listen"
var err error
doErr := c.control(op, func(fd int) error {
err = unix.Listen(fd, n)
return err
})
if doErr != nil {
return doErr
}
return os.NewSyscallError(op, err)
}
// Recvmsg wraps recvmsg(2).
func (c *Conn) Recvmsg(p, oob []byte, flags int) (int, int, int, unix.Sockaddr, error) {
const op = "recvmsg"
var (
n, oobn, recvflags int
from unix.Sockaddr
err error
)
doErr := c.read(op, func(fd int) error {
n, oobn, recvflags, from, err = unix.Recvmsg(fd, p, oob, flags)
return err
})
if doErr != nil {
return 0, 0, 0, nil, doErr
}
return n, oobn, recvflags, from, os.NewSyscallError(op, err)
}
// Recvfrom wraps recvfrom(2)
func (c *Conn) Recvfrom(p []byte, flags int) (int, unix.Sockaddr, error) {
const op = "recvfrom"
var (
n int
addr unix.Sockaddr
err error
)
doErr := c.read(op, func(fd int) error {
n, addr, err = unix.Recvfrom(fd, p, flags)
return err
})
if doErr != nil {
return 0, nil, doErr
}
return n, addr, os.NewSyscallError(op, err)
}
// Sendmsg wraps sendmsg(2).
func (c *Conn) Sendmsg(p, oob []byte, to unix.Sockaddr, flags int) error {
const op = "sendmsg"
var err error
doErr := c.write(op, func(fd int) error {
err = unix.Sendmsg(fd, p, oob, to, flags)
return err
})
if doErr != nil {
return doErr
}
return os.NewSyscallError(op, err)
}
// Sendto wraps Sendto(2).
func (c *Conn) Sendto(b []byte, to unix.Sockaddr, flags int) error {
const op = "sendto"
var err error
doErr := c.write(op, func(fd int) error {
err = unix.Sendto(fd, b, flags, to)
return err
})
if doErr != nil {
return doErr
}
return os.NewSyscallError(op, err)
}
// SetsockoptInt wraps setsockopt(2) for integer values.
func (c *Conn) SetsockoptInt(level, opt, value int) error {
const op = "setsockopt"
var err error
doErr := c.control(op, func(fd int) error {
err = unix.SetsockoptInt(fd, level, opt, value)
return err
})
if doErr != nil {
return doErr
}
return os.NewSyscallError(op, err)
}
// Conn low-level read/write/control functions. These functions mirror the
// syscall.RawConn APIs but the input closures return errors rather than
// booleans. Any syscalls invoked within f should return their error to allow
// the Conn to check for readiness with the runtime network poller, or to retry
// operations which may have been interrupted by EINTR or similar.
//
// Note that errors from the input closure functions are not propagated to the
// error return values of read/write/control, and the caller is still
// responsible for error handling.
// read executes f, a read function, against the associated file descriptor.
// op is used to create an *os.SyscallError if the file descriptor is closed.
func (c *Conn) read(op string, f func(fd int) error) error {
if atomic.LoadUint32(&c.closed) != 0 {
return os.NewSyscallError(op, unix.EBADF)
}
return c.rc.Read(func(fd uintptr) bool {
return ready(f(int(fd)))
})
}
// write executes f, a write function, against the associated file descriptor.
// op is used to create an *os.SyscallError if the file descriptor is closed.
func (c *Conn) write(op string, f func(fd int) error) error {
if atomic.LoadUint32(&c.closed) != 0 {
return os.NewSyscallError(op, unix.EBADF)
}
return c.rc.Write(func(fd uintptr) bool {
return ready(f(int(fd)))
})
}
// control executes f, a control function, against the associated file
// descriptor. op is used to create an *os.SyscallError if the file descriptor
// is closed.
func (c *Conn) control(op string, f func(fd int) error) error {
if atomic.LoadUint32(&c.closed) != 0 {
return os.NewSyscallError(op, unix.EBADF)
}
return c.rc.Control(func(fd uintptr) {
// Repeatedly attempt the syscall(s) invoked by f until completion is
// indicated by the return value of ready.
for {
if ready(f(int(fd))) {
return
}
}
})
}
// ready indicates readiness based on the value of err.
func ready(err error) bool {
// When a socket is in non-blocking mode, we might see EAGAIN or
// EINPROGRESS. In that case, return false to let the poller wait for
// readiness. See the source code for internal/poll.FD.RawRead for more
// details.
//
// Starting in Go 1.14, goroutines are asynchronously preemptible. The 1.14
// release notes indicate that applications should expect to see EINTR more
// often on slow system calls (like recvmsg while waiting for input), so we
// must handle that case as well.
switch err {
case unix.EAGAIN, unix.EINTR, unix.EINPROGRESS:
// Not ready.
return false
default:
// Ready regardless of whether there was an error or no error.
return true
}
}

88
vendor/github.com/mdlayher/socket/conn_linux.go generated vendored Normal file
View File

@ -0,0 +1,88 @@
//go:build linux
// +build linux
package socket
import (
"os"
"unsafe"
"golang.org/x/net/bpf"
"golang.org/x/sys/unix"
)
// SetBPF attaches an assembled BPF program to a Conn.
func (c *Conn) SetBPF(filter []bpf.RawInstruction) error {
// We can't point to the first instruction in the array if no instructions
// are present.
if len(filter) == 0 {
return os.NewSyscallError("setsockopt", unix.EINVAL)
}
prog := unix.SockFprog{
Len: uint16(len(filter)),
Filter: (*unix.SockFilter)(unsafe.Pointer(&filter[0])),
}
return c.SetsockoptSockFprog(unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, &prog)
}
// RemoveBPF removes a BPF filter from a Conn.
func (c *Conn) RemoveBPF() error {
// 0 argument is ignored.
return c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_DETACH_FILTER, 0)
}
// SetsockoptSockFprog wraps setsockopt(2) for unix.SockFprog values.
func (c *Conn) SetsockoptSockFprog(level, opt int, fprog *unix.SockFprog) error {
const op = "setsockopt"
var err error
doErr := c.control(op, func(fd int) error {
err = unix.SetsockoptSockFprog(fd, level, opt, fprog)
return err
})
if doErr != nil {
return doErr
}
return os.NewSyscallError(op, err)
}
// GetSockoptTpacketStats wraps getsockopt(2) for getting TpacketStats
func (c *Conn) GetSockoptTpacketStats(level, name int) (*unix.TpacketStats, error) {
const op = "getsockopt"
var (
stats *unix.TpacketStats
err error
)
doErr := c.control(op, func(fd int) error {
stats, err = unix.GetsockoptTpacketStats(fd, level, name)
return err
})
if doErr != nil {
return stats, doErr
}
return stats, os.NewSyscallError(op, err)
}
// GetSockoptTpacketStatsV3 wraps getsockopt(2) for getting TpacketStatsV3
func (c *Conn) GetSockoptTpacketStatsV3(level, name int) (*unix.TpacketStatsV3, error) {
const op = "getsockopt"
var (
stats *unix.TpacketStatsV3
err error
)
doErr := c.control(op, func(fd int) error {
stats, err = unix.GetsockoptTpacketStatsV3(fd, level, name)
return err
})
if doErr != nil {
return stats, doErr
}
return stats, os.NewSyscallError(op, err)
}

13
vendor/github.com/mdlayher/socket/doc.go generated vendored Normal file
View File

@ -0,0 +1,13 @@
// Package socket provides a low-level network connection type which integrates
// with Go's runtime network poller to provide asynchronous I/O and deadline
// support.
//
// This package focuses on UNIX-like operating systems which make use of BSD
// sockets system call APIs. It is meant to be used as a foundation for the
// creation of operating system-specific socket packages, for socket families
// such as Linux's AF_NETLINK, AF_PACKET, or AF_VSOCK. This package should not
// be used directly in end user applications.
//
// Any use of package socket should be guarded by build tags, as one would also
// use when importing the syscall or golang.org/x/sys packages.
package socket

24
vendor/github.com/mdlayher/socket/setbuffer_linux.go generated vendored Normal file
View File

@ -0,0 +1,24 @@
//go:build linux
// +build linux
package socket
import "golang.org/x/sys/unix"
// setReadBuffer wraps the SO_RCVBUF{,FORCE} setsockopt(2) options.
func (c *Conn) setReadBuffer(bytes int) error {
err := c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, bytes)
if err != nil {
err = c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_RCVBUF, bytes)
}
return err
}
// setWriteBuffer wraps the SO_SNDBUF{,FORCE} setsockopt(2) options.
func (c *Conn) setWriteBuffer(bytes int) error {
err := c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, bytes)
if err != nil {
err = c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_SNDBUF, bytes)
}
return err
}

16
vendor/github.com/mdlayher/socket/setbuffer_others.go generated vendored Normal file
View File

@ -0,0 +1,16 @@
//go:build !linux
// +build !linux
package socket
import "golang.org/x/sys/unix"
// setReadBuffer wraps the SO_RCVBUF setsockopt(2) option.
func (c *Conn) setReadBuffer(bytes int) error {
return c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_RCVBUF, bytes)
}
// setWriteBuffer wraps the SO_SNDBUF setsockopt(2) option.
func (c *Conn) setWriteBuffer(bytes int) error {
return c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_SNDBUF, bytes)
}

View File

@ -0,0 +1,12 @@
//go:build !darwin
// +build !darwin
package socket
import "golang.org/x/sys/unix"
const (
// These operating systems support CLOEXEC and NONBLOCK socket options.
flagCLOEXEC = true
socketFlags = unix.SOCK_CLOEXEC | unix.SOCK_NONBLOCK
)

11
vendor/github.com/mdlayher/socket/typ_none.go generated vendored Normal file
View File

@ -0,0 +1,11 @@
//go:build darwin
// +build darwin
package socket
const (
// These operating systems do not support CLOEXEC and NONBLOCK socket
// options.
flagCLOEXEC = false
socketFlags = 0
)

3
vendor/golang.org/x/crypto/AUTHORS generated vendored Normal file
View File

@ -0,0 +1,3 @@
# This source code refers to The Go Authors for copyright purposes.
# The master list of authors is in the main Go distribution,
# visible at https://tip.golang.org/AUTHORS.

3
vendor/golang.org/x/crypto/CONTRIBUTORS generated vendored Normal file
View File

@ -0,0 +1,3 @@
# This source code was written by the Go contributors.
# The master list of contributors is in the main Go distribution,
# visible at https://tip.golang.org/CONTRIBUTORS.

27
vendor/golang.org/x/crypto/LICENSE generated vendored Normal file
View File

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
vendor/golang.org/x/crypto/PATENTS generated vendored Normal file
View File

@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

145
vendor/golang.org/x/crypto/curve25519/curve25519.go generated vendored Normal file
View File

@ -0,0 +1,145 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package curve25519 provides an implementation of the X25519 function, which
// performs scalar multiplication on the elliptic curve known as Curve25519.
// See RFC 7748.
package curve25519 // import "golang.org/x/crypto/curve25519"
import (
"crypto/subtle"
"fmt"
"golang.org/x/crypto/curve25519/internal/field"
)
// ScalarMult sets dst to the product scalar * point.
//
// Deprecated: when provided a low-order point, ScalarMult will set dst to all
// zeroes, irrespective of the scalar. Instead, use the X25519 function, which
// will return an error.
func ScalarMult(dst, scalar, point *[32]byte) {
var e [32]byte
copy(e[:], scalar[:])
e[0] &= 248
e[31] &= 127
e[31] |= 64
var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
x1.SetBytes(point[:])
x2.One()
x3.Set(&x1)
z3.One()
swap := 0
for pos := 254; pos >= 0; pos-- {
b := e[pos/8] >> uint(pos&7)
b &= 1
swap ^= int(b)
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
swap = int(b)
tmp0.Subtract(&x3, &z3)
tmp1.Subtract(&x2, &z2)
x2.Add(&x2, &z2)
z2.Add(&x3, &z3)
z3.Multiply(&tmp0, &x2)
z2.Multiply(&z2, &tmp1)
tmp0.Square(&tmp1)
tmp1.Square(&x2)
x3.Add(&z3, &z2)
z2.Subtract(&z3, &z2)
x2.Multiply(&tmp1, &tmp0)
tmp1.Subtract(&tmp1, &tmp0)
z2.Square(&z2)
z3.Mult32(&tmp1, 121666)
x3.Square(&x3)
tmp0.Add(&tmp0, &z3)
z3.Multiply(&x1, &z2)
z2.Multiply(&tmp1, &tmp0)
}
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
z2.Invert(&z2)
x2.Multiply(&x2, &z2)
copy(dst[:], x2.Bytes())
}
// ScalarBaseMult sets dst to the product scalar * base where base is the
// standard generator.
//
// It is recommended to use the X25519 function with Basepoint instead, as
// copying into fixed size arrays can lead to unexpected bugs.
func ScalarBaseMult(dst, scalar *[32]byte) {
ScalarMult(dst, scalar, &basePoint)
}
const (
// ScalarSize is the size of the scalar input to X25519.
ScalarSize = 32
// PointSize is the size of the point input to X25519.
PointSize = 32
)
// Basepoint is the canonical Curve25519 generator.
var Basepoint []byte
var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
func init() { Basepoint = basePoint[:] }
func checkBasepoint() {
if subtle.ConstantTimeCompare(Basepoint, []byte{
0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}) != 1 {
panic("curve25519: global Basepoint value was modified")
}
}
// X25519 returns the result of the scalar multiplication (scalar * point),
// according to RFC 7748, Section 5. scalar, point and the return value are
// slices of 32 bytes.
//
// scalar can be generated at random, for example with crypto/rand. point should
// be either Basepoint or the output of another X25519 call.
//
// If point is Basepoint (but not if it's a different slice with the same
// contents) a precomputed implementation might be used for performance.
func X25519(scalar, point []byte) ([]byte, error) {
// Outline the body of function, to let the allocation be inlined in the
// caller, and possibly avoid escaping to the heap.
var dst [32]byte
return x25519(&dst, scalar, point)
}
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
var in [32]byte
if l := len(scalar); l != 32 {
return nil, fmt.Errorf("bad scalar length: %d, expected %d", l, 32)
}
if l := len(point); l != 32 {
return nil, fmt.Errorf("bad point length: %d, expected %d", l, 32)
}
copy(in[:], scalar)
if &point[0] == &Basepoint[0] {
checkBasepoint()
ScalarBaseMult(dst, &in)
} else {
var base, zero [32]byte
copy(base[:], point)
ScalarMult(dst, &in, &base)
if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
return nil, fmt.Errorf("bad input point: low order point")
}
}
return dst[:], nil
}

View File

@ -0,0 +1,7 @@
This package is kept in sync with crypto/ed25519/internal/edwards25519/field in
the standard library.
If there are any changes in the standard library that need to be synced to this
package, run sync.sh. It will not overwrite any local changes made since the
previous sync, so it's ok to land changes in this package first, and then sync
to the standard library later.

View File

@ -0,0 +1,416 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package field implements fast arithmetic modulo 2^255-19.
package field
import (
"crypto/subtle"
"encoding/binary"
"math/bits"
)
// Element represents an element of the field GF(2^255-19). Note that this
// is not a cryptographically secure group, and should only be used to interact
// with edwards25519.Point coordinates.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is a valid zero element.
type Element struct {
// An element t represents the integer
// t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204
//
// Between operations, all limbs are expected to be lower than 2^52.
l0 uint64
l1 uint64
l2 uint64
l3 uint64
l4 uint64
}
const maskLow51Bits uint64 = (1 << 51) - 1
var feZero = &Element{0, 0, 0, 0, 0}
// Zero sets v = 0, and returns v.
func (v *Element) Zero() *Element {
*v = *feZero
return v
}
var feOne = &Element{1, 0, 0, 0, 0}
// One sets v = 1, and returns v.
func (v *Element) One() *Element {
*v = *feOne
return v
}
// reduce reduces v modulo 2^255 - 19 and returns it.
func (v *Element) reduce() *Element {
v.carryPropagate()
// After the light reduction we now have a field element representation
// v < 2^255 + 2^13 * 19, but need v < 2^255 - 19.
// If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1,
// generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise.
c := (v.l0 + 19) >> 51
c = (v.l1 + c) >> 51
c = (v.l2 + c) >> 51
c = (v.l3 + c) >> 51
c = (v.l4 + c) >> 51
// If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's
// effectively applying the reduction identity to the carry.
v.l0 += 19 * c
v.l1 += v.l0 >> 51
v.l0 = v.l0 & maskLow51Bits
v.l2 += v.l1 >> 51
v.l1 = v.l1 & maskLow51Bits
v.l3 += v.l2 >> 51
v.l2 = v.l2 & maskLow51Bits
v.l4 += v.l3 >> 51
v.l3 = v.l3 & maskLow51Bits
// no additional carry
v.l4 = v.l4 & maskLow51Bits
return v
}
// Add sets v = a + b, and returns v.
func (v *Element) Add(a, b *Element) *Element {
v.l0 = a.l0 + b.l0
v.l1 = a.l1 + b.l1
v.l2 = a.l2 + b.l2
v.l3 = a.l3 + b.l3
v.l4 = a.l4 + b.l4
// Using the generic implementation here is actually faster than the
// assembly. Probably because the body of this function is so simple that
// the compiler can figure out better optimizations by inlining the carry
// propagation. TODO
return v.carryPropagateGeneric()
}
// Subtract sets v = a - b, and returns v.
func (v *Element) Subtract(a, b *Element) *Element {
// We first add 2 * p, to guarantee the subtraction won't underflow, and
// then subtract b (which can be up to 2^255 + 2^13 * 19).
v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0
v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1
v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2
v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3
v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4
return v.carryPropagate()
}
// Negate sets v = -a, and returns v.
func (v *Element) Negate(a *Element) *Element {
return v.Subtract(feZero, a)
}
// Invert sets v = 1/z mod p, and returns v.
//
// If z == 0, Invert returns v = 0.
func (v *Element) Invert(z *Element) *Element {
// Inversion is implemented as exponentiation with exponent p 2. It uses the
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element
z2.Square(z) // 2
t.Square(&z2) // 4
t.Square(&t) // 8
z9.Multiply(&t, z) // 9
z11.Multiply(&z9, &z2) // 11
t.Square(&z11) // 22
z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0
t.Square(&z2_5_0) // 2^6 - 2^1
for i := 0; i < 4; i++ {
t.Square(&t) // 2^10 - 2^5
}
z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0
t.Square(&z2_10_0) // 2^11 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^20 - 2^10
}
z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0
t.Square(&z2_20_0) // 2^21 - 2^1
for i := 0; i < 19; i++ {
t.Square(&t) // 2^40 - 2^20
}
t.Multiply(&t, &z2_20_0) // 2^40 - 2^0
t.Square(&t) // 2^41 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^50 - 2^10
}
z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0
t.Square(&z2_50_0) // 2^51 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^100 - 2^50
}
z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0
t.Square(&z2_100_0) // 2^101 - 2^1
for i := 0; i < 99; i++ {
t.Square(&t) // 2^200 - 2^100
}
t.Multiply(&t, &z2_100_0) // 2^200 - 2^0
t.Square(&t) // 2^201 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^250 - 2^50
}
t.Multiply(&t, &z2_50_0) // 2^250 - 2^0
t.Square(&t) // 2^251 - 2^1
t.Square(&t) // 2^252 - 2^2
t.Square(&t) // 2^253 - 2^3
t.Square(&t) // 2^254 - 2^4
t.Square(&t) // 2^255 - 2^5
return v.Multiply(&t, &z11) // 2^255 - 21
}
// Set sets v = a, and returns v.
func (v *Element) Set(a *Element) *Element {
*v = *a
return v
}
// SetBytes sets v to x, which must be a 32-byte little-endian encoding.
//
// Consistent with RFC 7748, the most significant bit (the high bit of the
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
// are accepted. Note that this is laxer than specified by RFC 8032.
func (v *Element) SetBytes(x []byte) *Element {
if len(x) != 32 {
panic("edwards25519: invalid field element input size")
}
// Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51).
v.l0 = binary.LittleEndian.Uint64(x[0:8])
v.l0 &= maskLow51Bits
// Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51).
v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3
v.l1 &= maskLow51Bits
// Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51).
v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6
v.l2 &= maskLow51Bits
// Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51).
v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1
v.l3 &= maskLow51Bits
// Bits 204:251 (bytes 24:32, bits 192:256, shift 12, mask 51).
// Note: not bytes 25:33, shift 4, to avoid overread.
v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12
v.l4 &= maskLow51Bits
return v
}
// Bytes returns the canonical 32-byte little-endian encoding of v.
func (v *Element) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [32]byte
return v.bytes(&out)
}
func (v *Element) bytes(out *[32]byte) []byte {
t := *v
t.reduce()
var buf [8]byte
for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
bitsOffset := i * 51
binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8))
for i, bb := range buf {
off := bitsOffset/8 + i
if off >= len(out) {
break
}
out[off] |= bb
}
}
return out[:]
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
func (v *Element) Equal(u *Element) int {
sa, sv := u.Bytes(), v.Bytes()
return subtle.ConstantTimeCompare(sa, sv)
}
// mask64Bits returns 0xffffffff if cond is 1, and 0 otherwise.
func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) }
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *Element) Select(a, b *Element, cond int) *Element {
m := mask64Bits(cond)
v.l0 = (m & a.l0) | (^m & b.l0)
v.l1 = (m & a.l1) | (^m & b.l1)
v.l2 = (m & a.l2) | (^m & b.l2)
v.l3 = (m & a.l3) | (^m & b.l3)
v.l4 = (m & a.l4) | (^m & b.l4)
return v
}
// Swap swaps v and u if cond == 1 or leaves them unchanged if cond == 0, and returns v.
func (v *Element) Swap(u *Element, cond int) {
m := mask64Bits(cond)
t := m & (v.l0 ^ u.l0)
v.l0 ^= t
u.l0 ^= t
t = m & (v.l1 ^ u.l1)
v.l1 ^= t
u.l1 ^= t
t = m & (v.l2 ^ u.l2)
v.l2 ^= t
u.l2 ^= t
t = m & (v.l3 ^ u.l3)
v.l3 ^= t
u.l3 ^= t
t = m & (v.l4 ^ u.l4)
v.l4 ^= t
u.l4 ^= t
}
// IsNegative returns 1 if v is negative, and 0 otherwise.
func (v *Element) IsNegative() int {
return int(v.Bytes()[0] & 1)
}
// Absolute sets v to |u|, and returns v.
func (v *Element) Absolute(u *Element) *Element {
return v.Select(new(Element).Negate(u), u, u.IsNegative())
}
// Multiply sets v = x * y, and returns v.
func (v *Element) Multiply(x, y *Element) *Element {
feMul(v, x, y)
return v
}
// Square sets v = x * x, and returns v.
func (v *Element) Square(x *Element) *Element {
feSquare(v, x)
return v
}
// Mult32 sets v = x * y, and returns v.
func (v *Element) Mult32(x *Element, y uint32) *Element {
x0lo, x0hi := mul51(x.l0, y)
x1lo, x1hi := mul51(x.l1, y)
x2lo, x2hi := mul51(x.l2, y)
x3lo, x3hi := mul51(x.l3, y)
x4lo, x4hi := mul51(x.l4, y)
v.l0 = x0lo + 19*x4hi // carried over per the reduction identity
v.l1 = x1lo + x0hi
v.l2 = x2lo + x1hi
v.l3 = x3lo + x2hi
v.l4 = x4lo + x3hi
// The hi portions are going to be only 32 bits, plus any previous excess,
// so we can skip the carry propagation.
return v
}
// mul51 returns lo + hi * 2⁵¹ = a * b.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
mh, ml := bits.Mul64(a, uint64(b))
lo = ml & maskLow51Bits
hi = (mh << 13) | (ml >> 51)
return
}
// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3.
func (v *Element) Pow22523(x *Element) *Element {
var t0, t1, t2 Element
t0.Square(x) // x^2
t1.Square(&t0) // x^4
t1.Square(&t1) // x^8
t1.Multiply(x, &t1) // x^9
t0.Multiply(&t0, &t1) // x^11
t0.Square(&t0) // x^22
t0.Multiply(&t1, &t0) // x^31
t1.Square(&t0) // x^62
for i := 1; i < 5; i++ { // x^992
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1
t1.Square(&t0) // 2^11 - 2
for i := 1; i < 10; i++ { // 2^20 - 2^10
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^20 - 1
t2.Square(&t1) // 2^21 - 2
for i := 1; i < 20; i++ { // 2^40 - 2^20
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^40 - 1
t1.Square(&t1) // 2^41 - 2
for i := 1; i < 10; i++ { // 2^50 - 2^10
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^50 - 1
t1.Square(&t0) // 2^51 - 2
for i := 1; i < 50; i++ { // 2^100 - 2^50
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^100 - 1
t2.Square(&t1) // 2^101 - 2
for i := 1; i < 100; i++ { // 2^200 - 2^100
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^200 - 1
t1.Square(&t1) // 2^201 - 2
for i := 1; i < 50; i++ { // 2^250 - 2^50
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^250 - 1
t0.Square(&t0) // 2^251 - 2
t0.Square(&t0) // 2^252 - 4
return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3)
}
// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion.
var sqrtM1 = &Element{1718705420411056, 234908883556509,
2233514472574048, 2117202627021982, 765476049583133}
// SqrtRatio sets r to the non-negative square root of the ratio of u and v.
//
// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio
// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00,
// and returns r and 0.
func (r *Element) SqrtRatio(u, v *Element) (rr *Element, wasSquare int) {
var a, b Element
// r = (u * v3) * (u * v7)^((p-5)/8)
v2 := a.Square(v)
uv3 := b.Multiply(u, b.Multiply(v2, v))
uv7 := a.Multiply(uv3, a.Square(v2))
r.Multiply(uv3, r.Pow22523(uv7))
check := a.Multiply(v, a.Square(r)) // check = v * r^2
uNeg := b.Negate(u)
correctSignSqrt := check.Equal(u)
flippedSignSqrt := check.Equal(uNeg)
flippedSignSqrtI := check.Equal(uNeg.Multiply(uNeg, sqrtM1))
rPrime := b.Multiply(r, sqrtM1) // r_prime = SQRT_M1 * r
// r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r)
r.Select(rPrime, r, flippedSignSqrt|flippedSignSqrtI)
r.Absolute(r) // Choose the nonnegative square root.
return r, correctSignSqrt | flippedSignSqrt
}

View File

@ -0,0 +1,13 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
// +build amd64,gc,!purego
package field
// feMul sets out = a * b. It works like feMulGeneric.
//go:noescape
func feMul(out *Element, a *Element, b *Element)
// feSquare sets out = a * a. It works like feSquareGeneric.
//go:noescape
func feSquare(out *Element, a *Element)

View File

@ -0,0 +1,379 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
#include "textflag.h"
// func feMul(out *Element, a *Element, b *Element)
TEXT ·feMul(SB), NOSPLIT, $0-24
MOVQ a+8(FP), CX
MOVQ b+16(FP), BX
// r0 = a0×b0
MOVQ (CX), AX
MULQ (BX)
MOVQ AX, DI
MOVQ DX, SI
// r0 += 19×a1×b4
MOVQ 8(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a2×b3
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a3×b2
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a4×b1
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 8(BX)
ADDQ AX, DI
ADCQ DX, SI
// r1 = a0×b1
MOVQ (CX), AX
MULQ 8(BX)
MOVQ AX, R9
MOVQ DX, R8
// r1 += a1×b0
MOVQ 8(CX), AX
MULQ (BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a2×b4
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a3×b3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a4×b2
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, R9
ADCQ DX, R8
// r2 = a0×b2
MOVQ (CX), AX
MULQ 16(BX)
MOVQ AX, R11
MOVQ DX, R10
// r2 += a1×b1
MOVQ 8(CX), AX
MULQ 8(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += a2×b0
MOVQ 16(CX), AX
MULQ (BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a3×b4
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a4×b3
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R11
ADCQ DX, R10
// r3 = a0×b3
MOVQ (CX), AX
MULQ 24(BX)
MOVQ AX, R13
MOVQ DX, R12
// r3 += a1×b2
MOVQ 8(CX), AX
MULQ 16(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a2×b1
MOVQ 16(CX), AX
MULQ 8(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a3×b0
MOVQ 24(CX), AX
MULQ (BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += 19×a4×b4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R13
ADCQ DX, R12
// r4 = a0×b4
MOVQ (CX), AX
MULQ 32(BX)
MOVQ AX, R15
MOVQ DX, R14
// r4 += a1×b3
MOVQ 8(CX), AX
MULQ 24(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a2×b2
MOVQ 16(CX), AX
MULQ 16(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a3×b1
MOVQ 24(CX), AX
MULQ 8(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a4×b0
MOVQ 32(CX), AX
MULQ (BX)
ADDQ AX, R15
ADCQ DX, R14
// First reduction chain
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, DI, SI
SHLQ $0x0d, R9, R8
SHLQ $0x0d, R11, R10
SHLQ $0x0d, R13, R12
SHLQ $0x0d, R15, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Second reduction chain (carryPropagate)
MOVQ DI, SI
SHRQ $0x33, SI
MOVQ R9, R8
SHRQ $0x33, R8
MOVQ R11, R10
SHRQ $0x33, R10
MOVQ R13, R12
SHRQ $0x33, R12
MOVQ R15, R14
SHRQ $0x33, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Store output
MOVQ out+0(FP), AX
MOVQ DI, (AX)
MOVQ R9, 8(AX)
MOVQ R11, 16(AX)
MOVQ R13, 24(AX)
MOVQ R15, 32(AX)
RET
// func feSquare(out *Element, a *Element)
TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ a+8(FP), CX
// r0 = l0×l0
MOVQ (CX), AX
MULQ (CX)
MOVQ AX, SI
MOVQ DX, BX
// r0 += 38×l1×l4
MOVQ 8(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, SI
ADCQ DX, BX
// r0 += 38×l2×l3
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 24(CX)
ADDQ AX, SI
ADCQ DX, BX
// r1 = 2×l0×l1
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 8(CX)
MOVQ AX, R8
MOVQ DX, DI
// r1 += 38×l2×l4
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R8
ADCQ DX, DI
// r1 += 19×l3×l3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(CX)
ADDQ AX, R8
ADCQ DX, DI
// r2 = 2×l0×l2
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 16(CX)
MOVQ AX, R10
MOVQ DX, R9
// r2 += l1×l1
MOVQ 8(CX), AX
MULQ 8(CX)
ADDQ AX, R10
ADCQ DX, R9
// r2 += 38×l3×l4
MOVQ 24(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R10
ADCQ DX, R9
// r3 = 2×l0×l3
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 24(CX)
MOVQ AX, R12
MOVQ DX, R11
// r3 += 2×l1×l2
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 16(CX)
ADDQ AX, R12
ADCQ DX, R11
// r3 += 19×l4×l4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(CX)
ADDQ AX, R12
ADCQ DX, R11
// r4 = 2×l0×l4
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 32(CX)
MOVQ AX, R14
MOVQ DX, R13
// r4 += 2×l1×l3
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 24(CX)
ADDQ AX, R14
ADCQ DX, R13
// r4 += l2×l2
MOVQ 16(CX), AX
MULQ 16(CX)
ADDQ AX, R14
ADCQ DX, R13
// First reduction chain
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, SI, BX
SHLQ $0x0d, R8, DI
SHLQ $0x0d, R10, R9
SHLQ $0x0d, R12, R11
SHLQ $0x0d, R14, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Second reduction chain (carryPropagate)
MOVQ SI, BX
SHRQ $0x33, BX
MOVQ R8, DI
SHRQ $0x33, DI
MOVQ R10, R9
SHRQ $0x33, R9
MOVQ R12, R11
SHRQ $0x33, R11
MOVQ R14, R13
SHRQ $0x33, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Store output
MOVQ out+0(FP), AX
MOVQ SI, (AX)
MOVQ R8, 8(AX)
MOVQ R10, 16(AX)
MOVQ R12, 24(AX)
MOVQ R14, 32(AX)
RET

View File

@ -0,0 +1,12 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego
package field
func feMul(v, x, y *Element) { feMulGeneric(v, x, y) }
func feSquare(v, x *Element) { feSquareGeneric(v, x) }

View File

@ -0,0 +1,16 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
// +build arm64,gc,!purego
package field
//go:noescape
func carryPropagate(v *Element)
func (v *Element) carryPropagate() *Element {
carryPropagate(v)
return v
}

View File

@ -0,0 +1,43 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
// +build arm64,gc,!purego
#include "textflag.h"
// carryPropagate works exactly like carryPropagateGeneric and uses the
// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but
// avoids loading R0-R4 twice and uses LDP and STP.
//
// See https://golang.org/issues/43145 for the main compiler issue.
//
// func carryPropagate(v *Element)
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
MOVD v+0(FP), R20
LDP 0(R20), (R0, R1)
LDP 16(R20), (R2, R3)
MOVD 32(R20), R4
AND $0x7ffffffffffff, R0, R10
AND $0x7ffffffffffff, R1, R11
AND $0x7ffffffffffff, R2, R12
AND $0x7ffffffffffff, R3, R13
AND $0x7ffffffffffff, R4, R14
ADD R0>>51, R11, R11
ADD R1>>51, R12, R12
ADD R2>>51, R13, R13
ADD R3>>51, R14, R14
// R4>>51 * 19 + R10 -> R10
LSR $51, R4, R21
MOVD $19, R22
MADD R22, R10, R21, R10
STP (R10, R11), 0(R20)
STP (R12, R13), 16(R20)
MOVD R14, 32(R20)
RET

View File

@ -0,0 +1,12 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64 || !gc || purego
// +build !arm64 !gc purego
package field
func (v *Element) carryPropagate() *Element {
return v.carryPropagateGeneric()
}

View File

@ -0,0 +1,264 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "math/bits"
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
// bits.Mul64 and bits.Add64 intrinsics.
type uint128 struct {
lo, hi uint64
}
// mul64 returns a * b.
func mul64(a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
return uint128{lo, hi}
}
// addMul64 returns v + a * b.
func addMul64(v uint128, a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
lo, c := bits.Add64(lo, v.lo, 0)
hi, _ = bits.Add64(hi, v.hi, c)
return uint128{lo, hi}
}
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
func shiftRightBy51(a uint128) uint64 {
return (a.hi << (64 - 51)) | (a.lo >> 51)
}
func feMulGeneric(v, a, b *Element) {
a0 := a.l0
a1 := a.l1
a2 := a.l2
a3 := a.l3
a4 := a.l4
b0 := b.l0
b1 := b.l1
b2 := b.l2
b3 := b.l3
b4 := b.l4
// Limb multiplication works like pen-and-paper columnar multiplication, but
// with 51-bit limbs instead of digits.
//
// a4 a3 a2 a1 a0 x
// b4 b3 b2 b1 b0 =
// ------------------------
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a4b1 a3b1 a2b1 a1b1 a0b1 +
// a4b2 a3b2 a2b2 a1b2 a0b2 +
// a4b3 a3b3 a2b3 a1b3 a0b3 +
// a4b4 a3b4 a2b4 a1b4 a0b4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to
// reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5,
// r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc.
//
// Reduction can be carried out simultaneously to multiplication. For
// example, we do not compute r5: whenever the result of a multiplication
// belongs to r5, like a1b4, we multiply it by 19 and add the result to r0.
//
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a3b1 a2b1 a1b1 a0b1 19×a4b1 +
// a2b2 a1b2 a0b2 19×a4b2 19×a3b2 +
// a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 +
// a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// Finally we add up the columns into wide, overlapping limbs.
a1_19 := a1 * 19
a2_19 := a2 * 19
a3_19 := a3 * 19
a4_19 := a4 * 19
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
r0 := mul64(a0, b0)
r0 = addMul64(r0, a1_19, b4)
r0 = addMul64(r0, a2_19, b3)
r0 = addMul64(r0, a3_19, b2)
r0 = addMul64(r0, a4_19, b1)
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
r1 := mul64(a0, b1)
r1 = addMul64(r1, a1, b0)
r1 = addMul64(r1, a2_19, b4)
r1 = addMul64(r1, a3_19, b3)
r1 = addMul64(r1, a4_19, b2)
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
r2 := mul64(a0, b2)
r2 = addMul64(r2, a1, b1)
r2 = addMul64(r2, a2, b0)
r2 = addMul64(r2, a3_19, b4)
r2 = addMul64(r2, a4_19, b3)
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
r3 := mul64(a0, b3)
r3 = addMul64(r3, a1, b2)
r3 = addMul64(r3, a2, b1)
r3 = addMul64(r3, a3, b0)
r3 = addMul64(r3, a4_19, b4)
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
r4 := mul64(a0, b4)
r4 = addMul64(r4, a1, b3)
r4 = addMul64(r4, a2, b2)
r4 = addMul64(r4, a3, b1)
r4 = addMul64(r4, a4, b0)
// After the multiplication, we need to reduce (carry) the five coefficients
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
// to respect the Element invariant.
//
// Overall, the reduction works the same as carryPropagate, except with
// wider inputs: we take the carry for each coefficient by shifting it right
// by 51, and add it to the limb above it. The top carry is multiplied by 19
// according to the reduction identity and added to the lowest limb.
//
// The largest coefficient (r0) will be at most 111 bits, which guarantees
// that all carries are at most 111 - 51 = 60 bits, which fits in a uint64.
//
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
// r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²)
// r0 < (1 + 19 × 4) × 2⁵² × 2⁵²
// r0 < 2⁷ × 2⁵² × 2⁵²
// r0 < 2¹¹¹
//
// Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most
// 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and
// allows us to easily apply the reduction identity.
//
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
// r4 < 5 × 2⁵² × 2⁵²
// r4 < 2¹⁰⁷
//
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
// Now all coefficients fit into 64-bit registers but are still too large to
// be passed around as a Element. We therefore do one last carry chain,
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
func feSquareGeneric(v, a *Element) {
l0 := a.l0
l1 := a.l1
l2 := a.l2
l3 := a.l3
l4 := a.l4
// Squaring works precisely like multiplication above, but thanks to its
// symmetry we get to group a few terms together.
//
// l4 l3 l2 l1 l0 x
// l4 l3 l2 l1 l0 =
// ------------------------
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l4l1 l3l1 l2l1 l1l1 l0l1 +
// l4l2 l3l2 l2l2 l1l2 l0l2 +
// l4l3 l3l3 l2l3 l1l3 l0l3 +
// l4l4 l3l4 l2l4 l1l4 l0l4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l3l1 l2l1 l1l1 l0l1 19×l4l1 +
// l2l2 l1l2 l0l2 19×l4l2 19×l3l2 +
// l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 +
// l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
// only three Mul64 and four Add64, instead of five and eight.
l0_2 := l0 * 2
l1_2 := l1 * 2
l1_38 := l1 * 38
l2_38 := l2 * 38
l3_38 := l3 * 38
l3_19 := l3 * 19
l4_19 := l4 * 19
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
r0 := mul64(l0, l0)
r0 = addMul64(r0, l1_38, l4)
r0 = addMul64(r0, l2_38, l3)
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
r1 := mul64(l0_2, l1)
r1 = addMul64(r1, l2_38, l4)
r1 = addMul64(r1, l3_19, l3)
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
r2 := mul64(l0_2, l2)
r2 = addMul64(r2, l1, l1)
r2 = addMul64(r2, l3_38, l4)
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
r3 := mul64(l0_2, l3)
r3 = addMul64(r3, l1_2, l2)
r3 = addMul64(r3, l4_19, l4)
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
r4 := mul64(l0_2, l4)
r4 = addMul64(r4, l1_2, l3)
r4 = addMul64(r4, l2, l2)
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
// carryPropagate brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry. TODO inline
func (v *Element) carryPropagateGeneric() *Element {
c0 := v.l0 >> 51
c1 := v.l1 >> 51
c2 := v.l2 >> 51
c3 := v.l3 >> 51
c4 := v.l4 >> 51
v.l0 = v.l0&maskLow51Bits + c4*19
v.l1 = v.l1&maskLow51Bits + c0
v.l2 = v.l2&maskLow51Bits + c1
v.l3 = v.l3&maskLow51Bits + c2
v.l4 = v.l4&maskLow51Bits + c3
return v
}

View File

@ -0,0 +1 @@
b0c49ae9f59d233526f8934262c5bbbe14d4358d

View File

@ -0,0 +1,19 @@
#! /bin/bash
set -euo pipefail
cd "$(git rev-parse --show-toplevel)"
STD_PATH=src/crypto/ed25519/internal/edwards25519/field
LOCAL_PATH=curve25519/internal/field
LAST_SYNC_REF=$(cat $LOCAL_PATH/sync.checkpoint)
git fetch https://go.googlesource.com/go master
if git diff --quiet $LAST_SYNC_REF:$STD_PATH FETCH_HEAD:$STD_PATH; then
echo "No changes."
else
NEW_REF=$(git rev-parse FETCH_HEAD | tee $LOCAL_PATH/sync.checkpoint)
echo "Applying changes from $LAST_SYNC_REF to $NEW_REF..."
git diff $LAST_SYNC_REF:$STD_PATH FETCH_HEAD:$STD_PATH | \
git apply -3 --directory=$LOCAL_PATH
fi

41
vendor/golang.org/x/net/bpf/asm.go generated vendored Normal file
View File

@ -0,0 +1,41 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bpf
import "fmt"
// Assemble converts insts into raw instructions suitable for loading
// into a BPF virtual machine.
//
// Currently, no optimization is attempted, the assembled program flow
// is exactly as provided.
func Assemble(insts []Instruction) ([]RawInstruction, error) {
ret := make([]RawInstruction, len(insts))
var err error
for i, inst := range insts {
ret[i], err = inst.Assemble()
if err != nil {
return nil, fmt.Errorf("assembling instruction %d: %s", i+1, err)
}
}
return ret, nil
}
// Disassemble attempts to parse raw back into
// Instructions. Unrecognized RawInstructions are assumed to be an
// extension not implemented by this package, and are passed through
// unchanged to the output. The allDecoded value reports whether insts
// contains no RawInstructions.
func Disassemble(raw []RawInstruction) (insts []Instruction, allDecoded bool) {
insts = make([]Instruction, len(raw))
allDecoded = true
for i, r := range raw {
insts[i] = r.Disassemble()
if _, ok := insts[i].(RawInstruction); ok {
allDecoded = false
}
}
return insts, allDecoded
}

222
vendor/golang.org/x/net/bpf/constants.go generated vendored Normal file
View File

@ -0,0 +1,222 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bpf
// A Register is a register of the BPF virtual machine.
type Register uint16
const (
// RegA is the accumulator register. RegA is always the
// destination register of ALU operations.
RegA Register = iota
// RegX is the indirection register, used by LoadIndirect
// operations.
RegX
)
// An ALUOp is an arithmetic or logic operation.
type ALUOp uint16
// ALU binary operation types.
const (
ALUOpAdd ALUOp = iota << 4
ALUOpSub
ALUOpMul
ALUOpDiv
ALUOpOr
ALUOpAnd
ALUOpShiftLeft
ALUOpShiftRight
aluOpNeg // Not exported because it's the only unary ALU operation, and gets its own instruction type.
ALUOpMod
ALUOpXor
)
// A JumpTest is a comparison operator used in conditional jumps.
type JumpTest uint16
// Supported operators for conditional jumps.
// K can be RegX for JumpIfX
const (
// K == A
JumpEqual JumpTest = iota
// K != A
JumpNotEqual
// K > A
JumpGreaterThan
// K < A
JumpLessThan
// K >= A
JumpGreaterOrEqual
// K <= A
JumpLessOrEqual
// K & A != 0
JumpBitsSet
// K & A == 0
JumpBitsNotSet
)
// An Extension is a function call provided by the kernel that
// performs advanced operations that are expensive or impossible
// within the BPF virtual machine.
//
// Extensions are only implemented by the Linux kernel.
//
// TODO: should we prune this list? Some of these extensions seem
// either broken or near-impossible to use correctly, whereas other
// (len, random, ifindex) are quite useful.
type Extension int
// Extension functions available in the Linux kernel.
const (
// extOffset is the negative maximum number of instructions used
// to load instructions by overloading the K argument.
extOffset = -0x1000
// ExtLen returns the length of the packet.
ExtLen Extension = 1
// ExtProto returns the packet's L3 protocol type.
ExtProto Extension = 0
// ExtType returns the packet's type (skb->pkt_type in the kernel)
//
// TODO: better documentation. How nice an API do we want to
// provide for these esoteric extensions?
ExtType Extension = 4
// ExtPayloadOffset returns the offset of the packet payload, or
// the first protocol header that the kernel does not know how to
// parse.
ExtPayloadOffset Extension = 52
// ExtInterfaceIndex returns the index of the interface on which
// the packet was received.
ExtInterfaceIndex Extension = 8
// ExtNetlinkAttr returns the netlink attribute of type X at
// offset A.
ExtNetlinkAttr Extension = 12
// ExtNetlinkAttrNested returns the nested netlink attribute of
// type X at offset A.
ExtNetlinkAttrNested Extension = 16
// ExtMark returns the packet's mark value.
ExtMark Extension = 20
// ExtQueue returns the packet's assigned hardware queue.
ExtQueue Extension = 24
// ExtLinkLayerType returns the packet's hardware address type
// (e.g. Ethernet, Infiniband).
ExtLinkLayerType Extension = 28
// ExtRXHash returns the packets receive hash.
//
// TODO: figure out what this rxhash actually is.
ExtRXHash Extension = 32
// ExtCPUID returns the ID of the CPU processing the current
// packet.
ExtCPUID Extension = 36
// ExtVLANTag returns the packet's VLAN tag.
ExtVLANTag Extension = 44
// ExtVLANTagPresent returns non-zero if the packet has a VLAN
// tag.
//
// TODO: I think this might be a lie: it reads bit 0x1000 of the
// VLAN header, which changed meaning in recent revisions of the
// spec - this extension may now return meaningless information.
ExtVLANTagPresent Extension = 48
// ExtVLANProto returns 0x8100 if the frame has a VLAN header,
// 0x88a8 if the frame has a "Q-in-Q" double VLAN header, or some
// other value if no VLAN information is present.
ExtVLANProto Extension = 60
// ExtRand returns a uniformly random uint32.
ExtRand Extension = 56
)
// The following gives names to various bit patterns used in opcode construction.
const (
opMaskCls uint16 = 0x7
// opClsLoad masks
opMaskLoadDest = 0x01
opMaskLoadWidth = 0x18
opMaskLoadMode = 0xe0
// opClsALU & opClsJump
opMaskOperand = 0x08
opMaskOperator = 0xf0
)
const (
// +---------------+-----------------+---+---+---+
// | AddrMode (3b) | LoadWidth (2b) | 0 | 0 | 0 |
// +---------------+-----------------+---+---+---+
opClsLoadA uint16 = iota
// +---------------+-----------------+---+---+---+
// | AddrMode (3b) | LoadWidth (2b) | 0 | 0 | 1 |
// +---------------+-----------------+---+---+---+
opClsLoadX
// +---+---+---+---+---+---+---+---+
// | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
// +---+---+---+---+---+---+---+---+
opClsStoreA
// +---+---+---+---+---+---+---+---+
// | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 |
// +---+---+---+---+---+---+---+---+
opClsStoreX
// +---------------+-----------------+---+---+---+
// | Operator (4b) | OperandSrc (1b) | 1 | 0 | 0 |
// +---------------+-----------------+---+---+---+
opClsALU
// +-----------------------------+---+---+---+---+
// | TestOperator (4b) | 0 | 1 | 0 | 1 |
// +-----------------------------+---+---+---+---+
opClsJump
// +---+-------------------------+---+---+---+---+
// | 0 | 0 | 0 | RetSrc (1b) | 0 | 1 | 1 | 0 |
// +---+-------------------------+---+---+---+---+
opClsReturn
// +---+-------------------------+---+---+---+---+
// | 0 | 0 | 0 | TXAorTAX (1b) | 0 | 1 | 1 | 1 |
// +---+-------------------------+---+---+---+---+
opClsMisc
)
const (
opAddrModeImmediate uint16 = iota << 5
opAddrModeAbsolute
opAddrModeIndirect
opAddrModeScratch
opAddrModePacketLen // actually an extension, not an addressing mode.
opAddrModeMemShift
)
const (
opLoadWidth4 uint16 = iota << 3
opLoadWidth2
opLoadWidth1
)
// Operand for ALU and Jump instructions
type opOperand uint16
// Supported operand sources.
const (
opOperandConstant opOperand = iota << 3
opOperandX
)
// An jumpOp is a conditional jump condition.
type jumpOp uint16
// Supported jump conditions.
const (
opJumpAlways jumpOp = iota << 4
opJumpEqual
opJumpGT
opJumpGE
opJumpSet
)
const (
opRetSrcConstant uint16 = iota << 4
opRetSrcA
)
const (
opMiscTAX = 0x00
opMiscTXA = 0x80
)

82
vendor/golang.org/x/net/bpf/doc.go generated vendored Normal file
View File

@ -0,0 +1,82 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package bpf implements marshaling and unmarshaling of programs for the
Berkeley Packet Filter virtual machine, and provides a Go implementation
of the virtual machine.
BPF's main use is to specify a packet filter for network taps, so that
the kernel doesn't have to expensively copy every packet it sees to
userspace. However, it's been repurposed to other areas where running
user code in-kernel is needed. For example, Linux's seccomp uses BPF
to apply security policies to system calls. For simplicity, this
documentation refers only to packets, but other uses of BPF have their
own data payloads.
BPF programs run in a restricted virtual machine. It has almost no
access to kernel functions, and while conditional branches are
allowed, they can only jump forwards, to guarantee that there are no
infinite loops.
The virtual machine
The BPF VM is an accumulator machine. Its main register, called
register A, is an implicit source and destination in all arithmetic
and logic operations. The machine also has 16 scratch registers for
temporary storage, and an indirection register (register X) for
indirect memory access. All registers are 32 bits wide.
Each run of a BPF program is given one packet, which is placed in the
VM's read-only "main memory". LoadAbsolute and LoadIndirect
instructions can fetch up to 32 bits at a time into register A for
examination.
The goal of a BPF program is to produce and return a verdict (uint32),
which tells the kernel what to do with the packet. In the context of
packet filtering, the returned value is the number of bytes of the
packet to forward to userspace, or 0 to ignore the packet. Other
contexts like seccomp define their own return values.
In order to simplify programs, attempts to read past the end of the
packet terminate the program execution with a verdict of 0 (ignore
packet). This means that the vast majority of BPF programs don't need
to do any explicit bounds checking.
In addition to the bytes of the packet, some BPF programs have access
to extensions, which are essentially calls to kernel utility
functions. Currently, the only extensions supported by this package
are the Linux packet filter extensions.
Examples
This packet filter selects all ARP packets.
bpf.Assemble([]bpf.Instruction{
// Load "EtherType" field from the ethernet header.
bpf.LoadAbsolute{Off: 12, Size: 2},
// Skip over the next instruction if EtherType is not ARP.
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: 0x0806, SkipTrue: 1},
// Verdict is "send up to 4k of the packet to userspace."
bpf.RetConstant{Val: 4096},
// Verdict is "ignore packet."
bpf.RetConstant{Val: 0},
})
This packet filter captures a random 1% sample of traffic.
bpf.Assemble([]bpf.Instruction{
// Get a 32-bit random number from the Linux kernel.
bpf.LoadExtension{Num: bpf.ExtRand},
// 1% dice roll?
bpf.JumpIf{Cond: bpf.JumpLessThan, Val: 2^32/100, SkipFalse: 1},
// Capture.
bpf.RetConstant{Val: 4096},
// Ignore.
bpf.RetConstant{Val: 0},
})
*/
package bpf // import "golang.org/x/net/bpf"

726
vendor/golang.org/x/net/bpf/instructions.go generated vendored Normal file
View File

@ -0,0 +1,726 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bpf
import "fmt"
// An Instruction is one instruction executed by the BPF virtual
// machine.
type Instruction interface {
// Assemble assembles the Instruction into a RawInstruction.
Assemble() (RawInstruction, error)
}
// A RawInstruction is a raw BPF virtual machine instruction.
type RawInstruction struct {
// Operation to execute.
Op uint16
// For conditional jump instructions, the number of instructions
// to skip if the condition is true/false.
Jt uint8
Jf uint8
// Constant parameter. The meaning depends on the Op.
K uint32
}
// Assemble implements the Instruction Assemble method.
func (ri RawInstruction) Assemble() (RawInstruction, error) { return ri, nil }
// Disassemble parses ri into an Instruction and returns it. If ri is
// not recognized by this package, ri itself is returned.
func (ri RawInstruction) Disassemble() Instruction {
switch ri.Op & opMaskCls {
case opClsLoadA, opClsLoadX:
reg := Register(ri.Op & opMaskLoadDest)
sz := 0
switch ri.Op & opMaskLoadWidth {
case opLoadWidth4:
sz = 4
case opLoadWidth2:
sz = 2
case opLoadWidth1:
sz = 1
default:
return ri
}
switch ri.Op & opMaskLoadMode {
case opAddrModeImmediate:
if sz != 4 {
return ri
}
return LoadConstant{Dst: reg, Val: ri.K}
case opAddrModeScratch:
if sz != 4 || ri.K > 15 {
return ri
}
return LoadScratch{Dst: reg, N: int(ri.K)}
case opAddrModeAbsolute:
if ri.K > extOffset+0xffffffff {
return LoadExtension{Num: Extension(-extOffset + ri.K)}
}
return LoadAbsolute{Size: sz, Off: ri.K}
case opAddrModeIndirect:
return LoadIndirect{Size: sz, Off: ri.K}
case opAddrModePacketLen:
if sz != 4 {
return ri
}
return LoadExtension{Num: ExtLen}
case opAddrModeMemShift:
return LoadMemShift{Off: ri.K}
default:
return ri
}
case opClsStoreA:
if ri.Op != opClsStoreA || ri.K > 15 {
return ri
}
return StoreScratch{Src: RegA, N: int(ri.K)}
case opClsStoreX:
if ri.Op != opClsStoreX || ri.K > 15 {
return ri
}
return StoreScratch{Src: RegX, N: int(ri.K)}
case opClsALU:
switch op := ALUOp(ri.Op & opMaskOperator); op {
case ALUOpAdd, ALUOpSub, ALUOpMul, ALUOpDiv, ALUOpOr, ALUOpAnd, ALUOpShiftLeft, ALUOpShiftRight, ALUOpMod, ALUOpXor:
switch operand := opOperand(ri.Op & opMaskOperand); operand {
case opOperandX:
return ALUOpX{Op: op}
case opOperandConstant:
return ALUOpConstant{Op: op, Val: ri.K}
default:
return ri
}
case aluOpNeg:
return NegateA{}
default:
return ri
}
case opClsJump:
switch op := jumpOp(ri.Op & opMaskOperator); op {
case opJumpAlways:
return Jump{Skip: ri.K}
case opJumpEqual, opJumpGT, opJumpGE, opJumpSet:
cond, skipTrue, skipFalse := jumpOpToTest(op, ri.Jt, ri.Jf)
switch operand := opOperand(ri.Op & opMaskOperand); operand {
case opOperandX:
return JumpIfX{Cond: cond, SkipTrue: skipTrue, SkipFalse: skipFalse}
case opOperandConstant:
return JumpIf{Cond: cond, Val: ri.K, SkipTrue: skipTrue, SkipFalse: skipFalse}
default:
return ri
}
default:
return ri
}
case opClsReturn:
switch ri.Op {
case opClsReturn | opRetSrcA:
return RetA{}
case opClsReturn | opRetSrcConstant:
return RetConstant{Val: ri.K}
default:
return ri
}
case opClsMisc:
switch ri.Op {
case opClsMisc | opMiscTAX:
return TAX{}
case opClsMisc | opMiscTXA:
return TXA{}
default:
return ri
}
default:
panic("unreachable") // switch is exhaustive on the bit pattern
}
}
func jumpOpToTest(op jumpOp, skipTrue uint8, skipFalse uint8) (JumpTest, uint8, uint8) {
var test JumpTest
// Decode "fake" jump conditions that don't appear in machine code
// Ensures the Assemble -> Disassemble stage recreates the same instructions
// See https://github.com/golang/go/issues/18470
if skipTrue == 0 {
switch op {
case opJumpEqual:
test = JumpNotEqual
case opJumpGT:
test = JumpLessOrEqual
case opJumpGE:
test = JumpLessThan
case opJumpSet:
test = JumpBitsNotSet
}
return test, skipFalse, 0
}
switch op {
case opJumpEqual:
test = JumpEqual
case opJumpGT:
test = JumpGreaterThan
case opJumpGE:
test = JumpGreaterOrEqual
case opJumpSet:
test = JumpBitsSet
}
return test, skipTrue, skipFalse
}
// LoadConstant loads Val into register Dst.
type LoadConstant struct {
Dst Register
Val uint32
}
// Assemble implements the Instruction Assemble method.
func (a LoadConstant) Assemble() (RawInstruction, error) {
return assembleLoad(a.Dst, 4, opAddrModeImmediate, a.Val)
}
// String returns the instruction in assembler notation.
func (a LoadConstant) String() string {
switch a.Dst {
case RegA:
return fmt.Sprintf("ld #%d", a.Val)
case RegX:
return fmt.Sprintf("ldx #%d", a.Val)
default:
return fmt.Sprintf("unknown instruction: %#v", a)
}
}
// LoadScratch loads scratch[N] into register Dst.
type LoadScratch struct {
Dst Register
N int // 0-15
}
// Assemble implements the Instruction Assemble method.
func (a LoadScratch) Assemble() (RawInstruction, error) {
if a.N < 0 || a.N > 15 {
return RawInstruction{}, fmt.Errorf("invalid scratch slot %d", a.N)
}
return assembleLoad(a.Dst, 4, opAddrModeScratch, uint32(a.N))
}
// String returns the instruction in assembler notation.
func (a LoadScratch) String() string {
switch a.Dst {
case RegA:
return fmt.Sprintf("ld M[%d]", a.N)
case RegX:
return fmt.Sprintf("ldx M[%d]", a.N)
default:
return fmt.Sprintf("unknown instruction: %#v", a)
}
}
// LoadAbsolute loads packet[Off:Off+Size] as an integer value into
// register A.
type LoadAbsolute struct {
Off uint32
Size int // 1, 2 or 4
}
// Assemble implements the Instruction Assemble method.
func (a LoadAbsolute) Assemble() (RawInstruction, error) {
return assembleLoad(RegA, a.Size, opAddrModeAbsolute, a.Off)
}
// String returns the instruction in assembler notation.
func (a LoadAbsolute) String() string {
switch a.Size {
case 1: // byte
return fmt.Sprintf("ldb [%d]", a.Off)
case 2: // half word
return fmt.Sprintf("ldh [%d]", a.Off)
case 4: // word
if a.Off > extOffset+0xffffffff {
return LoadExtension{Num: Extension(a.Off + 0x1000)}.String()
}
return fmt.Sprintf("ld [%d]", a.Off)
default:
return fmt.Sprintf("unknown instruction: %#v", a)
}
}
// LoadIndirect loads packet[X+Off:X+Off+Size] as an integer value
// into register A.
type LoadIndirect struct {
Off uint32
Size int // 1, 2 or 4
}
// Assemble implements the Instruction Assemble method.
func (a LoadIndirect) Assemble() (RawInstruction, error) {
return assembleLoad(RegA, a.Size, opAddrModeIndirect, a.Off)
}
// String returns the instruction in assembler notation.
func (a LoadIndirect) String() string {
switch a.Size {
case 1: // byte
return fmt.Sprintf("ldb [x + %d]", a.Off)
case 2: // half word
return fmt.Sprintf("ldh [x + %d]", a.Off)
case 4: // word
return fmt.Sprintf("ld [x + %d]", a.Off)
default:
return fmt.Sprintf("unknown instruction: %#v", a)
}
}
// LoadMemShift multiplies the first 4 bits of the byte at packet[Off]
// by 4 and stores the result in register X.
//
// This instruction is mainly useful to load into X the length of an
// IPv4 packet header in a single instruction, rather than have to do
// the arithmetic on the header's first byte by hand.
type LoadMemShift struct {
Off uint32
}
// Assemble implements the Instruction Assemble method.
func (a LoadMemShift) Assemble() (RawInstruction, error) {
return assembleLoad(RegX, 1, opAddrModeMemShift, a.Off)
}
// String returns the instruction in assembler notation.
func (a LoadMemShift) String() string {
return fmt.Sprintf("ldx 4*([%d]&0xf)", a.Off)
}
// LoadExtension invokes a linux-specific extension and stores the
// result in register A.
type LoadExtension struct {
Num Extension
}
// Assemble implements the Instruction Assemble method.
func (a LoadExtension) Assemble() (RawInstruction, error) {
if a.Num == ExtLen {
return assembleLoad(RegA, 4, opAddrModePacketLen, 0)
}
return assembleLoad(RegA, 4, opAddrModeAbsolute, uint32(extOffset+a.Num))
}
// String returns the instruction in assembler notation.
func (a LoadExtension) String() string {
switch a.Num {
case ExtLen:
return "ld #len"
case ExtProto:
return "ld #proto"
case ExtType:
return "ld #type"
case ExtPayloadOffset:
return "ld #poff"
case ExtInterfaceIndex:
return "ld #ifidx"
case ExtNetlinkAttr:
return "ld #nla"
case ExtNetlinkAttrNested:
return "ld #nlan"
case ExtMark:
return "ld #mark"
case ExtQueue:
return "ld #queue"
case ExtLinkLayerType:
return "ld #hatype"
case ExtRXHash:
return "ld #rxhash"
case ExtCPUID:
return "ld #cpu"
case ExtVLANTag:
return "ld #vlan_tci"
case ExtVLANTagPresent:
return "ld #vlan_avail"
case ExtVLANProto:
return "ld #vlan_tpid"
case ExtRand:
return "ld #rand"
default:
return fmt.Sprintf("unknown instruction: %#v", a)
}
}
// StoreScratch stores register Src into scratch[N].
type StoreScratch struct {
Src Register
N int // 0-15
}
// Assemble implements the Instruction Assemble method.
func (a StoreScratch) Assemble() (RawInstruction, error) {
if a.N < 0 || a.N > 15 {
return RawInstruction{}, fmt.Errorf("invalid scratch slot %d", a.N)
}
var op uint16
switch a.Src {
case RegA:
op = opClsStoreA
case RegX:
op = opClsStoreX
default:
return RawInstruction{}, fmt.Errorf("invalid source register %v", a.Src)
}
return RawInstruction{
Op: op,
K: uint32(a.N),
}, nil
}
// String returns the instruction in assembler notation.
func (a StoreScratch) String() string {
switch a.Src {
case RegA:
return fmt.Sprintf("st M[%d]", a.N)
case RegX:
return fmt.Sprintf("stx M[%d]", a.N)
default:
return fmt.Sprintf("unknown instruction: %#v", a)
}
}
// ALUOpConstant executes A = A <Op> Val.
type ALUOpConstant struct {
Op ALUOp
Val uint32
}
// Assemble implements the Instruction Assemble method.
func (a ALUOpConstant) Assemble() (RawInstruction, error) {
return RawInstruction{
Op: opClsALU | uint16(opOperandConstant) | uint16(a.Op),
K: a.Val,
}, nil
}
// String returns the instruction in assembler notation.
func (a ALUOpConstant) String() string {
switch a.Op {
case ALUOpAdd:
return fmt.Sprintf("add #%d", a.Val)
case ALUOpSub:
return fmt.Sprintf("sub #%d", a.Val)
case ALUOpMul:
return fmt.Sprintf("mul #%d", a.Val)
case ALUOpDiv:
return fmt.Sprintf("div #%d", a.Val)
case ALUOpMod:
return fmt.Sprintf("mod #%d", a.Val)
case ALUOpAnd:
return fmt.Sprintf("and #%d", a.Val)
case ALUOpOr:
return fmt.Sprintf("or #%d", a.Val)
case ALUOpXor:
return fmt.Sprintf("xor #%d", a.Val)
case ALUOpShiftLeft:
return fmt.Sprintf("lsh #%d", a.Val)
case ALUOpShiftRight:
return fmt.Sprintf("rsh #%d", a.Val)
default:
return fmt.Sprintf("unknown instruction: %#v", a)
}
}
// ALUOpX executes A = A <Op> X
type ALUOpX struct {
Op ALUOp
}
// Assemble implements the Instruction Assemble method.
func (a ALUOpX) Assemble() (RawInstruction, error) {
return RawInstruction{
Op: opClsALU | uint16(opOperandX) | uint16(a.Op),
}, nil
}
// String returns the instruction in assembler notation.
func (a ALUOpX) String() string {
switch a.Op {
case ALUOpAdd:
return "add x"
case ALUOpSub:
return "sub x"
case ALUOpMul:
return "mul x"
case ALUOpDiv:
return "div x"
case ALUOpMod:
return "mod x"
case ALUOpAnd:
return "and x"
case ALUOpOr:
return "or x"
case ALUOpXor:
return "xor x"
case ALUOpShiftLeft:
return "lsh x"
case ALUOpShiftRight:
return "rsh x"
default:
return fmt.Sprintf("unknown instruction: %#v", a)
}
}
// NegateA executes A = -A.
type NegateA struct{}
// Assemble implements the Instruction Assemble method.
func (a NegateA) Assemble() (RawInstruction, error) {
return RawInstruction{
Op: opClsALU | uint16(aluOpNeg),
}, nil
}
// String returns the instruction in assembler notation.
func (a NegateA) String() string {
return fmt.Sprintf("neg")
}
// Jump skips the following Skip instructions in the program.
type Jump struct {
Skip uint32
}
// Assemble implements the Instruction Assemble method.
func (a Jump) Assemble() (RawInstruction, error) {
return RawInstruction{
Op: opClsJump | uint16(opJumpAlways),
K: a.Skip,
}, nil
}
// String returns the instruction in assembler notation.
func (a Jump) String() string {
return fmt.Sprintf("ja %d", a.Skip)
}
// JumpIf skips the following Skip instructions in the program if A
// <Cond> Val is true.
type JumpIf struct {
Cond JumpTest
Val uint32
SkipTrue uint8
SkipFalse uint8
}
// Assemble implements the Instruction Assemble method.
func (a JumpIf) Assemble() (RawInstruction, error) {
return jumpToRaw(a.Cond, opOperandConstant, a.Val, a.SkipTrue, a.SkipFalse)
}
// String returns the instruction in assembler notation.
func (a JumpIf) String() string {
return jumpToString(a.Cond, fmt.Sprintf("#%d", a.Val), a.SkipTrue, a.SkipFalse)
}
// JumpIfX skips the following Skip instructions in the program if A
// <Cond> X is true.
type JumpIfX struct {
Cond JumpTest
SkipTrue uint8
SkipFalse uint8
}
// Assemble implements the Instruction Assemble method.
func (a JumpIfX) Assemble() (RawInstruction, error) {
return jumpToRaw(a.Cond, opOperandX, 0, a.SkipTrue, a.SkipFalse)
}
// String returns the instruction in assembler notation.
func (a JumpIfX) String() string {
return jumpToString(a.Cond, "x", a.SkipTrue, a.SkipFalse)
}
// jumpToRaw assembles a jump instruction into a RawInstruction
func jumpToRaw(test JumpTest, operand opOperand, k uint32, skipTrue, skipFalse uint8) (RawInstruction, error) {
var (
cond jumpOp
flip bool
)
switch test {
case JumpEqual:
cond = opJumpEqual
case JumpNotEqual:
cond, flip = opJumpEqual, true
case JumpGreaterThan:
cond = opJumpGT
case JumpLessThan:
cond, flip = opJumpGE, true
case JumpGreaterOrEqual:
cond = opJumpGE
case JumpLessOrEqual:
cond, flip = opJumpGT, true
case JumpBitsSet:
cond = opJumpSet
case JumpBitsNotSet:
cond, flip = opJumpSet, true
default:
return RawInstruction{}, fmt.Errorf("unknown JumpTest %v", test)
}
jt, jf := skipTrue, skipFalse
if flip {
jt, jf = jf, jt
}
return RawInstruction{
Op: opClsJump | uint16(cond) | uint16(operand),
Jt: jt,
Jf: jf,
K: k,
}, nil
}
// jumpToString converts a jump instruction to assembler notation
func jumpToString(cond JumpTest, operand string, skipTrue, skipFalse uint8) string {
switch cond {
// K == A
case JumpEqual:
return conditionalJump(operand, skipTrue, skipFalse, "jeq", "jneq")
// K != A
case JumpNotEqual:
return fmt.Sprintf("jneq %s,%d", operand, skipTrue)
// K > A
case JumpGreaterThan:
return conditionalJump(operand, skipTrue, skipFalse, "jgt", "jle")
// K < A
case JumpLessThan:
return fmt.Sprintf("jlt %s,%d", operand, skipTrue)
// K >= A
case JumpGreaterOrEqual:
return conditionalJump(operand, skipTrue, skipFalse, "jge", "jlt")
// K <= A
case JumpLessOrEqual:
return fmt.Sprintf("jle %s,%d", operand, skipTrue)
// K & A != 0
case JumpBitsSet:
if skipFalse > 0 {
return fmt.Sprintf("jset %s,%d,%d", operand, skipTrue, skipFalse)
}
return fmt.Sprintf("jset %s,%d", operand, skipTrue)
// K & A == 0, there is no assembler instruction for JumpBitNotSet, use JumpBitSet and invert skips
case JumpBitsNotSet:
return jumpToString(JumpBitsSet, operand, skipFalse, skipTrue)
default:
return fmt.Sprintf("unknown JumpTest %#v", cond)
}
}
func conditionalJump(operand string, skipTrue, skipFalse uint8, positiveJump, negativeJump string) string {
if skipTrue > 0 {
if skipFalse > 0 {
return fmt.Sprintf("%s %s,%d,%d", positiveJump, operand, skipTrue, skipFalse)
}
return fmt.Sprintf("%s %s,%d", positiveJump, operand, skipTrue)
}
return fmt.Sprintf("%s %s,%d", negativeJump, operand, skipFalse)
}
// RetA exits the BPF program, returning the value of register A.
type RetA struct{}
// Assemble implements the Instruction Assemble method.
func (a RetA) Assemble() (RawInstruction, error) {
return RawInstruction{
Op: opClsReturn | opRetSrcA,
}, nil
}
// String returns the instruction in assembler notation.
func (a RetA) String() string {
return fmt.Sprintf("ret a")
}
// RetConstant exits the BPF program, returning a constant value.
type RetConstant struct {
Val uint32
}
// Assemble implements the Instruction Assemble method.
func (a RetConstant) Assemble() (RawInstruction, error) {
return RawInstruction{
Op: opClsReturn | opRetSrcConstant,
K: a.Val,
}, nil
}
// String returns the instruction in assembler notation.
func (a RetConstant) String() string {
return fmt.Sprintf("ret #%d", a.Val)
}
// TXA copies the value of register X to register A.
type TXA struct{}
// Assemble implements the Instruction Assemble method.
func (a TXA) Assemble() (RawInstruction, error) {
return RawInstruction{
Op: opClsMisc | opMiscTXA,
}, nil
}
// String returns the instruction in assembler notation.
func (a TXA) String() string {
return fmt.Sprintf("txa")
}
// TAX copies the value of register A to register X.
type TAX struct{}
// Assemble implements the Instruction Assemble method.
func (a TAX) Assemble() (RawInstruction, error) {
return RawInstruction{
Op: opClsMisc | opMiscTAX,
}, nil
}
// String returns the instruction in assembler notation.
func (a TAX) String() string {
return fmt.Sprintf("tax")
}
func assembleLoad(dst Register, loadSize int, mode uint16, k uint32) (RawInstruction, error) {
var (
cls uint16
sz uint16
)
switch dst {
case RegA:
cls = opClsLoadA
case RegX:
cls = opClsLoadX
default:
return RawInstruction{}, fmt.Errorf("invalid target register %v", dst)
}
switch loadSize {
case 1:
sz = opLoadWidth1
case 2:
sz = opLoadWidth2
case 4:
sz = opLoadWidth4
default:
return RawInstruction{}, fmt.Errorf("invalid load byte length %d", sz)
}
return RawInstruction{
Op: cls | sz | mode,
K: k,
}, nil
}

10
vendor/golang.org/x/net/bpf/setter.go generated vendored Normal file
View File

@ -0,0 +1,10 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bpf
// A Setter is a type which can attach a compiled BPF filter to itself.
type Setter interface {
SetBPF(filter []RawInstruction) error
}

150
vendor/golang.org/x/net/bpf/vm.go generated vendored Normal file
View File

@ -0,0 +1,150 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bpf
import (
"errors"
"fmt"
)
// A VM is an emulated BPF virtual machine.
type VM struct {
filter []Instruction
}
// NewVM returns a new VM using the input BPF program.
func NewVM(filter []Instruction) (*VM, error) {
if len(filter) == 0 {
return nil, errors.New("one or more Instructions must be specified")
}
for i, ins := range filter {
check := len(filter) - (i + 1)
switch ins := ins.(type) {
// Check for out-of-bounds jumps in instructions
case Jump:
if check <= int(ins.Skip) {
return nil, fmt.Errorf("cannot jump %d instructions; jumping past program bounds", ins.Skip)
}
case JumpIf:
if check <= int(ins.SkipTrue) {
return nil, fmt.Errorf("cannot jump %d instructions in true case; jumping past program bounds", ins.SkipTrue)
}
if check <= int(ins.SkipFalse) {
return nil, fmt.Errorf("cannot jump %d instructions in false case; jumping past program bounds", ins.SkipFalse)
}
case JumpIfX:
if check <= int(ins.SkipTrue) {
return nil, fmt.Errorf("cannot jump %d instructions in true case; jumping past program bounds", ins.SkipTrue)
}
if check <= int(ins.SkipFalse) {
return nil, fmt.Errorf("cannot jump %d instructions in false case; jumping past program bounds", ins.SkipFalse)
}
// Check for division or modulus by zero
case ALUOpConstant:
if ins.Val != 0 {
break
}
switch ins.Op {
case ALUOpDiv, ALUOpMod:
return nil, errors.New("cannot divide by zero using ALUOpConstant")
}
// Check for unknown extensions
case LoadExtension:
switch ins.Num {
case ExtLen:
default:
return nil, fmt.Errorf("extension %d not implemented", ins.Num)
}
}
}
// Make sure last instruction is a return instruction
switch filter[len(filter)-1].(type) {
case RetA, RetConstant:
default:
return nil, errors.New("BPF program must end with RetA or RetConstant")
}
// Though our VM works using disassembled instructions, we
// attempt to assemble the input filter anyway to ensure it is compatible
// with an operating system VM.
_, err := Assemble(filter)
return &VM{
filter: filter,
}, err
}
// Run runs the VM's BPF program against the input bytes.
// Run returns the number of bytes accepted by the BPF program, and any errors
// which occurred while processing the program.
func (v *VM) Run(in []byte) (int, error) {
var (
// Registers of the virtual machine
regA uint32
regX uint32
regScratch [16]uint32
// OK is true if the program should continue processing the next
// instruction, or false if not, causing the loop to break
ok = true
)
// TODO(mdlayher): implement:
// - NegateA:
// - would require a change from uint32 registers to int32
// registers
// TODO(mdlayher): add interop tests that check signedness of ALU
// operations against kernel implementation, and make sure Go
// implementation matches behavior
for i := 0; i < len(v.filter) && ok; i++ {
ins := v.filter[i]
switch ins := ins.(type) {
case ALUOpConstant:
regA = aluOpConstant(ins, regA)
case ALUOpX:
regA, ok = aluOpX(ins, regA, regX)
case Jump:
i += int(ins.Skip)
case JumpIf:
jump := jumpIf(ins, regA)
i += jump
case JumpIfX:
jump := jumpIfX(ins, regA, regX)
i += jump
case LoadAbsolute:
regA, ok = loadAbsolute(ins, in)
case LoadConstant:
regA, regX = loadConstant(ins, regA, regX)
case LoadExtension:
regA = loadExtension(ins, in)
case LoadIndirect:
regA, ok = loadIndirect(ins, in, regX)
case LoadMemShift:
regX, ok = loadMemShift(ins, in)
case LoadScratch:
regA, regX = loadScratch(ins, regScratch, regA, regX)
case RetA:
return int(regA), nil
case RetConstant:
return int(ins.Val), nil
case StoreScratch:
regScratch = storeScratch(ins, regScratch, regA, regX)
case TAX:
regX = regA
case TXA:
regA = regX
default:
return 0, fmt.Errorf("unknown Instruction at index %d: %T", i, ins)
}
}
return 0, nil
}

182
vendor/golang.org/x/net/bpf/vm_instructions.go generated vendored Normal file
View File

@ -0,0 +1,182 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bpf
import (
"encoding/binary"
"fmt"
)
func aluOpConstant(ins ALUOpConstant, regA uint32) uint32 {
return aluOpCommon(ins.Op, regA, ins.Val)
}
func aluOpX(ins ALUOpX, regA uint32, regX uint32) (uint32, bool) {
// Guard against division or modulus by zero by terminating
// the program, as the OS BPF VM does
if regX == 0 {
switch ins.Op {
case ALUOpDiv, ALUOpMod:
return 0, false
}
}
return aluOpCommon(ins.Op, regA, regX), true
}
func aluOpCommon(op ALUOp, regA uint32, value uint32) uint32 {
switch op {
case ALUOpAdd:
return regA + value
case ALUOpSub:
return regA - value
case ALUOpMul:
return regA * value
case ALUOpDiv:
// Division by zero not permitted by NewVM and aluOpX checks
return regA / value
case ALUOpOr:
return regA | value
case ALUOpAnd:
return regA & value
case ALUOpShiftLeft:
return regA << value
case ALUOpShiftRight:
return regA >> value
case ALUOpMod:
// Modulus by zero not permitted by NewVM and aluOpX checks
return regA % value
case ALUOpXor:
return regA ^ value
default:
return regA
}
}
func jumpIf(ins JumpIf, regA uint32) int {
return jumpIfCommon(ins.Cond, ins.SkipTrue, ins.SkipFalse, regA, ins.Val)
}
func jumpIfX(ins JumpIfX, regA uint32, regX uint32) int {
return jumpIfCommon(ins.Cond, ins.SkipTrue, ins.SkipFalse, regA, regX)
}
func jumpIfCommon(cond JumpTest, skipTrue, skipFalse uint8, regA uint32, value uint32) int {
var ok bool
switch cond {
case JumpEqual:
ok = regA == value
case JumpNotEqual:
ok = regA != value
case JumpGreaterThan:
ok = regA > value
case JumpLessThan:
ok = regA < value
case JumpGreaterOrEqual:
ok = regA >= value
case JumpLessOrEqual:
ok = regA <= value
case JumpBitsSet:
ok = (regA & value) != 0
case JumpBitsNotSet:
ok = (regA & value) == 0
}
if ok {
return int(skipTrue)
}
return int(skipFalse)
}
func loadAbsolute(ins LoadAbsolute, in []byte) (uint32, bool) {
offset := int(ins.Off)
size := int(ins.Size)
return loadCommon(in, offset, size)
}
func loadConstant(ins LoadConstant, regA uint32, regX uint32) (uint32, uint32) {
switch ins.Dst {
case RegA:
regA = ins.Val
case RegX:
regX = ins.Val
}
return regA, regX
}
func loadExtension(ins LoadExtension, in []byte) uint32 {
switch ins.Num {
case ExtLen:
return uint32(len(in))
default:
panic(fmt.Sprintf("unimplemented extension: %d", ins.Num))
}
}
func loadIndirect(ins LoadIndirect, in []byte, regX uint32) (uint32, bool) {
offset := int(ins.Off) + int(regX)
size := int(ins.Size)
return loadCommon(in, offset, size)
}
func loadMemShift(ins LoadMemShift, in []byte) (uint32, bool) {
offset := int(ins.Off)
// Size of LoadMemShift is always 1 byte
if !inBounds(len(in), offset, 1) {
return 0, false
}
// Mask off high 4 bits and multiply low 4 bits by 4
return uint32(in[offset]&0x0f) * 4, true
}
func inBounds(inLen int, offset int, size int) bool {
return offset+size <= inLen
}
func loadCommon(in []byte, offset int, size int) (uint32, bool) {
if !inBounds(len(in), offset, size) {
return 0, false
}
switch size {
case 1:
return uint32(in[offset]), true
case 2:
return uint32(binary.BigEndian.Uint16(in[offset : offset+size])), true
case 4:
return uint32(binary.BigEndian.Uint32(in[offset : offset+size])), true
default:
panic(fmt.Sprintf("invalid load size: %d", size))
}
}
func loadScratch(ins LoadScratch, regScratch [16]uint32, regA uint32, regX uint32) (uint32, uint32) {
switch ins.Dst {
case RegA:
regA = regScratch[ins.N]
case RegX:
regX = regScratch[ins.N]
}
return regA, regX
}
func storeScratch(ins StoreScratch, regScratch [16]uint32, regA uint32, regX uint32) [16]uint32 {
switch ins.Src {
case RegA:
regScratch[ins.N] = regA
case RegX:
regScratch[ins.N] = regX
}
return regScratch
}

20
vendor/golang.org/x/net/http2/README generated vendored
View File

@ -1,20 +0,0 @@
This is a work-in-progress HTTP/2 implementation for Go.
It will eventually live in the Go standard library and won't require
any changes to your code to use. It will just be automatic.
Status:
* The server support is pretty good. A few things are missing
but are being worked on.
* The client work has just started but shares a lot of code
is coming along much quicker.
Docs are at https://godoc.org/golang.org/x/net/http2
Demo test server at https://http2.golang.org/
Help & bug reports welcome!
Contributing: https://golang.org/doc/contribute.html
Bugs: https://golang.org/issue/new?title=x/net/http2:+

53
vendor/golang.org/x/net/http2/ascii.go generated vendored Normal file
View File

@ -0,0 +1,53 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import "strings"
// The HTTP protocols are defined in terms of ASCII, not Unicode. This file
// contains helper functions which may use Unicode-aware functions which would
// otherwise be unsafe and could introduce vulnerabilities if used improperly.
// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
// are equal, ASCII-case-insensitively.
func asciiEqualFold(s, t string) bool {
if len(s) != len(t) {
return false
}
for i := 0; i < len(s); i++ {
if lower(s[i]) != lower(t[i]) {
return false
}
}
return true
}
// lower returns the ASCII lowercase version of b.
func lower(b byte) byte {
if 'A' <= b && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
// isASCIIPrint returns whether s is ASCII and printable according to
// https://tools.ietf.org/html/rfc20#section-4.2.
func isASCIIPrint(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > '~' {
return false
}
}
return true
}
// asciiToLower returns the lowercase version of s if s is ASCII and printable,
// and whether or not it was.
func asciiToLower(s string) (lower string, ok bool) {
if !isASCIIPrint(s) {
return "", false
}
return strings.ToLower(s), true
}

View File

@ -7,13 +7,21 @@
package http2
import (
"context"
"crypto/tls"
"errors"
"net/http"
"sync"
)
// ClientConnPool manages a pool of HTTP/2 client connections.
type ClientConnPool interface {
// GetClientConn returns a specific HTTP/2 connection (usually
// a TLS-TCP connection) to an HTTP/2 server. On success, the
// returned ClientConn accounts for the upcoming RoundTrip
// call, so the caller should not omit it. If the caller needs
// to, ClientConn.RoundTrip can be called with a bogus
// new(http.Request) to release the stream reservation.
GetClientConn(req *http.Request, addr string) (*ClientConn, error)
MarkDead(*ClientConn)
}
@ -40,7 +48,7 @@ type clientConnPool struct {
conns map[string][]*ClientConn // key is host:port
dialing map[string]*dialCall // currently in-flight dials
keys map[*ClientConn][]string
addConnCalls map[string]*addConnCall // in-flight addConnIfNeede calls
addConnCalls map[string]*addConnCall // in-flight addConnIfNeeded calls
}
func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
@ -52,87 +60,85 @@ const (
noDialOnMiss = false
)
// shouldTraceGetConn reports whether getClientConn should call any
// ClientTrace.GetConn hook associated with the http.Request.
//
// This complexity is needed to avoid double calls of the GetConn hook
// during the back-and-forth between net/http and x/net/http2 (when the
// net/http.Transport is upgraded to also speak http2), as well as support
// the case where x/net/http2 is being used directly.
func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
// If our Transport wasn't made via ConfigureTransport, always
// trace the GetConn hook if provided, because that means the
// http2 package is being used directly and it's the one
// dialing, as opposed to net/http.
if _, ok := p.t.ConnPool.(noDialClientConnPool); !ok {
return true
}
// Otherwise, only use the GetConn hook if this connection has
// been used previously for other requests. For fresh
// connections, the net/http package does the dialing.
return !st.freshConn
}
func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
// TODO(dneil): Dial a new connection when t.DisableKeepAlives is set?
if isConnectionCloseRequest(req) && dialOnMiss {
// It gets its own connection.
traceGetConn(req, addr)
const singleUse = true
cc, err := p.t.dialClientConn(addr, singleUse)
cc, err := p.t.dialClientConn(req.Context(), addr, singleUse)
if err != nil {
return nil, err
}
return cc, nil
}
p.mu.Lock()
for _, cc := range p.conns[addr] {
if st := cc.idleState(); st.canTakeNewRequest {
if p.shouldTraceGetConn(st) {
traceGetConn(req, addr)
for {
p.mu.Lock()
for _, cc := range p.conns[addr] {
if cc.ReserveNewRequest() {
// When a connection is presented to us by the net/http package,
// the GetConn hook has already been called.
// Don't call it a second time here.
if !cc.getConnCalled {
traceGetConn(req, addr)
}
cc.getConnCalled = false
p.mu.Unlock()
return cc, nil
}
}
if !dialOnMiss {
p.mu.Unlock()
return nil, ErrNoCachedConn
}
traceGetConn(req, addr)
call := p.getStartDialLocked(req.Context(), addr)
p.mu.Unlock()
<-call.done
if shouldRetryDial(call, req) {
continue
}
cc, err := call.res, call.err
if err != nil {
return nil, err
}
if cc.ReserveNewRequest() {
return cc, nil
}
}
if !dialOnMiss {
p.mu.Unlock()
return nil, ErrNoCachedConn
}
traceGetConn(req, addr)
call := p.getStartDialLocked(addr)
p.mu.Unlock()
<-call.done
return call.res, call.err
}
// dialCall is an in-flight Transport dial call to a host.
type dialCall struct {
_ incomparable
p *clientConnPool
_ incomparable
p *clientConnPool
// the context associated with the request
// that created this dialCall
ctx context.Context
done chan struct{} // closed when done
res *ClientConn // valid after done is closed
err error // valid after done is closed
}
// requires p.mu is held.
func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
if call, ok := p.dialing[addr]; ok {
// A dial is already in-flight. Don't start another.
return call
}
call := &dialCall{p: p, done: make(chan struct{})}
call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
if p.dialing == nil {
p.dialing = make(map[string]*dialCall)
}
p.dialing[addr] = call
go call.dial(addr)
go call.dial(call.ctx, addr)
return call
}
// run in its own goroutine.
func (c *dialCall) dial(addr string) {
func (c *dialCall) dial(ctx context.Context, addr string) {
const singleUse = false // shared conn
c.res, c.err = c.p.t.dialClientConn(addr, singleUse)
c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
close(c.done)
c.p.mu.Lock()
@ -195,6 +201,7 @@ func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) {
if err != nil {
c.err = err
} else {
cc.getConnCalled = true // already called by the net/http package
p.addConnLocked(key, cc)
}
delete(p.addConnCalls, key)
@ -276,3 +283,28 @@ type noDialClientConnPool struct{ *clientConnPool }
func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
return p.getClientConn(req, addr, noDialOnMiss)
}
// shouldRetryDial reports whether the current request should
// retry dialing after the call finished unsuccessfully, for example
// if the dial was canceled because of a context cancellation or
// deadline expiry.
func shouldRetryDial(call *dialCall, req *http.Request) bool {
if call.err == nil {
// No error, no need to retry
return false
}
if call.ctx == req.Context() {
// If the call has the same context as the request, the dial
// should not be retried, since any cancellation will have come
// from this request.
return false
}
if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
// If the call error is not because of a context cancellation or a deadline expiry,
// the dial should not be retried.
return false
}
// Only retry if the error is a context cancellation error or deadline expiry
// and the context associated with the call was canceled or expired.
return call.ctx.Err() != nil
}

View File

@ -53,6 +53,13 @@ func (e ErrCode) String() string {
return fmt.Sprintf("unknown error code 0x%x", uint32(e))
}
func (e ErrCode) stringToken() string {
if s, ok := errCodeName[e]; ok {
return s
}
return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e))
}
// ConnectionError is an error that results in the termination of the
// entire connection.
type ConnectionError ErrCode
@ -67,6 +74,11 @@ type StreamError struct {
Cause error // optional additional detail
}
// errFromPeer is a sentinel error value for StreamError.Cause to
// indicate that the StreamError was sent from the peer over the wire
// and wasn't locally generated in the Transport.
var errFromPeer = errors.New("received from peer")
func streamError(id uint32, code ErrCode) StreamError {
return StreamError{StreamID: id, Code: code}
}

Some files were not shown because too many files have changed in this diff Show More