From 29280a987e1409257831d35919acaac2f2cfe059 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20Serv=C3=A9n=20Mar=C3=ADn?= Date: Wed, 4 Mar 2020 00:39:54 +0100 Subject: [PATCH] pkg/mesh,pkg/wireguard: sync NAT endpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit changes how Kilo allows nodes and peers behind NAT to roam. Rather that ignore changes to endpoints when comparing WireGuard configurations, Kilo now incorporates changes to endpoints for peers behind NAT into its configuration first and later compares the configurations. Signed-off-by: Lucas Servén Marín --- pkg/mesh/mesh.go | 94 +++++++++++++++++++------------------------ pkg/wireguard/conf.go | 59 ++++++++++----------------- 2 files changed, 64 insertions(+), 89 deletions(-) diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index 76bc049..cec9ec0 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -529,7 +529,9 @@ func (m *Mesh) applyTopology() { if !m.nodes[k].Ready() { continue } - nodes[k] = m.nodes[k] + // Make a shallow copy of the node. + node := *m.nodes[k] + nodes[k] = &node readyNodes++ } // Ensure only ready nodes are considered. @@ -539,7 +541,9 @@ func (m *Mesh) applyTopology() { if !m.peers[k].Ready() { continue } - peers[k] = m.peers[k] + // Make a shallow copy of the peer. + peer := *m.peers[k] + peers[k] = &peer readyPeers++ } m.nodesGuage.Set(readyNodes) @@ -548,6 +552,22 @@ func (m *Mesh) applyTopology() { if nodes[m.hostname] == nil { return } + // Find the Kilo interface name. + link, err := linkByIndex(m.kiloIface) + if err != nil { + level.Error(m.logger).Log("error", err) + m.errorCounter.WithLabelValues("apply").Inc() + return + } + // Find the old configuration. + oldConfRaw, err := wireguard.ShowConf(link.Attrs().Name) + if err != nil { + level.Error(m.logger).Log("error", err) + m.errorCounter.WithLabelValues("apply").Inc() + return + } + oldConf := wireguard.Parse(oldConfRaw) + updateNATEndpoints(nodes, peers, oldConf) t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive) if err != nil { level.Error(m.logger).Log("error", err) @@ -582,7 +602,6 @@ func (m *Mesh) applyTopology() { } } ipRules = append(ipRules, m.enc.Rules(cidrs)...) - // If we are handling local routes, ensure the local // tunnel has an IP address. if err := m.enc.Set(oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil { @@ -596,28 +615,15 @@ func (m *Mesh) applyTopology() { m.errorCounter.WithLabelValues("apply").Inc() return } - // Find the Kilo interface name. - link, err := linkByIndex(m.kiloIface) - if err != nil { - level.Error(m.logger).Log("error", err) - m.errorCounter.WithLabelValues("apply").Inc() - return - } if t.leader { if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil { level.Error(m.logger).Log("error", err) m.errorCounter.WithLabelValues("apply").Inc() return } - oldConf, err := wireguard.ShowConf(link.Attrs().Name) - if err != nil { - level.Error(m.logger).Log("error", err) - m.errorCounter.WithLabelValues("apply").Inc() - return - } // Setting the WireGuard configuration interrupts existing connections // so only set the configuration if it has changed. - equal := conf.EqualWithPeerCheck(wireguard.Parse(oldConf), peersAreEqualIgnoreNAT) + equal := conf.Equal(oldConf) if !equal { level.Info(m.logger).Log("msg", "WireGuard configurations are different") if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil { @@ -814,41 +820,6 @@ func peersAreEqual(a, b *Peer) bool { return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive } -// Basic nil checks and checking the lengths of the allowed IPs is -// done by the WireGuard package. -func peersAreEqualIgnoreNAT(a, b *wireguard.Peer) bool { - for j := range a.AllowedIPs { - if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() { - return false - } - } - if a.PersistentKeepalive != b.PersistentKeepalive || !bytes.Equal(a.PublicKey, b.PublicKey) { - return false - } - // If a persistent keepalive is set, then the peer is behind NAT - // and we want to ignore changes in endpoints, since it may roam. - if a.PersistentKeepalive != 0 { - return true - } - if (a.Endpoint == nil) != (b.Endpoint == nil) { - return false - } - if a.Endpoint != nil { - if a.Endpoint.Port != b.Endpoint.Port { - return false - } - // IPs take priority, so check them first. - if !a.Endpoint.IP.Equal(b.Endpoint.IP) { - return false - } - // Only check the DNS name if the IP is empty. - if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS { - return false - } - } - return true -} - func ipNetsEqual(a, b *net.IPNet) bool { if a == nil && b == nil { return true @@ -888,3 +859,22 @@ func linkByIndex(index int) (netlink.Link, error) { } return link, nil } + +// updateNATEndpoints ensures that nodes and peers behind NAT update +// their endpoints from the WireGuard configuration so they can roam. +func updateNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf) { + keys := make(map[string]*wireguard.Peer) + for i := range conf.Peers { + keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i] + } + for _, n := range nodes { + if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 { + n.Endpoint = peer.Endpoint + } + } + for _, p := range peers { + if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 { + p.Endpoint = peer.Endpoint + } + } +} diff --git a/pkg/wireguard/conf.go b/pkg/wireguard/conf.go index c1fd0fd..04e2ddd 100644 --- a/pkg/wireguard/conf.go +++ b/pkg/wireguard/conf.go @@ -275,12 +275,6 @@ func (c *Conf) Bytes() ([]byte, error) { // Equal checks if two WireGuard configurations are equivalent. func (c *Conf) Equal(b *Conf) bool { - return c.EqualWithPeerCheck(b, strictPeerCheck) -} - -// EqualWithPeerCheck checks if two WireGuard configurations are equivalent -// when their peers are compared using the given peer comparison func. -func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool { if (c.Interface == nil) != (b.Interface == nil) { return false } @@ -294,47 +288,38 @@ func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool { } sortPeers(c.Peers) sortPeers(b.Peers) - var ok bool for i := range c.Peers { if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) { return false } sortCIDRs(c.Peers[i].AllowedIPs) sortCIDRs(b.Peers[i].AllowedIPs) - if ok = pc(c.Peers[i], b.Peers[i]); !ok { + for j := range c.Peers[i].AllowedIPs { + if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() { + return false + } + } + if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) { + return false + } + if c.Peers[i].Endpoint != nil { + if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port { + return false + } + // IPs take priority, so check them first. + if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) { + return false + } + // Only check the DNS name if the IP is empty. + if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS { + return false + } + } + if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) { return false } } return true - -} - -// PeerCheck is a function that compares two peers. -type PeerCheck func(a, b *Peer) bool - -func strictPeerCheck(a, b *Peer) bool { - for j := range a.AllowedIPs { - if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() { - return false - } - } - if (a.Endpoint == nil) != (b.Endpoint == nil) { - return false - } - if a.Endpoint != nil { - if a.Endpoint.Port != b.Endpoint.Port { - return false - } - // IPs take priority, so check them first. - if !a.Endpoint.IP.Equal(b.Endpoint.IP) { - return false - } - // Only check the DNS name if the IP is empty. - if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS { - return false - } - } - return a.PersistentKeepalive == b.PersistentKeepalive && bytes.Equal(a.PublicKey, b.PublicKey) } func sortPeers(peers []*Peer) {