pkg/mesh,pkg/wireguard: sync NAT endpoints

This commit changes how Kilo allows nodes and peers behind NAT to roam.
Rather that ignore changes to endpoints when comparing WireGuard
configurations, Kilo now incorporates changes to endpoints for peers
behind NAT into its configuration first and later compares the
configurations.

Signed-off-by: Lucas Servén Marín <lserven@gmail.com>
This commit is contained in:
Lucas Servén Marín 2020-03-04 00:39:54 +01:00
parent 24d7c27901
commit 29280a987e
No known key found for this signature in database
GPG Key ID: 586FEAF680DA74AD
2 changed files with 64 additions and 89 deletions

View File

@ -529,7 +529,9 @@ func (m *Mesh) applyTopology() {
if !m.nodes[k].Ready() {
continue
}
nodes[k] = m.nodes[k]
// Make a shallow copy of the node.
node := *m.nodes[k]
nodes[k] = &node
readyNodes++
}
// Ensure only ready nodes are considered.
@ -539,7 +541,9 @@ func (m *Mesh) applyTopology() {
if !m.peers[k].Ready() {
continue
}
peers[k] = m.peers[k]
// Make a shallow copy of the peer.
peer := *m.peers[k]
peers[k] = &peer
readyPeers++
}
m.nodesGuage.Set(readyNodes)
@ -548,6 +552,22 @@ func (m *Mesh) applyTopology() {
if nodes[m.hostname] == nil {
return
}
// Find the Kilo interface name.
link, err := linkByIndex(m.kiloIface)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
// Find the old configuration.
oldConfRaw, err := wireguard.ShowConf(link.Attrs().Name)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
oldConf := wireguard.Parse(oldConfRaw)
updateNATEndpoints(nodes, peers, oldConf)
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive)
if err != nil {
level.Error(m.logger).Log("error", err)
@ -582,7 +602,6 @@ func (m *Mesh) applyTopology() {
}
}
ipRules = append(ipRules, m.enc.Rules(cidrs)...)
// If we are handling local routes, ensure the local
// tunnel has an IP address.
if err := m.enc.Set(oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil {
@ -596,28 +615,15 @@ func (m *Mesh) applyTopology() {
m.errorCounter.WithLabelValues("apply").Inc()
return
}
// Find the Kilo interface name.
link, err := linkByIndex(m.kiloIface)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
if t.leader {
if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
oldConf, err := wireguard.ShowConf(link.Attrs().Name)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
// Setting the WireGuard configuration interrupts existing connections
// so only set the configuration if it has changed.
equal := conf.EqualWithPeerCheck(wireguard.Parse(oldConf), peersAreEqualIgnoreNAT)
equal := conf.Equal(oldConf)
if !equal {
level.Info(m.logger).Log("msg", "WireGuard configurations are different")
if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
@ -814,41 +820,6 @@ func peersAreEqual(a, b *Peer) bool {
return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive
}
// Basic nil checks and checking the lengths of the allowed IPs is
// done by the WireGuard package.
func peersAreEqualIgnoreNAT(a, b *wireguard.Peer) bool {
for j := range a.AllowedIPs {
if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
return false
}
}
if a.PersistentKeepalive != b.PersistentKeepalive || !bytes.Equal(a.PublicKey, b.PublicKey) {
return false
}
// If a persistent keepalive is set, then the peer is behind NAT
// and we want to ignore changes in endpoints, since it may roam.
if a.PersistentKeepalive != 0 {
return true
}
if (a.Endpoint == nil) != (b.Endpoint == nil) {
return false
}
if a.Endpoint != nil {
if a.Endpoint.Port != b.Endpoint.Port {
return false
}
// IPs take priority, so check them first.
if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
return false
}
}
return true
}
func ipNetsEqual(a, b *net.IPNet) bool {
if a == nil && b == nil {
return true
@ -888,3 +859,22 @@ func linkByIndex(index int) (netlink.Link, error) {
}
return link, nil
}
// updateNATEndpoints ensures that nodes and peers behind NAT update
// their endpoints from the WireGuard configuration so they can roam.
func updateNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf) {
keys := make(map[string]*wireguard.Peer)
for i := range conf.Peers {
keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
}
for _, n := range nodes {
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
n.Endpoint = peer.Endpoint
}
}
for _, p := range peers {
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
p.Endpoint = peer.Endpoint
}
}
}

View File

@ -275,12 +275,6 @@ func (c *Conf) Bytes() ([]byte, error) {
// Equal checks if two WireGuard configurations are equivalent.
func (c *Conf) Equal(b *Conf) bool {
return c.EqualWithPeerCheck(b, strictPeerCheck)
}
// EqualWithPeerCheck checks if two WireGuard configurations are equivalent
// when their peers are compared using the given peer comparison func.
func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool {
if (c.Interface == nil) != (b.Interface == nil) {
return false
}
@ -294,47 +288,38 @@ func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool {
}
sortPeers(c.Peers)
sortPeers(b.Peers)
var ok bool
for i := range c.Peers {
if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
return false
}
sortCIDRs(c.Peers[i].AllowedIPs)
sortCIDRs(b.Peers[i].AllowedIPs)
if ok = pc(c.Peers[i], b.Peers[i]); !ok {
for j := range c.Peers[i].AllowedIPs {
if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() {
return false
}
}
if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) {
return false
}
if c.Peers[i].Endpoint != nil {
if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port {
return false
}
// IPs take priority, so check them first.
if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS {
return false
}
}
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
return false
}
}
return true
}
// PeerCheck is a function that compares two peers.
type PeerCheck func(a, b *Peer) bool
func strictPeerCheck(a, b *Peer) bool {
for j := range a.AllowedIPs {
if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
return false
}
}
if (a.Endpoint == nil) != (b.Endpoint == nil) {
return false
}
if a.Endpoint != nil {
if a.Endpoint.Port != b.Endpoint.Port {
return false
}
// IPs take priority, so check them first.
if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
return false
}
}
return a.PersistentKeepalive == b.PersistentKeepalive && bytes.Equal(a.PublicKey, b.PublicKey)
}
func sortPeers(peers []*Peer) {