Nat to nat (#146)
* wireguard: export an Endpoint comparison method * Record discovered endpoints in node * Synchronize DiscoveredEndpoints in k8s backend * Add discoveredEndpointsAreEqual * Handle discovered Endpoints in topology to enable NAT 2 NAT * Refactor to use Endpoint.Equal Compare IP first by default and compare DNS name first when we know the Endpoint was resolved. * Drop the shallow copies of nodes and peers Now that updateNATEndpoints was updated to discoverNATEndpoints and that the endpoints are overridden by topology instead of mutating the nodes and peers object, we can safely drop this copy.
This commit is contained in:
committed by
GitHub
parent
863628ffaa
commit
2ac000c68a
@@ -389,6 +389,7 @@ func (m *Mesh) handleLocal(n *Node) {
|
||||
PersistentKeepalive: n.PersistentKeepalive,
|
||||
Subnet: n.Subnet,
|
||||
WireGuardIP: m.wireGuardIP,
|
||||
DiscoveredEndpoints: n.DiscoveredEndpoints,
|
||||
}
|
||||
if !nodesAreEqual(n, local) {
|
||||
level.Debug(m.logger).Log("msg", "local node differs from backend")
|
||||
@@ -431,9 +432,8 @@ func (m *Mesh) applyTopology() {
|
||||
if !m.nodes[k].Ready() {
|
||||
continue
|
||||
}
|
||||
// Make a shallow copy of the node.
|
||||
node := *m.nodes[k]
|
||||
nodes[k] = &node
|
||||
// Make it point to the node without copy.
|
||||
nodes[k] = m.nodes[k]
|
||||
readyNodes++
|
||||
}
|
||||
// Ensure only ready nodes are considered.
|
||||
@@ -443,9 +443,8 @@ func (m *Mesh) applyTopology() {
|
||||
if !m.peers[k].Ready() {
|
||||
continue
|
||||
}
|
||||
// Make a shallow copy of the peer.
|
||||
peer := *m.peers[k]
|
||||
peers[k] = &peer
|
||||
// Make it point the peer without copy.
|
||||
peers[k] = m.peers[k]
|
||||
readyPeers++
|
||||
}
|
||||
m.nodesGuage.Set(readyNodes)
|
||||
@@ -469,7 +468,8 @@ func (m *Mesh) applyTopology() {
|
||||
return
|
||||
}
|
||||
oldConf := wireguard.Parse(oldConfRaw)
|
||||
updateNATEndpoints(nodes, peers, oldConf)
|
||||
natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
|
||||
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
|
||||
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
@@ -676,26 +676,15 @@ func nodesAreEqual(a, b *Node) bool {
|
||||
if a == b {
|
||||
return true
|
||||
}
|
||||
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
|
||||
// Check the DNS name first since this package
|
||||
// is doing the DNS resolution.
|
||||
if !a.Endpoint.Equal(b.Endpoint, true) {
|
||||
return false
|
||||
}
|
||||
if a.Endpoint != nil {
|
||||
if a.Endpoint.Port != b.Endpoint.Port {
|
||||
return false
|
||||
}
|
||||
// Check the DNS name first since this package
|
||||
// is doing the DNS resolution.
|
||||
if a.Endpoint.DNS != b.Endpoint.DNS {
|
||||
return false
|
||||
}
|
||||
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
// Ignore LastSeen when comparing equality we want to check if the nodes are
|
||||
// equivalent. However, we do want to check if LastSeen has transitioned
|
||||
// between valid and invalid.
|
||||
return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive
|
||||
return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive && discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints)
|
||||
}
|
||||
|
||||
func peersAreEqual(a, b *Peer) bool {
|
||||
@@ -705,22 +694,11 @@ func peersAreEqual(a, b *Peer) bool {
|
||||
if a == b {
|
||||
return true
|
||||
}
|
||||
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
|
||||
// Check the DNS name first since this package
|
||||
// is doing the DNS resolution.
|
||||
if !a.Endpoint.Equal(b.Endpoint, true) {
|
||||
return false
|
||||
}
|
||||
if a.Endpoint != nil {
|
||||
if a.Endpoint.Port != b.Endpoint.Port {
|
||||
return false
|
||||
}
|
||||
// Check the DNS name first since this package
|
||||
// is doing the DNS resolution.
|
||||
if a.Endpoint.DNS != b.Endpoint.DNS {
|
||||
return false
|
||||
}
|
||||
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(a.AllowedIPs) != len(b.AllowedIPs) {
|
||||
return false
|
||||
}
|
||||
@@ -764,6 +742,24 @@ func subnetsEqual(a, b *net.IPNet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if (a != nil) != (b != nil) {
|
||||
return false
|
||||
}
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for k := range a {
|
||||
if !a[k].Equal(b[k], false) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func linkByIndex(index int) (netlink.Link, error) {
|
||||
link, err := netlink.LinkByIndex(index)
|
||||
if err != nil {
|
||||
@@ -772,21 +768,30 @@ func linkByIndex(index int) (netlink.Link, error) {
|
||||
return link, nil
|
||||
}
|
||||
|
||||
// updateNATEndpoints ensures that nodes and peers behind NAT update
|
||||
// their endpoints from the WireGuard configuration so they can roam.
|
||||
func updateNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf) {
|
||||
// discoverNATEndpoints uses the node's WireGuard configuration to returns a list of the most recently discovered endpoints for all nodes and peers behind NAT so that they can roam.
|
||||
func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf, logger log.Logger) map[string]*wireguard.Endpoint {
|
||||
natEndpoints := make(map[string]*wireguard.Endpoint)
|
||||
keys := make(map[string]*wireguard.Peer)
|
||||
for i := range conf.Peers {
|
||||
keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
|
||||
}
|
||||
for _, n := range nodes {
|
||||
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
|
||||
n.Endpoint = peer.Endpoint
|
||||
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false))
|
||||
// Should check location leader but only available in topology ... or have topology handle that list
|
||||
// Better check wg latest-handshake
|
||||
if !n.Endpoint.Equal(peer.Endpoint, false) {
|
||||
natEndpoints[string(n.Key)] = peer.Endpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, p := range peers {
|
||||
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
|
||||
p.Endpoint = peer.Endpoint
|
||||
if !p.Endpoint.Equal(peer.Endpoint, false) {
|
||||
natEndpoints[string(p.PublicKey)] = peer.Endpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
level.Debug(logger).Log("msg", "Discovered WireGuard NAT Endpoints", "DiscoveredEndpoints", natEndpoints)
|
||||
return natEndpoints
|
||||
}
|
||||
|
Reference in New Issue
Block a user