Nat to nat (#146)

* wireguard: export an Endpoint comparison method

* Record discovered endpoints in node

* Synchronize DiscoveredEndpoints in k8s backend

* Add discoveredEndpointsAreEqual

* Handle discovered Endpoints in topology to enable NAT 2 NAT

* Refactor to use Endpoint.Equal

Compare IP first by default and compare DNS name first when we know the Endpoint was resolved.

* Drop the shallow copies of nodes and peers

Now that updateNATEndpoints was updated to discoverNATEndpoints and that
the endpoints are overridden by topology instead of mutating the nodes and
peers object, we can safely drop this copy.
This commit is contained in:
Julien Viard de Galbert
2021-04-21 19:47:29 +02:00
committed by GitHub
parent 863628ffaa
commit 2ac000c68a
7 changed files with 475 additions and 270 deletions

View File

@@ -389,6 +389,7 @@ func (m *Mesh) handleLocal(n *Node) {
PersistentKeepalive: n.PersistentKeepalive,
Subnet: n.Subnet,
WireGuardIP: m.wireGuardIP,
DiscoveredEndpoints: n.DiscoveredEndpoints,
}
if !nodesAreEqual(n, local) {
level.Debug(m.logger).Log("msg", "local node differs from backend")
@@ -431,9 +432,8 @@ func (m *Mesh) applyTopology() {
if !m.nodes[k].Ready() {
continue
}
// Make a shallow copy of the node.
node := *m.nodes[k]
nodes[k] = &node
// Make it point to the node without copy.
nodes[k] = m.nodes[k]
readyNodes++
}
// Ensure only ready nodes are considered.
@@ -443,9 +443,8 @@ func (m *Mesh) applyTopology() {
if !m.peers[k].Ready() {
continue
}
// Make a shallow copy of the peer.
peer := *m.peers[k]
peers[k] = &peer
// Make it point the peer without copy.
peers[k] = m.peers[k]
readyPeers++
}
m.nodesGuage.Set(readyNodes)
@@ -469,7 +468,8 @@ func (m *Mesh) applyTopology() {
return
}
oldConf := wireguard.Parse(oldConfRaw)
updateNATEndpoints(nodes, peers, oldConf)
natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive)
if err != nil {
level.Error(m.logger).Log("error", err)
@@ -676,26 +676,15 @@ func nodesAreEqual(a, b *Node) bool {
if a == b {
return true
}
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
// Check the DNS name first since this package
// is doing the DNS resolution.
if !a.Endpoint.Equal(b.Endpoint, true) {
return false
}
if a.Endpoint != nil {
if a.Endpoint.Port != b.Endpoint.Port {
return false
}
// Check the DNS name first since this package
// is doing the DNS resolution.
if a.Endpoint.DNS != b.Endpoint.DNS {
return false
}
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
return false
}
}
// Ignore LastSeen when comparing equality we want to check if the nodes are
// equivalent. However, we do want to check if LastSeen has transitioned
// between valid and invalid.
return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive
return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive && discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints)
}
func peersAreEqual(a, b *Peer) bool {
@@ -705,22 +694,11 @@ func peersAreEqual(a, b *Peer) bool {
if a == b {
return true
}
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
// Check the DNS name first since this package
// is doing the DNS resolution.
if !a.Endpoint.Equal(b.Endpoint, true) {
return false
}
if a.Endpoint != nil {
if a.Endpoint.Port != b.Endpoint.Port {
return false
}
// Check the DNS name first since this package
// is doing the DNS resolution.
if a.Endpoint.DNS != b.Endpoint.DNS {
return false
}
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
return false
}
}
if len(a.AllowedIPs) != len(b.AllowedIPs) {
return false
}
@@ -764,6 +742,24 @@ func subnetsEqual(a, b *net.IPNet) bool {
return true
}
func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
if a == nil && b == nil {
return true
}
if (a != nil) != (b != nil) {
return false
}
if len(a) != len(b) {
return false
}
for k := range a {
if !a[k].Equal(b[k], false) {
return false
}
}
return true
}
func linkByIndex(index int) (netlink.Link, error) {
link, err := netlink.LinkByIndex(index)
if err != nil {
@@ -772,21 +768,30 @@ func linkByIndex(index int) (netlink.Link, error) {
return link, nil
}
// updateNATEndpoints ensures that nodes and peers behind NAT update
// their endpoints from the WireGuard configuration so they can roam.
func updateNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf) {
// discoverNATEndpoints uses the node's WireGuard configuration to returns a list of the most recently discovered endpoints for all nodes and peers behind NAT so that they can roam.
func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf, logger log.Logger) map[string]*wireguard.Endpoint {
natEndpoints := make(map[string]*wireguard.Endpoint)
keys := make(map[string]*wireguard.Peer)
for i := range conf.Peers {
keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
}
for _, n := range nodes {
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
n.Endpoint = peer.Endpoint
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false))
// Should check location leader but only available in topology ... or have topology handle that list
// Better check wg latest-handshake
if !n.Endpoint.Equal(peer.Endpoint, false) {
natEndpoints[string(n.Key)] = peer.Endpoint
}
}
}
for _, p := range peers {
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
p.Endpoint = peer.Endpoint
if !p.Endpoint.Equal(peer.Endpoint, false) {
natEndpoints[string(p.PublicKey)] = peer.Endpoint
}
}
}
level.Debug(logger).Log("msg", "Discovered WireGuard NAT Endpoints", "DiscoveredEndpoints", natEndpoints)
return natEndpoints
}