From b1927990c2e0df2352875de2f4f4265151f5c42e Mon Sep 17 00:00:00 2001 From: leonnicolas Date: Mon, 20 Sep 2021 15:47:47 +0200 Subject: [PATCH] migrate to golang.zx2c4.com/wireguard/wgctrl This commit introduces the usage of wgctrl. It avoids the usage of exec calls of the wg command and parsing the output of `wg show`. Signed-off-by: leonnicolas --- Dockerfile | 2 +- cmd/kg/handlers.go | 4 +- cmd/kg/main.go | 2 +- cmd/kgctl/connect_linux.nogo | 473 +++++++++++++++++++++++++++++++++ cmd/kgctl/connect_other.nogo | 20 ++ cmd/kgctl/graph.go | 4 +- cmd/kgctl/showconf.go | 50 ++-- e2e/lib.sh | 14 +- pkg/k8s/backend.go | 94 ++++--- pkg/k8s/backend_test.go | 147 +++++++--- pkg/mesh/backend.go | 27 +- pkg/mesh/graph.go | 2 +- pkg/mesh/mesh.go | 166 ++++++------ pkg/mesh/mesh_test.go | 57 ++-- pkg/mesh/routes.go | 12 +- pkg/mesh/routes_test.go | 112 ++++---- pkg/mesh/topology.go | 170 +++++++----- pkg/mesh/topology_test.go | 502 ++++++++++++----------------------- pkg/wireguard/conf.go | 436 +++++++----------------------- pkg/wireguard/conf_test.go | 234 +--------------- pkg/wireguard/wireguard.go | 73 ----- 21 files changed, 1266 insertions(+), 1335 deletions(-) create mode 100644 cmd/kgctl/connect_linux.nogo create mode 100644 cmd/kgctl/connect_other.nogo diff --git a/Dockerfile b/Dockerfile index 3b1d418..fbc2dc8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ ARG GOARCH ARG ALPINE_VERSION=v3.12 LABEL maintainer="squat " RUN echo -e "https://alpine.global.ssl.fastly.net/alpine/$ALPINE_VERSION/main\nhttps://alpine.global.ssl.fastly.net/alpine/$ALPINE_VERSION/community" > /etc/apk/repositories && \ - apk add --no-cache ipset iptables ip6tables wireguard-tools graphviz font-noto + apk add --no-cache ipset iptables ip6tables graphviz font-noto COPY --from=cni bridge host-local loopback portmap /opt/cni/bin/ COPY bin/linux/$GOARCH/kg /opt/bin/ ENTRYPOINT ["/opt/bin/kg"] diff --git a/cmd/kg/handlers.go b/cmd/kg/handlers.go index fd71cb6..c729b37 100644 --- a/cmd/kg/handlers.go +++ b/cmd/kg/handlers.go @@ -24,6 +24,8 @@ import ( "os" "os/exec" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/squat/kilo/pkg/mesh" ) @@ -62,7 +64,7 @@ func (h *graphHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { peers[p.Name] = p } } - topo, err := mesh.NewTopology(nodes, peers, h.granularity, *h.hostname, 0, []byte{}, h.subnet, nodes[*h.hostname].PersistentKeepalive, nil) + topo, err := mesh.NewTopology(nodes, peers, h.granularity, *h.hostname, 0, wgtypes.Key{}, h.subnet, nodes[*h.hostname].PersistentKeepalive, nil) if err != nil { http.Error(w, fmt.Sprintf("failed to create topology: %v", err), http.StatusInternalServerError) return diff --git a/cmd/kg/main.go b/cmd/kg/main.go index 69acdcf..5e66e46 100644 --- a/cmd/kg/main.go +++ b/cmd/kg/main.go @@ -239,7 +239,7 @@ func runRoot(_ *cobra.Command, _ []string) error { return fmt.Errorf("backend %v unknown; possible values are: %s", backend, availableBackends) } - m, err := mesh.New(b, enc, gr, hostname, uint32(port), s, local, cni, cniPath, iface, cleanUpIface, createIface, mtu, resyncPeriod, prioritisePrivateAddr, iptablesForwardRule, log.With(logger, "component", "kilo")) + m, err := mesh.New(b, enc, gr, hostname, int(port), s, local, cni, cniPath, iface, cleanUpIface, createIface, mtu, resyncPeriod, prioritisePrivateAddr, iptablesForwardRule, log.With(logger, "component", "kilo")) if err != nil { return fmt.Errorf("failed to create Kilo mesh: %v", err) } diff --git a/cmd/kgctl/connect_linux.nogo b/cmd/kgctl/connect_linux.nogo new file mode 100644 index 0000000..fa066dd --- /dev/null +++ b/cmd/kgctl/connect_linux.nogo @@ -0,0 +1,473 @@ +// +build linux + +package main + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "net" + "os" + "strings" + "syscall" + "time" + + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/log/level" + "github.com/oklog/run" + "github.com/spf13/cobra" + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/tools/clientcmd" + + "github.com/squat/kilo/pkg/iproute" + "github.com/squat/kilo/pkg/k8s" + "github.com/squat/kilo/pkg/k8s/apis/kilo/v1alpha1" + kiloclient "github.com/squat/kilo/pkg/k8s/clientset/versioned" + "github.com/squat/kilo/pkg/mesh" + "github.com/squat/kilo/pkg/route" + "github.com/squat/kilo/pkg/wireguard" +) + +func takeIPNet(_ net.IP, i *net.IPNet, _ error) *net.IPNet { + return i +} + +func connect() *cobra.Command { + cmd := &cobra.Command{ + Use: "connect", + RunE: connectAsPeer, + Short: "connect to a Kilo cluster as a peer over WireGuard", + } + cmd.Flags().IPNetP("allowed-ip", "a", *takeIPNet(net.ParseCIDR("10.10.10.10/32")), "Allowed IP of the peer") + cmd.Flags().IPNetP("service-cidr", "c", *takeIPNet(net.ParseCIDR("10.43.0.0/16")), "service CIDR of the cluster") + cmd.Flags().String("log-level", logLevelInfo, fmt.Sprintf("Log level to use. Possible values: %s", availableLogLevels)) + cmd.Flags().String("config-path", "/tmp/wg.ini", "path to WireGuard configuation file") + cmd.Flags().Bool("clean-up", true, "clean up routes and interface") + cmd.Flags().Uint("mtu", uint(1420), "clean up routes and interface") + cmd.Flags().Duration("resync-period", 30*time.Second, "How often should Kilo reconcile?") + + availableLogLevels = strings.Join([]string{ + logLevelAll, + logLevelDebug, + logLevelInfo, + logLevelWarn, + logLevelError, + logLevelNone, + }, ", ") + + return cmd +} + +func connectAsPeer(cmd *cobra.Command, args []string) error { + var cancel context.CancelFunc + resyncPersiod, err := cmd.Flags().GetDuration("resync-period") + if err != nil { + return err + } + mtu, err := cmd.Flags().GetUint("mtu") + if err != nil { + return err + } + configPath, err := cmd.Flags().GetString("config-path") + if err != nil { + return err + } + serviceCIDR, err := cmd.Flags().GetIPNet("service-cidr") + if err != nil { + return err + } + allowedIP, err := cmd.Flags().GetIPNet("allowed-ip") + if err != nil { + return err + } + logger := log.NewJSONLogger(log.NewSyncWriter(os.Stdout)) + logLevel, err := cmd.Flags().GetString("log-level") + if err != nil { + return err + } + switch logLevel { + case logLevelAll: + logger = level.NewFilter(logger, level.AllowAll()) + case logLevelDebug: + logger = level.NewFilter(logger, level.AllowDebug()) + case logLevelInfo: + logger = level.NewFilter(logger, level.AllowInfo()) + case logLevelWarn: + logger = level.NewFilter(logger, level.AllowWarn()) + case logLevelError: + logger = level.NewFilter(logger, level.AllowError()) + case logLevelNone: + logger = level.NewFilter(logger, level.AllowNone()) + default: + return fmt.Errorf("log level %s unknown; possible values are: %s", logLevel, availableLogLevels) + } + logger = log.With(logger, "ts", log.DefaultTimestampUTC) + logger = log.With(logger, "caller", log.DefaultCaller) + peername := "random" + if len(args) > 0 { + peername = args[0] + } + + var kiloClient *kiloclient.Clientset + switch backend { + case k8s.Backend: + config, err := clientcmd.BuildConfigFromFlags("", kubeconfig) + if err != nil { + return fmt.Errorf("failed to create Kubernetes config: %v", err) + } + kiloClient = kiloclient.NewForConfigOrDie(config) + default: + return fmt.Errorf("backend %v unknown; posible values are: %s", backend, availableBackends) + } + privateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return fmt.Errorf("failed to generate private key: %w", err) + } + publicKey := privateKey.PublicKey() + level.Info(logger).Log("msg", "generated public key", "key", publicKey) + + peer := &v1alpha1.Peer{ + ObjectMeta: metav1.ObjectMeta{ + Name: peername, + }, + Spec: v1alpha1.PeerSpec{ + AllowedIPs: []string{allowedIP.String()}, + PersistentKeepalive: 10, + PublicKey: publicKey.String(), + }, + } + if p, err := kiloClient.KiloV1alpha1().Peers().Get(context.TODO(), peername, metav1.GetOptions{}); err != nil || p == nil { + peer, err = kiloClient.KiloV1alpha1().Peers().Create(context.TODO(), peer, metav1.CreateOptions{}) + if err != nil { + return fmt.Errorf("failed to create peer: %w", err) + } + } + + kiloIfaceName := "kilo0" + + iface, _, err := wireguard.New(kiloIfaceName, mtu) + if err != nil { + return fmt.Errorf("failed to create wg interface: %w", err) + } else { + level.Info(logger).Log("msg", "successfully created wg interface", "name", kiloIfaceName, "no", iface) + } + if err := iproute.Set(iface, false); err != nil { + return err + } + + if err := iproute.SetAddress(iface, &allowedIP); err != nil { + return err + } else { + level.Info(logger).Log("mag", "successfully set IP address of wg interface", "IP", allowedIP.String()) + } + + if err := iproute.Set(iface, true); err != nil { + return err + } + + var g run.Group + ctx, cancel := context.WithCancel(context.Background()) + g.Add(run.SignalHandler(ctx, syscall.SIGINT, syscall.SIGTERM)) + + table := route.NewTable() + stop := make(chan struct{}, 1) + errCh := make(<-chan error, 1) + { + ch := make(chan struct{}) + g.Add( + func() error { + for { + select { + case err, ok := <-errCh: + if ok { + level.Error(logger).Log("err", err.Error()) + } else { + return nil + } + case <-ch: + return nil + } + } + }, + func(err error) { + ch <- struct{}{} + close(ch) + stop <- struct{}{} + close(stop) + level.Error(logger).Log("msg", "stopped ip routes table", "err", err.Error()) + }, + ) + } + { + ch := make(chan struct{}) + g.Add( + func() error { + for { + ns, err := opts.backend.Nodes().List() + if err != nil { + return fmt.Errorf("failed to list nodes: %v", err) + } + ps, err := opts.backend.Peers().List() + if err != nil { + return fmt.Errorf("failed to list peers: %v", err) + } + // Obtain the Granularity by looking at the annotation of the first node. + if opts.granularity, err = optainGranularity(opts.granularity, ns); err != nil { + return fmt.Errorf("failed to obtain granularity: %w", err) + } + var hostname string + subnet := mesh.DefaultKiloSubnet + nodes := make(map[string]*mesh.Node) + for _, n := range ns { + if n.Ready() { + nodes[n.Name] = n + hostname = n.Name + } + if n.WireGuardIP != nil { + subnet = n.WireGuardIP + } + } + subnet.IP = subnet.IP.Mask(subnet.Mask) + if len(nodes) == 0 { + return errors.New("did not find any valid Kilo nodes in the cluster") + } + peers := make(map[string]*mesh.Peer) + for _, p := range ps { + if p.Ready() { + peers[p.Name] = p + } + } + if _, ok := peers[peername]; !ok { + return fmt.Errorf("did not find any peer named %q in the cluster", peername) + } + + t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, []byte{}, subnet, peers[peername].PersistentKeepalive, logger) + if err != nil { + return fmt.Errorf("failed to create topology: %v", err) + } + conf := t.PeerConf(peername) + conf.Interface = &wireguard.Interface{ + PrivateKey: []byte(privateKey.String()), + ListenPort: uint32(55555), + } + buf, err := conf.Bytes() + if err != nil { + //level.Error(m.logger).Log("error", err) + //m.errorCounter.WithLabelValues("apply").Inc() + return err + } + if err := ioutil.WriteFile("/tmp/wg.ini", buf, 0600); err != nil { + //level.Error(m.logger).Log("error", err) + //m.errorCounter.WithLabelValues("apply").Inc() + return err + } + if err := wireguard.SetConf(kiloIfaceName, "/tmp/wg.ini"); err != nil { + return err + } + wgClient, err := wgctrl.New() + if err != nil { + return fmt.Errorf("failed to initialize wg Client: %w", err) + } + defer func() { + wgClient.Close() + }() + wgConf := wgtypes.Config{ + PrivateKey: &privateKey, + // Peers: []wgtypes.PeerConfig{}, + } + if err := wgClient.ConfigureDevice(kiloIfaceName, wgConf); err != nil { + return fmt.Errorf("failed to configure wg interface: %w", err) + } + + var routes []*netlink.Route + for _, segment := range t.Segments() { + for i := range segment.CIDRS() { + // Add routes to the Pod CIDRs of nodes in other segments. + routes = append(routes, &netlink.Route{ + Dst: segment.CIDRS()[i], + //Flags: int(netlink.FLAG_ONLINK), + LinkIndex: iface, + Protocol: unix.RTPROT_STATIC, + }) + } + for i := range segment.PrivateIPs() { + // Add routes to the private IPs of nodes in other segments. + routes = append(routes, &netlink.Route{ + Dst: mesh.OneAddressCIDR(segment.PrivateIPs()[i]), + //Flags: int(netlink.FLAG_ONLINK), + LinkIndex: iface, + Protocol: unix.RTPROT_STATIC, + }) + } + // For segments / locations other than the location of this instance of kg, + // we need to set routes for allowed location IPs over the leader in the current location. + for i := range segment.AllowedLocationIPs() { + routes = append(routes, &netlink.Route{ + Dst: segment.AllowedLocationIPs()[i], + //Flags: int(netlink.FLAG_ONLINK), + LinkIndex: iface, + Protocol: unix.RTPROT_STATIC, + }) + } + routes = append(routes, &netlink.Route{ + Dst: mesh.OneAddressCIDR(segment.WireGuardIP()), + //Flags: int(netlink.FLAG_ONLINK), + LinkIndex: iface, + Protocol: unix.RTPROT_STATIC, + }) + } + // Add routes for the allowed IPs of peers. + for _, peer := range t.Peers() { + for i := range peer.AllowedIPs { + routes = append(routes, &netlink.Route{ + Dst: peer.AllowedIPs[i], + //Flags: int(netlink.FLAG_ONLINK), + LinkIndex: iface, + Protocol: unix.RTPROT_STATIC, + }) + } + } + routes = append(routes, &netlink.Route{ + Dst: &serviceCIDR, + //Flags: int(netlink.FLAG_ONLINK), + Gw: nil, + LinkIndex: iface, + Protocol: unix.RTPROT_STATIC, + }) + for _, r := range routes { + fmt.Println(r) + } + + level.Debug(logger).Log("routes", routes) + if err := table.Set(routes, []*netlink.Rule{}); err != nil { + return fmt.Errorf("failed to set ip routes table: %w", err) + } + errCh, err = table.Run(stop) + if err != nil { + return fmt.Errorf("failed to start ip routes tables: %w", err) + } + select { + case <-time.After(resyncPersiod): + case <-ch: + return nil + } + } + }, func(err error) { + // Cancel the root context in the very end. + defer func() { + cancel() + level.Debug(logger).Log("msg", "canceled parent context") + }() + ch <- struct{}{} + var serr run.SignalError + if ok := errors.As(err, &serr); ok { + level.Info(logger).Log("msg", "received signal", "signal", serr.Signal.String(), "err", err.Error()) + } else { + level.Error(logger).Log("msg", "received error", "err", err.Error()) + } + level.Debug(logger).Log("msg", "stoped ip routes table") + ctxWithTimeOut, cancelWithTimeOut := context.WithTimeout(ctx, 10*time.Second) + defer func() { + cancelWithTimeOut() + level.Debug(logger).Log("msg", "canceled timed context") + }() + if err := kiloClient.KiloV1alpha1().Peers().Delete(ctxWithTimeOut, peername, metav1.DeleteOptions{}); err != nil { + level.Error(logger).Log("failed to delete peer: %w", err) + } else { + level.Info(logger).Log("msg", "deleted peer", "peer", peername) + } + if ok, err := cmd.Flags().GetBool("clean-up"); err != nil { + level.Error(logger).Log("err", err.Error(), "msg", "failed to get value from clean-up flag") + } else if ok { + cleanUp(iface, table, configPath, logger) + } + }) + } + err = g.Run() + var serr run.SignalError + if ok := errors.As(err, &serr); ok { + return nil + } + return err +} + +func cleanUp(iface int, t *route.Table, configPath string, logger log.Logger) { + if err := iproute.Set(iface, false); err != nil { + level.Error(logger).Log("err", err.Error(), "msg", "failed to set down wg interface") + } + if err := os.Remove(configPath); err != nil { + level.Error(logger).Log("error", fmt.Sprintf("failed to delete configuration file: %v", err)) + } + if err := iproute.RemoveInterface(iface); err != nil { + level.Error(logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err)) + } + if err := t.CleanUp(); err != nil { + level.Error(logger).Log("failed to clean up routes: %w", err) + } + return +} + +//func nodeToWGPeer(n *mesh.Node) (ret wgtypes.PeerConfig, err error) { +// pubKey, err := wgtypes.ParseKey(string(n.Key)) +// if err != nil { +// return ret, err +// } +// ret.PublicKey = pubKey +// if err != nil { +// return ret, err +// } +// aIPs := []net.IPNet{*n.WireGuardIP, *n.InternalIP, *n.Subnet} +// for _, a := range n.AllowedLocationIPs { +// aIPs = append(aIPs, *a) +// } +// ret.AllowedIPs = aIPs +// +// dur := time.Second * time.Duration(n.PersistentKeepalive) +// ret.PersistentKeepaliveInterval = &dur +// +// udpEndpoint, err := net.ResolveUDPAddr("udp", n.Endpoint.String()) +// if err != nil { +// return ret, err +// } else if udpEndpoint.IP == nil { +// udpEndpoint = nil +// } +// ret.Endpoint = udpEndpoint +// return ret, nil +//} +// +//func toWGPeer(p wireguard.Peer) (ret wgtypes.PeerConfig, err error) { +// pubKey, err := wgtypes.ParseKey(string(p.PublicKey)) +// if err != nil { +// return ret, err +// } +// ret.PublicKey = pubKey +// aIPs := make([]net.IPNet, len(p.AllowedIPs)) +// for i, a := range p.AllowedIPs { +// aIPs[i] = *a +// } +// ret.AllowedIPs = aIPs +// +// if preSharedKey, err := wgtypes.ParseKey(string(p.PresharedKey)); len(p.PresharedKey) > 0 && err != nil { +// return ret, err +// } else { +// ret.PresharedKey = &preSharedKey +// } +// dur := time.Second * time.Duration(p.PersistentKeepalive) +// ret.PersistentKeepaliveInterval = &dur +// +// udpEndpoint, err := net.ResolveUDPAddr("udp", p.Endpoint.String()) +// if err != nil { +// return ret, err +// } else if udpEndpoint.IP == nil { +// udpEndpoint = nil +// } +// fmt.Println("peer") +// fmt.Println(aIPs) +// ret.Endpoint = udpEndpoint +// return ret, nil +//} diff --git a/cmd/kgctl/connect_other.nogo b/cmd/kgctl/connect_other.nogo new file mode 100644 index 0000000..9e81a64 --- /dev/null +++ b/cmd/kgctl/connect_other.nogo @@ -0,0 +1,20 @@ +// +build !linux + +package main + +import ( + "errors" + + "github.com/spf13/cobra" +) + +func connect() *cobra.Command { + cmd := &cobra.Command{ + Use: "connect", + Short: "not supporred on you OS", + RunE: func(_ *cobra.Command, _ []string) error { + return errors.New("this command is not supported on your OS") + }, + } + return cmd +} diff --git a/cmd/kgctl/graph.go b/cmd/kgctl/graph.go index 5813113..46a6b12 100644 --- a/cmd/kgctl/graph.go +++ b/cmd/kgctl/graph.go @@ -18,6 +18,8 @@ import ( "fmt" "github.com/spf13/cobra" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/squat/kilo/pkg/mesh" ) @@ -65,7 +67,7 @@ func runGraph(_ *cobra.Command, _ []string) error { peers[p.Name] = p } } - t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, 0, []byte{}, subnet, nodes[hostname].PersistentKeepalive, nil) + t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, 0, wgtypes.Key{}, subnet, nodes[hostname].PersistentKeepalive, nil) if err != nil { return fmt.Errorf("failed to create topology: %v", err) } diff --git a/cmd/kgctl/showconf.go b/cmd/kgctl/showconf.go index 495093c..9e877d5 100644 --- a/cmd/kgctl/showconf.go +++ b/cmd/kgctl/showconf.go @@ -1,4 +1,4 @@ -// Copyright 2019 the Kilo authors +// Copyright 2021 the Kilo authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,14 +15,15 @@ package main import ( - "bytes" "errors" "fmt" "net" "os" "strings" + "time" "github.com/spf13/cobra" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -47,7 +48,7 @@ var ( }, ", ") allowedIPs []string showConfOpts struct { - allowedIPs []*net.IPNet + allowedIPs []net.IPNet serializer *json.Serializer output string asPeer bool @@ -89,7 +90,7 @@ func runShowConf(c *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("allowed-ips must contain only valid CIDRs; got %q", allowedIPs[i]) } - showConfOpts.allowedIPs = append(showConfOpts.allowedIPs, aip) + showConfOpts.allowedIPs = append(showConfOpts.allowedIPs, *aip) } return runRoot(c, args) } @@ -151,14 +152,14 @@ func runShowConfNode(_ *cobra.Command, args []string) error { } } - t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, []byte{}, subnet, nodes[hostname].PersistentKeepalive, nil) + t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, int(opts.port), wgtypes.Key{}, subnet, nodes[hostname].PersistentKeepalive, nil) if err != nil { return fmt.Errorf("failed to create topology: %v", err) } var found bool for _, p := range t.PeerConf("").Peers { - if bytes.Equal(p.PublicKey, nodes[hostname].Key) { + if p.PublicKey == nodes[hostname].Key { found = true break } @@ -184,7 +185,7 @@ func runShowConfNode(_ *cobra.Command, args []string) error { p := t.AsPeer() p.AllowedIPs = append(p.AllowedIPs, showConfOpts.allowedIPs...) p.DeduplicateIPs() - k8sp := translatePeer(p) + k8sp := translatePeer(&p) k8sp.Name = hostname return showConfOpts.serializer.Encode(k8sp, os.Stdout) case outputFormatWireGuard: @@ -192,7 +193,7 @@ func runShowConfNode(_ *cobra.Command, args []string) error { p.AllowedIPs = append(p.AllowedIPs, showConfOpts.allowedIPs...) p.DeduplicateIPs() c, err := (&wireguard.Conf{ - Peers: []*wireguard.Peer{p}, + Peers: []wireguard.Peer{p}, }).Bytes() if err != nil { return fmt.Errorf("failed to generate configuration: %v", err) @@ -244,7 +245,11 @@ func runShowConfPeer(_ *cobra.Command, args []string) error { return fmt.Errorf("did not find any peer named %q in the cluster", peer) } - t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, mesh.DefaultKiloPort, []byte{}, subnet, peers[peer].PersistentKeepalive, nil) + pka := time.Duration(0) + if p := peers[peer].PersistentKeepaliveInterval; p != nil { + pka = *p + } + t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, mesh.DefaultKiloPort, wgtypes.Key{}, subnet, pka, nil) if err != nil { return fmt.Errorf("failed to create topology: %v", err) } @@ -272,7 +277,7 @@ func runShowConfPeer(_ *cobra.Command, args []string) error { p.AllowedIPs = append(p.AllowedIPs, showConfOpts.allowedIPs...) p.DeduplicateIPs() c, err := (&wireguard.Conf{ - Peers: []*wireguard.Peer{p}, + Peers: []wireguard.Peer{*p}, }).Bytes() if err != nil { return fmt.Errorf("failed to generate configuration: %v", err) @@ -291,36 +296,37 @@ func translatePeer(peer *wireguard.Peer) *v1alpha1.Peer { var aips []string for _, aip := range peer.AllowedIPs { // Skip any invalid IPs. - if aip == nil { + // TODO all IPs should be valid, so no need to skip here? + if aip.String() == (&net.IPNet{}).String() { continue } aips = append(aips, aip.String()) } var endpoint *v1alpha1.PeerEndpoint - if peer.Endpoint != nil && peer.Endpoint.Port > 0 && (peer.Endpoint.IP != nil || peer.Endpoint.DNS != "") { + if peer.KiloEndpoint != nil && peer.KiloEndpoint.Port > 0 && (peer.KiloEndpoint.IP != nil || peer.KiloEndpoint.DNS != "") { var ip string - if peer.Endpoint.IP != nil { - ip = peer.Endpoint.IP.String() + if peer.KiloEndpoint.IP != nil { + ip = peer.KiloEndpoint.IP.String() } endpoint = &v1alpha1.PeerEndpoint{ DNSOrIP: v1alpha1.DNSOrIP{ - DNS: peer.Endpoint.DNS, + DNS: peer.KiloEndpoint.DNS, IP: ip, }, - Port: peer.Endpoint.Port, + Port: uint32(peer.KiloEndpoint.Port), } } var key string - if len(peer.PublicKey) > 0 { - key = string(peer.PublicKey) + if peer.PublicKey != (wgtypes.Key{}) { + key = peer.PublicKey.String() } var psk string - if len(peer.PresharedKey) > 0 { - psk = string(peer.PresharedKey) + if peer.PresharedKey != nil { + psk = peer.PresharedKey.String() } var pka int - if peer.PersistentKeepalive > 0 { - pka = peer.PersistentKeepalive + if peer.PersistentKeepaliveInterval != nil && *peer.PersistentKeepaliveInterval > time.Duration(0) { + pka = int(*peer.PersistentKeepaliveInterval) } return &v1alpha1.Peer{ TypeMeta: metav1.TypeMeta{ diff --git a/e2e/lib.sh b/e2e/lib.sh index e5052a2..b8cfabc 100755 --- a/e2e/lib.sh +++ b/e2e/lib.sh @@ -184,14 +184,14 @@ check_peer() { local ALLOWED_IP=$3 local GRANULARITY=$4 create_interface "$INTERFACE" - docker run --rm --entrypoint=/usr/bin/wg "$KILO_IMAGE" genkey > "$INTERFACE" - assert "create_peer $PEER $ALLOWED_IP 10 $(docker run --rm --entrypoint=/bin/sh -v "$PWD/$INTERFACE":/key "$KILO_IMAGE" -c 'cat /key | wg pubkey')" "should be able to create Peer" + docker run --rm leonnicolas/wg-tools wg genkey > "$INTERFACE" + assert "create_peer $PEER $ALLOWED_IP 10 $(docker run --rm --entrypoint=/bin/sh -v "$PWD/$INTERFACE":/key leonnicolas/wg-tools -c 'cat /key | wg pubkey')" "should be able to create Peer" assert "_kgctl showconf peer $PEER --mesh-granularity=$GRANULARITY > $PEER.ini" "should be able to get Peer configuration" - assert "docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/usr/bin/wg -v /var/run/wireguard:/var/run/wireguard -v $PWD/$PEER.ini:/peer.ini $KILO_IMAGE setconf $INTERFACE /peer.ini" "should be able to apply configuration from kgctl" - docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/usr/bin/wg -v /var/run/wireguard:/var/run/wireguard -v "$PWD/$INTERFACE":/key "$KILO_IMAGE" set "$INTERFACE" private-key /key - docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip "$KILO_IMAGE" address add "$ALLOWED_IP" dev "$INTERFACE" - docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip "$KILO_IMAGE" link set "$INTERFACE" up - docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip "$KILO_IMAGE" route add 10.42/16 dev "$INTERFACE" + assert "docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/usr/bin/wg -v /var/run/wireguard:/var/run/wireguard -v $PWD/$PEER.ini:/peer.ini leonnicolas/wg-tools setconf $INTERFACE /peer.ini" "should be able to apply configuration from kgctl" + docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/usr/bin/wg -v /var/run/wireguard:/var/run/wireguard -v "$PWD/$INTERFACE":/key leonnicolas/wg-tools set "$INTERFACE" private-key /key + docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip leonnicolas/wg-tools address add "$ALLOWED_IP" dev "$INTERFACE" + docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip leonnicolas/wg-tools link set "$INTERFACE" up + docker run --rm --network=host --cap-add=NET_ADMIN --entrypoint=/sbin/ip leonnicolas/wg-tools route add 10.42/16 dev "$INTERFACE" assert "retry 10 5 '' check_ping --local" "should be able to ping Pods from host" assert_equals "$(_kgctl showconf peer "$PEER")" "$(_kgctl showconf peer "$PEER" --mesh-granularity="$GRANULARITY")" "kgctl should be able to auto detect the mesh granularity" rm "$INTERFACE" "$PEER".ini diff --git a/pkg/k8s/backend.go b/pkg/k8s/backend.go index 4065c31..07bc06a 100644 --- a/pkg/k8s/backend.go +++ b/pkg/k8s/backend.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" v1 "k8s.io/api/core/v1" apiextensions "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -212,13 +213,13 @@ func (nb *nodeBackend) Set(name string, node *mesh.Node) error { return fmt.Errorf("failed to find node: %v", err) } n := old.DeepCopy() - n.ObjectMeta.Annotations[endpointAnnotationKey] = node.Endpoint.String() + n.ObjectMeta.Annotations[endpointAnnotationKey] = node.KiloEndpoint.String() if node.InternalIP == nil { n.ObjectMeta.Annotations[internalIPAnnotationKey] = "" } else { n.ObjectMeta.Annotations[internalIPAnnotationKey] = node.InternalIP.String() } - n.ObjectMeta.Annotations[keyAnnotationKey] = string(node.Key) + n.ObjectMeta.Annotations[keyAnnotationKey] = node.Key.String() n.ObjectMeta.Annotations[lastSeenAnnotationKey] = strconv.FormatInt(node.LastSeen, 10) if node.WireGuardIP == nil { n.ObjectMeta.Annotations[wireGuardIPAnnotationKey] = "" @@ -292,13 +293,11 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node { internalIP = nil } // Set Wireguard PersistentKeepalive setting for the node. - var persistentKeepalive int64 - if keepAlive, ok := node.ObjectMeta.Annotations[persistentKeepaliveKey]; !ok { - persistentKeepalive = 0 - } else { - if persistentKeepalive, err = strconv.ParseInt(keepAlive, 10, 64); err != nil { - persistentKeepalive = 0 - } + var persistentKeepalive = time.Duration(0) + if keepAlive, ok := node.ObjectMeta.Annotations[persistentKeepaliveKey]; ok { + // We can ignore the error, because p will be set to 0 if an error occures. + p, _ := strconv.ParseInt(keepAlive, 10, 64) + persistentKeepalive = time.Duration(p) * time.Second } var lastSeen int64 if ls, ok := node.ObjectMeta.Annotations[lastSeenAnnotationKey]; !ok { @@ -308,7 +307,7 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node { lastSeen = 0 } } - var discoveredEndpoints map[string]*wireguard.Endpoint + var discoveredEndpoints map[string]*net.UDPAddr if de, ok := node.ObjectMeta.Annotations[discoveredEndpointsKey]; ok { err := json.Unmarshal([]byte(de), &discoveredEndpoints) if err != nil { @@ -316,11 +315,11 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node { } } // Set allowed IPs for a location. - var allowedLocationIPs []*net.IPNet + var allowedLocationIPs []net.IPNet if str, ok := node.ObjectMeta.Annotations[allowedLocationIPsKey]; ok { for _, ip := range strings.Split(str, ",") { if ipnet := normalizeIP(ip); ipnet != nil { - allowedLocationIPs = append(allowedLocationIPs, ipnet) + allowedLocationIPs = append(allowedLocationIPs, *ipnet) } } } @@ -335,6 +334,9 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node { } } + // TODO log some error or warning. + key, _ := wgtypes.ParseKey(node.ObjectMeta.Annotations[keyAnnotationKey]) + return &mesh.Node{ // Endpoint and InternalIP should only ever fail to parse if the // remote node's agent has not yet set its IP address; @@ -342,15 +344,15 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node { // the mesh can wait for the node to be updated. // It is valid for the InternalIP to be nil, // if the given node only has public IP addresses. - Endpoint: endpoint, + KiloEndpoint: endpoint, NoInternalIP: noInternalIP, InternalIP: internalIP, - Key: []byte(node.ObjectMeta.Annotations[keyAnnotationKey]), + Key: key, LastSeen: lastSeen, Leader: leader, Location: location, Name: node.Name, - PersistentKeepalive: int(persistentKeepalive), + PersistentKeepalive: persistentKeepalive, Subnet: subnet, // WireGuardIP can fail to parse if the node is not a leader or if // the node's agent has not yet reconciled. In either case, the IP @@ -367,14 +369,14 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer { if peer == nil { return nil } - var aips []*net.IPNet + var aips []net.IPNet for _, aip := range peer.Spec.AllowedIPs { aip := normalizeIP(aip) // Skip any invalid IPs. if aip == nil { continue } - aips = append(aips, aip) + aips = append(aips, *aip) } var endpoint *wireguard.Endpoint if peer.Spec.Endpoint != nil { @@ -390,30 +392,34 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer { DNS: peer.Spec.Endpoint.DNS, IP: ip, }, - Port: peer.Spec.Endpoint.Port, + Port: int(peer.Spec.Endpoint.Port), } } } - var key []byte - if len(peer.Spec.PublicKey) > 0 { - key = []byte(peer.Spec.PublicKey) + + key, _ := wgtypes.ParseKey(peer.Spec.PublicKey) + var psk *wgtypes.Key + if k, err := wgtypes.ParseKey(peer.Spec.PresharedKey); err != nil { + // Set key to nil to avoid setting a key to the zero value wgtypes.Key{} + psk = nil + } else { + psk = &k } - var psk []byte - if len(peer.Spec.PresharedKey) > 0 { - psk = []byte(peer.Spec.PresharedKey) - } - var pka int + var pka time.Duration if peer.Spec.PersistentKeepalive > 0 { - pka = peer.Spec.PersistentKeepalive + pka = time.Duration(peer.Spec.PersistentKeepalive) } return &mesh.Peer{ Name: peer.Name, Peer: wireguard.Peer{ - AllowedIPs: aips, - Endpoint: endpoint, - PersistentKeepalive: pka, - PresharedKey: psk, - PublicKey: key, + PeerConfig: wgtypes.PeerConfig{ + AllowedIPs: aips, + Endpoint: nil, // applyTopology will resolve this endpoint from the KiloEndpoint. + PersistentKeepaliveInterval: &pka, + PresharedKey: psk, + PublicKey: key, + }, + KiloEndpoint: endpoint, }, } } @@ -513,19 +519,27 @@ func (pb *peerBackend) Set(name string, peer *mesh.Peer) error { if peer.Endpoint != nil { var ip string if peer.Endpoint.IP != nil { - ip = peer.Endpoint.IP.String() + ip = peer.KiloEndpoint.IP.String() } p.Spec.Endpoint = &v1alpha1.PeerEndpoint{ DNSOrIP: v1alpha1.DNSOrIP{ IP: ip, - DNS: peer.Endpoint.DNS, + DNS: peer.KiloEndpoint.DNS, }, - Port: peer.Endpoint.Port, + Port: uint32(peer.KiloEndpoint.Port), } } - p.Spec.PersistentKeepalive = peer.PersistentKeepalive - p.Spec.PresharedKey = string(peer.PresharedKey) - p.Spec.PublicKey = string(peer.PublicKey) + if peer.PersistentKeepaliveInterval == nil { + p.Spec.PersistentKeepalive = 0 + } else { + p.Spec.PersistentKeepalive = int(*peer.PersistentKeepaliveInterval) + } + if peer.PresharedKey == nil { + p.Spec.PresharedKey = "" + } else { + p.Spec.PresharedKey = peer.PresharedKey.String() + } + p.Spec.PublicKey = peer.PublicKey.String() if _, err = pb.client.KiloV1alpha1().Peers().Update(context.TODO(), p, metav1.UpdateOptions{}); err != nil { return fmt.Errorf("failed to update peer: %v", err) } @@ -570,7 +584,7 @@ func parseEndpoint(endpoint string) *wireguard.Endpoint { ip := net.ParseIP(hostRaw) if ip == nil { if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 { - return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: hostRaw}, Port: uint32(port)} + return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{DNS: hostRaw}, Port: int(port)} } return nil } @@ -579,5 +593,5 @@ func parseEndpoint(endpoint string) *wireguard.Endpoint { } else { ip = ip.To16() } - return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: ip}, Port: uint32(port)} + return &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: ip}, Port: int(port)} } diff --git a/pkg/k8s/backend_test.go b/pkg/k8s/backend_test.go index 51d2c2b..ad18306 100644 --- a/pkg/k8s/backend_test.go +++ b/pkg/k8s/backend_test.go @@ -17,8 +17,10 @@ package k8s import ( "net" "testing" + "time" "github.com/kylelemons/godebug/pretty" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" v1 "k8s.io/api/core/v1" "github.com/squat/kilo/pkg/k8s/apis/kilo/v1alpha1" @@ -26,6 +28,30 @@ import ( "github.com/squat/kilo/pkg/wireguard" ) +func mustKey() (k wgtypes.Key) { + var err error + if k, err = wgtypes.GeneratePrivateKey(); err != nil { + panic(err.Error()) + } + return +} + +func mustPSKKey() (key *wgtypes.Key) { + if k, err := wgtypes.GenerateKey(); err != nil { + panic(err.Error()) + } else { + key = &k + } + return +} + +var ( + fooKey = mustKey() + pskKey = mustPSKKey() + second = time.Second + zero = time.Duration(0) +) + func TestTranslateNode(t *testing.T) { for _, tc := range []struct { name string @@ -54,8 +80,8 @@ func TestTranslateNode(t *testing.T) { internalIPAnnotationKey: "10.0.0.2/32", }, out: &mesh.Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort}, - InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort}, + InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)}, }, }, { @@ -108,7 +134,7 @@ func TestTranslateNode(t *testing.T) { forceEndpointAnnotationKey: "-10.0.0.2:51821", }, out: &mesh.Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort}, }, }, { @@ -118,7 +144,7 @@ func TestTranslateNode(t *testing.T) { forceEndpointAnnotationKey: "10.0.0.2:51821", }, out: &mesh.Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821}, }, }, { @@ -127,7 +153,7 @@ func TestTranslateNode(t *testing.T) { persistentKeepaliveKey: "25", }, out: &mesh.Node{ - PersistentKeepalive: 25, + PersistentKeepalive: 25 * time.Second, }, }, { @@ -166,7 +192,7 @@ func TestTranslateNode(t *testing.T) { forceEndpointAnnotationKey: "10.0.0.2:51821", forceInternalIPAnnotationKey: "10.1.0.2/32", internalIPAnnotationKey: "10.1.0.1/32", - keyAnnotationKey: "foo", + keyAnnotationKey: fooKey.String(), lastSeenAnnotationKey: "1000000000", leaderAnnotationKey: "", locationAnnotationKey: "b", @@ -177,14 +203,14 @@ func TestTranslateNode(t *testing.T) { RegionLabelKey: "a", }, out: &mesh.Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.2")}, Port: 51821}, NoInternalIP: false, InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)}, - Key: []byte("foo"), + Key: fooKey, LastSeen: 1000000000, Leader: true, Location: "b", - PersistentKeepalive: 25, + PersistentKeepalive: 25 * time.Second, Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)}, WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)}, }, @@ -195,7 +221,7 @@ func TestTranslateNode(t *testing.T) { annotations: map[string]string{ endpointAnnotationKey: "10.0.0.1:51820", internalIPAnnotationKey: "", - keyAnnotationKey: "foo", + keyAnnotationKey: fooKey.String(), lastSeenAnnotationKey: "1000000000", locationAnnotationKey: "b", persistentKeepaliveKey: "25", @@ -205,13 +231,13 @@ func TestTranslateNode(t *testing.T) { RegionLabelKey: "a", }, out: &mesh.Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820}, InternalIP: nil, - Key: []byte("foo"), + Key: fooKey, LastSeen: 1000000000, Leader: false, Location: "b", - PersistentKeepalive: 25, + PersistentKeepalive: 25 * time.Second, Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)}, WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)}, }, @@ -223,7 +249,7 @@ func TestTranslateNode(t *testing.T) { endpointAnnotationKey: "10.0.0.1:51820", internalIPAnnotationKey: "10.1.0.1/32", forceInternalIPAnnotationKey: "", - keyAnnotationKey: "foo", + keyAnnotationKey: fooKey.String(), lastSeenAnnotationKey: "1000000000", locationAnnotationKey: "b", persistentKeepaliveKey: "25", @@ -233,14 +259,14 @@ func TestTranslateNode(t *testing.T) { RegionLabelKey: "a", }, out: &mesh.Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: 51820}, NoInternalIP: true, InternalIP: nil, - Key: []byte("foo"), + Key: fooKey, LastSeen: 1000000000, Leader: false, Location: "b", - PersistentKeepalive: 25, + PersistentKeepalive: 25 * time.Second, Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)}, WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)}, }, @@ -266,7 +292,13 @@ func TestTranslatePeer(t *testing.T) { }{ { name: "empty", - out: &mesh.Peer{}, + out: &mesh.Peer{ + Peer: wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + PersistentKeepaliveInterval: &zero, + }, + }, + }, }, { name: "invalid ips", @@ -276,7 +308,13 @@ func TestTranslatePeer(t *testing.T) { "foo", }, }, - out: &mesh.Peer{}, + out: &mesh.Peer{ + Peer: wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + PersistentKeepaliveInterval: &zero, + }, + }, + }, }, { name: "valid ips", @@ -288,9 +326,12 @@ func TestTranslatePeer(t *testing.T) { }, out: &mesh.Peer{ Peer: wireguard.Peer{ - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, - {IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)}, + PeerConfig: wgtypes.PeerConfig{ + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)}, + }, + PersistentKeepaliveInterval: &zero, }, }, }, @@ -305,7 +346,13 @@ func TestTranslatePeer(t *testing.T) { Port: mesh.DefaultKiloPort, }, }, - out: &mesh.Peer{}, + out: &mesh.Peer{ + Peer: wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + PersistentKeepaliveInterval: &zero, + }, + }, + }, }, { name: "only endpoint port", @@ -314,7 +361,13 @@ func TestTranslatePeer(t *testing.T) { Port: mesh.DefaultKiloPort, }, }, - out: &mesh.Peer{}, + out: &mesh.Peer{ + Peer: wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + PersistentKeepaliveInterval: &zero, + }, + }, + }, }, { name: "valid endpoint ip", @@ -328,10 +381,13 @@ func TestTranslatePeer(t *testing.T) { }, out: &mesh.Peer{ Peer: wireguard.Peer{ - Endpoint: &wireguard.Endpoint{ + KiloEndpoint: &wireguard.Endpoint{ DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("10.0.0.1")}, Port: mesh.DefaultKiloPort, }, + PeerConfig: wgtypes.PeerConfig{ + PersistentKeepaliveInterval: &zero, + }, }, }, }, @@ -347,10 +403,13 @@ func TestTranslatePeer(t *testing.T) { }, out: &mesh.Peer{ Peer: wireguard.Peer{ - Endpoint: &wireguard.Endpoint{ + KiloEndpoint: &wireguard.Endpoint{ DNSOrIP: wireguard.DNSOrIP{DNS: "example.com"}, Port: mesh.DefaultKiloPort, }, + PeerConfig: wgtypes.PeerConfig{ + PersistentKeepaliveInterval: &zero, + }, }, }, }, @@ -359,16 +418,25 @@ func TestTranslatePeer(t *testing.T) { spec: v1alpha1.PeerSpec{ PublicKey: "", }, - out: &mesh.Peer{}, + out: &mesh.Peer{ + Peer: wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + PersistentKeepaliveInterval: &zero, + }, + }, + }, }, { name: "valid key", spec: v1alpha1.PeerSpec{ - PublicKey: "foo", + PublicKey: fooKey.String(), }, out: &mesh.Peer{ Peer: wireguard.Peer{ - PublicKey: []byte("foo"), + PeerConfig: wgtypes.PeerConfig{ + PublicKey: fooKey, + PersistentKeepaliveInterval: &zero, + }, }, }, }, @@ -377,27 +445,38 @@ func TestTranslatePeer(t *testing.T) { spec: v1alpha1.PeerSpec{ PersistentKeepalive: -1, }, - out: &mesh.Peer{}, + out: &mesh.Peer{ + Peer: wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + PersistentKeepaliveInterval: &zero, + }, + }, + }, }, { name: "valid keepalive", spec: v1alpha1.PeerSpec{ - PersistentKeepalive: 1, + PersistentKeepalive: 1 * int(time.Second), }, out: &mesh.Peer{ Peer: wireguard.Peer{ - PersistentKeepalive: 1, + PeerConfig: wgtypes.PeerConfig{ + PersistentKeepaliveInterval: &second, + }, }, }, }, { name: "valid preshared key", spec: v1alpha1.PeerSpec{ - PresharedKey: "psk", + PresharedKey: pskKey.String(), }, out: &mesh.Peer{ Peer: wireguard.Peer{ - PresharedKey: []byte("psk"), + PeerConfig: wgtypes.PeerConfig{ + PersistentKeepaliveInterval: &zero, + PresharedKey: pskKey, + }, }, }, }, diff --git a/pkg/mesh/backend.go b/pkg/mesh/backend.go index db4123d..d30eb88 100644 --- a/pkg/mesh/backend.go +++ b/pkg/mesh/backend.go @@ -18,6 +18,8 @@ import ( "net" "time" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/squat/kilo/pkg/wireguard" ) @@ -54,8 +56,9 @@ const ( // Node represents a node in the network. type Node struct { - Endpoint *wireguard.Endpoint - Key []byte + KiloEndpoint *wireguard.Endpoint + Endpoint *net.UDPAddr + Key wgtypes.Key NoInternalIP bool InternalIP *net.IPNet // LastSeen is a Unix time for the last time @@ -66,18 +69,25 @@ type Node struct { Leader bool Location string Name string - PersistentKeepalive int + PersistentKeepalive time.Duration Subnet *net.IPNet WireGuardIP *net.IPNet - DiscoveredEndpoints map[string]*wireguard.Endpoint - AllowedLocationIPs []*net.IPNet + // DiscoveredEndpoints cannot be DNS endpoints, only net.UDPAddr. + DiscoveredEndpoints map[string]*net.UDPAddr + AllowedLocationIPs []net.IPNet Granularity Granularity } // Ready indicates whether or not the node is ready. func (n *Node) Ready() bool { // Nodes that are not leaders will not have WireGuardIPs, so it is not required. - return n != nil && n.Endpoint != nil && !(n.Endpoint.IP == nil && n.Endpoint.DNS == "") && n.Endpoint.Port != 0 && n.Key != nil && n.Subnet != nil && time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second) + return n != nil && + n.KiloEndpoint != nil && + !(n.KiloEndpoint.IP == nil && n.KiloEndpoint.DNS == "") && + n.KiloEndpoint.Port != 0 && + n.Key != wgtypes.Key{} && + n.Subnet != nil && + time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second) } // Peer represents a peer in the network. @@ -92,7 +102,10 @@ type Peer struct { // will not declare their endpoint and instead allow it to be // discovered. func (p *Peer) Ready() bool { - return p != nil && p.AllowedIPs != nil && len(p.AllowedIPs) != 0 && p.PublicKey != nil + return p != nil && + p.AllowedIPs != nil && + len(p.AllowedIPs) != 0 && + p.PublicKey != wgtypes.Key{} // If Key was not set, it will be wgtypes.Key{}. } // EventType describes what kind of an action an event represents. diff --git a/pkg/mesh/graph.go b/pkg/mesh/graph.go index 619f6a8..84b50a1 100644 --- a/pkg/mesh/graph.go +++ b/pkg/mesh/graph.go @@ -65,7 +65,7 @@ func (t *Topology) Dot() (string, error) { var endpoint *wireguard.Endpoint if j == s.leader { wg = s.wireGuardIP - endpoint = s.endpoint + endpoint = s.kiloEndpoint if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Rank), "1"); err != nil { return "", fmt.Errorf("failed to add rank to node") } diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index 0fcff39..34ffe2a 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -1,4 +1,4 @@ -// Copyright 2019 the Kilo authors +// Copyright 2021 the Kilo authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -30,6 +30,8 @@ import ( "github.com/go-kit/kit/log/level" "github.com/prometheus/client_golang/prometheus" "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/squat/kilo/pkg/encapsulation" "github.com/squat/kilo/pkg/iproute" @@ -43,8 +45,6 @@ const ( kiloPath = "/var/lib/kilo" // privateKeyPath is the filepath where the WireGuard private key is stored. privateKeyPath = kiloPath + "/key" - // confPath is the filepath where the WireGuard configuration is stored. - confPath = kiloPath + "/conf" ) // Mesh is able to create Kilo network meshes. @@ -60,12 +60,13 @@ type Mesh struct { internalIP *net.IPNet ipTables *iptables.Controller kiloIface int + kiloIfaceName string key []byte local bool - port uint32 - priv []byte + port int + priv wgtypes.Key privIface int - pub []byte + pub wgtypes.Key resyncPeriod time.Duration iptablesForwardRule bool stop chan struct{} @@ -88,23 +89,24 @@ type Mesh struct { } // New returns a new Mesh instance. -func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port uint32, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, mtu uint, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, logger log.Logger) (*Mesh, error) { +func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port int, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, mtu uint, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, logger log.Logger) (*Mesh, error) { if err := os.MkdirAll(kiloPath, 0700); err != nil { return nil, fmt.Errorf("failed to create directory to store configuration: %v", err) } - private, err := ioutil.ReadFile(privateKeyPath) - private = bytes.Trim(private, "\n") + privateB, err := ioutil.ReadFile(privateKeyPath) + privateB = bytes.Trim(privateB, "\n") + private, err := wgtypes.ParseKey(string(privateB)) if err != nil { level.Warn(logger).Log("msg", "no private key found on disk; generating one now") - if private, err = wireguard.GenKey(); err != nil { + if private, err = wgtypes.GenerateKey(); err != nil { return nil, err } } - public, err := wireguard.PubKey(private) + public := private.PublicKey() if err != nil { return nil, err } - if err := ioutil.WriteFile(privateKeyPath, private, 0600); err != nil { + if err := ioutil.WriteFile(privateKeyPath, []byte(private.String()), 0600); err != nil { return nil, fmt.Errorf("failed to write private key to disk: %v", err) } cniIndex, err := cniDeviceIndex() @@ -168,6 +170,7 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit internalIP: privateIP, ipTables: ipTables, kiloIface: kiloIface, + kiloIfaceName: iface, nodes: make(map[string]*Node), peers: make(map[string]*Peer), port: port, @@ -314,7 +317,7 @@ func (m *Mesh) syncPeers(e *PeerEvent) { var diff bool m.mu.Lock() // Peers are indexed by public key. - key := string(e.Peer.PublicKey) + key := e.Peer.PublicKey.String() if !e.Peer.Ready() { // Trace non ready peer with their presence in the mesh. _, ok := m.peers[key] @@ -324,8 +327,8 @@ func (m *Mesh) syncPeers(e *PeerEvent) { case AddEvent: fallthrough case UpdateEvent: - if e.Old != nil && key != string(e.Old.PublicKey) { - delete(m.peers, string(e.Old.PublicKey)) + if e.Old != nil && key != e.Old.PublicKey.String() { + delete(m.peers, e.Old.PublicKey.String()) diff = true } if !peersAreEqual(m.peers[key], e.Peer) { @@ -367,8 +370,8 @@ func (m *Mesh) checkIn() { func (m *Mesh) handleLocal(n *Node) { // Allow the IPs to be overridden. - if n.Endpoint == nil || (n.Endpoint.DNS == "" && n.Endpoint.IP == nil) { - n.Endpoint = &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: m.externalIP.IP}, Port: m.port} + if n.KiloEndpoint == nil || (n.KiloEndpoint.DNS == "" && n.KiloEndpoint.IP == nil) { + n.KiloEndpoint = &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: m.externalIP.IP}, Port: m.port} } if n.InternalIP == nil && !n.NoInternalIP { n.InternalIP = m.internalIP @@ -377,7 +380,7 @@ func (m *Mesh) handleLocal(n *Node) { // Take leader, location, and subnet from the argument, as these // are not determined by kilo. local := &Node{ - Endpoint: n.Endpoint, + KiloEndpoint: n.KiloEndpoint, Key: m.pub, NoInternalIP: n.NoInternalIP, InternalIP: n.InternalIP, @@ -462,22 +465,26 @@ func (m *Mesh) applyTopology() { m.errorCounter.WithLabelValues("apply").Inc() return } - // Find the old configuration. - oldConfDump, err := wireguard.ShowDump(link.Attrs().Name) + + wgClient, err := wgctrl.New() if err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() return } - oldConf, err := wireguard.ParseDump(oldConfDump) + defer wgClient.Close() + + // wgDevice is the current configuration of the wg interface. + wgDevice, err := wgClient.Device(m.kiloIfaceName) if err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() return } - natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger) + + natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger) nodes[m.hostname].DiscoveredEndpoints = natEndpoints - t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger) + t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].KiloEndpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger) if err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() @@ -489,19 +496,8 @@ func (m *Mesh) applyTopology() { } else { m.wireGuardIP = nil } - conf := t.Conf() - buf, err := conf.Bytes() - if err != nil { - level.Error(m.logger).Log("error", err) - m.errorCounter.WithLabelValues("apply").Inc() - return - } - if err := ioutil.WriteFile(confPath, buf, 0600); err != nil { - level.Error(m.logger).Log("error", err) - m.errorCounter.WithLabelValues("apply").Inc() - return - } ipRules := t.Rules(m.cni, m.iptablesForwardRule) + // If we are handling local routes, ensure the local // tunnel has an IP address and IPIP traffic is allowed. if m.enc.Strategy() != encapsulation.Never && m.local { @@ -540,10 +536,12 @@ func (m *Mesh) applyTopology() { } // Setting the WireGuard configuration interrupts existing connections // so only set the configuration if it has changed. - equal := conf.Equal(oldConf) + conf := t.Conf() + equal, diff := conf.Equal(wgDevice) if !equal { - level.Info(m.logger).Log("msg", "WireGuard configurations are different") - if err := wireguard.SetConf(link.Attrs().Name, confPath); err != nil { + level.Info(m.logger).Log("msg", "WireGuard configurations are different", "diff", diff) + level.Debug(m.logger).Log("changing wg config", "config", conf.WGConfig()) + if err := wgClient.ConfigureDevice(m.kiloIfaceName, conf.WGConfig()); err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() return @@ -598,10 +596,6 @@ func (m *Mesh) cleanUp() { level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up routes: %v", err)) m.errorCounter.WithLabelValues("cleanUp").Inc() } - if err := os.Remove(confPath); err != nil { - level.Error(m.logger).Log("error", fmt.Sprintf("failed to delete configuration file: %v", err)) - m.errorCounter.WithLabelValues("cleanUp").Inc() - } if m.cleanUpIface { if err := iproute.RemoveInterface(m.kiloIface); err != nil { level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err)) @@ -631,10 +625,13 @@ func (m *Mesh) resolveEndpoints() error { } // If the node is ready, then the endpoint is not nil // but it may not have a DNS name. - if m.nodes[k].Endpoint.DNS == "" { + if m.nodes[k].KiloEndpoint.DNS == "" { continue } - if err := resolveEndpoint(m.nodes[k].Endpoint); err != nil { + if u, err := net.ResolveUDPAddr("udp", m.nodes[k].KiloEndpoint.String()); err == nil { + m.nodes[k].Endpoint = u + m.nodes[k].KiloEndpoint.IP = u.IP + } else { return err } } @@ -645,33 +642,19 @@ func (m *Mesh) resolveEndpoints() error { continue } // Peers may have nil endpoints. - if m.peers[k].Endpoint == nil || m.peers[k].Endpoint.DNS == "" { + if m.peers[k].KiloEndpoint == nil || m.peers[k].KiloEndpoint.DNS == "" { continue } - if err := resolveEndpoint(m.peers[k].Endpoint); err != nil { + if u, err := net.ResolveUDPAddr("udp", m.peers[k].KiloEndpoint.String()); err == nil { + m.peers[k].Endpoint = u + m.peers[k].KiloEndpoint.IP = u.IP + } else { return err } } return nil } -func resolveEndpoint(endpoint *wireguard.Endpoint) error { - ips, err := net.LookupIP(endpoint.DNS) - if err != nil { - return fmt.Errorf("failed to look up DNS name %q: %v", endpoint.DNS, err) - } - nets := make([]*net.IPNet, len(ips), len(ips)) - for i := range ips { - nets[i] = oneAddressCIDR(ips[i]) - } - sortIPs(nets) - if len(nets) == 0 { - return fmt.Errorf("did not find any addresses for DNS name %q", endpoint.DNS) - } - endpoint.IP = nets[0].IP - return nil -} - func isSelf(hostname string, node *Node) bool { return node != nil && node.Name == hostname } @@ -685,13 +668,24 @@ func nodesAreEqual(a, b *Node) bool { } // Check the DNS name first since this package // is doing the DNS resolution. - if !a.Endpoint.Equal(b.Endpoint, true) { + if !a.KiloEndpoint.Equal(b.KiloEndpoint, true) { return false } // Ignore LastSeen when comparing equality we want to check if the nodes are // equivalent. However, we do want to check if LastSeen has transitioned // between valid and invalid. - return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive && discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) && ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs) && a.Granularity == b.Granularity + return a.Key.String() == b.Key.String() && + ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && + ipNetsEqual(a.InternalIP, b.InternalIP) && + a.Leader == b.Leader && + a.Location == b.Location && + a.Name == b.Name && + subnetsEqual(a.Subnet, b.Subnet) && + a.Ready() == b.Ready() && + a.PersistentKeepalive == b.PersistentKeepalive && + discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) && + ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs) && + a.Granularity == b.Granularity } func peersAreEqual(a, b *Peer) bool { @@ -703,18 +697,22 @@ func peersAreEqual(a, b *Peer) bool { } // Check the DNS name first since this package // is doing the DNS resolution. - if !a.Endpoint.Equal(b.Endpoint, true) { + if !a.KiloEndpoint.Equal(b.KiloEndpoint, true) { return false } if len(a.AllowedIPs) != len(b.AllowedIPs) { return false } for i := range a.AllowedIPs { - if !ipNetsEqual(a.AllowedIPs[i], b.AllowedIPs[i]) { + if !ipNetsEqual(&a.AllowedIPs[i], &b.AllowedIPs[i]) { return false } } - return string(a.PublicKey) == string(b.PublicKey) && string(a.PresharedKey) == string(b.PresharedKey) && a.PersistentKeepalive == b.PersistentKeepalive + return a.PublicKey.String() == b.PublicKey.String() && + (a.PresharedKey == nil) == (b.PresharedKey == nil) && + (a.PresharedKey == nil || a.PresharedKey.String() == b.PresharedKey.String()) && + (a.PersistentKeepaliveInterval == nil) == (b.PersistentKeepaliveInterval == nil) && + (a.PersistentKeepaliveInterval == nil || a.PersistentKeepaliveInterval == b.PersistentKeepaliveInterval) } func ipNetsEqual(a, b *net.IPNet) bool { @@ -730,12 +728,12 @@ func ipNetsEqual(a, b *net.IPNet) bool { return a.IP.Equal(b.IP) } -func ipNetSlicesEqual(a, b []*net.IPNet) bool { +func ipNetSlicesEqual(a, b []net.IPNet) bool { if len(a) != len(b) { return false } for i := range a { - if !ipNetsEqual(a[i], b[i]) { + if !ipNetsEqual(&a[i], &b[i]) { return false } } @@ -761,7 +759,7 @@ func subnetsEqual(a, b *net.IPNet) bool { return true } -func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool { +func discoveredEndpointsAreEqual(a, b map[string]*net.UDPAddr) bool { if a == nil && b == nil { return true } @@ -772,7 +770,7 @@ func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool { return false } for k := range a { - if !a[k].Equal(b[k], false) { + if a[k] != b[k] { return false } } @@ -788,24 +786,26 @@ func linkByIndex(index int) (netlink.Link, error) { } // discoverNATEndpoints uses the node's WireGuard configuration to returns a list of the most recently discovered endpoints for all nodes and peers behind NAT so that they can roam. -func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf, logger log.Logger) map[string]*wireguard.Endpoint { - natEndpoints := make(map[string]*wireguard.Endpoint) - keys := make(map[string]*wireguard.Peer) +// Discovered endpionts will never be DNS names, because WireGuard will always resolve them to net.UDPAddr. +func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wgtypes.Device, logger log.Logger) map[string]*net.UDPAddr { + natEndpoints := make(map[string]*net.UDPAddr) + keys := make(map[string]wgtypes.Peer) for i := range conf.Peers { - keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i] + keys[conf.Peers[i].PublicKey.String()] = conf.Peers[i] } for _, n := range nodes { - if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 { - level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false), "latest-handshake", peer.LatestHandshake) - if (peer.LatestHandshake != time.Time{}) { - natEndpoints[string(n.Key)] = peer.Endpoint + if peer, ok := keys[n.Key.String()]; ok && n.PersistentKeepalive != time.Duration(0) { + level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", peer.Endpoint.String() == n.Endpoint.String(), "latest-handshake", peer.LastHandshakeTime) + // Don't update the endpoint, if there was never any handshake. + if !peer.LastHandshakeTime.Equal(time.Time{}) { + natEndpoints[n.Key.String()] = peer.Endpoint } } } for _, p := range peers { - if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 { - if (peer.LatestHandshake != time.Time{}) { - natEndpoints[string(p.PublicKey)] = peer.Endpoint + if peer, ok := keys[p.PublicKey.String()]; ok && p.PersistentKeepaliveInterval != nil { + if !peer.LastHandshakeTime.Equal(time.Time{}) { + natEndpoints[p.PublicKey.String()] = peer.Endpoint } } } diff --git a/pkg/mesh/mesh_test.go b/pkg/mesh/mesh_test.go index 95ec0df..ea5c396 100644 --- a/pkg/mesh/mesh_test.go +++ b/pkg/mesh/mesh_test.go @@ -20,8 +20,19 @@ import ( "time" "github.com/squat/kilo/pkg/wireguard" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +func mustKey() wgtypes.Key { + if k, err := wgtypes.GeneratePrivateKey(); err != nil { + panic(err.Error()) + } else { + return k + } +} + +var key = mustKey() + func TestReady(t *testing.T) { internalIP := oneAddressCIDR(net.ParseIP("1.1.1.1")) externalIP := oneAddressCIDR(net.ParseIP("2.2.2.2")) @@ -44,7 +55,7 @@ func TestReady(t *testing.T) { name: "empty endpoint", node: &Node{ InternalIP: internalIP, - Key: []byte{}, + Key: key, Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: false, @@ -52,58 +63,58 @@ func TestReady(t *testing.T) { { name: "empty endpoint IP", node: &Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{}, Port: DefaultKiloPort}, - InternalIP: internalIP, - Key: []byte{}, - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{}, Port: DefaultKiloPort}, + InternalIP: internalIP, + Key: wgtypes.Key{}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: false, }, { name: "empty endpoint port", node: &Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}}, - InternalIP: internalIP, - Key: []byte{}, - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}}, + InternalIP: internalIP, + Key: wgtypes.Key{}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: false, }, { name: "empty internal IP", node: &Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, - Key: []byte{}, - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, + Key: wgtypes.Key{}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: false, }, { name: "empty key", node: &Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, - InternalIP: internalIP, - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, + InternalIP: internalIP, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: false, }, { name: "empty subnet", node: &Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, - InternalIP: internalIP, - Key: []byte{}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, + InternalIP: internalIP, + Key: wgtypes.Key{}, }, ready: false, }, { name: "valid", node: &Node{ - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, - InternalIP: internalIP, - Key: []byte{}, - LastSeen: time.Now().Unix(), - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: externalIP.IP}, Port: DefaultKiloPort}, + InternalIP: internalIP, + Key: key, + LastSeen: time.Now().Unix(), + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, ready: true, }, diff --git a/pkg/mesh/routes.go b/pkg/mesh/routes.go index 3e39442..e3bc4b0 100644 --- a/pkg/mesh/routes.go +++ b/pkg/mesh/routes.go @@ -40,7 +40,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface var gw net.IP for _, segment := range t.segments { if segment.location == t.location { - gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader]) + gw = enc.Gw(segment.kiloEndpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader]) break } } @@ -113,7 +113,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface // we need to set routes for allowed location IPs over the leader in the current location. for i := range segment.allowedLocationIPs { routes = append(routes, encapsulateRoute(&netlink.Route{ - Dst: segment.allowedLocationIPs[i], + Dst: &segment.allowedLocationIPs[i], Flags: int(netlink.FLAG_ONLINK), Gw: gw, LinkIndex: privIface, @@ -125,7 +125,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface for _, peer := range t.peers { for i := range peer.AllowedIPs { routes = append(routes, encapsulateRoute(&netlink.Route{ - Dst: peer.AllowedIPs[i], + Dst: &peer.AllowedIPs[i], Flags: int(netlink.FLAG_ONLINK), Gw: gw, LinkIndex: privIface, @@ -196,7 +196,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface // equals the external IP. This means that the node // is only accessible through an external IP and we // cannot encapsulate traffic to an IP through the IP. - if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP) { + if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.kiloEndpoint.IP) { continue } // Add routes to the private IPs of nodes in other segments. @@ -214,7 +214,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface // we need to set routes for allowed location IPs over the wg interface. for i := range segment.allowedLocationIPs { routes = append(routes, &netlink.Route{ - Dst: segment.allowedLocationIPs[i], + Dst: &segment.allowedLocationIPs[i], Flags: int(netlink.FLAG_ONLINK), Gw: segment.wireGuardIP, LinkIndex: kiloIface, @@ -226,7 +226,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface for _, peer := range t.peers { for i := range peer.AllowedIPs { routes = append(routes, &netlink.Route{ - Dst: peer.AllowedIPs[i], + Dst: &peer.AllowedIPs[i], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }) diff --git a/pkg/mesh/routes_test.go b/pkg/mesh/routes_test.go index 71fe7c5..648ce0e 100644 --- a/pkg/mesh/routes_test.go +++ b/pkg/mesh/routes_test.go @@ -75,7 +75,7 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: nodes["b"].AllowedLocationIPs[0], + Dst: &nodes["b"].AllowedLocationIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP, LinkIndex: kiloIface, @@ -89,17 +89,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -132,17 +132,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -196,21 +196,21 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: nodes["b"].InternalIP.IP, LinkIndex: privIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], Flags: int(netlink.FLAG_ONLINK), Gw: nodes["b"].InternalIP.IP, LinkIndex: privIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: nodes["b"].InternalIP.IP, LinkIndex: privIface, @@ -266,24 +266,24 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: nodes["b"].AllowedLocationIPs[0], + Dst: &nodes["b"].AllowedLocationIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["d"].Name).segments[1].wireGuardIP, LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -309,7 +309,7 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: nodes["b"].AllowedLocationIPs[0], + Dst: &nodes["b"].AllowedLocationIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].wireGuardIP, LinkIndex: kiloIface, @@ -337,17 +337,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -394,17 +394,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -444,7 +444,7 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: nodes["b"].AllowedLocationIPs[0], + Dst: &nodes["b"].AllowedLocationIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[1].wireGuardIP, LinkIndex: kiloIface, @@ -458,17 +458,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -509,7 +509,7 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: nodes["b"].AllowedLocationIPs[0], + Dst: &nodes["b"].AllowedLocationIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP, LinkIndex: kiloIface, @@ -523,17 +523,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -574,7 +574,7 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: nodes["b"].AllowedLocationIPs[0], + Dst: &nodes["b"].AllowedLocationIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP, LinkIndex: kiloIface, @@ -588,17 +588,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -639,17 +639,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -698,17 +698,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -782,21 +782,21 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: nodes["b"].InternalIP.IP, LinkIndex: privIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], Flags: int(netlink.FLAG_ONLINK), Gw: nodes["b"].InternalIP.IP, LinkIndex: privIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: nodes["b"].InternalIP.IP, LinkIndex: privIface, @@ -868,21 +868,21 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: nodes["b"].InternalIP.IP, LinkIndex: tunlIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], Flags: int(netlink.FLAG_ONLINK), Gw: nodes["b"].InternalIP.IP, LinkIndex: tunlIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: nodes["b"].InternalIP.IP, LinkIndex: tunlIface, @@ -918,7 +918,7 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: nodes["b"].AllowedLocationIPs[0], + Dst: &nodes["b"].AllowedLocationIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].wireGuardIP, LinkIndex: kiloIface, @@ -946,17 +946,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -1004,17 +1004,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, @@ -1055,7 +1055,7 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: nodes["b"].AllowedLocationIPs[0], + Dst: &nodes["b"].AllowedLocationIPs[0], Flags: int(netlink.FLAG_ONLINK), Gw: mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[1].wireGuardIP, LinkIndex: kiloIface, @@ -1069,17 +1069,17 @@ func TestRoutes(t *testing.T) { Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[0], + Dst: &peers["a"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["a"].AllowedIPs[1], + Dst: &peers["a"].AllowedIPs[1], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, { - Dst: peers["b"].AllowedIPs[0], + Dst: &peers["b"].AllowedIPs[0], LinkIndex: kiloIface, Protocol: unix.RTPROT_STATIC, }, diff --git a/pkg/mesh/topology.go b/pkg/mesh/topology.go index a0a9015..628df81 100644 --- a/pkg/mesh/topology.go +++ b/pkg/mesh/topology.go @@ -1,4 +1,4 @@ -// Copyright 2019 the Kilo authors +// Copyright 2021 the Kilo authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,9 +18,11 @@ import ( "errors" "net" "sort" + "time" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/squat/kilo/pkg/wireguard" ) @@ -33,8 +35,8 @@ const ( // Topology represents the logical structure of the overlay network. type Topology struct { // key is the private key of the node creating the topology. - key []byte - port uint32 + key wgtypes.Key + port int // Location is the logical location of the local host. location string segments []*segment @@ -47,7 +49,7 @@ type Topology struct { leader bool // persistentKeepalive is the interval in seconds of the emission // of keepalive packets by the local node to its peers. - persistentKeepalive int + persistentKeepalive time.Duration // privateIP is the private IP address of the local node. privateIP *net.IPNet // subnet is the Pod subnet of the local node. @@ -59,15 +61,16 @@ type Topology struct { // is equal to the Kilo subnet. wireGuardCIDR *net.IPNet // discoveredEndpoints is the updated map of valid discovered Endpoints - discoveredEndpoints map[string]*wireguard.Endpoint + discoveredEndpoints map[string]*net.UDPAddr logger log.Logger } type segment struct { - allowedIPs []*net.IPNet - endpoint *wireguard.Endpoint - key []byte - persistentKeepalive int + allowedIPs []net.IPNet + kiloEndpoint *wireguard.Endpoint + endpoint *net.UDPAddr + key wgtypes.Key + persistentKeepalive time.Duration // Location is the logical location of this segment. location string @@ -85,11 +88,11 @@ type segment struct { // allowedLocationIPs are not part of the cluster and are not peers. // They are directly routable from nodes within the segment. // A classic example is a printer that ought to be routable from other locations. - allowedLocationIPs []*net.IPNet + allowedLocationIPs []net.IPNet } // NewTopology creates a new Topology struct from a given set of nodes and peers. -func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int, logger log.Logger) (*Topology, error) { +func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration, logger log.Logger) (*Topology, error) { if logger == nil { logger = log.NewNopLogger() } @@ -120,7 +123,18 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra localLocation = nodeLocationPrefix + hostname } - t := Topology{key: key, port: port, hostname: hostname, location: localLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet, wireGuardCIDR: subnet, discoveredEndpoints: make(map[string]*wireguard.Endpoint), logger: logger} + t := Topology{ + key: key, + port: port, + hostname: hostname, + location: localLocation, + persistentKeepalive: persistentKeepalive, + privateIP: nodes[hostname].InternalIP, + subnet: nodes[hostname].Subnet, + wireGuardCIDR: subnet, + discoveredEndpoints: make(map[string]*net.UDPAddr), + logger: logger, + } for location := range topoMap { // Sort the location so the result is stable. sort.Slice(topoMap[location], func(i, j int) bool { @@ -130,9 +144,9 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra if location == localLocation && topoMap[location][leader].Name == hostname { t.leader = true } - var allowedIPs []*net.IPNet + var allowedIPs []net.IPNet allowedLocationIPsMap := make(map[string]struct{}) - var allowedLocationIPs []*net.IPNet + var allowedLocationIPs []net.IPNet var cidrs []*net.IPNet var hostnames []string var privateIPs []net.IP @@ -142,7 +156,9 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra // - the node's WireGuard IP // - the node's internal IP // - IPs that were specified by the allowed-location-ips annotation - allowedIPs = append(allowedIPs, node.Subnet) + if node.Subnet != nil { + allowedIPs = append(allowedIPs, *node.Subnet) + } for _, ip := range node.AllowedLocationIPs { if _, ok := allowedLocationIPsMap[ip.String()]; !ok { allowedLocationIPs = append(allowedLocationIPs, ip) @@ -150,7 +166,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra } } if node.InternalIP != nil { - allowedIPs = append(allowedIPs, oneAddressCIDR(node.InternalIP.IP)) + allowedIPs = append(allowedIPs, *oneAddressCIDR(node.InternalIP.IP)) privateIPs = append(privateIPs, node.InternalIP.IP) } cidrs = append(cidrs, node.Subnet) @@ -162,6 +178,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra }) t.segments = append(t.segments, &segment{ allowedIPs: allowedIPs, + kiloEndpoint: topoMap[location][leader].KiloEndpoint, endpoint: topoMap[location][leader].Endpoint, key: topoMap[location][leader].Key, persistentKeepalive: topoMap[location][leader].PersistentKeepalive, @@ -202,7 +219,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra return nil, errors.New("failed to allocate an IP address; ran out of IP addresses") } segment.wireGuardIP = ipNet.IP - segment.allowedIPs = append(segment.allowedIPs, oneAddressCIDR(ipNet.IP)) + segment.allowedIPs = append(segment.allowedIPs, *oneAddressCIDR(ipNet.IP)) if t.leader && segment.location == t.location { t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask} } @@ -224,11 +241,11 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra return &t, nil } -func intersect(n1, n2 *net.IPNet) bool { +func intersect(n1, n2 net.IPNet) bool { return n1.Contains(n2.IP) || n2.Contains(n1.IP) } -func (t *Topology) filterAllowedLocationIPs(ips []*net.IPNet, location string) (ret []*net.IPNet) { +func (t *Topology) filterAllowedLocationIPs(ips []net.IPNet, location string) (ret []net.IPNet) { CheckIPs: for _, ip := range ips { for _, s := range t.segments { @@ -270,45 +287,54 @@ CheckIPs: return } -func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key []byte, persistentKeepalive int) *wireguard.Endpoint { +func (t *Topology) updateEndpoint(kiloEndpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *net.UDPAddr { // Do not update non-nat peers - if persistentKeepalive == 0 { - return endpoint + if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) { + return kiloEndpoint.UDPAddr() } - e, ok := t.discoveredEndpoints[string(key)] + e, ok := t.discoveredEndpoints[key.String()] if ok { return e } - return endpoint + return nil } // Conf generates a WireGuard configuration file for a given Topology. func (t *Topology) Conf() *wireguard.Conf { c := &wireguard.Conf{ - Interface: &wireguard.Interface{ - PrivateKey: t.key, - ListenPort: t.port, + Config: wgtypes.Config{ + PrivateKey: &t.key, + ListenPort: &t.port, + ReplacePeers: true, }, } for _, s := range t.segments { if s.location == t.location { continue } - peer := &wireguard.Peer{ - AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...), - Endpoint: t.updateEndpoint(s.endpoint, s.key, s.persistentKeepalive), - PersistentKeepalive: t.persistentKeepalive, - PublicKey: s.key, + peer := wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...), + Endpoint: t.updateEndpoint(s.kiloEndpoint, s.key, &s.persistentKeepalive), + PersistentKeepaliveInterval: &t.persistentKeepalive, + PublicKey: s.key, + ReplaceAllowedIPs: true, + }, + KiloEndpoint: s.kiloEndpoint, } c.Peers = append(c.Peers, peer) } for _, p := range t.peers { - peer := &wireguard.Peer{ - AllowedIPs: p.AllowedIPs, - Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepalive), - PersistentKeepalive: t.persistentKeepalive, - PresharedKey: p.PresharedKey, - PublicKey: p.PublicKey, + peer := wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + AllowedIPs: p.AllowedIPs, + Endpoint: t.updateEndpoint(p.KiloEndpoint, p.PublicKey, p.PersistentKeepaliveInterval), + PersistentKeepaliveInterval: &t.persistentKeepalive, + PresharedKey: p.PresharedKey, + PublicKey: p.PublicKey, + ReplaceAllowedIPs: true, + }, + KiloEndpoint: p.KiloEndpoint, } c.Peers = append(c.Peers, peer) } @@ -317,39 +343,46 @@ func (t *Topology) Conf() *wireguard.Conf { // AsPeer generates the WireGuard peer configuration for the local location of the given Topology. // This configuration can be used to configure this location as a peer of another WireGuard interface. -func (t *Topology) AsPeer() *wireguard.Peer { +func (t *Topology) AsPeer() wireguard.Peer { for _, s := range t.segments { if s.location != t.location { continue } - return &wireguard.Peer{ - AllowedIPs: s.allowedIPs, - Endpoint: s.endpoint, - PublicKey: s.key, + p := wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + AllowedIPs: s.allowedIPs, + PublicKey: s.key, + Endpoint: s.endpoint, + }, + KiloEndpoint: s.kiloEndpoint, } + return p } - return nil + return wireguard.Peer{} } // PeerConf generates a WireGuard configuration file for a given peer in a Topology. -func (t *Topology) PeerConf(name string) *wireguard.Conf { - var pka int - var psk []byte +func (t *Topology) PeerConf(name string) wireguard.Conf { + var pka *time.Duration + var psk *wgtypes.Key for i := range t.peers { if t.peers[i].Name == name { - pka = t.peers[i].PersistentKeepalive + pka = t.peers[i].PersistentKeepaliveInterval psk = t.peers[i].PresharedKey break } } - c := &wireguard.Conf{} + c := wireguard.Conf{} for _, s := range t.segments { - peer := &wireguard.Peer{ - AllowedIPs: s.allowedIPs, - Endpoint: s.endpoint, - PersistentKeepalive: pka, - PresharedKey: psk, - PublicKey: s.key, + peer := wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + AllowedIPs: s.allowedIPs, + Endpoint: s.kiloEndpoint.UDPAddr(), + PersistentKeepaliveInterval: pka, + PresharedKey: psk, + PublicKey: s.key, + }, + KiloEndpoint: s.kiloEndpoint, } c.Peers = append(c.Peers, peer) } @@ -357,11 +390,13 @@ func (t *Topology) PeerConf(name string) *wireguard.Conf { if t.peers[i].Name == name { continue } - peer := &wireguard.Peer{ - AllowedIPs: t.peers[i].AllowedIPs, - PersistentKeepalive: pka, - PublicKey: t.peers[i].PublicKey, - Endpoint: t.peers[i].Endpoint, + peer := wireguard.Peer{ + PeerConfig: wgtypes.PeerConfig{ + AllowedIPs: t.peers[i].AllowedIPs, + PersistentKeepaliveInterval: pka, + PublicKey: t.peers[i].PublicKey, + Endpoint: t.peers[i].Endpoint, + }, } c.Peers = append(c.Peers, peer) } @@ -382,13 +417,13 @@ func findLeader(nodes []*Node) int { var leaders, public []int for i := range nodes { if nodes[i].Leader { - if isPublic(nodes[i].Endpoint.IP) { + if isPublic(nodes[i].KiloEndpoint.IP) { return i } leaders = append(leaders, i) } - if isPublic(nodes[i].Endpoint.IP) { + if isPublic(nodes[i].KiloEndpoint.IP) { public = append(public, i) } } @@ -408,10 +443,13 @@ func deduplicatePeerIPs(peers []*Peer) []*Peer { p := Peer{ Name: peer.Name, Peer: wireguard.Peer{ - Endpoint: peer.Endpoint, - PersistentKeepalive: peer.PersistentKeepalive, - PresharedKey: peer.PresharedKey, - PublicKey: peer.PublicKey, + PeerConfig: wgtypes.PeerConfig{ + Endpoint: peer.Endpoint, + PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval, + PresharedKey: peer.PresharedKey, + PublicKey: peer.PublicKey, + }, + KiloEndpoint: peer.KiloEndpoint, }, } for _, ip := range peer.AllowedIPs { diff --git a/pkg/mesh/topology_test.go b/pkg/mesh/topology_test.go index 33dec97..c6c5ec0 100644 --- a/pkg/mesh/topology_test.go +++ b/pkg/mesh/topology_test.go @@ -18,9 +18,11 @@ import ( "net" "strings" "testing" + "time" "github.com/go-kit/kit/log" "github.com/kylelemons/godebug/pretty" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/squat/kilo/pkg/wireguard" ) @@ -29,17 +31,25 @@ func allowedIPs(ips ...string) string { return strings.Join(ips, ", ") } -func mustParseCIDR(s string) (r *net.IPNet) { +func mustParseCIDR(s string) (r net.IPNet) { if _, ip, err := net.ParseCIDR(s); err != nil { panic("failed to parse CIDR") } else { - r = ip + r = *ip } return } -func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) { - key := []byte("private") +var ( + key1 = wgtypes.Key{'k', 'e', 'y', '1'} + key2 = wgtypes.Key{'k', 'e', 'y', '2'} + key3 = wgtypes.Key{'k', 'e', 'y', '3'} + key4 = wgtypes.Key{'k', 'e', 'y', '4'} + key5 = wgtypes.Key{'k', 'e', 'y', '5'} +) + +func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) { + key := wgtypes.Key{'p', 'r', 'i', 'v'} e1 := &net.IPNet{IP: net.ParseIP("10.1.0.1").To4(), Mask: net.CIDRMask(16, 32)} e2 := &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(16, 32)} e3 := &net.IPNet{IP: net.ParseIP("10.1.0.3").To4(), Mask: net.CIDRMask(16, 32)} @@ -50,62 +60,66 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) { nodes := map[string]*Node{ "a": { Name: "a", - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, InternalIP: i1, Location: "1", Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)}, - Key: []byte("key1"), + Key: key1, PersistentKeepalive: 25, }, "b": { Name: "b", - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, InternalIP: i1, Location: "2", Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)}, - Key: []byte("key2"), - AllowedLocationIPs: []*net.IPNet{i3}, + Key: key2, + AllowedLocationIPs: []net.IPNet{*i3}, }, "c": { - Name: "c", - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e3.IP}, Port: DefaultKiloPort}, - InternalIP: i2, + Name: "c", + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e3.IP}, Port: DefaultKiloPort}, + InternalIP: i2, // Same location as node b. Location: "2", Subnet: &net.IPNet{IP: net.ParseIP("10.2.3.0"), Mask: net.CIDRMask(24, 32)}, - Key: []byte("key3"), + Key: key3, }, "d": { - Name: "d", - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e4.IP}, Port: DefaultKiloPort}, + Name: "d", + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e4.IP}, Port: DefaultKiloPort}, // Same location as node a, but without private IP Location: "1", Subnet: &net.IPNet{IP: net.ParseIP("10.2.4.0"), Mask: net.CIDRMask(24, 32)}, - Key: []byte("key4"), + Key: key4, }, } peers := map[string]*Peer{ "a": { Name: "a", Peer: wireguard.Peer{ - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.5.0.1"), Mask: net.CIDRMask(24, 32)}, - {IP: net.ParseIP("10.5.0.2"), Mask: net.CIDRMask(24, 32)}, + PeerConfig: wgtypes.PeerConfig{ + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.5.0.1"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("10.5.0.2"), Mask: net.CIDRMask(24, 32)}, + }, + PublicKey: key4, }, - PublicKey: []byte("key4"), }, }, "b": { Name: "b", Peer: wireguard.Peer{ - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)}, + PeerConfig: wgtypes.PeerConfig{ + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)}, + }, + PublicKey: key5, }, - Endpoint: &wireguard.Endpoint{ + KiloEndpoint: &wireguard.Endpoint{ DNSOrIP: wireguard.DNSOrIP{IP: net.ParseIP("192.168.0.1")}, Port: DefaultKiloPort, }, - PublicKey: []byte("key5"), }, }, } @@ -138,8 +152,8 @@ func TestNewTopology(t *testing.T) { wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)}, segments: []*segment{ { - allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].Endpoint, + allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["a"].KiloEndpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: logicalLocationPrefix + nodes["a"].Location, @@ -149,8 +163,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w1, }, { - allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].Endpoint, + allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["b"].KiloEndpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: logicalLocationPrefix + nodes["b"].Location, @@ -161,8 +175,8 @@ func TestNewTopology(t *testing.T) { allowedLocationIPs: nodes["b"].AllowedLocationIPs, }, { - allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["d"].Endpoint, + allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["d"].KiloEndpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -189,8 +203,8 @@ func TestNewTopology(t *testing.T) { wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)}, segments: []*segment{ { - allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].Endpoint, + allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["a"].KiloEndpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: logicalLocationPrefix + nodes["a"].Location, @@ -200,8 +214,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w1, }, { - allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].Endpoint, + allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["b"].KiloEndpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: logicalLocationPrefix + nodes["b"].Location, @@ -212,8 +226,8 @@ func TestNewTopology(t *testing.T) { allowedLocationIPs: nodes["b"].AllowedLocationIPs, }, { - allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["d"].Endpoint, + allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["d"].KiloEndpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -240,8 +254,8 @@ func TestNewTopology(t *testing.T) { wireGuardCIDR: DefaultKiloSubnet, segments: []*segment{ { - allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].Endpoint, + allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["a"].KiloEndpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: logicalLocationPrefix + nodes["a"].Location, @@ -251,8 +265,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w1, }, { - allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].Endpoint, + allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, *nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["b"].KiloEndpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: logicalLocationPrefix + nodes["b"].Location, @@ -263,8 +277,8 @@ func TestNewTopology(t *testing.T) { allowedLocationIPs: nodes["b"].AllowedLocationIPs, }, { - allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["d"].Endpoint, + allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["d"].KiloEndpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -291,8 +305,8 @@ func TestNewTopology(t *testing.T) { wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)}, segments: []*segment{ { - allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].Endpoint, + allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["a"].KiloEndpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: nodeLocationPrefix + nodes["a"].Name, @@ -302,8 +316,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w1, }, { - allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].Endpoint, + allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["b"].KiloEndpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: nodeLocationPrefix + nodes["b"].Name, @@ -314,8 +328,8 @@ func TestNewTopology(t *testing.T) { allowedLocationIPs: nodes["b"].AllowedLocationIPs, }, { - allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["c"].Endpoint, + allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["c"].KiloEndpoint, key: nodes["c"].Key, persistentKeepalive: nodes["c"].PersistentKeepalive, location: nodeLocationPrefix + nodes["c"].Name, @@ -325,8 +339,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w3, }, { - allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["d"].Endpoint, + allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["d"].KiloEndpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -353,8 +367,8 @@ func TestNewTopology(t *testing.T) { wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)}, segments: []*segment{ { - allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].Endpoint, + allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["a"].KiloEndpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: nodeLocationPrefix + nodes["a"].Name, @@ -364,8 +378,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w1, }, { - allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].Endpoint, + allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["b"].KiloEndpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: nodeLocationPrefix + nodes["b"].Name, @@ -376,8 +390,8 @@ func TestNewTopology(t *testing.T) { allowedLocationIPs: nodes["b"].AllowedLocationIPs, }, { - allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["c"].Endpoint, + allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["c"].KiloEndpoint, key: nodes["c"].Key, persistentKeepalive: nodes["c"].PersistentKeepalive, location: nodeLocationPrefix + nodes["c"].Name, @@ -387,8 +401,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w3, }, { - allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["d"].Endpoint, + allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["d"].KiloEndpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -415,8 +429,8 @@ func TestNewTopology(t *testing.T) { wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)}, segments: []*segment{ { - allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].Endpoint, + allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["a"].KiloEndpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: nodeLocationPrefix + nodes["a"].Name, @@ -426,8 +440,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w1, }, { - allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].Endpoint, + allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["b"].KiloEndpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: nodeLocationPrefix + nodes["b"].Name, @@ -438,8 +452,8 @@ func TestNewTopology(t *testing.T) { allowedLocationIPs: nodes["b"].AllowedLocationIPs, }, { - allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["c"].Endpoint, + allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["c"].KiloEndpoint, key: nodes["c"].Key, persistentKeepalive: nodes["c"].PersistentKeepalive, location: nodeLocationPrefix + nodes["c"].Name, @@ -449,8 +463,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w3, }, { - allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["d"].Endpoint, + allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["d"].KiloEndpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -477,8 +491,8 @@ func TestNewTopology(t *testing.T) { wireGuardCIDR: &net.IPNet{IP: w4, Mask: net.CIDRMask(16, 32)}, segments: []*segment{ { - allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["a"].Endpoint, + allowedIPs: []net.IPNet{*nodes["a"].Subnet, *nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["a"].KiloEndpoint, key: nodes["a"].Key, persistentKeepalive: nodes["a"].PersistentKeepalive, location: nodeLocationPrefix + nodes["a"].Name, @@ -488,8 +502,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w1, }, { - allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["b"].Endpoint, + allowedIPs: []net.IPNet{*nodes["b"].Subnet, *nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["b"].KiloEndpoint, key: nodes["b"].Key, persistentKeepalive: nodes["b"].PersistentKeepalive, location: nodeLocationPrefix + nodes["b"].Name, @@ -500,8 +514,8 @@ func TestNewTopology(t *testing.T) { allowedLocationIPs: nodes["b"].AllowedLocationIPs, }, { - allowedIPs: []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["c"].Endpoint, + allowedIPs: []net.IPNet{*nodes["c"].Subnet, *nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["c"].KiloEndpoint, key: nodes["c"].Key, persistentKeepalive: nodes["c"].PersistentKeepalive, location: nodeLocationPrefix + nodes["c"].Name, @@ -511,8 +525,8 @@ func TestNewTopology(t *testing.T) { wireGuardIP: w3, }, { - allowedIPs: []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, - endpoint: nodes["d"].Endpoint, + allowedIPs: []net.IPNet{*nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}}, + kiloEndpoint: nodes["d"].KiloEndpoint, key: nodes["d"].Key, persistentKeepalive: nodes["d"].PersistentKeepalive, location: nodeLocationPrefix + nodes["d"].Name, @@ -539,7 +553,7 @@ func TestNewTopology(t *testing.T) { } } -func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int) *Topology { +func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration) *Topology { topo, err := NewTopology(nodes, peers, granularity, hostname, port, key, subnet, persistentKeepalive, nil) if err != nil { t.Errorf("failed to generate Topology: %v", err) @@ -547,211 +561,6 @@ func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, gran return topo } -func TestConf(t *testing.T) { - nodes, peers, key, port := setup(t) - for _, tc := range []struct { - name string - topology *Topology - result string - }{ - { - name: "logical from a", - topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive), - result: `[Interface] -PrivateKey = private -ListenPort = 51820 - -[Peer] -PublicKey = key2 -Endpoint = 10.1.0.2:51820 -AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 192.168.178.3/32 -PersistentKeepalive = 25 - -[Peer] -PublicKey = key4 -Endpoint = 10.1.0.4:51820 -AllowedIPs = 10.2.4.0/24, 10.4.0.3/32 -PersistentKeepalive = 25 - -[Peer] -PublicKey = key4 -AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 -PersistentKeepalive = 25 - -[Peer] -PublicKey = key5 -Endpoint = 192.168.0.1:51820 -AllowedIPs = 10.5.0.3/24 -PersistentKeepalive = 25 -`, - }, - { - name: "logical from b", - topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet, nodes["b"].PersistentKeepalive), - result: `[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - PublicKey = key1 - Endpoint = 10.1.0.1:51820 - AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32 - - [Peer] - PublicKey = key4 - Endpoint = 10.1.0.4:51820 - AllowedIPs = 10.2.4.0/24, 10.4.0.3/32 - - [Peer] - PublicKey = key4 - AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 - - [Peer] - PublicKey = key5 - Endpoint = 192.168.0.1:51820 - AllowedIPs = 10.5.0.3/24 - `, - }, - { - name: "logical from c", - topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet, nodes["c"].PersistentKeepalive), - result: `[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - PublicKey = key1 - Endpoint = 10.1.0.1:51820 - AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32 - - [Peer] - PublicKey = key4 - Endpoint = 10.1.0.4:51820 - AllowedIPs = 10.2.4.0/24, 10.4.0.3/32 - - [Peer] - PublicKey = key4 - AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 - - [Peer] - PublicKey = key5 - Endpoint = 192.168.0.1:51820 - AllowedIPs = 10.5.0.3/24 - `, - }, - { - name: "full from a", - topology: mustTopo(t, nodes, peers, FullGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive), - result: `[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - PublicKey = key2 - Endpoint = 10.1.0.2:51820 - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32, 192.168.178.3/32 - PersistentKeepalive = 25 - - [Peer] - PublicKey = key3 - Endpoint = 10.1.0.3:51820 - AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32 - PersistentKeepalive = 25 - - [Peer] - PublicKey = key4 - Endpoint = 10.1.0.4:51820 - AllowedIPs = 10.2.4.0/24, 10.4.0.4/32 - PersistentKeepalive = 25 - - [Peer] - PublicKey = key4 - AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 - PersistentKeepalive = 25 - - [Peer] - PublicKey = key5 - Endpoint = 192.168.0.1:51820 - AllowedIPs = 10.5.0.3/24 - PersistentKeepalive = 25 - `, - }, - { - name: "full from b", - topology: mustTopo(t, nodes, peers, FullGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet, nodes["b"].PersistentKeepalive), - result: `[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - PublicKey = key1 - Endpoint = 10.1.0.1:51820 - AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32 - - [Peer] - PublicKey = key3 - Endpoint = 10.1.0.3:51820 - AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32 - - [Peer] - PublicKey = key4 - Endpoint = 10.1.0.4:51820 - AllowedIPs = 10.2.4.0/24, 10.4.0.4/32 - - [Peer] - PublicKey = key4 - AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 - - [Peer] - PublicKey = key5 - Endpoint = 192.168.0.1:51820 - AllowedIPs = 10.5.0.3/24 - `, - }, - { - name: "full from c", - topology: mustTopo(t, nodes, peers, FullGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet, nodes["c"].PersistentKeepalive), - result: `[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - PublicKey = key1 - Endpoint = 10.1.0.1:51820 - AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32 - - [Peer] - PublicKey = key2 - Endpoint = 10.1.0.2:51820 - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32, 192.168.178.3/32 - - [Peer] - PublicKey = key4 - Endpoint = 10.1.0.4:51820 - AllowedIPs = 10.2.4.0/24, 10.4.0.4/32 - - [Peer] - PublicKey = key4 - AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 - - [Peer] - PublicKey = key5 - Endpoint = 192.168.0.1:51820 - AllowedIPs = 10.5.0.3/24 - `, - }, - } { - conf := tc.topology.Conf() - if !conf.Equal(wireguard.Parse([]byte(tc.result))) { - buf, err := conf.Bytes() - if err != nil { - t.Errorf("test case %q: failed to render conf: %v", tc.name, err) - } - t.Errorf("test case %q: expected %s got %s", tc.name, tc.result, string(buf)) - } - } -} - func TestFindLeader(t *testing.T) { ip, e1, err := net.ParseCIDR("10.0.0.1/32") if err != nil { @@ -766,26 +575,26 @@ func TestFindLeader(t *testing.T) { nodes := []*Node{ { - Name: "a", - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, + Name: "a", + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, }, { - Name: "b", - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, + Name: "b", + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, }, { - Name: "c", - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, + Name: "c", + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, }, { - Name: "d", - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, - Leader: true, + Name: "d", + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e1.IP}, Port: DefaultKiloPort}, + Leader: true, }, { - Name: "2", - Endpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, - Leader: true, + Name: "2", + KiloEndpoint: &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort}, + Leader: true, }, } for _, tc := range []struct { @@ -840,31 +649,38 @@ func TestDeduplicatePeerIPs(t *testing.T) { p1 := &Peer{ Name: "1", Peer: wireguard.Peer{ - PublicKey: []byte("key1"), - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, - {IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)}, + PeerConfig: wgtypes.PeerConfig{ + + PublicKey: key1, + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)}, + }, }, }, } p2 := &Peer{ Name: "2", Peer: wireguard.Peer{ - PublicKey: []byte("key2"), - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, - {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, + PeerConfig: wgtypes.PeerConfig{ + PublicKey: key2, + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, + }, }, }, } p3 := &Peer{ Name: "3", Peer: wireguard.Peer{ - PublicKey: []byte("key3"), - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)}, - {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, - {IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, + PeerConfig: wgtypes.PeerConfig{ + PublicKey: key3, + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, + }, }, }, } @@ -872,10 +688,12 @@ func TestDeduplicatePeerIPs(t *testing.T) { p4 := &Peer{ Name: "4", Peer: wireguard.Peer{ - PublicKey: []byte("key4"), - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, - {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, + PeerConfig: wgtypes.PeerConfig{ + PublicKey: key4, + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, + }, }, }, } @@ -898,9 +716,11 @@ func TestDeduplicatePeerIPs(t *testing.T) { { Name: "2", Peer: wireguard.Peer{ - PublicKey: []byte("key2"), - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, + PeerConfig: wgtypes.PeerConfig{ + PublicKey: key2, + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, + }, }, }, }, @@ -914,9 +734,11 @@ func TestDeduplicatePeerIPs(t *testing.T) { { Name: "1", Peer: wireguard.Peer{ - PublicKey: []byte("key1"), - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)}, + PeerConfig: wgtypes.PeerConfig{ + PublicKey: key1, + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)}, + }, }, }, }, @@ -930,19 +752,25 @@ func TestDeduplicatePeerIPs(t *testing.T) { { Name: "2", Peer: wireguard.Peer{ - PublicKey: []byte("key2"), + PeerConfig: wgtypes.PeerConfig{ + PublicKey: key2, + }, }, }, { Name: "1", Peer: wireguard.Peer{ - PublicKey: []byte("key1"), + PeerConfig: wgtypes.PeerConfig{ + PublicKey: key1, + }, }, }, { Name: "4", Peer: wireguard.Peer{ - PublicKey: []byte("key4"), + PeerConfig: wgtypes.PeerConfig{ + PublicKey: key4, + }, }, }, }, @@ -954,19 +782,23 @@ func TestDeduplicatePeerIPs(t *testing.T) { { Name: "4", Peer: wireguard.Peer{ - PublicKey: []byte("key4"), - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, + PeerConfig: wgtypes.PeerConfig{ + PublicKey: key4, + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.0.0.3"), Mask: net.CIDRMask(24, 32)}, + }, }, }, }, { Name: "1", Peer: wireguard.Peer{ - PublicKey: []byte("key1"), - AllowedIPs: []*net.IPNet{ - {IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, - {IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)}, + PeerConfig: wgtypes.PeerConfig{ + PublicKey: key1, + AllowedIPs: []net.IPNet{ + {IP: net.ParseIP("10.0.0.1"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(24, 32)}, + }, }, }, }, @@ -985,12 +817,12 @@ func TestFilterAllowedIPs(t *testing.T) { topo := mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive) for _, tc := range []struct { name string - allowedLocationIPs map[int][]*net.IPNet - result map[int][]*net.IPNet + allowedLocationIPs map[int][]net.IPNet + result map[int][]net.IPNet }{ { name: "nothing to filter", - allowedLocationIPs: map[int][]*net.IPNet{ + allowedLocationIPs: map[int][]net.IPNet{ 0: { mustParseCIDR("192.168.178.4/32"), }, @@ -1002,7 +834,7 @@ func TestFilterAllowedIPs(t *testing.T) { mustParseCIDR("192.168.178.7/32"), }, }, - result: map[int][]*net.IPNet{ + result: map[int][]net.IPNet{ 0: { mustParseCIDR("192.168.178.4/32"), }, @@ -1017,7 +849,7 @@ func TestFilterAllowedIPs(t *testing.T) { }, { name: "intersections between segments", - allowedLocationIPs: map[int][]*net.IPNet{ + allowedLocationIPs: map[int][]net.IPNet{ 0: { mustParseCIDR("192.168.178.4/32"), mustParseCIDR("192.168.178.8/32"), @@ -1031,7 +863,7 @@ func TestFilterAllowedIPs(t *testing.T) { mustParseCIDR("192.168.178.4/32"), }, }, - result: map[int][]*net.IPNet{ + result: map[int][]net.IPNet{ 0: { mustParseCIDR("192.168.178.8/32"), }, @@ -1047,7 +879,7 @@ func TestFilterAllowedIPs(t *testing.T) { }, { name: "intersections with wireGuardCIDR", - allowedLocationIPs: map[int][]*net.IPNet{ + allowedLocationIPs: map[int][]net.IPNet{ 0: { mustParseCIDR("10.4.0.1/32"), mustParseCIDR("192.168.178.8/32"), @@ -1060,7 +892,7 @@ func TestFilterAllowedIPs(t *testing.T) { mustParseCIDR("192.168.178.7/32"), }, }, - result: map[int][]*net.IPNet{ + result: map[int][]net.IPNet{ 0: { mustParseCIDR("192.168.178.8/32"), }, @@ -1075,7 +907,7 @@ func TestFilterAllowedIPs(t *testing.T) { }, { name: "intersections with more than one allowedLocationIPs", - allowedLocationIPs: map[int][]*net.IPNet{ + allowedLocationIPs: map[int][]net.IPNet{ 0: { mustParseCIDR("192.168.178.8/32"), }, @@ -1086,7 +918,7 @@ func TestFilterAllowedIPs(t *testing.T) { mustParseCIDR("192.168.178.7/24"), }, }, - result: map[int][]*net.IPNet{ + result: map[int][]net.IPNet{ 0: {}, 1: {}, 2: { diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index 7cef64d..3f624a1 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -15,27 +15,20 @@ package wireguard import ( - "bufio" "bytes" - "errors" "fmt" "net" "sort" "strconv" - "strings" "time" - "k8s.io/apimachinery/pkg/util/validation" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) type section string type key string const ( - separator = "=" - dumpSeparator = "\t" - dumpNone = "(none)" - dumpOff = "off" interfaceSection section = "Interface" peerSection section = "Peer" listenPortKey key = "ListenPort" @@ -47,34 +40,24 @@ const ( publicKeyKey key = "PublicKey" ) -type dumpInterfaceIndex int - -const ( - dumpInterfacePrivateKeyIndex = iota - dumpInterfacePublicKeyIndex - dumpInterfaceListenPortIndex - dumpInterfaceFWMarkIndex - dumpInterfaceLen -) - -type dumpPeerIndex int - -const ( - dumpPeerPublicKeyIndex = iota - dumpPeerPresharedKeyIndex - dumpPeerEndpointIndex - dumpPeerAllowedIPsIndex - dumpPeerLatestHandshakeIndex - dumpPeerTransferRXIndex - dumpPeerTransferTXIndex - dumpPeerPersistentKeepaliveIndex - dumpPeerLen -) - // Conf represents a WireGuard configuration file. type Conf struct { - Interface *Interface - Peers []*Peer + wgtypes.Config + // The Peers field is shadowed because every Peer needs the KiloEndpoint field that contains a DNS endpoint. + Peers []Peer +} + +// WGConfig returns a wgytpes.Config from a Conf. +func (c Conf) WGConfig() wgtypes.Config { + r := c.Config + wgPs := make([]wgtypes.PeerConfig, len(c.Peers)) + for i, p := range c.Peers { + wgPs[i] = p.PeerConfig + wgPs[i].ReplaceAllowedIPs = true + } + r.Peers = wgPs + r.ReplacePeers = true + return r } // Interface represents the `interface` section of a WireGuard configuration. @@ -85,18 +68,13 @@ type Interface struct { // Peer represents a `peer` section of a WireGuard configuration. type Peer struct { - AllowedIPs []*net.IPNet - Endpoint *Endpoint - PersistentKeepalive int - PresharedKey []byte - PublicKey []byte - // The following fields are part of the runtime information, not the configuration. - LatestHandshake time.Time + wgtypes.PeerConfig + KiloEndpoint *Endpoint } // DeduplicateIPs eliminates duplicate allowed IPs. func (p *Peer) DeduplicateIPs() { - var ips []*net.IPNet + var ips []net.IPNet seen := make(map[string]struct{}) for _, ip := range p.AllowedIPs { if _, ok := seen[ip.String()]; ok { @@ -111,7 +89,7 @@ func (p *Peer) DeduplicateIPs() { // Endpoint represents an `endpoint` key of a `peer` section. type Endpoint struct { DNSOrIP - Port uint32 + Port int } // String prints the string representation of the endpoint. @@ -126,6 +104,14 @@ func (e *Endpoint) String() string { return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10) } +// UDPAddr returns the corresponding net.UDPAddr of the Endpoint or nil. +func (e *Endpoint) UDPAddr() (u *net.UDPAddr) { + if a, err := net.ResolveUDPAddr("udp", e.String()); err == nil { + u = a + } + return +} + // Equal compares two endpoints. func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool { if (e == nil) != (b == nil) { @@ -173,116 +159,24 @@ func (d DNSOrIP) String() string { return d.DNS } -// Parse parses a given WireGuard configuration file and produces a Conf struct. -func Parse(buf []byte) *Conf { - var ( - active section - kv []string - c Conf - err error - iface *Interface - i int - k key - line, v string - peer *Peer - port uint64 - ) - s := bufio.NewScanner(bytes.NewBuffer(buf)) - for s.Scan() { - line = strings.TrimSpace(s.Text()) - // Skip comments. - if strings.HasPrefix(line, "#") { - continue - } - // Line is a section title. - if strings.HasPrefix(line, "[") { - if peer != nil { - c.Peers = append(c.Peers, peer) - peer = nil - } - if iface != nil { - c.Interface = iface - iface = nil - } - active = section(strings.TrimSpace(strings.Trim(line, "[]"))) - switch active { - case interfaceSection: - iface = new(Interface) - case peerSection: - peer = new(Peer) - } - continue - } - kv = strings.SplitN(line, separator, 2) - if len(kv) != 2 { - continue - } - k = key(strings.TrimSpace(kv[0])) - v = strings.TrimSpace(kv[1]) - switch active { - case interfaceSection: - switch k { - case listenPortKey: - port, err = strconv.ParseUint(v, 10, 32) - if err != nil { - continue - } - iface.ListenPort = uint32(port) - case privateKeyKey: - iface.PrivateKey = []byte(v) - } - case peerSection: - switch k { - case allowedIPsKey: - err = peer.parseAllowedIPs(v) - if err != nil { - continue - } - case endpointKey: - err = peer.parseEndpoint(v) - if err != nil { - continue - } - case persistentKeepaliveKey: - i, err = strconv.Atoi(v) - if err != nil { - continue - } - peer.PersistentKeepalive = i - case presharedKeyKey: - peer.PresharedKey = []byte(v) - case publicKeyKey: - peer.PublicKey = []byte(v) - } - } - } - if peer != nil { - c.Peers = append(c.Peers, peer) - } - if iface != nil { - c.Interface = iface - } - return &c -} - // Bytes renders a WireGuard configuration to bytes. -func (c *Conf) Bytes() ([]byte, error) { +func (c Conf) Bytes() ([]byte, error) { var err error buf := bytes.NewBuffer(make([]byte, 0, 512)) - if c.Interface != nil { + if c.PrivateKey != nil { if err = writeSection(buf, interfaceSection); err != nil { return nil, fmt.Errorf("failed to write interface: %v", err) } - if err = writePKey(buf, privateKeyKey, c.Interface.PrivateKey); err != nil { + if err = writePKey(buf, privateKeyKey, c.PrivateKey); err != nil { return nil, fmt.Errorf("failed to write private key: %v", err) } - if err = writeValue(buf, listenPortKey, strconv.FormatUint(uint64(c.Interface.ListenPort), 10)); err != nil { + if err = writeValue(buf, listenPortKey, strconv.Itoa(*c.ListenPort)); err != nil { return nil, fmt.Errorf("failed to write listen port: %v", err) } } for i, p := range c.Peers { // Add newlines to make the formatting nicer. - if i == 0 && c.Interface != nil || i != 0 { + if i == 0 && c.PrivateKey != nil || i != 0 { if err = buf.WriteByte('\n'); err != nil { return nil, err } @@ -294,74 +188,99 @@ func (c *Conf) Bytes() ([]byte, error) { if err = writeAllowedIPs(buf, p.AllowedIPs); err != nil { return nil, fmt.Errorf("failed to write allowed IPs: %v", err) } - if err = writeEndpoint(buf, p.Endpoint); err != nil { + if err = writeEndpoint(buf, p.KiloEndpoint); err != nil { return nil, fmt.Errorf("failed to write endpoint: %v", err) } - if err = writeValue(buf, persistentKeepaliveKey, strconv.Itoa(p.PersistentKeepalive)); err != nil { + if p.PersistentKeepaliveInterval == nil { + p.PersistentKeepaliveInterval = new(time.Duration) + } + if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval), 10)); err != nil { return nil, fmt.Errorf("failed to write persistent keepalive: %v", err) } if err = writePKey(buf, presharedKeyKey, p.PresharedKey); err != nil { return nil, fmt.Errorf("failed to write preshared key: %v", err) } - if err = writePKey(buf, publicKeyKey, p.PublicKey); err != nil { + if err = writePKey(buf, publicKeyKey, &p.PublicKey); err != nil { return nil, fmt.Errorf("failed to write public key: %v", err) } } return buf.Bytes(), nil } -// Equal checks if two WireGuard configurations are equivalent. -func (c *Conf) Equal(b *Conf) bool { - if (c.Interface == nil) != (b.Interface == nil) { - return false +// Equal returns true if the Conf and wgtypes.Device are equal. +func (c *Conf) Equal(d *wgtypes.Device) (bool, string) { + if c == nil || d == nil { + return c == nil && d == nil, "nil values" } - if c.Interface != nil { - if c.Interface.ListenPort != b.Interface.ListenPort || !bytes.Equal(c.Interface.PrivateKey, b.Interface.PrivateKey) { - return false - } + if c.ListenPort == nil || *c.ListenPort != d.ListenPort { + return false, fmt.Sprintf("port: old=%q, new=\"%v\"", d.ListenPort, c.ListenPort) } - if len(c.Peers) != len(b.Peers) { - return false + if c.PrivateKey == nil || *c.PrivateKey != d.PrivateKey { + return false, fmt.Sprintf("private key: old=\"%s...\", new=\"%s\"", d.PrivateKey.String()[0:5], c.PrivateKey.String()[0:5]) } + if len(c.Peers) != len(d.Peers) { + return false, fmt.Sprintf("number of peers: old=%d, new=%d", len(d.Peers), len(c.Peers)) + } + sortPeerConfigs(d.Peers) sortPeers(c.Peers) - sortPeers(b.Peers) for i := range c.Peers { - if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) { - return false + if len(c.Peers[i].AllowedIPs) != len(d.Peers[i].AllowedIPs) { + return false, fmt.Sprintf("Peer %d allowed IP length: old=%d, new=%d", i, len(d.Peers[i].AllowedIPs), len(c.Peers[i].AllowedIPs)) } sortCIDRs(c.Peers[i].AllowedIPs) - sortCIDRs(b.Peers[i].AllowedIPs) + sortCIDRs(d.Peers[i].AllowedIPs) for j := range c.Peers[i].AllowedIPs { - if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() { - return false + if c.Peers[i].AllowedIPs[j].String() != d.Peers[i].AllowedIPs[j].String() { + return false, fmt.Sprintf("Peer %d allowed IP: old=%q, new=%q", i, d.Peers[i].AllowedIPs[j].String(), c.Peers[i].AllowedIPs[j].String()) } } - if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) { - return false + if c.Peers[i].Endpoint == nil || d.Peers[i].Endpoint == nil { + return c.Peers[i].Endpoint == nil && d.Peers[i].Endpoint == nil, "peer endpoints: nil value" } - if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) { - return false + if !c.Peers[i].Endpoint.IP.Equal(d.Peers[i].Endpoint.IP) || c.Peers[i].Endpoint.Port != d.Peers[i].Endpoint.Port { + return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.String()) + } + + pki := time.Duration(0) + if p := c.Peers[i].PersistentKeepaliveInterval; p != nil { + pki = *p + } + psk := wgtypes.Key{} + if p := c.Peers[i].PresharedKey; p != nil { + psk = *p + } + if pki != d.Peers[i].PersistentKeepaliveInterval || psk != d.Peers[i].PresharedKey || c.Peers[i].PublicKey != d.Peers[i].PublicKey { + return false, "persistent keepalive or pershared key" } } - return true + return true, "" } -func sortPeers(peers []*Peer) { +func sortPeerConfigs(peers []wgtypes.Peer) { sort.Slice(peers, func(i, j int) bool { - if bytes.Compare(peers[i].PublicKey, peers[j].PublicKey) < 0 { + if peers[i].PublicKey.String() < peers[j].PublicKey.String() { return true } return false }) } -func sortCIDRs(cidrs []*net.IPNet) { +func sortPeers(peers []Peer) { + sort.Slice(peers, func(i, j int) bool { + if peers[i].PublicKey.String() < peers[j].PublicKey.String() { + return true + } + return false + }) +} + +func sortCIDRs(cidrs []net.IPNet) { sort.Slice(cidrs, func(i, j int) bool { return cidrs[i].String() < cidrs[j].String() }) } -func writeAllowedIPs(buf *bytes.Buffer, ais []*net.IPNet) error { +func writeAllowedIPs(buf *bytes.Buffer, ais []net.IPNet) error { if len(ais) == 0 { return nil } @@ -382,15 +301,16 @@ func writeAllowedIPs(buf *bytes.Buffer, ais []*net.IPNet) error { return buf.WriteByte('\n') } -func writePKey(buf *bytes.Buffer, k key, b []byte) error { - if len(b) == 0 { +func writePKey(buf *bytes.Buffer, k key, b *wgtypes.Key) error { + // Print nothing if the public key was never initialized. + if b == nil || (wgtypes.Key{}) == *b { return nil } var err error if err = writeKey(buf, k); err != nil { return err } - if _, err = buf.Write(b); err != nil { + if _, err = buf.Write([]byte(b.String())); err != nil { return err } return buf.WriteByte('\n') @@ -443,177 +363,3 @@ func writeKey(buf *bytes.Buffer, k key) error { _, err = buf.WriteString(" = ") return err } - -var ( - errParseEndpoint = errors.New("could not parse Endpoint") -) - -func (p *Peer) parseEndpoint(v string) error { - var ( - kv []string - err error - ip, ip4 net.IP - port uint64 - ) - kv = strings.Split(v, ":") - if len(kv) < 2 { - return errParseEndpoint - } - port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32) - if err != nil { - return err - } - d := DNSOrIP{} - ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]")) - if ip == nil { - if len(validation.IsDNS1123Subdomain(kv[0])) != 0 { - return errParseEndpoint - } - d.DNS = kv[0] - } else { - if ip4 = ip.To4(); ip4 != nil { - d.IP = ip4 - } else { - d.IP = ip.To16() - } - } - - p.Endpoint = &Endpoint{ - DNSOrIP: d, - Port: uint32(port), - } - return nil -} - -func (p *Peer) parseAllowedIPs(v string) error { - var ( - ai *net.IPNet - kv []string - err error - i int - ip, ip4 net.IP - ) - - kv = strings.Split(v, ",") - for i = range kv { - ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i])) - if err != nil { - return err - } - if ip4 = ip.To4(); ip4 != nil { - ip = ip4 - } else { - ip = ip.To16() - } - ai.IP = ip - p.AllowedIPs = append(p.AllowedIPs, ai) - } - return nil -} - -// ParseDump parses a given WireGuard dump and produces a Conf struct. -func ParseDump(buf []byte) (*Conf, error) { - // from man wg, show section: - // If dump is specified, then several lines are printed; - // the first contains in order separated by tab: private-key, public-key, listen-port, fw‐mark. - // Subsequent lines are printed for each peer and contain in order separated by tab: - // public-key, preshared-key, endpoint, allowed-ips, latest-handshake, transfer-rx, transfer-tx, persistent-keepalive. - var ( - active section - values []string - c Conf - err error - iface *Interface - peer *Peer - port uint64 - sec int64 - pka int - line int - ) - // First line is Interface - active = interfaceSection - s := bufio.NewScanner(bytes.NewBuffer(buf)) - for s.Scan() { - values = strings.Split(s.Text(), dumpSeparator) - - switch active { - case interfaceSection: - if len(values) < dumpInterfaceLen { - return nil, fmt.Errorf("invalid interface line: missing fields (%d < %d)", len(values), dumpInterfaceLen) - } - iface = new(Interface) - for i := range values { - switch i { - case dumpInterfacePrivateKeyIndex: - iface.PrivateKey = []byte(values[i]) - case dumpInterfaceListenPortIndex: - port, err = strconv.ParseUint(values[i], 10, 32) - if err != nil { - return nil, fmt.Errorf("invalid interface line: error parsing listen-port: %w", err) - } - iface.ListenPort = uint32(port) - } - } - c.Interface = iface - // Next lines are Peers - active = peerSection - case peerSection: - if len(values) < dumpPeerLen { - return nil, fmt.Errorf("invalid peer line %d: missing fields (%d < %d)", line, len(values), dumpPeerLen) - } - peer = new(Peer) - - for i := range values { - switch i { - case dumpPeerPublicKeyIndex: - peer.PublicKey = []byte(values[i]) - case dumpPeerPresharedKeyIndex: - if values[i] == dumpNone { - continue - } - peer.PresharedKey = []byte(values[i]) - case dumpPeerEndpointIndex: - if values[i] == dumpNone { - continue - } - err = peer.parseEndpoint(values[i]) - if err != nil { - return nil, fmt.Errorf("invalid peer line %d: error parsing endpoint: %w", line, err) - } - case dumpPeerAllowedIPsIndex: - if values[i] == dumpNone { - continue - } - err = peer.parseAllowedIPs(values[i]) - if err != nil { - return nil, fmt.Errorf("invalid peer line %d: error parsing allowed-ips: %w", line, err) - } - case dumpPeerLatestHandshakeIndex: - if values[i] == "0" { - // Use go zero value, not unix 0 timestamp. - peer.LatestHandshake = time.Time{} - continue - } - sec, err = strconv.ParseInt(values[i], 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid peer line %d: error parsing latest-handshake: %w", line, err) - } - peer.LatestHandshake = time.Unix(sec, 0) - case dumpPeerPersistentKeepaliveIndex: - if values[i] == dumpOff { - continue - } - pka, err = strconv.Atoi(values[i]) - if err != nil { - return nil, fmt.Errorf("invalid peer line %d: error parsing persistent-keepalive: %w", line, err) - } - peer.PersistentKeepalive = pka - } - } - c.Peers = append(c.Peers, peer) - peer = nil - } - line++ - } - return &c, nil -} diff --git a/pkg/wireguard/conf_test.go b/pkg/wireguard/conf_test.go index aa3e782..629c2fd 100644 --- a/pkg/wireguard/conf_test.go +++ b/pkg/wireguard/conf_test.go @@ -1,4 +1,4 @@ -// Copyright 2019 the Kilo authors +// Copyright 2021 the Kilo authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,196 +17,8 @@ package wireguard import ( "net" "testing" - - "github.com/kylelemons/godebug/pretty" ) -func TestCompareConf(t *testing.T) { - for _, tc := range []struct { - name string - a []byte - b []byte - out bool - }{ - { - name: "empty", - a: []byte{}, - b: []byte{}, - out: true, - }, - { - name: "key and value order", - a: []byte(`[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - Endpoint = 10.1.0.2:51820 - PresharedKey = psk - PublicKey = key - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - `), - b: []byte(`[Interface] - ListenPort = 51820 - PrivateKey = private - - [Peer] - PublicKey = key - AllowedIPs = 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 10.2.2.0/24 - PresharedKey = psk - Endpoint = 10.1.0.2:51820 - `), - out: true, - }, - { - name: "whitespace", - a: []byte(`[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - Endpoint = 10.1.0.2:51820 - PresharedKey = psk - PublicKey = key - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - `), - b: []byte(`[Interface] - PrivateKey=private - ListenPort=51820 - [Peer] - Endpoint=10.1.0.2:51820 - PresharedKey = psk - PublicKey=key - AllowedIPs=10.2.2.0/24,192.168.0.1/32,10.2.3.0/24,192.168.0.2/32,10.4.0.2/32 - `), - out: true, - }, - { - name: "missing key", - a: []byte(`[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - Endpoint = 10.1.0.2:51820 - PublicKey = key - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - `), - b: []byte(`[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - PublicKey = key - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - `), - out: false, - }, - { - name: "different value", - a: []byte(`[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - Endpoint = 10.1.0.2:51820 - PublicKey = key - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - `), - b: []byte(`[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - Endpoint = 10.1.0.2:51820 - PublicKey = key2 - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - `), - out: false, - }, - { - name: "section order", - a: []byte(`[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - Endpoint = 10.1.0.2:51820 - PresharedKey = psk - PublicKey = key - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - `), - b: []byte(`[Peer] - Endpoint = 10.1.0.2:51820 - PresharedKey = psk - PublicKey = key - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - - [Interface] - PrivateKey = private - ListenPort = 51820 - `), - out: true, - }, - { - name: "out of order peers", - a: []byte(`[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - Endpoint = 10.1.0.2:51820 - PresharedKey = psk2 - PublicKey = key2 - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - - [Peer] - Endpoint = 10.1.0.2:51820 - PresharedKey = psk1 - PublicKey = key1 - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - `), - b: []byte(`[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - Endpoint = 10.1.0.2:51820 - PresharedKey = psk1 - PublicKey = key1 - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - - [Peer] - Endpoint = 10.1.0.2:51820 - PresharedKey = psk2 - PublicKey = key2 - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - `), - out: true, - }, - { - name: "one empty", - a: []byte(`[Interface] - PrivateKey = private - ListenPort = 51820 - - [Peer] - Endpoint = 10.1.0.2:51820 - PresharedKey = psk - PublicKey = key - AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 - `), - b: []byte(``), - out: false, - }, - } { - equal := Parse(tc.a).Equal(Parse(tc.b)) - if equal != tc.out { - t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal) - } - } -} - func TestCompareEndpoint(t *testing.T) { for _, tc := range []struct { name string @@ -310,47 +122,3 @@ func TestCompareEndpoint(t *testing.T) { } } } - -func TestCompareDumpConf(t *testing.T) { - for _, tc := range []struct { - name string - d []byte - c []byte - }{ - { - name: "empty", - d: []byte{}, - c: []byte{}, - }, - { - name: "redacted copy from wg output", - d: []byte(`private B7qk8EMlob0nfado0ABM6HulUV607r4yqtBKjhap7S4= 51820 off -key1 (none) 10.254.1.1:51820 100.64.1.0/24,192.168.0.125/32,10.4.0.1/32 1619012801 67048 34952 10 -key2 (none) 10.254.2.1:51820 100.64.4.0/24,10.69.76.55/32,100.64.3.0/24,10.66.25.131/32,10.4.0.2/32 1619013058 1134456 10077852 10`), - c: []byte(`[Interface] - ListenPort = 51820 - PrivateKey = private - - [Peer] - PublicKey = key1 - AllowedIPs = 100.64.1.0/24, 192.168.0.125/32, 10.4.0.1/32 - Endpoint = 10.254.1.1:51820 - PersistentKeepalive = 10 - - [Peer] - PublicKey = key2 - AllowedIPs = 100.64.4.0/24, 10.69.76.55/32, 100.64.3.0/24, 10.66.25.131/32, 10.4.0.2/32 - Endpoint = 10.254.2.1:51820 - PersistentKeepalive = 10`), - }, - } { - - dumpConf, _ := ParseDump(tc.d) - conf := Parse(tc.c) - // Equal will ignore runtime fields and only compare configuration fields. - if !dumpConf.Equal(conf) { - diff := pretty.Compare(dumpConf, conf) - t.Errorf("test case %q: got diff: %v", tc.name, diff) - } - } -} diff --git a/pkg/wireguard/wireguard.go b/pkg/wireguard/wireguard.go index 8615b1f..3d1de77 100644 --- a/pkg/wireguard/wireguard.go +++ b/pkg/wireguard/wireguard.go @@ -18,9 +18,7 @@ package wireguard import ( - "bytes" "fmt" - "os/exec" "github.com/vishvananda/netlink" ) @@ -65,74 +63,3 @@ func New(name string, mtu uint) (int, bool, error) { } return link.Attrs().Index, true, nil } - -// Keys generates a WireGuard private and public key-pair. -func Keys() ([]byte, []byte, error) { - private, err := GenKey() - if err != nil { - return nil, nil, fmt.Errorf("failed to generate private key: %v", err) - } - public, err := PubKey(private) - return private, public, err -} - -// GenKey generates a WireGuard private key. -func GenKey() ([]byte, error) { - key, err := exec.Command("wg", "genkey").Output() - return bytes.Trim(key, "\n"), err -} - -// PubKey generates a WireGuard public key for a given private key. -func PubKey(key []byte) ([]byte, error) { - cmd := exec.Command("wg", "pubkey") - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, fmt.Errorf("failed to open pipe to stdin: %v", err) - } - - go func() { - defer stdin.Close() - stdin.Write(key) - }() - - public, err := cmd.Output() - if err != nil { - return nil, fmt.Errorf("failed to generate public key: %v", err) - } - return bytes.Trim(public, "\n"), nil -} - -// SetConf applies a WireGuard configuration file to the given interface. -func SetConf(iface string, path string) error { - cmd := exec.Command("wg", "setconf", iface, path) - var stderr bytes.Buffer - cmd.Stderr = &stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to apply the WireGuard configuration: %s", stderr.String()) - } - return nil -} - -// ShowConf gets the WireGuard configuration for the given interface. -func ShowConf(iface string) ([]byte, error) { - cmd := exec.Command("wg", "showconf", iface) - var stderr, stdout bytes.Buffer - cmd.Stderr = &stderr - cmd.Stdout = &stdout - if err := cmd.Run(); err != nil { - return nil, fmt.Errorf("failed to read the WireGuard configuration: %s", stderr.String()) - } - return stdout.Bytes(), nil -} - -// ShowDump gets the WireGuard configuration and runtime information for the given interface. -func ShowDump(iface string) ([]byte, error) { - cmd := exec.Command("wg", "show", iface, "dump") - var stderr, stdout bytes.Buffer - cmd.Stderr = &stderr - cmd.Stdout = &stdout - if err := cmd.Run(); err != nil { - return nil, fmt.Errorf("failed to read the WireGuard dump output: %s", stderr.String()) - } - return stdout.Bytes(), nil -}