From 08eea4f3c1adb416d2ff254989311c91a8d38803 Mon Sep 17 00:00:00 2001 From: leonnicolas Date: Wed, 29 Sep 2021 22:30:32 +0200 Subject: [PATCH] pkg/*: use wireguard.Enpoint This commit introduces the wireguard.Enpoint struct. It encapsulates a DN name with port and a net.UPDAddr. The fields are private and only accessible over exported Methods to avoid accidental modification. Also iptables.GetProtocol is improved to avoid ipv4 rules being applied by `ip6tables`. Signed-off-by: leonnicolas --- cmd/kgctl/connect_linux.nogo | 473 ----------------------------------- cmd/kgctl/connect_other.nogo | 20 -- cmd/kgctl/showconf.go | 17 +- e2e/full-mesh.sh | 2 +- e2e/location-mesh.sh | 2 +- pkg/encapsulation/ipip.go | 2 +- pkg/iptables/iptables.go | 8 +- pkg/k8s/backend.go | 65 +---- pkg/k8s/backend_test.go | 103 ++------ pkg/mesh/backend.go | 5 +- pkg/mesh/graph.go | 17 +- pkg/mesh/mesh.go | 29 +-- pkg/mesh/mesh_test.go | 14 +- pkg/mesh/routes.go | 32 +-- pkg/mesh/topology.go | 29 +-- pkg/mesh/topology_test.go | 23 +- pkg/wireguard/conf.go | 190 +++++++++++++- 17 files changed, 287 insertions(+), 744 deletions(-) delete mode 100644 cmd/kgctl/connect_linux.nogo delete mode 100644 cmd/kgctl/connect_other.nogo diff --git a/cmd/kgctl/connect_linux.nogo b/cmd/kgctl/connect_linux.nogo deleted file mode 100644 index fa066dd..0000000 --- a/cmd/kgctl/connect_linux.nogo +++ /dev/null @@ -1,473 +0,0 @@ -// +build linux - -package main - -import ( - "context" - "errors" - "fmt" - "io/ioutil" - "net" - "os" - "strings" - "syscall" - "time" - - "github.com/go-kit/kit/log" - "github.com/go-kit/kit/log/level" - "github.com/oklog/run" - "github.com/spf13/cobra" - "github.com/vishvananda/netlink" - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/wgctrl" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/tools/clientcmd" - - "github.com/squat/kilo/pkg/iproute" - "github.com/squat/kilo/pkg/k8s" - "github.com/squat/kilo/pkg/k8s/apis/kilo/v1alpha1" - kiloclient "github.com/squat/kilo/pkg/k8s/clientset/versioned" - "github.com/squat/kilo/pkg/mesh" - "github.com/squat/kilo/pkg/route" - "github.com/squat/kilo/pkg/wireguard" -) - -func takeIPNet(_ net.IP, i *net.IPNet, _ error) *net.IPNet { - return i -} - -func connect() *cobra.Command { - cmd := &cobra.Command{ - Use: "connect", - RunE: connectAsPeer, - Short: "connect to a Kilo cluster as a peer over WireGuard", - } - cmd.Flags().IPNetP("allowed-ip", "a", *takeIPNet(net.ParseCIDR("10.10.10.10/32")), "Allowed IP of the peer") - cmd.Flags().IPNetP("service-cidr", "c", *takeIPNet(net.ParseCIDR("10.43.0.0/16")), "service CIDR of the cluster") - cmd.Flags().String("log-level", logLevelInfo, fmt.Sprintf("Log level to use. Possible values: %s", availableLogLevels)) - cmd.Flags().String("config-path", "/tmp/wg.ini", "path to WireGuard configuation file") - cmd.Flags().Bool("clean-up", true, "clean up routes and interface") - cmd.Flags().Uint("mtu", uint(1420), "clean up routes and interface") - cmd.Flags().Duration("resync-period", 30*time.Second, "How often should Kilo reconcile?") - - availableLogLevels = strings.Join([]string{ - logLevelAll, - logLevelDebug, - logLevelInfo, - logLevelWarn, - logLevelError, - logLevelNone, - }, ", ") - - return cmd -} - -func connectAsPeer(cmd *cobra.Command, args []string) error { - var cancel context.CancelFunc - resyncPersiod, err := cmd.Flags().GetDuration("resync-period") - if err != nil { - return err - } - mtu, err := cmd.Flags().GetUint("mtu") - if err != nil { - return err - } - configPath, err := cmd.Flags().GetString("config-path") - if err != nil { - return err - } - serviceCIDR, err := cmd.Flags().GetIPNet("service-cidr") - if err != nil { - return err - } - allowedIP, err := cmd.Flags().GetIPNet("allowed-ip") - if err != nil { - return err - } - logger := log.NewJSONLogger(log.NewSyncWriter(os.Stdout)) - logLevel, err := cmd.Flags().GetString("log-level") - if err != nil { - return err - } - switch logLevel { - case logLevelAll: - logger = level.NewFilter(logger, level.AllowAll()) - case logLevelDebug: - logger = level.NewFilter(logger, level.AllowDebug()) - case logLevelInfo: - logger = level.NewFilter(logger, level.AllowInfo()) - case logLevelWarn: - logger = level.NewFilter(logger, level.AllowWarn()) - case logLevelError: - logger = level.NewFilter(logger, level.AllowError()) - case logLevelNone: - logger = level.NewFilter(logger, level.AllowNone()) - default: - return fmt.Errorf("log level %s unknown; possible values are: %s", logLevel, availableLogLevels) - } - logger = log.With(logger, "ts", log.DefaultTimestampUTC) - logger = log.With(logger, "caller", log.DefaultCaller) - peername := "random" - if len(args) > 0 { - peername = args[0] - } - - var kiloClient *kiloclient.Clientset - switch backend { - case k8s.Backend: - config, err := clientcmd.BuildConfigFromFlags("", kubeconfig) - if err != nil { - return fmt.Errorf("failed to create Kubernetes config: %v", err) - } - kiloClient = kiloclient.NewForConfigOrDie(config) - default: - return fmt.Errorf("backend %v unknown; posible values are: %s", backend, availableBackends) - } - privateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return fmt.Errorf("failed to generate private key: %w", err) - } - publicKey := privateKey.PublicKey() - level.Info(logger).Log("msg", "generated public key", "key", publicKey) - - peer := &v1alpha1.Peer{ - ObjectMeta: metav1.ObjectMeta{ - Name: peername, - }, - Spec: v1alpha1.PeerSpec{ - AllowedIPs: []string{allowedIP.String()}, - PersistentKeepalive: 10, - PublicKey: publicKey.String(), - }, - } - if p, err := kiloClient.KiloV1alpha1().Peers().Get(context.TODO(), peername, metav1.GetOptions{}); err != nil || p == nil { - peer, err = kiloClient.KiloV1alpha1().Peers().Create(context.TODO(), peer, metav1.CreateOptions{}) - if err != nil { - return fmt.Errorf("failed to create peer: %w", err) - } - } - - kiloIfaceName := "kilo0" - - iface, _, err := wireguard.New(kiloIfaceName, mtu) - if err != nil { - return fmt.Errorf("failed to create wg interface: %w", err) - } else { - level.Info(logger).Log("msg", "successfully created wg interface", "name", kiloIfaceName, "no", iface) - } - if err := iproute.Set(iface, false); err != nil { - return err - } - - if err := iproute.SetAddress(iface, &allowedIP); err != nil { - return err - } else { - level.Info(logger).Log("mag", "successfully set IP address of wg interface", "IP", allowedIP.String()) - } - - if err := iproute.Set(iface, true); err != nil { - return err - } - - var g run.Group - ctx, cancel := context.WithCancel(context.Background()) - g.Add(run.SignalHandler(ctx, syscall.SIGINT, syscall.SIGTERM)) - - table := route.NewTable() - stop := make(chan struct{}, 1) - errCh := make(<-chan error, 1) - { - ch := make(chan struct{}) - g.Add( - func() error { - for { - select { - case err, ok := <-errCh: - if ok { - level.Error(logger).Log("err", err.Error()) - } else { - return nil - } - case <-ch: - return nil - } - } - }, - func(err error) { - ch <- struct{}{} - close(ch) - stop <- struct{}{} - close(stop) - level.Error(logger).Log("msg", "stopped ip routes table", "err", err.Error()) - }, - ) - } - { - ch := make(chan struct{}) - g.Add( - func() error { - for { - ns, err := opts.backend.Nodes().List() - if err != nil { - return fmt.Errorf("failed to list nodes: %v", err) - } - ps, err := opts.backend.Peers().List() - if err != nil { - return fmt.Errorf("failed to list peers: %v", err) - } - // Obtain the Granularity by looking at the annotation of the first node. - if opts.granularity, err = optainGranularity(opts.granularity, ns); err != nil { - return fmt.Errorf("failed to obtain granularity: %w", err) - } - var hostname string - subnet := mesh.DefaultKiloSubnet - nodes := make(map[string]*mesh.Node) - for _, n := range ns { - if n.Ready() { - nodes[n.Name] = n - hostname = n.Name - } - if n.WireGuardIP != nil { - subnet = n.WireGuardIP - } - } - subnet.IP = subnet.IP.Mask(subnet.Mask) - if len(nodes) == 0 { - return errors.New("did not find any valid Kilo nodes in the cluster") - } - peers := make(map[string]*mesh.Peer) - for _, p := range ps { - if p.Ready() { - peers[p.Name] = p - } - } - if _, ok := peers[peername]; !ok { - return fmt.Errorf("did not find any peer named %q in the cluster", peername) - } - - t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, []byte{}, subnet, peers[peername].PersistentKeepalive, logger) - if err != nil { - return fmt.Errorf("failed to create topology: %v", err) - } - conf := t.PeerConf(peername) - conf.Interface = &wireguard.Interface{ - PrivateKey: []byte(privateKey.String()), - ListenPort: uint32(55555), - } - buf, err := conf.Bytes() - if err != nil { - //level.Error(m.logger).Log("error", err) - //m.errorCounter.WithLabelValues("apply").Inc() - return err - } - if err := ioutil.WriteFile("/tmp/wg.ini", buf, 0600); err != nil { - //level.Error(m.logger).Log("error", err) - //m.errorCounter.WithLabelValues("apply").Inc() - return err - } - if err := wireguard.SetConf(kiloIfaceName, "/tmp/wg.ini"); err != nil { - return err - } - wgClient, err := wgctrl.New() - if err != nil { - return fmt.Errorf("failed to initialize wg Client: %w", err) - } - defer func() { - wgClient.Close() - }() - wgConf := wgtypes.Config{ - PrivateKey: &privateKey, - // Peers: []wgtypes.PeerConfig{}, - } - if err := wgClient.ConfigureDevice(kiloIfaceName, wgConf); err != nil { - return fmt.Errorf("failed to configure wg interface: %w", err) - } - - var routes []*netlink.Route - for _, segment := range t.Segments() { - for i := range segment.CIDRS() { - // Add routes to the Pod CIDRs of nodes in other segments. - routes = append(routes, &netlink.Route{ - Dst: segment.CIDRS()[i], - //Flags: int(netlink.FLAG_ONLINK), - LinkIndex: iface, - Protocol: unix.RTPROT_STATIC, - }) - } - for i := range segment.PrivateIPs() { - // Add routes to the private IPs of nodes in other segments. - routes = append(routes, &netlink.Route{ - Dst: mesh.OneAddressCIDR(segment.PrivateIPs()[i]), - //Flags: int(netlink.FLAG_ONLINK), - LinkIndex: iface, - Protocol: unix.RTPROT_STATIC, - }) - } - // For segments / locations other than the location of this instance of kg, - // we need to set routes for allowed location IPs over the leader in the current location. - for i := range segment.AllowedLocationIPs() { - routes = append(routes, &netlink.Route{ - Dst: segment.AllowedLocationIPs()[i], - //Flags: int(netlink.FLAG_ONLINK), - LinkIndex: iface, - Protocol: unix.RTPROT_STATIC, - }) - } - routes = append(routes, &netlink.Route{ - Dst: mesh.OneAddressCIDR(segment.WireGuardIP()), - //Flags: int(netlink.FLAG_ONLINK), - LinkIndex: iface, - Protocol: unix.RTPROT_STATIC, - }) - } - // Add routes for the allowed IPs of peers. - for _, peer := range t.Peers() { - for i := range peer.AllowedIPs { - routes = append(routes, &netlink.Route{ - Dst: peer.AllowedIPs[i], - //Flags: int(netlink.FLAG_ONLINK), - LinkIndex: iface, - Protocol: unix.RTPROT_STATIC, - }) - } - } - routes = append(routes, &netlink.Route{ - Dst: &serviceCIDR, - //Flags: int(netlink.FLAG_ONLINK), - Gw: nil, - LinkIndex: iface, - Protocol: unix.RTPROT_STATIC, - }) - for _, r := range routes { - fmt.Println(r) - } - - level.Debug(logger).Log("routes", routes) - if err := table.Set(routes, []*netlink.Rule{}); err != nil { - return fmt.Errorf("failed to set ip routes table: %w", err) - } - errCh, err = table.Run(stop) - if err != nil { - return fmt.Errorf("failed to start ip routes tables: %w", err) - } - select { - case <-time.After(resyncPersiod): - case <-ch: - return nil - } - } - }, func(err error) { - // Cancel the root context in the very end. - defer func() { - cancel() - level.Debug(logger).Log("msg", "canceled parent context") - }() - ch <- struct{}{} - var serr run.SignalError - if ok := errors.As(err, &serr); ok { - level.Info(logger).Log("msg", "received signal", "signal", serr.Signal.String(), "err", err.Error()) - } else { - level.Error(logger).Log("msg", "received error", "err", err.Error()) - } - level.Debug(logger).Log("msg", "stoped ip routes table") - ctxWithTimeOut, cancelWithTimeOut := context.WithTimeout(ctx, 10*time.Second) - defer func() { - cancelWithTimeOut() - level.Debug(logger).Log("msg", "canceled timed context") - }() - if err := kiloClient.KiloV1alpha1().Peers().Delete(ctxWithTimeOut, peername, metav1.DeleteOptions{}); err != nil { - level.Error(logger).Log("failed to delete peer: %w", err) - } else { - level.Info(logger).Log("msg", "deleted peer", "peer", peername) - } - if ok, err := cmd.Flags().GetBool("clean-up"); err != nil { - level.Error(logger).Log("err", err.Error(), "msg", "failed to get value from clean-up flag") - } else if ok { - cleanUp(iface, table, configPath, logger) - } - }) - } - err = g.Run() - var serr run.SignalError - if ok := errors.As(err, &serr); ok { - return nil - } - return err -} - -func cleanUp(iface int, t *route.Table, configPath string, logger log.Logger) { - if err := iproute.Set(iface, false); err != nil { - level.Error(logger).Log("err", err.Error(), "msg", "failed to set down wg interface") - } - if err := os.Remove(configPath); err != nil { - level.Error(logger).Log("error", fmt.Sprintf("failed to delete configuration file: %v", err)) - } - if err := iproute.RemoveInterface(iface); err != nil { - level.Error(logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err)) - } - if err := t.CleanUp(); err != nil { - level.Error(logger).Log("failed to clean up routes: %w", err) - } - return -} - -//func nodeToWGPeer(n *mesh.Node) (ret wgtypes.PeerConfig, err error) { -// pubKey, err := wgtypes.ParseKey(string(n.Key)) -// if err != nil { -// return ret, err -// } -// ret.PublicKey = pubKey -// if err != nil { -// return ret, err -// } -// aIPs := []net.IPNet{*n.WireGuardIP, *n.InternalIP, *n.Subnet} -// for _, a := range n.AllowedLocationIPs { -// aIPs = append(aIPs, *a) -// } -// ret.AllowedIPs = aIPs -// -// dur := time.Second * time.Duration(n.PersistentKeepalive) -// ret.PersistentKeepaliveInterval = &dur -// -// udpEndpoint, err := net.ResolveUDPAddr("udp", n.Endpoint.String()) -// if err != nil { -// return ret, err -// } else if udpEndpoint.IP == nil { -// udpEndpoint = nil -// } -// ret.Endpoint = udpEndpoint -// return ret, nil -//} -// -//func toWGPeer(p wireguard.Peer) (ret wgtypes.PeerConfig, err error) { -// pubKey, err := wgtypes.ParseKey(string(p.PublicKey)) -// if err != nil { -// return ret, err -// } -// ret.PublicKey = pubKey -// aIPs := make([]net.IPNet, len(p.AllowedIPs)) -// for i, a := range p.AllowedIPs { -// aIPs[i] = *a -// } -// ret.AllowedIPs = aIPs -// -// if preSharedKey, err := wgtypes.ParseKey(string(p.PresharedKey)); len(p.PresharedKey) > 0 && err != nil { -// return ret, err -// } else { -// ret.PresharedKey = &preSharedKey -// } -// dur := time.Second * time.Duration(p.PersistentKeepalive) -// ret.PersistentKeepaliveInterval = &dur -// -// udpEndpoint, err := net.ResolveUDPAddr("udp", p.Endpoint.String()) -// if err != nil { -// return ret, err -// } else if udpEndpoint.IP == nil { -// udpEndpoint = nil -// } -// fmt.Println("peer") -// fmt.Println(aIPs) -// ret.Endpoint = udpEndpoint -// return ret, nil -//} diff --git a/cmd/kgctl/connect_other.nogo b/cmd/kgctl/connect_other.nogo deleted file mode 100644 index 9e81a64..0000000 --- a/cmd/kgctl/connect_other.nogo +++ /dev/null @@ -1,20 +0,0 @@ -// +build !linux - -package main - -import ( - "errors" - - "github.com/spf13/cobra" -) - -func connect() *cobra.Command { - cmd := &cobra.Command{ - Use: "connect", - Short: "not supporred on you OS", - RunE: func(_ *cobra.Command, _ []string) error { - return errors.New("this command is not supported on your OS") - }, - } - return cmd -} diff --git a/cmd/kgctl/showconf.go b/cmd/kgctl/showconf.go index 1cd76d4..7ff2909 100644 --- a/cmd/kgctl/showconf.go +++ b/cmd/kgctl/showconf.go @@ -289,6 +289,7 @@ func runShowConfPeer(_ *cobra.Command, args []string) error { } // translatePeer translates a wireguard.Peer to a Peer CRD. +// TODO this function has many similarities to peerBackend.Set(name, peer) func translatePeer(peer *wireguard.Peer) *v1alpha1.Peer { if peer == nil { return &v1alpha1.Peer{} @@ -303,21 +304,13 @@ func translatePeer(peer *wireguard.Peer) *v1alpha1.Peer { aips = append(aips, aip.String()) } var endpoint *v1alpha1.PeerEndpoint - if (peer.Endpoint != nil && peer.Endpoint.Port > 0) || peer.Addr != "" { - var ip string - if peer.Endpoint.IP != nil { - ip = peer.Endpoint.IP.String() - } - var dns string - if strs := strings.Split(peer.Addr, ":"); len(strs) == 2 && strs[0] != "" { - dns = strs[0] - } + if peer.Endpoint.Port() > 0 || !peer.Endpoint.HasDNS() { endpoint = &v1alpha1.PeerEndpoint{ DNSOrIP: v1alpha1.DNSOrIP{ - DNS: dns, - IP: ip, + IP: peer.Endpoint.IP().String(), + DNS: peer.Endpoint.DNS(), }, - Port: uint32(peer.Endpoint.Port), + Port: uint32(peer.Endpoint.Port()), } } var key string diff --git a/e2e/full-mesh.sh b/e2e/full-mesh.sh index b306fa1..1f4310a 100644 --- a/e2e/full-mesh.sh +++ b/e2e/full-mesh.sh @@ -18,7 +18,7 @@ test_full_mesh_connectivity() { } test_full_mesh_peer() { - check_peer wg1 e2e 10.5.0.1/32 full + check_peer wg99 e2e 10.5.0.1/32 full } test_full_mesh_allowed_location_ips() { diff --git a/e2e/location-mesh.sh b/e2e/location-mesh.sh index bf42555..590d5ea 100755 --- a/e2e/location-mesh.sh +++ b/e2e/location-mesh.sh @@ -18,7 +18,7 @@ test_location_mesh_connectivity() { } test_location_mesh_peer() { - check_peer wg1 e2e 10.5.0.1/32 location + check_peer wg99 e2e 10.5.0.1/32 location } test_mesh_granularity_auto_detect() { diff --git a/pkg/encapsulation/ipip.go b/pkg/encapsulation/ipip.go index fda4d28..d92b39f 100644 --- a/pkg/encapsulation/ipip.go +++ b/pkg/encapsulation/ipip.go @@ -74,7 +74,7 @@ func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule { rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) for _, n := range nodes { // Accept encapsulated traffic from peers. - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(n.IP)), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT")) } // Drop all other IPIP traffic. rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) diff --git a/pkg/iptables/iptables.go b/pkg/iptables/iptables.go index 3a6023f..8290676 100644 --- a/pkg/iptables/iptables.go +++ b/pkg/iptables/iptables.go @@ -53,11 +53,11 @@ const ( ) // GetProtocol will return a protocol from the length of an IP address. -func GetProtocol(length int) Protocol { - if length == net.IPv6len { - return ProtocolIPv6 +func GetProtocol(ip net.IP) Protocol { + if len(ip) == net.IPv4len || ip.To4() != nil { + return ProtocolIPv4 } - return ProtocolIPv4 + return ProtocolIPv6 } // Client represents any type that can administer iptables rules. diff --git a/pkg/k8s/backend.go b/pkg/k8s/backend.go index 4c9bbc5..f7d81fc 100644 --- a/pkg/k8s/backend.go +++ b/pkg/k8s/backend.go @@ -32,7 +32,6 @@ import ( "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/strategicpatch" - "k8s.io/apimachinery/pkg/util/validation" v1informers "k8s.io/client-go/informers/core/v1" "k8s.io/client-go/kubernetes" v1listers "k8s.io/client-go/listers/core/v1" @@ -277,9 +276,9 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node { location = node.ObjectMeta.Labels[topologyLabel] } // Allow the endpoint to be overridden. - endpoint, addr := parseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey]) - if endpoint == nil && addr == "" { - endpoint, addr = parseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey]) + endpoint := wireguard.ParseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey]) + if endpoint == nil { + endpoint = wireguard.ParseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey]) } // Allow the internal IP to be overridden. internalIP := normalizeIP(node.ObjectMeta.Annotations[forceInternalIPAnnotationKey]) @@ -345,7 +344,6 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node { // It is valid for the InternalIP to be nil, // if the given node only has public IP addresses. Endpoint: endpoint, - Addr: addr, NoInternalIP: noInternalIP, InternalIP: internalIP, Key: key, @@ -379,8 +377,7 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer { } aips = append(aips, *aip) } - var endpoint *net.UDPAddr - var addr string + var endpoint *wireguard.Endpoint if peer.Spec.Endpoint != nil { ip := net.ParseIP(peer.Spec.Endpoint.IP) if ip4 := ip.To4(); ip4 != nil { @@ -390,10 +387,10 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer { } if peer.Spec.Endpoint.Port > 0 { if ip != nil { - endpoint = &net.UDPAddr{IP: ip, Port: int(peer.Spec.Endpoint.Port)} + endpoint = wireguard.NewEndpoint(ip, int(peer.Spec.Endpoint.Port)) } if peer.Spec.Endpoint.DNS != "" { - addr = fmt.Sprintf("%s:%d", peer.Spec.Endpoint.DNS, peer.Spec.Endpoint.Port) + endpoint = wireguard.ParseEndpoint(fmt.Sprintf("%s:%d", peer.Spec.Endpoint.DNS, peer.Spec.Endpoint.Port)) } } } @@ -415,12 +412,11 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer { Peer: wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: aips, - Endpoint: endpoint, // applyTopology will resolve this endpoint from the KiloEndpoint. PersistentKeepaliveInterval: &pka, PresharedKey: psk, PublicKey: key, }, - Addr: addr, + Endpoint: endpoint, }, } } @@ -518,22 +514,12 @@ func (pb *peerBackend) Set(name string, peer *mesh.Peer) error { p.Spec.AllowedIPs[i] = peer.AllowedIPs[i].String() } if peer.Endpoint != nil { - var ip string - if peer.Endpoint.IP != nil { - ip = peer.Endpoint.IP.String() - } - var dns string - if peer.Addr != "" { - if strs := strings.Split(peer.Addr, ":"); len(strs) == 2 && strs[0] != "" { - dns = strs[0] - } - } p.Spec.Endpoint = &v1alpha1.PeerEndpoint{ DNSOrIP: v1alpha1.DNSOrIP{ - IP: ip, - DNS: dns, + IP: peer.Endpoint.IP().String(), + DNS: peer.Endpoint.DNS(), }, - Port: uint32(peer.Endpoint.Port), + Port: uint32(peer.Endpoint.Port()), } } if peer.PersistentKeepaliveInterval == nil { @@ -570,34 +556,3 @@ func normalizeIP(ip string) *net.IPNet { ipNet.IP = i.To16() return ipNet } - -func parseEndpoint(endpoint string) (*net.UDPAddr, string) { - if len(endpoint) == 0 { - return nil, "" - } - parts := strings.Split(endpoint, ":") - if len(parts) < 2 { - return nil, "" - } - portRaw := parts[len(parts)-1] - hostRaw := strings.Trim(strings.Join(parts[:len(parts)-1], ":"), "[]") - port, err := strconv.ParseUint(portRaw, 10, 32) - if err != nil { - return nil, "" - } - if len(validation.IsValidPortNum(int(port))) != 0 { - return nil, "" - } - ip := net.ParseIP(hostRaw) - if ip == nil { - if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 { - return nil, endpoint - } - return nil, "" - } - u, err := net.ResolveUDPAddr("udp", endpoint) - if err != nil { - return nil, "" - } - return u, "" -} diff --git a/pkg/k8s/backend_test.go b/pkg/k8s/backend_test.go index 1cc5898..cf4eadc 100644 --- a/pkg/k8s/backend_test.go +++ b/pkg/k8s/backend_test.go @@ -1,4 +1,4 @@ -// Copyright 2019 the Kilo authors +// Copyright 2021 the Kilo authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -80,8 +80,8 @@ func TestTranslateNode(t *testing.T) { internalIPAnnotationKey: "10.0.0.2/32", }, out: &mesh.Node{ - Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort}, - InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)}, + Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort), + InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2").To4(), Mask: net.CIDRMask(32, 32)}, }, }, { @@ -91,8 +91,8 @@ func TestTranslateNode(t *testing.T) { internalIPAnnotationKey: "ff60::10/64", }, out: &mesh.Node{ - Endpoint: &net.UDPAddr{IP: net.ParseIP("ff10::10"), Port: mesh.DefaultKiloPort}, - InternalIP: &net.IPNet{IP: net.ParseIP("ff60::10"), Mask: net.CIDRMask(64, 128)}, + Endpoint: wireguard.NewEndpoint(net.ParseIP("ff10::10").To16(), mesh.DefaultKiloPort), + InternalIP: &net.IPNet{IP: net.ParseIP("ff60::10").To16(), Mask: net.CIDRMask(64, 128)}, }, }, { @@ -105,7 +105,7 @@ func TestTranslateNode(t *testing.T) { name: "normalize subnet", annotations: map[string]string{}, out: &mesh.Node{ - Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(24, 32)}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0").To4(), Mask: net.CIDRMask(24, 32)}, }, subnet: "10.2.0.1/24", }, @@ -113,7 +113,7 @@ func TestTranslateNode(t *testing.T) { name: "valid subnet", annotations: map[string]string{}, out: &mesh.Node{ - Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0").To4(), Mask: net.CIDRMask(24, 32)}, }, subnet: "10.2.1.0/24", }, @@ -145,7 +145,7 @@ func TestTranslateNode(t *testing.T) { forceEndpointAnnotationKey: "-10.0.0.2:51821", }, out: &mesh.Node{ - Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort), }, }, { @@ -155,7 +155,7 @@ func TestTranslateNode(t *testing.T) { forceEndpointAnnotationKey: "10.0.0.2:51821", }, out: &mesh.Node{ - Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.2"), Port: 51821}, + Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.2").To4(), 51821), }, }, { @@ -174,7 +174,7 @@ func TestTranslateNode(t *testing.T) { forceInternalIPAnnotationKey: "-10.1.0.2/24", }, out: &mesh.Node{ - InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.1"), Mask: net.CIDRMask(24, 32)}, + InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.1").To4(), Mask: net.CIDRMask(24, 32)}, NoInternalIP: false, }, }, @@ -185,7 +185,7 @@ func TestTranslateNode(t *testing.T) { forceInternalIPAnnotationKey: "10.1.0.2/24", }, out: &mesh.Node{ - InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(24, 32)}, + InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(24, 32)}, NoInternalIP: false, }, }, @@ -214,16 +214,16 @@ func TestTranslateNode(t *testing.T) { RegionLabelKey: "a", }, out: &mesh.Node{ - Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.2"), Port: 51821}, + Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.2").To4(), 51821), NoInternalIP: false, - InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)}, + InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(32, 32)}, Key: fooKey, LastSeen: 1000000000, Leader: true, Location: "b", PersistentKeepalive: 25 * time.Second, - Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)}, - WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)}, + Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0").To4(), Mask: net.CIDRMask(24, 32)}, + WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1").To4(), Mask: net.CIDRMask(16, 32)}, }, subnet: "10.2.1.0/24", }, @@ -245,7 +245,7 @@ func TestTranslateNode(t *testing.T) { RegionLabelKey: "a", }, out: &mesh.Node{ - Endpoint: &net.UDPAddr{IP: net.ParseIP("1100::10"), Port: 51821}, + Endpoint: wireguard.NewEndpoint(net.ParseIP("1100::10"), 51821), NoInternalIP: false, InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)}, Key: fooKey, @@ -273,7 +273,7 @@ func TestTranslateNode(t *testing.T) { RegionLabelKey: "a", }, out: &mesh.Node{ - Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 51820}, + Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1"), 51820), InternalIP: nil, Key: fooKey, LastSeen: 1000000000, @@ -301,7 +301,7 @@ func TestTranslateNode(t *testing.T) { RegionLabelKey: "a", }, out: &mesh.Node{ - Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 51820}, + Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1"), 51820), NoInternalIP: true, InternalIP: nil, Key: fooKey, @@ -424,9 +424,9 @@ func TestTranslatePeer(t *testing.T) { out: &mesh.Peer{ Peer: wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ - Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort}, PersistentKeepaliveInterval: &zero, }, + Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort), }, }, }, @@ -443,9 +443,9 @@ func TestTranslatePeer(t *testing.T) { out: &mesh.Peer{ Peer: wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ - Endpoint: &net.UDPAddr{IP: net.ParseIP("ff60::2"), Port: mesh.DefaultKiloPort}, PersistentKeepaliveInterval: &zero, }, + Endpoint: wireguard.NewEndpoint(net.ParseIP("ff60::2").To16(), mesh.DefaultKiloPort), }, }, }, @@ -461,7 +461,7 @@ func TestTranslatePeer(t *testing.T) { }, out: &mesh.Peer{ Peer: wireguard.Peer{ - Addr: "example.com:51820", + Endpoint: wireguard.ParseEndpoint("example.com:51820"), PeerConfig: wgtypes.PeerConfig{ PersistentKeepaliveInterval: &zero, }, @@ -544,64 +544,3 @@ func TestTranslatePeer(t *testing.T) { } } } - -func TestParseEndpoint(t *testing.T) { - for _, tc := range []struct { - name string - endpoint string - udp *net.UDPAddr - addr string - }{ - { - name: "empty", - endpoint: "", - udp: nil, - addr: "", - }, - { - name: "invalid IP", - endpoint: "10.0.0.:51820", - udp: nil, - addr: "", - }, - { - name: "invalid hostname", - endpoint: "foo-:51820", - udp: nil, - addr: "", - }, - { - name: "invalid port", - endpoint: "10.0.0.1:100000000", - udp: nil, - addr: "", - }, - { - name: "valid IP", - endpoint: "10.0.0.1:51820", - udp: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort}, - addr: "", - }, - { - name: "valid IPv6", - endpoint: "[ff02::114]:51820", - udp: &net.UDPAddr{IP: net.ParseIP("ff02::114"), Port: mesh.DefaultKiloPort}, - addr: "", - }, - { - name: "valid hostname", - endpoint: "foo:51821", - udp: nil, - addr: "foo:51821", - }, - } { - udp, addr := parseEndpoint(tc.endpoint) - if diff := pretty.Compare(udp, tc.udp); diff != "" { - t.Errorf("test case %q: got diff: %v", tc.name, diff) - } - if addr != tc.addr { - t.Errorf("test case %q: got: %q, wants: %q", tc.name, addr, tc.addr) - } - - } -} diff --git a/pkg/mesh/backend.go b/pkg/mesh/backend.go index f4bc152..562d32b 100644 --- a/pkg/mesh/backend.go +++ b/pkg/mesh/backend.go @@ -56,8 +56,7 @@ const ( // Node represents a node in the network. type Node struct { - Endpoint *net.UDPAddr - Addr string // eg. dnsname:port + Endpoint *wireguard.Endpoint Key wgtypes.Key NoInternalIP bool InternalIP *net.IPNet @@ -82,7 +81,7 @@ type Node struct { func (n *Node) Ready() bool { // Nodes that are not leaders will not have WireGuardIPs, so it is not required. return n != nil && - (n.Endpoint != nil || n.Addr != "") && + n.Endpoint.Ready() && n.Key != wgtypes.Key{} && n.Subnet != nil && time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second) diff --git a/pkg/mesh/graph.go b/pkg/mesh/graph.go index 1b6133f..cdf9ff3 100644 --- a/pkg/mesh/graph.go +++ b/pkg/mesh/graph.go @@ -1,4 +1,4 @@ -// Copyright 2019 the Kilo authors +// Copyright 2021 the Kilo authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ import ( "strings" "github.com/awalterschulze/gographviz" + + "github.com/squat/kilo/pkg/wireguard" ) // Dot generates a Graphviz graph of the Topology in DOT fomat. @@ -61,7 +63,7 @@ func (t *Topology) Dot() (string, error) { return "", fmt.Errorf("failed to add node to subgraph") } var wg net.IP - var endpoint *net.UDPAddr + var endpoint *wireguard.Endpoint if j == s.leader { wg = s.wireGuardIP endpoint = s.endpoint @@ -73,7 +75,7 @@ func (t *Topology) Dot() (string, error) { if s.privateIPs != nil { priv = s.privateIPs[j] } - if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], priv, wg, endpoint, s.addr)); err != nil { + if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], priv, wg, endpoint)); err != nil { return "", fmt.Errorf("failed to add label to node") } } @@ -153,7 +155,7 @@ func subGraphName(name string) string { return graphEscape(fmt.Sprintf("cluster_location_%s", name)) } -func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *net.UDPAddr, addr string) string { +func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *wireguard.Endpoint) string { label := []string{ location, name, @@ -165,12 +167,7 @@ func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoi if wgIP != nil { label = append(label, wgIP.String()) } - var str string - if addr != "" { - str = addr - } else if endpoint != nil { - str = endpoint.String() - } + str := endpoint.String() if str != "" { label = append(label, str) } diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index cbb3697..ce44c75 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -370,8 +370,10 @@ func (m *Mesh) checkIn() { func (m *Mesh) handleLocal(n *Node) { // Allow the IPs to be overridden. - if n.Endpoint == nil || n.Addr == "" { - n.Endpoint = &net.UDPAddr{IP: m.externalIP.IP, Port: m.port} + if !n.Endpoint.Ready() { + e := wireguard.NewEndpoint(m.externalIP.IP, m.port) + level.Info(m.logger).Log("msg", "overriding endpoint", "node", m.hostname, "old endpoint", n.Endpoint.String(), "new endpoint", e.String()) + n.Endpoint = e } if n.InternalIP == nil && !n.NoInternalIP { n.InternalIP = m.internalIP @@ -484,7 +486,7 @@ func (m *Mesh) applyTopology() { natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger) nodes[m.hostname].DiscoveredEndpoints = natEndpoints - t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger) + t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port(), m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger) if err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() @@ -623,15 +625,8 @@ func (m *Mesh) resolveEndpoints() error { if !m.nodes[k].Ready() { continue } - // If the node is ready, then the endpoint is not nil - // but it may not have a DNS name. - if m.nodes[k].Addr == "" { - continue - } - if u, err := net.ResolveUDPAddr("udp", m.nodes[k].Addr); err == nil { - m.nodes[k].Endpoint = u - m.nodes[k].Endpoint.IP = u.IP - } else { + // Resolve the Endpoint + if _, err := m.nodes[k].Endpoint.UDPAddr(true); err != nil { return err } } @@ -642,12 +637,10 @@ func (m *Mesh) resolveEndpoints() error { continue } // Peers may have nil endpoints. - if m.peers[k].Addr == "" { + if !m.peers[k].Endpoint.Ready() { continue } - if u, err := net.ResolveUDPAddr("udp", m.peers[k].Addr); err == nil { - m.peers[k].Endpoint = u - } else { + if _, err := m.peers[k].Endpoint.UDPAddr(true); err != nil { return err } } @@ -667,7 +660,7 @@ func nodesAreEqual(a, b *Node) bool { } // Check the DNS name first since this package // is doing the DNS resolution. - if a.Addr != b.Addr || a.Endpoint.String() != b.Endpoint.String() { + if a.Endpoint.StringOpt(false) != b.Endpoint.StringOpt(false) { return false } // Ignore LastSeen when comparing equality we want to check if the nodes are @@ -696,7 +689,7 @@ func peersAreEqual(a, b *Peer) bool { } // Check the DNS name first since this package // is doing the DNS resolution. - if a.Addr != b.Addr || a.Endpoint.String() != b.Endpoint.String() { + if a.Endpoint.StringOpt(false) != b.Endpoint.StringOpt(false) { return false } if len(a.AllowedIPs) != len(b.AllowedIPs) { diff --git a/pkg/mesh/mesh_test.go b/pkg/mesh/mesh_test.go index d4d80d7..f02c3af 100644 --- a/pkg/mesh/mesh_test.go +++ b/pkg/mesh/mesh_test.go @@ -20,6 +20,8 @@ import ( "time" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/squat/kilo/pkg/wireguard" ) func mustKey() wgtypes.Key { @@ -62,7 +64,7 @@ func TestReady(t *testing.T) { { name: "empty endpoint IP", node: &Node{ - Endpoint: &net.UDPAddr{Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(nil, DefaultKiloPort), InternalIP: internalIP, Key: wgtypes.Key{}, Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, @@ -72,7 +74,7 @@ func TestReady(t *testing.T) { { name: "empty endpoint port", node: &Node{ - Endpoint: &net.UDPAddr{IP: externalIP.IP}, + Endpoint: wireguard.NewEndpoint(externalIP.IP, 0), InternalIP: internalIP, Key: wgtypes.Key{}, Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, @@ -82,7 +84,7 @@ func TestReady(t *testing.T) { { name: "empty internal IP", node: &Node{ - Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort), Key: wgtypes.Key{}, Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, @@ -91,7 +93,7 @@ func TestReady(t *testing.T) { { name: "empty key", node: &Node{ - Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort), InternalIP: internalIP, Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)}, }, @@ -100,7 +102,7 @@ func TestReady(t *testing.T) { { name: "empty subnet", node: &Node{ - Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort), InternalIP: internalIP, Key: wgtypes.Key{}, }, @@ -109,7 +111,7 @@ func TestReady(t *testing.T) { { name: "valid", node: &Node{ - Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort), InternalIP: internalIP, Key: key, LastSeen: time.Now().Unix(), diff --git a/pkg/mesh/routes.go b/pkg/mesh/routes.go index 950da0c..0827ebd 100644 --- a/pkg/mesh/routes.go +++ b/pkg/mesh/routes.go @@ -40,7 +40,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface var gw net.IP for _, segment := range t.segments { if segment.location == t.location { - gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader]) + gw = enc.Gw(segment.endpoint.IP(), segment.privateIPs[segment.leader], segment.cidrs[segment.leader]) break } } @@ -196,7 +196,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface // equals the external IP. This means that the node // is only accessible through an external IP and we // cannot encapsulate traffic to an IP through the IP. - if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP) { + if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP()) { continue } // Add routes to the private IPs of nodes in other segments. @@ -248,7 +248,7 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule { rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT")) rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT")) if cni { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT")) // Some linux distros or docker will set forward DROP in the filter table. // To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location. // Leader nodes will forward packets from all nodes within a location because they act as a gateway for them. @@ -258,30 +258,30 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule { if s.location == t.location { // Make sure packets to and from pod cidrs are not dropped in the forward chain. for _, c := range s.cidrs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT")) } // Make sure packets to and from allowed location IPs are not dropped in the forward chain. for _, c := range s.allowedLocationIPs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT")) } // Make sure packets to and from private IPs are not dropped in the forward chain. for _, c := range s.privateIPs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT")) } } } } else if iptablesForwardRule { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT")) } } for _, s := range t.segments { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(s.wireGuardIP)), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN")) for _, aip := range s.allowedIPs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN")) + rules = append(rules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN")) } // Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd, // Otherwise packets to these destinations will reach the destination, but never find their way back. @@ -289,7 +289,7 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule { if t.location == s.location { for _, alip := range s.allowedLocationIPs { rules = append(rules, - iptables.NewRule(iptables.GetProtocol(len(alip.IP)), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"), + iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"), ) } } @@ -297,8 +297,8 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule { for _, p := range t.peers { for _, aip := range p.AllowedIPs { rules = append(rules, - iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"), - iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"), + iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"), + iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"), ) } } diff --git a/pkg/mesh/topology.go b/pkg/mesh/topology.go index 24b56c4..63a1238 100644 --- a/pkg/mesh/topology.go +++ b/pkg/mesh/topology.go @@ -67,8 +67,7 @@ type Topology struct { type segment struct { allowedIPs []net.IPNet - addr string - endpoint *net.UDPAddr + endpoint *wireguard.Endpoint key wgtypes.Key persistentKeepalive time.Duration // Location is the logical location of this segment. @@ -178,7 +177,6 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra }) t.segments = append(t.segments, &segment{ allowedIPs: allowedIPs, - addr: topoMap[location][leader].Addr, endpoint: topoMap[location][leader].Endpoint, key: topoMap[location][leader].Key, persistentKeepalive: topoMap[location][leader].PersistentKeepalive, @@ -287,14 +285,14 @@ CheckIPs: return } -func (t *Topology) updateEndpoint(endpoint *net.UDPAddr, key wgtypes.Key, persistentKeepalive *time.Duration) *net.UDPAddr { +func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *wireguard.Endpoint { // Do not update non-nat peers if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) { return endpoint } e, ok := t.discoveredEndpoints[key.String()] if ok { - return e + return wireguard.NewEndpointFromUDPAddr(e) } return nil } @@ -315,12 +313,11 @@ func (t *Topology) Conf() *wireguard.Conf { peer := wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...), - Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive), PersistentKeepaliveInterval: &t.persistentKeepalive, PublicKey: s.key, ReplaceAllowedIPs: true, }, - Addr: s.addr, + Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive), } c.Peers = append(c.Peers, peer) } @@ -328,13 +325,12 @@ func (t *Topology) Conf() *wireguard.Conf { peer := wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: p.AllowedIPs, - Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval), PersistentKeepaliveInterval: &t.persistentKeepalive, PresharedKey: p.PresharedKey, PublicKey: p.PublicKey, ReplaceAllowedIPs: true, }, - Addr: p.Addr, + Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval), } c.Peers = append(c.Peers, peer) } @@ -352,9 +348,8 @@ func (t *Topology) AsPeer() wireguard.Peer { PeerConfig: wgtypes.PeerConfig{ AllowedIPs: s.allowedIPs, PublicKey: s.key, - Endpoint: s.endpoint, }, - Addr: s.addr, + Endpoint: s.endpoint, } return p } @@ -377,12 +372,11 @@ func (t *Topology) PeerConf(name string) wireguard.Conf { peer := wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ AllowedIPs: s.allowedIPs, - Endpoint: s.endpoint, PersistentKeepaliveInterval: pka, PresharedKey: psk, PublicKey: s.key, }, - Addr: s.addr, + Endpoint: s.endpoint, } c.Peers = append(c.Peers, peer) } @@ -395,8 +389,8 @@ func (t *Topology) PeerConf(name string) wireguard.Conf { AllowedIPs: t.peers[i].AllowedIPs, PersistentKeepaliveInterval: pka, PublicKey: t.peers[i].PublicKey, - Endpoint: t.peers[i].Endpoint, }, + Endpoint: t.peers[i].Endpoint, } c.Peers = append(c.Peers, peer) } @@ -417,13 +411,13 @@ func findLeader(nodes []*Node) int { var leaders, public []int for i := range nodes { if nodes[i].Leader { - if isPublic(nodes[i].Endpoint.IP) { + if isPublic(nodes[i].Endpoint.IP()) { return i } leaders = append(leaders, i) } - if nodes[i].Endpoint != nil && isPublic(nodes[i].Endpoint.IP) { + if nodes[i].Endpoint != nil && isPublic(nodes[i].Endpoint.IP()) { public = append(public, i) } } @@ -444,12 +438,11 @@ func deduplicatePeerIPs(peers []*Peer) []*Peer { Name: peer.Name, Peer: wireguard.Peer{ PeerConfig: wgtypes.PeerConfig{ - Endpoint: peer.Endpoint, PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval, PresharedKey: peer.PresharedKey, PublicKey: peer.PublicKey, }, - Addr: peer.Addr, + Endpoint: peer.Endpoint, }, } for _, ip := range peer.AllowedIPs { diff --git a/pkg/mesh/topology_test.go b/pkg/mesh/topology_test.go index d05fea5..9b0b8a8 100644 --- a/pkg/mesh/topology_test.go +++ b/pkg/mesh/topology_test.go @@ -60,7 +60,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) nodes := map[string]*Node{ "a": { Name: "a", - Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort), InternalIP: i1, Location: "1", Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)}, @@ -69,7 +69,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) }, "b": { Name: "b", - Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort), InternalIP: i1, Location: "2", Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)}, @@ -78,7 +78,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) }, "c": { Name: "c", - Endpoint: &net.UDPAddr{IP: e3.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(e3.IP, DefaultKiloPort), InternalIP: i2, // Same location as node b. Location: "2", @@ -87,7 +87,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) }, "d": { Name: "d", - Endpoint: &net.UDPAddr{IP: e4.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(e4.IP, DefaultKiloPort), // Same location as node a, but without private IP Location: "1", Subnet: &net.IPNet{IP: net.ParseIP("10.2.4.0"), Mask: net.CIDRMask(24, 32)}, @@ -115,11 +115,8 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int) {IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)}, }, PublicKey: key5, - Endpoint: &net.UDPAddr{ - IP: net.ParseIP("192.168.0.1"), - Port: DefaultKiloPort, - }, }, + Endpoint: wireguard.NewEndpoint(net.ParseIP("192.168.0.1"), DefaultKiloPort), }, }, } @@ -576,24 +573,24 @@ func TestFindLeader(t *testing.T) { nodes := []*Node{ { Name: "a", - Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort), }, { Name: "b", - Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort), }, { Name: "c", - Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort), }, { Name: "d", - Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort), Leader: true, }, { Name: "2", - Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort}, + Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort), Leader: true, }, } diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index 2fc6f66..d844b56 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -16,6 +16,7 @@ package wireguard import ( "bytes" + "errors" "fmt" "net" "sort" @@ -23,6 +24,7 @@ import ( "time" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "k8s.io/apimachinery/pkg/util/validation" ) type section string @@ -53,6 +55,10 @@ func (c Conf) WGConfig() wgtypes.Config { wgPs := make([]wgtypes.PeerConfig, len(c.Peers)) for i, p := range c.Peers { wgPs[i] = p.PeerConfig + if p.Endpoint.Resolved() { + // We can ingore the error because we already checked if the Endpoint was resolved in the above line. + wgPs[i].Endpoint, _ = p.Endpoint.UDPAddr(false) + } wgPs[i].ReplaceAllowedIPs = true } r.Peers = wgPs @@ -60,10 +66,169 @@ func (c Conf) WGConfig() wgtypes.Config { return r } +// Endpoint represents a WireGuard endpoint. +type Endpoint struct { + udpAddr *net.UDPAddr + addr string +} + +// ParseEndpoint returns an Endpoint from a string. +// The input should look like "10.0.0.0:100", "[ff10::10]:100" +// or "example.com:100". +func ParseEndpoint(endpoint string) *Endpoint { + if len(endpoint) == 0 { + return nil + } + hostRaw, portRaw, err := net.SplitHostPort(endpoint) + if err != nil { + return nil + } + port, err := strconv.ParseUint(portRaw, 10, 32) + if err != nil { + return nil + } + if len(validation.IsValidPortNum(int(port))) != 0 { + return nil + } + ip := net.ParseIP(hostRaw) + if ip == nil { + if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 { + return &Endpoint{ + addr: endpoint, + } + } + return nil + } + // ResolveUDPAddr will not resolve the endpoint as long as a valid IP and port is given. + // This should be the case here. + u, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return nil + } + u.IP = cutIP(u.IP) + return &Endpoint{ + udpAddr: u, + } +} + +// NewEndpointFromUDPAddr returns an Endpoint from a net.UDPAddr. +func NewEndpointFromUDPAddr(u *net.UDPAddr) *Endpoint { + if u != nil { + u.IP = cutIP(u.IP) + } + return &Endpoint{ + udpAddr: u, + } +} + +// NewEndpoint returns an Endpoint from a net.IP and port. +func NewEndpoint(ip net.IP, port int) *Endpoint { + return &Endpoint{ + udpAddr: &net.UDPAddr{ + IP: cutIP(ip), + Port: port, + }, + } +} + +// Ready return true, if the Enpoint is ready. +// Ready means that an IP or DN and port exists. +func (e *Endpoint) Ready() bool { + if e == nil { + return false + } + return (e.udpAddr != nil && e.udpAddr.IP != nil && e.udpAddr.Port > 0) || len(e.addr) > 0 +} + +// Port returns the port of the Endpoint. +func (e *Endpoint) Port() int { + if !e.Ready() { + return 0 + } + if e.udpAddr != nil { + return e.udpAddr.Port + } + // We can ignore the errors here bacause the returned port will be "". + // This will result to Port 0 after the conversion to and int. + _, p, _ := net.SplitHostPort(e.addr) + port, _ := strconv.ParseUint(p, 10, 32) + return int(port) +} + +// HasDNS returns true if the endpoint has a DN. +func (e *Endpoint) HasDNS() bool { + return e != nil && e.addr != "" +} + +// DNS returns the DN of the Endpoint. +func (e *Endpoint) DNS() string { + if e == nil { + return "" + } + _, s, _ := net.SplitHostPort(e.addr) + return s +} + +// Resolved returns true, if the DN of the Endpoint was resolved +// or if the Endpoint has a resolved endpoint. +func (e *Endpoint) Resolved() bool { + return e != nil && e.udpAddr != nil +} + +// UDPAddr returns the UDPAddr of the Endpoint. If resolve is false, +// UDPAddr() will not try to resolve a DN name, if the Endpoint is not yet resolved. +func (e *Endpoint) UDPAddr(resolve bool) (*net.UDPAddr, error) { + if !e.Ready() { + return nil, errors.New("Enpoint is not ready") + } + if e.udpAddr != nil { + // Make a copy of the UDPAddr to protect it from modification outside this package. + h := *e.udpAddr + return &h, nil + } + if !resolve { + return nil, errors.New("Endpoint is not resolved") + } + var err error + if e.udpAddr, err = net.ResolveUDPAddr("udp", e.addr); err != nil { + return nil, err + } + // Make a copy of the UDPAddr to protect it from modification outside this package. + h := *e.udpAddr + return &h, nil +} + +// IP returns the IP address of the Enpoint or nil. +func (e *Endpoint) IP() net.IP { + if !e.Resolved() { + return nil + } + return e.udpAddr.IP +} + +// String will return the endpoint as a string. +// If a DN exists, it will take prcedence over the resolved endpoint. +func (e *Endpoint) String() string { + return e.StringOpt(true) +} + +// StringOpt will return string of the Endpoint. +// If dnsFirst is false, the resolved Endpoint will +// take precedence over the DN. +func (e *Endpoint) StringOpt(dnsFirst bool) string { + if e == nil { + return "" + } + if e.udpAddr != nil && (!dnsFirst || e.addr == "") { + return e.udpAddr.String() + } + return e.addr +} + // Peer represents a `peer` section of a WireGuard configuration. type Peer struct { wgtypes.PeerConfig - Addr string // eg: dnsname:port + Endpoint *Endpoint } // DeduplicateIPs eliminates duplicate allowed IPs. @@ -109,7 +274,7 @@ func (c Conf) Bytes() ([]byte, error) { if err = writeAllowedIPs(buf, p.AllowedIPs); err != nil { return nil, fmt.Errorf("failed to write allowed IPs: %v", err) } - if err = writeEndpoint(buf, p.Endpoint, p.Addr); err != nil { + if err = writeEndpoint(buf, p.Endpoint); err != nil { return nil, fmt.Errorf("failed to write endpoint: %v", err) } if p.PersistentKeepaliveInterval == nil { @@ -158,8 +323,8 @@ func (c *Conf) Equal(d *wgtypes.Device) (bool, string) { if c.Peers[i].Endpoint == nil || d.Peers[i].Endpoint == nil { return c.Peers[i].Endpoint == nil && d.Peers[i].Endpoint == nil, "peer endpoints: nil value" } - if !c.Peers[i].Endpoint.IP.Equal(d.Peers[i].Endpoint.IP) || c.Peers[i].Endpoint.Port != d.Peers[i].Endpoint.Port { - return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.String()) + if c.Peers[i].Endpoint.StringOpt(false) != d.Peers[i].Endpoint.String() { + return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.StringOpt(false)) } pki := time.Duration(0) @@ -201,6 +366,13 @@ func sortCIDRs(cidrs []net.IPNet) { }) } +func cutIP(ip net.IP) net.IP { + if i4 := ip.To4(); i4 != nil { + return i4 + } + return ip.To16() +} + func writeAllowedIPs(buf *bytes.Buffer, ais []net.IPNet) error { if len(ais) == 0 { return nil @@ -248,13 +420,9 @@ func writeValue(buf *bytes.Buffer, k key, v string) error { return buf.WriteByte('\n') } -func writeEndpoint(buf *bytes.Buffer, e *net.UDPAddr, d string) error { - str := "" - if d != "" { - str = d - } else if e != nil { - str = e.String() - } else { +func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error { + str := e.String() + if str == "" { return nil } var err error