pkg/*: use wireguard.Enpoint
This commit introduces the wireguard.Enpoint struct. It encapsulates a DN name with port and a net.UPDAddr. The fields are private and only accessible over exported Methods to avoid accidental modification. Also iptables.GetProtocol is improved to avoid ipv4 rules being applied by `ip6tables`. Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
parent
b370ed3511
commit
08eea4f3c1
@ -1,473 +0,0 @@
|
||||
// +build linux
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
"github.com/oklog/run"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/client-go/tools/clientcmd"
|
||||
|
||||
"github.com/squat/kilo/pkg/iproute"
|
||||
"github.com/squat/kilo/pkg/k8s"
|
||||
"github.com/squat/kilo/pkg/k8s/apis/kilo/v1alpha1"
|
||||
kiloclient "github.com/squat/kilo/pkg/k8s/clientset/versioned"
|
||||
"github.com/squat/kilo/pkg/mesh"
|
||||
"github.com/squat/kilo/pkg/route"
|
||||
"github.com/squat/kilo/pkg/wireguard"
|
||||
)
|
||||
|
||||
func takeIPNet(_ net.IP, i *net.IPNet, _ error) *net.IPNet {
|
||||
return i
|
||||
}
|
||||
|
||||
func connect() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "connect",
|
||||
RunE: connectAsPeer,
|
||||
Short: "connect to a Kilo cluster as a peer over WireGuard",
|
||||
}
|
||||
cmd.Flags().IPNetP("allowed-ip", "a", *takeIPNet(net.ParseCIDR("10.10.10.10/32")), "Allowed IP of the peer")
|
||||
cmd.Flags().IPNetP("service-cidr", "c", *takeIPNet(net.ParseCIDR("10.43.0.0/16")), "service CIDR of the cluster")
|
||||
cmd.Flags().String("log-level", logLevelInfo, fmt.Sprintf("Log level to use. Possible values: %s", availableLogLevels))
|
||||
cmd.Flags().String("config-path", "/tmp/wg.ini", "path to WireGuard configuation file")
|
||||
cmd.Flags().Bool("clean-up", true, "clean up routes and interface")
|
||||
cmd.Flags().Uint("mtu", uint(1420), "clean up routes and interface")
|
||||
cmd.Flags().Duration("resync-period", 30*time.Second, "How often should Kilo reconcile?")
|
||||
|
||||
availableLogLevels = strings.Join([]string{
|
||||
logLevelAll,
|
||||
logLevelDebug,
|
||||
logLevelInfo,
|
||||
logLevelWarn,
|
||||
logLevelError,
|
||||
logLevelNone,
|
||||
}, ", ")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func connectAsPeer(cmd *cobra.Command, args []string) error {
|
||||
var cancel context.CancelFunc
|
||||
resyncPersiod, err := cmd.Flags().GetDuration("resync-period")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mtu, err := cmd.Flags().GetUint("mtu")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configPath, err := cmd.Flags().GetString("config-path")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serviceCIDR, err := cmd.Flags().GetIPNet("service-cidr")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
allowedIP, err := cmd.Flags().GetIPNet("allowed-ip")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger := log.NewJSONLogger(log.NewSyncWriter(os.Stdout))
|
||||
logLevel, err := cmd.Flags().GetString("log-level")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch logLevel {
|
||||
case logLevelAll:
|
||||
logger = level.NewFilter(logger, level.AllowAll())
|
||||
case logLevelDebug:
|
||||
logger = level.NewFilter(logger, level.AllowDebug())
|
||||
case logLevelInfo:
|
||||
logger = level.NewFilter(logger, level.AllowInfo())
|
||||
case logLevelWarn:
|
||||
logger = level.NewFilter(logger, level.AllowWarn())
|
||||
case logLevelError:
|
||||
logger = level.NewFilter(logger, level.AllowError())
|
||||
case logLevelNone:
|
||||
logger = level.NewFilter(logger, level.AllowNone())
|
||||
default:
|
||||
return fmt.Errorf("log level %s unknown; possible values are: %s", logLevel, availableLogLevels)
|
||||
}
|
||||
logger = log.With(logger, "ts", log.DefaultTimestampUTC)
|
||||
logger = log.With(logger, "caller", log.DefaultCaller)
|
||||
peername := "random"
|
||||
if len(args) > 0 {
|
||||
peername = args[0]
|
||||
}
|
||||
|
||||
var kiloClient *kiloclient.Clientset
|
||||
switch backend {
|
||||
case k8s.Backend:
|
||||
config, err := clientcmd.BuildConfigFromFlags("", kubeconfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Kubernetes config: %v", err)
|
||||
}
|
||||
kiloClient = kiloclient.NewForConfigOrDie(config)
|
||||
default:
|
||||
return fmt.Errorf("backend %v unknown; posible values are: %s", backend, availableBackends)
|
||||
}
|
||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
publicKey := privateKey.PublicKey()
|
||||
level.Info(logger).Log("msg", "generated public key", "key", publicKey)
|
||||
|
||||
peer := &v1alpha1.Peer{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: peername,
|
||||
},
|
||||
Spec: v1alpha1.PeerSpec{
|
||||
AllowedIPs: []string{allowedIP.String()},
|
||||
PersistentKeepalive: 10,
|
||||
PublicKey: publicKey.String(),
|
||||
},
|
||||
}
|
||||
if p, err := kiloClient.KiloV1alpha1().Peers().Get(context.TODO(), peername, metav1.GetOptions{}); err != nil || p == nil {
|
||||
peer, err = kiloClient.KiloV1alpha1().Peers().Create(context.TODO(), peer, metav1.CreateOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create peer: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
kiloIfaceName := "kilo0"
|
||||
|
||||
iface, _, err := wireguard.New(kiloIfaceName, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create wg interface: %w", err)
|
||||
} else {
|
||||
level.Info(logger).Log("msg", "successfully created wg interface", "name", kiloIfaceName, "no", iface)
|
||||
}
|
||||
if err := iproute.Set(iface, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := iproute.SetAddress(iface, &allowedIP); err != nil {
|
||||
return err
|
||||
} else {
|
||||
level.Info(logger).Log("mag", "successfully set IP address of wg interface", "IP", allowedIP.String())
|
||||
}
|
||||
|
||||
if err := iproute.Set(iface, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var g run.Group
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
g.Add(run.SignalHandler(ctx, syscall.SIGINT, syscall.SIGTERM))
|
||||
|
||||
table := route.NewTable()
|
||||
stop := make(chan struct{}, 1)
|
||||
errCh := make(<-chan error, 1)
|
||||
{
|
||||
ch := make(chan struct{})
|
||||
g.Add(
|
||||
func() error {
|
||||
for {
|
||||
select {
|
||||
case err, ok := <-errCh:
|
||||
if ok {
|
||||
level.Error(logger).Log("err", err.Error())
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
case <-ch:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
},
|
||||
func(err error) {
|
||||
ch <- struct{}{}
|
||||
close(ch)
|
||||
stop <- struct{}{}
|
||||
close(stop)
|
||||
level.Error(logger).Log("msg", "stopped ip routes table", "err", err.Error())
|
||||
},
|
||||
)
|
||||
}
|
||||
{
|
||||
ch := make(chan struct{})
|
||||
g.Add(
|
||||
func() error {
|
||||
for {
|
||||
ns, err := opts.backend.Nodes().List()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list nodes: %v", err)
|
||||
}
|
||||
ps, err := opts.backend.Peers().List()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list peers: %v", err)
|
||||
}
|
||||
// Obtain the Granularity by looking at the annotation of the first node.
|
||||
if opts.granularity, err = optainGranularity(opts.granularity, ns); err != nil {
|
||||
return fmt.Errorf("failed to obtain granularity: %w", err)
|
||||
}
|
||||
var hostname string
|
||||
subnet := mesh.DefaultKiloSubnet
|
||||
nodes := make(map[string]*mesh.Node)
|
||||
for _, n := range ns {
|
||||
if n.Ready() {
|
||||
nodes[n.Name] = n
|
||||
hostname = n.Name
|
||||
}
|
||||
if n.WireGuardIP != nil {
|
||||
subnet = n.WireGuardIP
|
||||
}
|
||||
}
|
||||
subnet.IP = subnet.IP.Mask(subnet.Mask)
|
||||
if len(nodes) == 0 {
|
||||
return errors.New("did not find any valid Kilo nodes in the cluster")
|
||||
}
|
||||
peers := make(map[string]*mesh.Peer)
|
||||
for _, p := range ps {
|
||||
if p.Ready() {
|
||||
peers[p.Name] = p
|
||||
}
|
||||
}
|
||||
if _, ok := peers[peername]; !ok {
|
||||
return fmt.Errorf("did not find any peer named %q in the cluster", peername)
|
||||
}
|
||||
|
||||
t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, []byte{}, subnet, peers[peername].PersistentKeepalive, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create topology: %v", err)
|
||||
}
|
||||
conf := t.PeerConf(peername)
|
||||
conf.Interface = &wireguard.Interface{
|
||||
PrivateKey: []byte(privateKey.String()),
|
||||
ListenPort: uint32(55555),
|
||||
}
|
||||
buf, err := conf.Bytes()
|
||||
if err != nil {
|
||||
//level.Error(m.logger).Log("error", err)
|
||||
//m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return err
|
||||
}
|
||||
if err := ioutil.WriteFile("/tmp/wg.ini", buf, 0600); err != nil {
|
||||
//level.Error(m.logger).Log("error", err)
|
||||
//m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return err
|
||||
}
|
||||
if err := wireguard.SetConf(kiloIfaceName, "/tmp/wg.ini"); err != nil {
|
||||
return err
|
||||
}
|
||||
wgClient, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize wg Client: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
wgClient.Close()
|
||||
}()
|
||||
wgConf := wgtypes.Config{
|
||||
PrivateKey: &privateKey,
|
||||
// Peers: []wgtypes.PeerConfig{},
|
||||
}
|
||||
if err := wgClient.ConfigureDevice(kiloIfaceName, wgConf); err != nil {
|
||||
return fmt.Errorf("failed to configure wg interface: %w", err)
|
||||
}
|
||||
|
||||
var routes []*netlink.Route
|
||||
for _, segment := range t.Segments() {
|
||||
for i := range segment.CIDRS() {
|
||||
// Add routes to the Pod CIDRs of nodes in other segments.
|
||||
routes = append(routes, &netlink.Route{
|
||||
Dst: segment.CIDRS()[i],
|
||||
//Flags: int(netlink.FLAG_ONLINK),
|
||||
LinkIndex: iface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
})
|
||||
}
|
||||
for i := range segment.PrivateIPs() {
|
||||
// Add routes to the private IPs of nodes in other segments.
|
||||
routes = append(routes, &netlink.Route{
|
||||
Dst: mesh.OneAddressCIDR(segment.PrivateIPs()[i]),
|
||||
//Flags: int(netlink.FLAG_ONLINK),
|
||||
LinkIndex: iface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
})
|
||||
}
|
||||
// For segments / locations other than the location of this instance of kg,
|
||||
// we need to set routes for allowed location IPs over the leader in the current location.
|
||||
for i := range segment.AllowedLocationIPs() {
|
||||
routes = append(routes, &netlink.Route{
|
||||
Dst: segment.AllowedLocationIPs()[i],
|
||||
//Flags: int(netlink.FLAG_ONLINK),
|
||||
LinkIndex: iface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
})
|
||||
}
|
||||
routes = append(routes, &netlink.Route{
|
||||
Dst: mesh.OneAddressCIDR(segment.WireGuardIP()),
|
||||
//Flags: int(netlink.FLAG_ONLINK),
|
||||
LinkIndex: iface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
})
|
||||
}
|
||||
// Add routes for the allowed IPs of peers.
|
||||
for _, peer := range t.Peers() {
|
||||
for i := range peer.AllowedIPs {
|
||||
routes = append(routes, &netlink.Route{
|
||||
Dst: peer.AllowedIPs[i],
|
||||
//Flags: int(netlink.FLAG_ONLINK),
|
||||
LinkIndex: iface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
})
|
||||
}
|
||||
}
|
||||
routes = append(routes, &netlink.Route{
|
||||
Dst: &serviceCIDR,
|
||||
//Flags: int(netlink.FLAG_ONLINK),
|
||||
Gw: nil,
|
||||
LinkIndex: iface,
|
||||
Protocol: unix.RTPROT_STATIC,
|
||||
})
|
||||
for _, r := range routes {
|
||||
fmt.Println(r)
|
||||
}
|
||||
|
||||
level.Debug(logger).Log("routes", routes)
|
||||
if err := table.Set(routes, []*netlink.Rule{}); err != nil {
|
||||
return fmt.Errorf("failed to set ip routes table: %w", err)
|
||||
}
|
||||
errCh, err = table.Run(stop)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start ip routes tables: %w", err)
|
||||
}
|
||||
select {
|
||||
case <-time.After(resyncPersiod):
|
||||
case <-ch:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}, func(err error) {
|
||||
// Cancel the root context in the very end.
|
||||
defer func() {
|
||||
cancel()
|
||||
level.Debug(logger).Log("msg", "canceled parent context")
|
||||
}()
|
||||
ch <- struct{}{}
|
||||
var serr run.SignalError
|
||||
if ok := errors.As(err, &serr); ok {
|
||||
level.Info(logger).Log("msg", "received signal", "signal", serr.Signal.String(), "err", err.Error())
|
||||
} else {
|
||||
level.Error(logger).Log("msg", "received error", "err", err.Error())
|
||||
}
|
||||
level.Debug(logger).Log("msg", "stoped ip routes table")
|
||||
ctxWithTimeOut, cancelWithTimeOut := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer func() {
|
||||
cancelWithTimeOut()
|
||||
level.Debug(logger).Log("msg", "canceled timed context")
|
||||
}()
|
||||
if err := kiloClient.KiloV1alpha1().Peers().Delete(ctxWithTimeOut, peername, metav1.DeleteOptions{}); err != nil {
|
||||
level.Error(logger).Log("failed to delete peer: %w", err)
|
||||
} else {
|
||||
level.Info(logger).Log("msg", "deleted peer", "peer", peername)
|
||||
}
|
||||
if ok, err := cmd.Flags().GetBool("clean-up"); err != nil {
|
||||
level.Error(logger).Log("err", err.Error(), "msg", "failed to get value from clean-up flag")
|
||||
} else if ok {
|
||||
cleanUp(iface, table, configPath, logger)
|
||||
}
|
||||
})
|
||||
}
|
||||
err = g.Run()
|
||||
var serr run.SignalError
|
||||
if ok := errors.As(err, &serr); ok {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func cleanUp(iface int, t *route.Table, configPath string, logger log.Logger) {
|
||||
if err := iproute.Set(iface, false); err != nil {
|
||||
level.Error(logger).Log("err", err.Error(), "msg", "failed to set down wg interface")
|
||||
}
|
||||
if err := os.Remove(configPath); err != nil {
|
||||
level.Error(logger).Log("error", fmt.Sprintf("failed to delete configuration file: %v", err))
|
||||
}
|
||||
if err := iproute.RemoveInterface(iface); err != nil {
|
||||
level.Error(logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err))
|
||||
}
|
||||
if err := t.CleanUp(); err != nil {
|
||||
level.Error(logger).Log("failed to clean up routes: %w", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//func nodeToWGPeer(n *mesh.Node) (ret wgtypes.PeerConfig, err error) {
|
||||
// pubKey, err := wgtypes.ParseKey(string(n.Key))
|
||||
// if err != nil {
|
||||
// return ret, err
|
||||
// }
|
||||
// ret.PublicKey = pubKey
|
||||
// if err != nil {
|
||||
// return ret, err
|
||||
// }
|
||||
// aIPs := []net.IPNet{*n.WireGuardIP, *n.InternalIP, *n.Subnet}
|
||||
// for _, a := range n.AllowedLocationIPs {
|
||||
// aIPs = append(aIPs, *a)
|
||||
// }
|
||||
// ret.AllowedIPs = aIPs
|
||||
//
|
||||
// dur := time.Second * time.Duration(n.PersistentKeepalive)
|
||||
// ret.PersistentKeepaliveInterval = &dur
|
||||
//
|
||||
// udpEndpoint, err := net.ResolveUDPAddr("udp", n.Endpoint.String())
|
||||
// if err != nil {
|
||||
// return ret, err
|
||||
// } else if udpEndpoint.IP == nil {
|
||||
// udpEndpoint = nil
|
||||
// }
|
||||
// ret.Endpoint = udpEndpoint
|
||||
// return ret, nil
|
||||
//}
|
||||
//
|
||||
//func toWGPeer(p wireguard.Peer) (ret wgtypes.PeerConfig, err error) {
|
||||
// pubKey, err := wgtypes.ParseKey(string(p.PublicKey))
|
||||
// if err != nil {
|
||||
// return ret, err
|
||||
// }
|
||||
// ret.PublicKey = pubKey
|
||||
// aIPs := make([]net.IPNet, len(p.AllowedIPs))
|
||||
// for i, a := range p.AllowedIPs {
|
||||
// aIPs[i] = *a
|
||||
// }
|
||||
// ret.AllowedIPs = aIPs
|
||||
//
|
||||
// if preSharedKey, err := wgtypes.ParseKey(string(p.PresharedKey)); len(p.PresharedKey) > 0 && err != nil {
|
||||
// return ret, err
|
||||
// } else {
|
||||
// ret.PresharedKey = &preSharedKey
|
||||
// }
|
||||
// dur := time.Second * time.Duration(p.PersistentKeepalive)
|
||||
// ret.PersistentKeepaliveInterval = &dur
|
||||
//
|
||||
// udpEndpoint, err := net.ResolveUDPAddr("udp", p.Endpoint.String())
|
||||
// if err != nil {
|
||||
// return ret, err
|
||||
// } else if udpEndpoint.IP == nil {
|
||||
// udpEndpoint = nil
|
||||
// }
|
||||
// fmt.Println("peer")
|
||||
// fmt.Println(aIPs)
|
||||
// ret.Endpoint = udpEndpoint
|
||||
// return ret, nil
|
||||
//}
|
@ -1,20 +0,0 @@
|
||||
// +build !linux
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func connect() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "connect",
|
||||
Short: "not supporred on you OS",
|
||||
RunE: func(_ *cobra.Command, _ []string) error {
|
||||
return errors.New("this command is not supported on your OS")
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
@ -289,6 +289,7 @@ func runShowConfPeer(_ *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
// translatePeer translates a wireguard.Peer to a Peer CRD.
|
||||
// TODO this function has many similarities to peerBackend.Set(name, peer)
|
||||
func translatePeer(peer *wireguard.Peer) *v1alpha1.Peer {
|
||||
if peer == nil {
|
||||
return &v1alpha1.Peer{}
|
||||
@ -303,21 +304,13 @@ func translatePeer(peer *wireguard.Peer) *v1alpha1.Peer {
|
||||
aips = append(aips, aip.String())
|
||||
}
|
||||
var endpoint *v1alpha1.PeerEndpoint
|
||||
if (peer.Endpoint != nil && peer.Endpoint.Port > 0) || peer.Addr != "" {
|
||||
var ip string
|
||||
if peer.Endpoint.IP != nil {
|
||||
ip = peer.Endpoint.IP.String()
|
||||
}
|
||||
var dns string
|
||||
if strs := strings.Split(peer.Addr, ":"); len(strs) == 2 && strs[0] != "" {
|
||||
dns = strs[0]
|
||||
}
|
||||
if peer.Endpoint.Port() > 0 || !peer.Endpoint.HasDNS() {
|
||||
endpoint = &v1alpha1.PeerEndpoint{
|
||||
DNSOrIP: v1alpha1.DNSOrIP{
|
||||
DNS: dns,
|
||||
IP: ip,
|
||||
IP: peer.Endpoint.IP().String(),
|
||||
DNS: peer.Endpoint.DNS(),
|
||||
},
|
||||
Port: uint32(peer.Endpoint.Port),
|
||||
Port: uint32(peer.Endpoint.Port()),
|
||||
}
|
||||
}
|
||||
var key string
|
||||
|
@ -18,7 +18,7 @@ test_full_mesh_connectivity() {
|
||||
}
|
||||
|
||||
test_full_mesh_peer() {
|
||||
check_peer wg1 e2e 10.5.0.1/32 full
|
||||
check_peer wg99 e2e 10.5.0.1/32 full
|
||||
}
|
||||
|
||||
test_full_mesh_allowed_location_ips() {
|
||||
|
@ -18,7 +18,7 @@ test_location_mesh_connectivity() {
|
||||
}
|
||||
|
||||
test_location_mesh_peer() {
|
||||
check_peer wg1 e2e 10.5.0.1/32 location
|
||||
check_peer wg99 e2e 10.5.0.1/32 location
|
||||
}
|
||||
|
||||
test_mesh_granularity_auto_detect() {
|
||||
|
@ -74,7 +74,7 @@ func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule {
|
||||
rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
|
||||
for _, n := range nodes {
|
||||
// Accept encapsulated traffic from peers.
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(n.IP)), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
|
||||
}
|
||||
// Drop all other IPIP traffic.
|
||||
rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
|
||||
|
@ -53,11 +53,11 @@ const (
|
||||
)
|
||||
|
||||
// GetProtocol will return a protocol from the length of an IP address.
|
||||
func GetProtocol(length int) Protocol {
|
||||
if length == net.IPv6len {
|
||||
return ProtocolIPv6
|
||||
func GetProtocol(ip net.IP) Protocol {
|
||||
if len(ip) == net.IPv4len || ip.To4() != nil {
|
||||
return ProtocolIPv4
|
||||
}
|
||||
return ProtocolIPv4
|
||||
return ProtocolIPv6
|
||||
}
|
||||
|
||||
// Client represents any type that can administer iptables rules.
|
||||
|
@ -32,7 +32,6 @@ import (
|
||||
"k8s.io/apimachinery/pkg/labels"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"k8s.io/apimachinery/pkg/util/strategicpatch"
|
||||
"k8s.io/apimachinery/pkg/util/validation"
|
||||
v1informers "k8s.io/client-go/informers/core/v1"
|
||||
"k8s.io/client-go/kubernetes"
|
||||
v1listers "k8s.io/client-go/listers/core/v1"
|
||||
@ -277,9 +276,9 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
|
||||
location = node.ObjectMeta.Labels[topologyLabel]
|
||||
}
|
||||
// Allow the endpoint to be overridden.
|
||||
endpoint, addr := parseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey])
|
||||
if endpoint == nil && addr == "" {
|
||||
endpoint, addr = parseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey])
|
||||
endpoint := wireguard.ParseEndpoint(node.ObjectMeta.Annotations[forceEndpointAnnotationKey])
|
||||
if endpoint == nil {
|
||||
endpoint = wireguard.ParseEndpoint(node.ObjectMeta.Annotations[endpointAnnotationKey])
|
||||
}
|
||||
// Allow the internal IP to be overridden.
|
||||
internalIP := normalizeIP(node.ObjectMeta.Annotations[forceInternalIPAnnotationKey])
|
||||
@ -345,7 +344,6 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
|
||||
// It is valid for the InternalIP to be nil,
|
||||
// if the given node only has public IP addresses.
|
||||
Endpoint: endpoint,
|
||||
Addr: addr,
|
||||
NoInternalIP: noInternalIP,
|
||||
InternalIP: internalIP,
|
||||
Key: key,
|
||||
@ -379,8 +377,7 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
|
||||
}
|
||||
aips = append(aips, *aip)
|
||||
}
|
||||
var endpoint *net.UDPAddr
|
||||
var addr string
|
||||
var endpoint *wireguard.Endpoint
|
||||
if peer.Spec.Endpoint != nil {
|
||||
ip := net.ParseIP(peer.Spec.Endpoint.IP)
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
@ -390,10 +387,10 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
|
||||
}
|
||||
if peer.Spec.Endpoint.Port > 0 {
|
||||
if ip != nil {
|
||||
endpoint = &net.UDPAddr{IP: ip, Port: int(peer.Spec.Endpoint.Port)}
|
||||
endpoint = wireguard.NewEndpoint(ip, int(peer.Spec.Endpoint.Port))
|
||||
}
|
||||
if peer.Spec.Endpoint.DNS != "" {
|
||||
addr = fmt.Sprintf("%s:%d", peer.Spec.Endpoint.DNS, peer.Spec.Endpoint.Port)
|
||||
endpoint = wireguard.ParseEndpoint(fmt.Sprintf("%s:%d", peer.Spec.Endpoint.DNS, peer.Spec.Endpoint.Port))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -415,12 +412,11 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: aips,
|
||||
Endpoint: endpoint, // applyTopology will resolve this endpoint from the KiloEndpoint.
|
||||
PersistentKeepaliveInterval: &pka,
|
||||
PresharedKey: psk,
|
||||
PublicKey: key,
|
||||
},
|
||||
Addr: addr,
|
||||
Endpoint: endpoint,
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -518,22 +514,12 @@ func (pb *peerBackend) Set(name string, peer *mesh.Peer) error {
|
||||
p.Spec.AllowedIPs[i] = peer.AllowedIPs[i].String()
|
||||
}
|
||||
if peer.Endpoint != nil {
|
||||
var ip string
|
||||
if peer.Endpoint.IP != nil {
|
||||
ip = peer.Endpoint.IP.String()
|
||||
}
|
||||
var dns string
|
||||
if peer.Addr != "" {
|
||||
if strs := strings.Split(peer.Addr, ":"); len(strs) == 2 && strs[0] != "" {
|
||||
dns = strs[0]
|
||||
}
|
||||
}
|
||||
p.Spec.Endpoint = &v1alpha1.PeerEndpoint{
|
||||
DNSOrIP: v1alpha1.DNSOrIP{
|
||||
IP: ip,
|
||||
DNS: dns,
|
||||
IP: peer.Endpoint.IP().String(),
|
||||
DNS: peer.Endpoint.DNS(),
|
||||
},
|
||||
Port: uint32(peer.Endpoint.Port),
|
||||
Port: uint32(peer.Endpoint.Port()),
|
||||
}
|
||||
}
|
||||
if peer.PersistentKeepaliveInterval == nil {
|
||||
@ -570,34 +556,3 @@ func normalizeIP(ip string) *net.IPNet {
|
||||
ipNet.IP = i.To16()
|
||||
return ipNet
|
||||
}
|
||||
|
||||
func parseEndpoint(endpoint string) (*net.UDPAddr, string) {
|
||||
if len(endpoint) == 0 {
|
||||
return nil, ""
|
||||
}
|
||||
parts := strings.Split(endpoint, ":")
|
||||
if len(parts) < 2 {
|
||||
return nil, ""
|
||||
}
|
||||
portRaw := parts[len(parts)-1]
|
||||
hostRaw := strings.Trim(strings.Join(parts[:len(parts)-1], ":"), "[]")
|
||||
port, err := strconv.ParseUint(portRaw, 10, 32)
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
}
|
||||
if len(validation.IsValidPortNum(int(port))) != 0 {
|
||||
return nil, ""
|
||||
}
|
||||
ip := net.ParseIP(hostRaw)
|
||||
if ip == nil {
|
||||
if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 {
|
||||
return nil, endpoint
|
||||
}
|
||||
return nil, ""
|
||||
}
|
||||
u, err := net.ResolveUDPAddr("udp", endpoint)
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
}
|
||||
return u, ""
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
// Copyright 2021 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -80,8 +80,8 @@ func TestTranslateNode(t *testing.T) {
|
||||
internalIPAnnotationKey: "10.0.0.2/32",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort},
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2"), Mask: net.CIDRMask(32, 32)},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort),
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.0.0.2").To4(), Mask: net.CIDRMask(32, 32)},
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -91,8 +91,8 @@ func TestTranslateNode(t *testing.T) {
|
||||
internalIPAnnotationKey: "ff60::10/64",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &net.UDPAddr{IP: net.ParseIP("ff10::10"), Port: mesh.DefaultKiloPort},
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("ff60::10"), Mask: net.CIDRMask(64, 128)},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("ff10::10").To16(), mesh.DefaultKiloPort),
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("ff60::10").To16(), Mask: net.CIDRMask(64, 128)},
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -105,7 +105,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
name: "normalize subnet",
|
||||
annotations: map[string]string{},
|
||||
out: &mesh.Node{
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0").To4(), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
subnet: "10.2.0.1/24",
|
||||
},
|
||||
@ -113,7 +113,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
name: "valid subnet",
|
||||
annotations: map[string]string{},
|
||||
out: &mesh.Node{
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0").To4(), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
subnet: "10.2.1.0/24",
|
||||
},
|
||||
@ -145,7 +145,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
forceEndpointAnnotationKey: "-10.0.0.2:51821",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort),
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -155,7 +155,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
forceEndpointAnnotationKey: "10.0.0.2:51821",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.2"), Port: 51821},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.2").To4(), 51821),
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -174,7 +174,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
forceInternalIPAnnotationKey: "-10.1.0.2/24",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.1"), Mask: net.CIDRMask(24, 32)},
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.1").To4(), Mask: net.CIDRMask(24, 32)},
|
||||
NoInternalIP: false,
|
||||
},
|
||||
},
|
||||
@ -185,7 +185,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
forceInternalIPAnnotationKey: "10.1.0.2/24",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(24, 32)},
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(24, 32)},
|
||||
NoInternalIP: false,
|
||||
},
|
||||
},
|
||||
@ -214,16 +214,16 @@ func TestTranslateNode(t *testing.T) {
|
||||
RegionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.2"), Port: 51821},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.2").To4(), 51821),
|
||||
NoInternalIP: false,
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)},
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2").To4(), Mask: net.CIDRMask(32, 32)},
|
||||
Key: fooKey,
|
||||
LastSeen: 1000000000,
|
||||
Leader: true,
|
||||
Location: "b",
|
||||
PersistentKeepalive: 25 * time.Second,
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1"), Mask: net.CIDRMask(16, 32)},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0").To4(), Mask: net.CIDRMask(24, 32)},
|
||||
WireGuardIP: &net.IPNet{IP: net.ParseIP("10.4.0.1").To4(), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
subnet: "10.2.1.0/24",
|
||||
},
|
||||
@ -245,7 +245,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
RegionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &net.UDPAddr{IP: net.ParseIP("1100::10"), Port: 51821},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("1100::10"), 51821),
|
||||
NoInternalIP: false,
|
||||
InternalIP: &net.IPNet{IP: net.ParseIP("10.1.0.2"), Mask: net.CIDRMask(32, 32)},
|
||||
Key: fooKey,
|
||||
@ -273,7 +273,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
RegionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 51820},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1"), 51820),
|
||||
InternalIP: nil,
|
||||
Key: fooKey,
|
||||
LastSeen: 1000000000,
|
||||
@ -301,7 +301,7 @@ func TestTranslateNode(t *testing.T) {
|
||||
RegionLabelKey: "a",
|
||||
},
|
||||
out: &mesh.Node{
|
||||
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 51820},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1"), 51820),
|
||||
NoInternalIP: true,
|
||||
InternalIP: nil,
|
||||
Key: fooKey,
|
||||
@ -424,9 +424,9 @@ func TestTranslatePeer(t *testing.T) {
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
Endpoint: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort},
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("10.0.0.1").To4(), mesh.DefaultKiloPort),
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -443,9 +443,9 @@ func TestTranslatePeer(t *testing.T) {
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
Endpoint: &net.UDPAddr{IP: net.ParseIP("ff60::2"), Port: mesh.DefaultKiloPort},
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("ff60::2").To16(), mesh.DefaultKiloPort),
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -461,7 +461,7 @@ func TestTranslatePeer(t *testing.T) {
|
||||
},
|
||||
out: &mesh.Peer{
|
||||
Peer: wireguard.Peer{
|
||||
Addr: "example.com:51820",
|
||||
Endpoint: wireguard.ParseEndpoint("example.com:51820"),
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
PersistentKeepaliveInterval: &zero,
|
||||
},
|
||||
@ -544,64 +544,3 @@ func TestTranslatePeer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEndpoint(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
endpoint string
|
||||
udp *net.UDPAddr
|
||||
addr string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
endpoint: "",
|
||||
udp: nil,
|
||||
addr: "",
|
||||
},
|
||||
{
|
||||
name: "invalid IP",
|
||||
endpoint: "10.0.0.:51820",
|
||||
udp: nil,
|
||||
addr: "",
|
||||
},
|
||||
{
|
||||
name: "invalid hostname",
|
||||
endpoint: "foo-:51820",
|
||||
udp: nil,
|
||||
addr: "",
|
||||
},
|
||||
{
|
||||
name: "invalid port",
|
||||
endpoint: "10.0.0.1:100000000",
|
||||
udp: nil,
|
||||
addr: "",
|
||||
},
|
||||
{
|
||||
name: "valid IP",
|
||||
endpoint: "10.0.0.1:51820",
|
||||
udp: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: mesh.DefaultKiloPort},
|
||||
addr: "",
|
||||
},
|
||||
{
|
||||
name: "valid IPv6",
|
||||
endpoint: "[ff02::114]:51820",
|
||||
udp: &net.UDPAddr{IP: net.ParseIP("ff02::114"), Port: mesh.DefaultKiloPort},
|
||||
addr: "",
|
||||
},
|
||||
{
|
||||
name: "valid hostname",
|
||||
endpoint: "foo:51821",
|
||||
udp: nil,
|
||||
addr: "foo:51821",
|
||||
},
|
||||
} {
|
||||
udp, addr := parseEndpoint(tc.endpoint)
|
||||
if diff := pretty.Compare(udp, tc.udp); diff != "" {
|
||||
t.Errorf("test case %q: got diff: %v", tc.name, diff)
|
||||
}
|
||||
if addr != tc.addr {
|
||||
t.Errorf("test case %q: got: %q, wants: %q", tc.name, addr, tc.addr)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -56,8 +56,7 @@ const (
|
||||
|
||||
// Node represents a node in the network.
|
||||
type Node struct {
|
||||
Endpoint *net.UDPAddr
|
||||
Addr string // eg. dnsname:port
|
||||
Endpoint *wireguard.Endpoint
|
||||
Key wgtypes.Key
|
||||
NoInternalIP bool
|
||||
InternalIP *net.IPNet
|
||||
@ -82,7 +81,7 @@ type Node struct {
|
||||
func (n *Node) Ready() bool {
|
||||
// Nodes that are not leaders will not have WireGuardIPs, so it is not required.
|
||||
return n != nil &&
|
||||
(n.Endpoint != nil || n.Addr != "") &&
|
||||
n.Endpoint.Ready() &&
|
||||
n.Key != wgtypes.Key{} &&
|
||||
n.Subnet != nil &&
|
||||
time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second)
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2019 the Kilo authors
|
||||
// Copyright 2021 the Kilo authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -20,6 +20,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/awalterschulze/gographviz"
|
||||
|
||||
"github.com/squat/kilo/pkg/wireguard"
|
||||
)
|
||||
|
||||
// Dot generates a Graphviz graph of the Topology in DOT fomat.
|
||||
@ -61,7 +63,7 @@ func (t *Topology) Dot() (string, error) {
|
||||
return "", fmt.Errorf("failed to add node to subgraph")
|
||||
}
|
||||
var wg net.IP
|
||||
var endpoint *net.UDPAddr
|
||||
var endpoint *wireguard.Endpoint
|
||||
if j == s.leader {
|
||||
wg = s.wireGuardIP
|
||||
endpoint = s.endpoint
|
||||
@ -73,7 +75,7 @@ func (t *Topology) Dot() (string, error) {
|
||||
if s.privateIPs != nil {
|
||||
priv = s.privateIPs[j]
|
||||
}
|
||||
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], priv, wg, endpoint, s.addr)); err != nil {
|
||||
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], priv, wg, endpoint)); err != nil {
|
||||
return "", fmt.Errorf("failed to add label to node")
|
||||
}
|
||||
}
|
||||
@ -153,7 +155,7 @@ func subGraphName(name string) string {
|
||||
return graphEscape(fmt.Sprintf("cluster_location_%s", name))
|
||||
}
|
||||
|
||||
func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *net.UDPAddr, addr string) string {
|
||||
func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *wireguard.Endpoint) string {
|
||||
label := []string{
|
||||
location,
|
||||
name,
|
||||
@ -165,12 +167,7 @@ func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoi
|
||||
if wgIP != nil {
|
||||
label = append(label, wgIP.String())
|
||||
}
|
||||
var str string
|
||||
if addr != "" {
|
||||
str = addr
|
||||
} else if endpoint != nil {
|
||||
str = endpoint.String()
|
||||
}
|
||||
str := endpoint.String()
|
||||
if str != "" {
|
||||
label = append(label, str)
|
||||
}
|
||||
|
@ -370,8 +370,10 @@ func (m *Mesh) checkIn() {
|
||||
|
||||
func (m *Mesh) handleLocal(n *Node) {
|
||||
// Allow the IPs to be overridden.
|
||||
if n.Endpoint == nil || n.Addr == "" {
|
||||
n.Endpoint = &net.UDPAddr{IP: m.externalIP.IP, Port: m.port}
|
||||
if !n.Endpoint.Ready() {
|
||||
e := wireguard.NewEndpoint(m.externalIP.IP, m.port)
|
||||
level.Info(m.logger).Log("msg", "overriding endpoint", "node", m.hostname, "old endpoint", n.Endpoint.String(), "new endpoint", e.String())
|
||||
n.Endpoint = e
|
||||
}
|
||||
if n.InternalIP == nil && !n.NoInternalIP {
|
||||
n.InternalIP = m.internalIP
|
||||
@ -484,7 +486,7 @@ func (m *Mesh) applyTopology() {
|
||||
|
||||
natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger)
|
||||
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
|
||||
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
|
||||
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port(), m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
@ -623,15 +625,8 @@ func (m *Mesh) resolveEndpoints() error {
|
||||
if !m.nodes[k].Ready() {
|
||||
continue
|
||||
}
|
||||
// If the node is ready, then the endpoint is not nil
|
||||
// but it may not have a DNS name.
|
||||
if m.nodes[k].Addr == "" {
|
||||
continue
|
||||
}
|
||||
if u, err := net.ResolveUDPAddr("udp", m.nodes[k].Addr); err == nil {
|
||||
m.nodes[k].Endpoint = u
|
||||
m.nodes[k].Endpoint.IP = u.IP
|
||||
} else {
|
||||
// Resolve the Endpoint
|
||||
if _, err := m.nodes[k].Endpoint.UDPAddr(true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -642,12 +637,10 @@ func (m *Mesh) resolveEndpoints() error {
|
||||
continue
|
||||
}
|
||||
// Peers may have nil endpoints.
|
||||
if m.peers[k].Addr == "" {
|
||||
if !m.peers[k].Endpoint.Ready() {
|
||||
continue
|
||||
}
|
||||
if u, err := net.ResolveUDPAddr("udp", m.peers[k].Addr); err == nil {
|
||||
m.peers[k].Endpoint = u
|
||||
} else {
|
||||
if _, err := m.peers[k].Endpoint.UDPAddr(true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -667,7 +660,7 @@ func nodesAreEqual(a, b *Node) bool {
|
||||
}
|
||||
// Check the DNS name first since this package
|
||||
// is doing the DNS resolution.
|
||||
if a.Addr != b.Addr || a.Endpoint.String() != b.Endpoint.String() {
|
||||
if a.Endpoint.StringOpt(false) != b.Endpoint.StringOpt(false) {
|
||||
return false
|
||||
}
|
||||
// Ignore LastSeen when comparing equality we want to check if the nodes are
|
||||
@ -696,7 +689,7 @@ func peersAreEqual(a, b *Peer) bool {
|
||||
}
|
||||
// Check the DNS name first since this package
|
||||
// is doing the DNS resolution.
|
||||
if a.Addr != b.Addr || a.Endpoint.String() != b.Endpoint.String() {
|
||||
if a.Endpoint.StringOpt(false) != b.Endpoint.StringOpt(false) {
|
||||
return false
|
||||
}
|
||||
if len(a.AllowedIPs) != len(b.AllowedIPs) {
|
||||
|
@ -20,6 +20,8 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/squat/kilo/pkg/wireguard"
|
||||
)
|
||||
|
||||
func mustKey() wgtypes.Key {
|
||||
@ -62,7 +64,7 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "empty endpoint IP",
|
||||
node: &Node{
|
||||
Endpoint: &net.UDPAddr{Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(nil, DefaultKiloPort),
|
||||
InternalIP: internalIP,
|
||||
Key: wgtypes.Key{},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
@ -72,7 +74,7 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "empty endpoint port",
|
||||
node: &Node{
|
||||
Endpoint: &net.UDPAddr{IP: externalIP.IP},
|
||||
Endpoint: wireguard.NewEndpoint(externalIP.IP, 0),
|
||||
InternalIP: internalIP,
|
||||
Key: wgtypes.Key{},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
@ -82,7 +84,7 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "empty internal IP",
|
||||
node: &Node{
|
||||
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
|
||||
Key: wgtypes.Key{},
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
@ -91,7 +93,7 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "empty key",
|
||||
node: &Node{
|
||||
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
|
||||
InternalIP: internalIP,
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
},
|
||||
@ -100,7 +102,7 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "empty subnet",
|
||||
node: &Node{
|
||||
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
|
||||
InternalIP: internalIP,
|
||||
Key: wgtypes.Key{},
|
||||
},
|
||||
@ -109,7 +111,7 @@ func TestReady(t *testing.T) {
|
||||
{
|
||||
name: "valid",
|
||||
node: &Node{
|
||||
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
|
||||
InternalIP: internalIP,
|
||||
Key: key,
|
||||
LastSeen: time.Now().Unix(),
|
||||
|
@ -40,7 +40,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
|
||||
var gw net.IP
|
||||
for _, segment := range t.segments {
|
||||
if segment.location == t.location {
|
||||
gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
|
||||
gw = enc.Gw(segment.endpoint.IP(), segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -196,7 +196,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
|
||||
// equals the external IP. This means that the node
|
||||
// is only accessible through an external IP and we
|
||||
// cannot encapsulate traffic to an IP through the IP.
|
||||
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP) {
|
||||
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP()) {
|
||||
continue
|
||||
}
|
||||
// Add routes to the private IPs of nodes in other segments.
|
||||
@ -248,7 +248,7 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
|
||||
rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
|
||||
rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
|
||||
if cni {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
|
||||
// Some linux distros or docker will set forward DROP in the filter table.
|
||||
// To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location.
|
||||
// Leader nodes will forward packets from all nodes within a location because they act as a gateway for them.
|
||||
@ -258,30 +258,30 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
|
||||
if s.location == t.location {
|
||||
// Make sure packets to and from pod cidrs are not dropped in the forward chain.
|
||||
for _, c := range s.cidrs {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
|
||||
}
|
||||
// Make sure packets to and from allowed location IPs are not dropped in the forward chain.
|
||||
for _, c := range s.allowedLocationIPs {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
|
||||
}
|
||||
// Make sure packets to and from private IPs are not dropped in the forward chain.
|
||||
for _, c := range s.privateIPs {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if iptablesForwardRule {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
|
||||
}
|
||||
}
|
||||
for _, s := range t.segments {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(s.wireGuardIP)), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
|
||||
for _, aip := range s.allowedIPs {
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
|
||||
rules = append(rules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
|
||||
}
|
||||
// Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd,
|
||||
// Otherwise packets to these destinations will reach the destination, but never find their way back.
|
||||
@ -289,7 +289,7 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
|
||||
if t.location == s.location {
|
||||
for _, alip := range s.allowedLocationIPs {
|
||||
rules = append(rules,
|
||||
iptables.NewRule(iptables.GetProtocol(len(alip.IP)), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
|
||||
iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -297,8 +297,8 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
|
||||
for _, p := range t.peers {
|
||||
for _, aip := range p.AllowedIPs {
|
||||
rules = append(rules,
|
||||
iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
|
||||
iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
|
||||
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
|
||||
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -67,8 +67,7 @@ type Topology struct {
|
||||
|
||||
type segment struct {
|
||||
allowedIPs []net.IPNet
|
||||
addr string
|
||||
endpoint *net.UDPAddr
|
||||
endpoint *wireguard.Endpoint
|
||||
key wgtypes.Key
|
||||
persistentKeepalive time.Duration
|
||||
// Location is the logical location of this segment.
|
||||
@ -178,7 +177,6 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
|
||||
})
|
||||
t.segments = append(t.segments, &segment{
|
||||
allowedIPs: allowedIPs,
|
||||
addr: topoMap[location][leader].Addr,
|
||||
endpoint: topoMap[location][leader].Endpoint,
|
||||
key: topoMap[location][leader].Key,
|
||||
persistentKeepalive: topoMap[location][leader].PersistentKeepalive,
|
||||
@ -287,14 +285,14 @@ CheckIPs:
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Topology) updateEndpoint(endpoint *net.UDPAddr, key wgtypes.Key, persistentKeepalive *time.Duration) *net.UDPAddr {
|
||||
func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *wireguard.Endpoint {
|
||||
// Do not update non-nat peers
|
||||
if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) {
|
||||
return endpoint
|
||||
}
|
||||
e, ok := t.discoveredEndpoints[key.String()]
|
||||
if ok {
|
||||
return e
|
||||
return wireguard.NewEndpointFromUDPAddr(e)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -315,12 +313,11 @@ func (t *Topology) Conf() *wireguard.Conf {
|
||||
peer := wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
|
||||
Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
|
||||
PersistentKeepaliveInterval: &t.persistentKeepalive,
|
||||
PublicKey: s.key,
|
||||
ReplaceAllowedIPs: true,
|
||||
},
|
||||
Addr: s.addr,
|
||||
Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
|
||||
}
|
||||
c.Peers = append(c.Peers, peer)
|
||||
}
|
||||
@ -328,13 +325,12 @@ func (t *Topology) Conf() *wireguard.Conf {
|
||||
peer := wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: p.AllowedIPs,
|
||||
Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval),
|
||||
PersistentKeepaliveInterval: &t.persistentKeepalive,
|
||||
PresharedKey: p.PresharedKey,
|
||||
PublicKey: p.PublicKey,
|
||||
ReplaceAllowedIPs: true,
|
||||
},
|
||||
Addr: p.Addr,
|
||||
Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval),
|
||||
}
|
||||
c.Peers = append(c.Peers, peer)
|
||||
}
|
||||
@ -352,9 +348,8 @@ func (t *Topology) AsPeer() wireguard.Peer {
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: s.allowedIPs,
|
||||
PublicKey: s.key,
|
||||
Endpoint: s.endpoint,
|
||||
},
|
||||
Addr: s.addr,
|
||||
Endpoint: s.endpoint,
|
||||
}
|
||||
return p
|
||||
}
|
||||
@ -377,12 +372,11 @@ func (t *Topology) PeerConf(name string) wireguard.Conf {
|
||||
peer := wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
AllowedIPs: s.allowedIPs,
|
||||
Endpoint: s.endpoint,
|
||||
PersistentKeepaliveInterval: pka,
|
||||
PresharedKey: psk,
|
||||
PublicKey: s.key,
|
||||
},
|
||||
Addr: s.addr,
|
||||
Endpoint: s.endpoint,
|
||||
}
|
||||
c.Peers = append(c.Peers, peer)
|
||||
}
|
||||
@ -395,8 +389,8 @@ func (t *Topology) PeerConf(name string) wireguard.Conf {
|
||||
AllowedIPs: t.peers[i].AllowedIPs,
|
||||
PersistentKeepaliveInterval: pka,
|
||||
PublicKey: t.peers[i].PublicKey,
|
||||
Endpoint: t.peers[i].Endpoint,
|
||||
},
|
||||
Endpoint: t.peers[i].Endpoint,
|
||||
}
|
||||
c.Peers = append(c.Peers, peer)
|
||||
}
|
||||
@ -417,13 +411,13 @@ func findLeader(nodes []*Node) int {
|
||||
var leaders, public []int
|
||||
for i := range nodes {
|
||||
if nodes[i].Leader {
|
||||
if isPublic(nodes[i].Endpoint.IP) {
|
||||
if isPublic(nodes[i].Endpoint.IP()) {
|
||||
return i
|
||||
}
|
||||
leaders = append(leaders, i)
|
||||
|
||||
}
|
||||
if nodes[i].Endpoint != nil && isPublic(nodes[i].Endpoint.IP) {
|
||||
if nodes[i].Endpoint != nil && isPublic(nodes[i].Endpoint.IP()) {
|
||||
public = append(public, i)
|
||||
}
|
||||
}
|
||||
@ -444,12 +438,11 @@ func deduplicatePeerIPs(peers []*Peer) []*Peer {
|
||||
Name: peer.Name,
|
||||
Peer: wireguard.Peer{
|
||||
PeerConfig: wgtypes.PeerConfig{
|
||||
Endpoint: peer.Endpoint,
|
||||
PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval,
|
||||
PresharedKey: peer.PresharedKey,
|
||||
PublicKey: peer.PublicKey,
|
||||
},
|
||||
Addr: peer.Addr,
|
||||
Endpoint: peer.Endpoint,
|
||||
},
|
||||
}
|
||||
for _, ip := range peer.AllowedIPs {
|
||||
|
@ -60,7 +60,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
|
||||
nodes := map[string]*Node{
|
||||
"a": {
|
||||
Name: "a",
|
||||
Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
|
||||
InternalIP: i1,
|
||||
Location: "1",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
|
||||
@ -69,7 +69,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
|
||||
},
|
||||
"b": {
|
||||
Name: "b",
|
||||
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
|
||||
InternalIP: i1,
|
||||
Location: "2",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)},
|
||||
@ -78,7 +78,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
|
||||
},
|
||||
"c": {
|
||||
Name: "c",
|
||||
Endpoint: &net.UDPAddr{IP: e3.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e3.IP, DefaultKiloPort),
|
||||
InternalIP: i2,
|
||||
// Same location as node b.
|
||||
Location: "2",
|
||||
@ -87,7 +87,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
|
||||
},
|
||||
"d": {
|
||||
Name: "d",
|
||||
Endpoint: &net.UDPAddr{IP: e4.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e4.IP, DefaultKiloPort),
|
||||
// Same location as node a, but without private IP
|
||||
Location: "1",
|
||||
Subnet: &net.IPNet{IP: net.ParseIP("10.2.4.0"), Mask: net.CIDRMask(24, 32)},
|
||||
@ -115,11 +115,8 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
|
||||
{IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)},
|
||||
},
|
||||
PublicKey: key5,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.1"),
|
||||
Port: DefaultKiloPort,
|
||||
},
|
||||
},
|
||||
Endpoint: wireguard.NewEndpoint(net.ParseIP("192.168.0.1"), DefaultKiloPort),
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -576,24 +573,24 @@ func TestFindLeader(t *testing.T) {
|
||||
nodes := []*Node{
|
||||
{
|
||||
Name: "a",
|
||||
Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
|
||||
},
|
||||
{
|
||||
Name: "b",
|
||||
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
|
||||
},
|
||||
{
|
||||
Name: "c",
|
||||
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
|
||||
},
|
||||
{
|
||||
Name: "d",
|
||||
Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
|
||||
Leader: true,
|
||||
},
|
||||
{
|
||||
Name: "2",
|
||||
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
|
||||
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
|
||||
Leader: true,
|
||||
},
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ package wireguard
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
@ -23,6 +24,7 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"k8s.io/apimachinery/pkg/util/validation"
|
||||
)
|
||||
|
||||
type section string
|
||||
@ -53,6 +55,10 @@ func (c Conf) WGConfig() wgtypes.Config {
|
||||
wgPs := make([]wgtypes.PeerConfig, len(c.Peers))
|
||||
for i, p := range c.Peers {
|
||||
wgPs[i] = p.PeerConfig
|
||||
if p.Endpoint.Resolved() {
|
||||
// We can ingore the error because we already checked if the Endpoint was resolved in the above line.
|
||||
wgPs[i].Endpoint, _ = p.Endpoint.UDPAddr(false)
|
||||
}
|
||||
wgPs[i].ReplaceAllowedIPs = true
|
||||
}
|
||||
r.Peers = wgPs
|
||||
@ -60,10 +66,169 @@ func (c Conf) WGConfig() wgtypes.Config {
|
||||
return r
|
||||
}
|
||||
|
||||
// Endpoint represents a WireGuard endpoint.
|
||||
type Endpoint struct {
|
||||
udpAddr *net.UDPAddr
|
||||
addr string
|
||||
}
|
||||
|
||||
// ParseEndpoint returns an Endpoint from a string.
|
||||
// The input should look like "10.0.0.0:100", "[ff10::10]:100"
|
||||
// or "example.com:100".
|
||||
func ParseEndpoint(endpoint string) *Endpoint {
|
||||
if len(endpoint) == 0 {
|
||||
return nil
|
||||
}
|
||||
hostRaw, portRaw, err := net.SplitHostPort(endpoint)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
port, err := strconv.ParseUint(portRaw, 10, 32)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(validation.IsValidPortNum(int(port))) != 0 {
|
||||
return nil
|
||||
}
|
||||
ip := net.ParseIP(hostRaw)
|
||||
if ip == nil {
|
||||
if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 {
|
||||
return &Endpoint{
|
||||
addr: endpoint,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// ResolveUDPAddr will not resolve the endpoint as long as a valid IP and port is given.
|
||||
// This should be the case here.
|
||||
u, err := net.ResolveUDPAddr("udp", endpoint)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
u.IP = cutIP(u.IP)
|
||||
return &Endpoint{
|
||||
udpAddr: u,
|
||||
}
|
||||
}
|
||||
|
||||
// NewEndpointFromUDPAddr returns an Endpoint from a net.UDPAddr.
|
||||
func NewEndpointFromUDPAddr(u *net.UDPAddr) *Endpoint {
|
||||
if u != nil {
|
||||
u.IP = cutIP(u.IP)
|
||||
}
|
||||
return &Endpoint{
|
||||
udpAddr: u,
|
||||
}
|
||||
}
|
||||
|
||||
// NewEndpoint returns an Endpoint from a net.IP and port.
|
||||
func NewEndpoint(ip net.IP, port int) *Endpoint {
|
||||
return &Endpoint{
|
||||
udpAddr: &net.UDPAddr{
|
||||
IP: cutIP(ip),
|
||||
Port: port,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Ready return true, if the Enpoint is ready.
|
||||
// Ready means that an IP or DN and port exists.
|
||||
func (e *Endpoint) Ready() bool {
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
return (e.udpAddr != nil && e.udpAddr.IP != nil && e.udpAddr.Port > 0) || len(e.addr) > 0
|
||||
}
|
||||
|
||||
// Port returns the port of the Endpoint.
|
||||
func (e *Endpoint) Port() int {
|
||||
if !e.Ready() {
|
||||
return 0
|
||||
}
|
||||
if e.udpAddr != nil {
|
||||
return e.udpAddr.Port
|
||||
}
|
||||
// We can ignore the errors here bacause the returned port will be "".
|
||||
// This will result to Port 0 after the conversion to and int.
|
||||
_, p, _ := net.SplitHostPort(e.addr)
|
||||
port, _ := strconv.ParseUint(p, 10, 32)
|
||||
return int(port)
|
||||
}
|
||||
|
||||
// HasDNS returns true if the endpoint has a DN.
|
||||
func (e *Endpoint) HasDNS() bool {
|
||||
return e != nil && e.addr != ""
|
||||
}
|
||||
|
||||
// DNS returns the DN of the Endpoint.
|
||||
func (e *Endpoint) DNS() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
_, s, _ := net.SplitHostPort(e.addr)
|
||||
return s
|
||||
}
|
||||
|
||||
// Resolved returns true, if the DN of the Endpoint was resolved
|
||||
// or if the Endpoint has a resolved endpoint.
|
||||
func (e *Endpoint) Resolved() bool {
|
||||
return e != nil && e.udpAddr != nil
|
||||
}
|
||||
|
||||
// UDPAddr returns the UDPAddr of the Endpoint. If resolve is false,
|
||||
// UDPAddr() will not try to resolve a DN name, if the Endpoint is not yet resolved.
|
||||
func (e *Endpoint) UDPAddr(resolve bool) (*net.UDPAddr, error) {
|
||||
if !e.Ready() {
|
||||
return nil, errors.New("Enpoint is not ready")
|
||||
}
|
||||
if e.udpAddr != nil {
|
||||
// Make a copy of the UDPAddr to protect it from modification outside this package.
|
||||
h := *e.udpAddr
|
||||
return &h, nil
|
||||
}
|
||||
if !resolve {
|
||||
return nil, errors.New("Endpoint is not resolved")
|
||||
}
|
||||
var err error
|
||||
if e.udpAddr, err = net.ResolveUDPAddr("udp", e.addr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Make a copy of the UDPAddr to protect it from modification outside this package.
|
||||
h := *e.udpAddr
|
||||
return &h, nil
|
||||
}
|
||||
|
||||
// IP returns the IP address of the Enpoint or nil.
|
||||
func (e *Endpoint) IP() net.IP {
|
||||
if !e.Resolved() {
|
||||
return nil
|
||||
}
|
||||
return e.udpAddr.IP
|
||||
}
|
||||
|
||||
// String will return the endpoint as a string.
|
||||
// If a DN exists, it will take prcedence over the resolved endpoint.
|
||||
func (e *Endpoint) String() string {
|
||||
return e.StringOpt(true)
|
||||
}
|
||||
|
||||
// StringOpt will return string of the Endpoint.
|
||||
// If dnsFirst is false, the resolved Endpoint will
|
||||
// take precedence over the DN.
|
||||
func (e *Endpoint) StringOpt(dnsFirst bool) string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.udpAddr != nil && (!dnsFirst || e.addr == "") {
|
||||
return e.udpAddr.String()
|
||||
}
|
||||
return e.addr
|
||||
}
|
||||
|
||||
// Peer represents a `peer` section of a WireGuard configuration.
|
||||
type Peer struct {
|
||||
wgtypes.PeerConfig
|
||||
Addr string // eg: dnsname:port
|
||||
Endpoint *Endpoint
|
||||
}
|
||||
|
||||
// DeduplicateIPs eliminates duplicate allowed IPs.
|
||||
@ -109,7 +274,7 @@ func (c Conf) Bytes() ([]byte, error) {
|
||||
if err = writeAllowedIPs(buf, p.AllowedIPs); err != nil {
|
||||
return nil, fmt.Errorf("failed to write allowed IPs: %v", err)
|
||||
}
|
||||
if err = writeEndpoint(buf, p.Endpoint, p.Addr); err != nil {
|
||||
if err = writeEndpoint(buf, p.Endpoint); err != nil {
|
||||
return nil, fmt.Errorf("failed to write endpoint: %v", err)
|
||||
}
|
||||
if p.PersistentKeepaliveInterval == nil {
|
||||
@ -158,8 +323,8 @@ func (c *Conf) Equal(d *wgtypes.Device) (bool, string) {
|
||||
if c.Peers[i].Endpoint == nil || d.Peers[i].Endpoint == nil {
|
||||
return c.Peers[i].Endpoint == nil && d.Peers[i].Endpoint == nil, "peer endpoints: nil value"
|
||||
}
|
||||
if !c.Peers[i].Endpoint.IP.Equal(d.Peers[i].Endpoint.IP) || c.Peers[i].Endpoint.Port != d.Peers[i].Endpoint.Port {
|
||||
return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.String())
|
||||
if c.Peers[i].Endpoint.StringOpt(false) != d.Peers[i].Endpoint.String() {
|
||||
return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.StringOpt(false))
|
||||
}
|
||||
|
||||
pki := time.Duration(0)
|
||||
@ -201,6 +366,13 @@ func sortCIDRs(cidrs []net.IPNet) {
|
||||
})
|
||||
}
|
||||
|
||||
func cutIP(ip net.IP) net.IP {
|
||||
if i4 := ip.To4(); i4 != nil {
|
||||
return i4
|
||||
}
|
||||
return ip.To16()
|
||||
}
|
||||
|
||||
func writeAllowedIPs(buf *bytes.Buffer, ais []net.IPNet) error {
|
||||
if len(ais) == 0 {
|
||||
return nil
|
||||
@ -248,13 +420,9 @@ func writeValue(buf *bytes.Buffer, k key, v string) error {
|
||||
return buf.WriteByte('\n')
|
||||
}
|
||||
|
||||
func writeEndpoint(buf *bytes.Buffer, e *net.UDPAddr, d string) error {
|
||||
str := ""
|
||||
if d != "" {
|
||||
str = d
|
||||
} else if e != nil {
|
||||
str = e.String()
|
||||
} else {
|
||||
func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error {
|
||||
str := e.String()
|
||||
if str == "" {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
|
Loading…
Reference in New Issue
Block a user