pkg/mesh,pkg/wireguard: sync NAT endpoints
This commit changes how Kilo allows nodes and peers behind NAT to roam. Rather that ignore changes to endpoints when comparing WireGuard configurations, Kilo now incorporates changes to endpoints for peers behind NAT into its configuration first and later compares the configurations. Signed-off-by: Lucas Servén Marín <lserven@gmail.com>
This commit is contained in:
parent
24d7c27901
commit
29280a987e
@ -529,7 +529,9 @@ func (m *Mesh) applyTopology() {
|
||||
if !m.nodes[k].Ready() {
|
||||
continue
|
||||
}
|
||||
nodes[k] = m.nodes[k]
|
||||
// Make a shallow copy of the node.
|
||||
node := *m.nodes[k]
|
||||
nodes[k] = &node
|
||||
readyNodes++
|
||||
}
|
||||
// Ensure only ready nodes are considered.
|
||||
@ -539,7 +541,9 @@ func (m *Mesh) applyTopology() {
|
||||
if !m.peers[k].Ready() {
|
||||
continue
|
||||
}
|
||||
peers[k] = m.peers[k]
|
||||
// Make a shallow copy of the peer.
|
||||
peer := *m.peers[k]
|
||||
peers[k] = &peer
|
||||
readyPeers++
|
||||
}
|
||||
m.nodesGuage.Set(readyNodes)
|
||||
@ -548,6 +552,22 @@ func (m *Mesh) applyTopology() {
|
||||
if nodes[m.hostname] == nil {
|
||||
return
|
||||
}
|
||||
// Find the Kilo interface name.
|
||||
link, err := linkByIndex(m.kiloIface)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
// Find the old configuration.
|
||||
oldConfRaw, err := wireguard.ShowConf(link.Attrs().Name)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
oldConf := wireguard.Parse(oldConfRaw)
|
||||
updateNATEndpoints(nodes, peers, oldConf)
|
||||
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
@ -582,7 +602,6 @@ func (m *Mesh) applyTopology() {
|
||||
}
|
||||
}
|
||||
ipRules = append(ipRules, m.enc.Rules(cidrs)...)
|
||||
|
||||
// If we are handling local routes, ensure the local
|
||||
// tunnel has an IP address.
|
||||
if err := m.enc.Set(oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil {
|
||||
@ -596,28 +615,15 @@ func (m *Mesh) applyTopology() {
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
// Find the Kilo interface name.
|
||||
link, err := linkByIndex(m.kiloIface)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
if t.leader {
|
||||
if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
oldConf, err := wireguard.ShowConf(link.Attrs().Name)
|
||||
if err != nil {
|
||||
level.Error(m.logger).Log("error", err)
|
||||
m.errorCounter.WithLabelValues("apply").Inc()
|
||||
return
|
||||
}
|
||||
// Setting the WireGuard configuration interrupts existing connections
|
||||
// so only set the configuration if it has changed.
|
||||
equal := conf.EqualWithPeerCheck(wireguard.Parse(oldConf), peersAreEqualIgnoreNAT)
|
||||
equal := conf.Equal(oldConf)
|
||||
if !equal {
|
||||
level.Info(m.logger).Log("msg", "WireGuard configurations are different")
|
||||
if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
|
||||
@ -814,41 +820,6 @@ func peersAreEqual(a, b *Peer) bool {
|
||||
return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive
|
||||
}
|
||||
|
||||
// Basic nil checks and checking the lengths of the allowed IPs is
|
||||
// done by the WireGuard package.
|
||||
func peersAreEqualIgnoreNAT(a, b *wireguard.Peer) bool {
|
||||
for j := range a.AllowedIPs {
|
||||
if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if a.PersistentKeepalive != b.PersistentKeepalive || !bytes.Equal(a.PublicKey, b.PublicKey) {
|
||||
return false
|
||||
}
|
||||
// If a persistent keepalive is set, then the peer is behind NAT
|
||||
// and we want to ignore changes in endpoints, since it may roam.
|
||||
if a.PersistentKeepalive != 0 {
|
||||
return true
|
||||
}
|
||||
if (a.Endpoint == nil) != (b.Endpoint == nil) {
|
||||
return false
|
||||
}
|
||||
if a.Endpoint != nil {
|
||||
if a.Endpoint.Port != b.Endpoint.Port {
|
||||
return false
|
||||
}
|
||||
// IPs take priority, so check them first.
|
||||
if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
||||
return false
|
||||
}
|
||||
// Only check the DNS name if the IP is empty.
|
||||
if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ipNetsEqual(a, b *net.IPNet) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
@ -888,3 +859,22 @@ func linkByIndex(index int) (netlink.Link, error) {
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
// updateNATEndpoints ensures that nodes and peers behind NAT update
|
||||
// their endpoints from the WireGuard configuration so they can roam.
|
||||
func updateNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf) {
|
||||
keys := make(map[string]*wireguard.Peer)
|
||||
for i := range conf.Peers {
|
||||
keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
|
||||
}
|
||||
for _, n := range nodes {
|
||||
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
|
||||
n.Endpoint = peer.Endpoint
|
||||
}
|
||||
}
|
||||
for _, p := range peers {
|
||||
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
|
||||
p.Endpoint = peer.Endpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -275,12 +275,6 @@ func (c *Conf) Bytes() ([]byte, error) {
|
||||
|
||||
// Equal checks if two WireGuard configurations are equivalent.
|
||||
func (c *Conf) Equal(b *Conf) bool {
|
||||
return c.EqualWithPeerCheck(b, strictPeerCheck)
|
||||
}
|
||||
|
||||
// EqualWithPeerCheck checks if two WireGuard configurations are equivalent
|
||||
// when their peers are compared using the given peer comparison func.
|
||||
func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool {
|
||||
if (c.Interface == nil) != (b.Interface == nil) {
|
||||
return false
|
||||
}
|
||||
@ -294,47 +288,38 @@ func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool {
|
||||
}
|
||||
sortPeers(c.Peers)
|
||||
sortPeers(b.Peers)
|
||||
var ok bool
|
||||
for i := range c.Peers {
|
||||
if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
|
||||
return false
|
||||
}
|
||||
sortCIDRs(c.Peers[i].AllowedIPs)
|
||||
sortCIDRs(b.Peers[i].AllowedIPs)
|
||||
if ok = pc(c.Peers[i], b.Peers[i]); !ok {
|
||||
for j := range c.Peers[i].AllowedIPs {
|
||||
if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) {
|
||||
return false
|
||||
}
|
||||
if c.Peers[i].Endpoint != nil {
|
||||
if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port {
|
||||
return false
|
||||
}
|
||||
// IPs take priority, so check them first.
|
||||
if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) {
|
||||
return false
|
||||
}
|
||||
// Only check the DNS name if the IP is empty.
|
||||
if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// PeerCheck is a function that compares two peers.
|
||||
type PeerCheck func(a, b *Peer) bool
|
||||
|
||||
func strictPeerCheck(a, b *Peer) bool {
|
||||
for j := range a.AllowedIPs {
|
||||
if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if (a.Endpoint == nil) != (b.Endpoint == nil) {
|
||||
return false
|
||||
}
|
||||
if a.Endpoint != nil {
|
||||
if a.Endpoint.Port != b.Endpoint.Port {
|
||||
return false
|
||||
}
|
||||
// IPs take priority, so check them first.
|
||||
if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
||||
return false
|
||||
}
|
||||
// Only check the DNS name if the IP is empty.
|
||||
if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return a.PersistentKeepalive == b.PersistentKeepalive && bytes.Equal(a.PublicKey, b.PublicKey)
|
||||
}
|
||||
|
||||
func sortPeers(peers []*Peer) {
|
||||
|
Loading…
Reference in New Issue
Block a user