kilo/pkg/mesh/mesh.go

787 lines
22 KiB
Go
Raw Normal View History

2019-01-18 01:50:10 +00:00
// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build linux
2019-01-18 01:50:10 +00:00
package mesh
import (
"bytes"
2019-01-18 01:50:10 +00:00
"fmt"
"io/ioutil"
"net"
"os"
"sync"
"time"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/prometheus/client_golang/prometheus"
"github.com/vishvananda/netlink"
"github.com/squat/kilo/pkg/encapsulation"
2019-01-18 01:50:10 +00:00
"github.com/squat/kilo/pkg/iproute"
"github.com/squat/kilo/pkg/iptables"
"github.com/squat/kilo/pkg/route"
"github.com/squat/kilo/pkg/wireguard"
)
// kiloPath is the root directory under which Kilo keeps its on-disk state.
const kiloPath = "/var/lib/kilo"

const (
	// privateKeyPath is where the node's WireGuard private key is persisted
	// so that the same key is reused across restarts.
	privateKeyPath = kiloPath + "/key"
	// confPath is where the rendered WireGuard configuration is written
	// before it is applied to the interface.
	confPath = kiloPath + "/conf"
)
// Mesh is able to create Kilo network meshes.
type Mesh struct {
Backend
cleanUpIface bool
cni bool
cniPath string
enc encapsulation.Encapsulator
externalIP *net.IPNet
granularity Granularity
hostname string
internalIP *net.IPNet
ipTables *iptables.Controller
kiloIface int
key []byte
local bool
port uint32
priv []byte
privIface int
pub []byte
resyncPeriod time.Duration
stop chan struct{}
subnet *net.IPNet
table *route.Table
wireGuardIP *net.IPNet
2019-01-18 01:50:10 +00:00
// nodes and peers are mutable fields in the struct
// and need to be guarded.
2019-01-18 01:50:10 +00:00
nodes map[string]*Node
peers map[string]*Peer
2019-01-18 01:50:10 +00:00
mu sync.Mutex
errorCounter *prometheus.CounterVec
leaderGuage prometheus.Gauge
nodesGuage prometheus.Gauge
peersGuage prometheus.Gauge
reconcileCounter prometheus.Counter
logger log.Logger
2019-01-18 01:50:10 +00:00
}
// New returns a new Mesh instance.
//
// It loads the node's WireGuard private key from disk (generating and
// persisting one if none is usable), creates or looks up the WireGuard
// interface named iface, discovers the node's private and public IPs,
// initializes the encapsulator on the private interface when a private IP
// exists, and builds the iptables controller. The returned Mesh is not
// started; call Run.
func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port uint32, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, resyncPeriod time.Duration, logger log.Logger) (*Mesh, error) {
	if err := os.MkdirAll(kiloPath, 0700); err != nil {
		return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
	}
	private, err := ioutil.ReadFile(privateKeyPath)
	if err == nil {
		private = bytes.Trim(private, "\n")
	}
	// A missing, unreadable, or empty key file all mean a fresh key is needed;
	// an empty key would otherwise only fail later in PubKey.
	if err != nil || len(private) == 0 {
		level.Warn(logger).Log("msg", "no private key found on disk; generating one now")
		if private, err = wireguard.GenKey(); err != nil {
			return nil, err
		}
	}
	public, err := wireguard.PubKey(private)
	if err != nil {
		return nil, err
	}
	if err := ioutil.WriteFile(privateKeyPath, private, 0600); err != nil {
		return nil, fmt.Errorf("failed to write private key to disk: %v", err)
	}
	cniIndex, err := cniDeviceIndex()
	if err != nil {
		return nil, fmt.Errorf("failed to query netlink for CNI device: %v", err)
	}
	var kiloIface int
	if createIface {
		kiloIface, _, err = wireguard.New(iface)
		if err != nil {
			return nil, fmt.Errorf("failed to create WireGuard interface: %v", err)
		}
	} else {
		link, err := netlink.LinkByName(iface)
		if err != nil {
			return nil, fmt.Errorf("failed to get interface index: %v", err)
		}
		kiloIface = link.Attrs().Index
	}
	privateIP, publicIP, err := getIP(hostname, kiloIface, enc.Index(), cniIndex)
	if err != nil {
		return nil, fmt.Errorf("failed to find public IP: %v", err)
	}
	var privIface int
	if privateIP != nil {
		ifaces, err := interfacesForIP(privateIP)
		if err != nil {
			return nil, fmt.Errorf("failed to find interface for private IP: %v", err)
		}
		// Guard the index below rather than panicking on an empty result.
		if len(ifaces) == 0 {
			return nil, fmt.Errorf("failed to find interface for private IP: no interface has address %s", privateIP.String())
		}
		privIface = ifaces[0].Index
		if enc.Strategy() != encapsulation.Never {
			if err := enc.Init(privIface); err != nil {
				return nil, fmt.Errorf("failed to initialize encapsulator: %v", err)
			}
		}
		level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the private IP address", privateIP.String()))
	} else {
		// Without a private IP there is nothing to encapsulate over.
		enc = encapsulation.Noop(enc.Strategy())
		level.Debug(logger).Log("msg", "running without a private IP address")
	}
	level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the public IP address", publicIP.String()))
	ipTables, err := iptables.New(iptables.WithLogger(log.With(logger, "component", "iptables")), iptables.WithResyncPeriod(resyncPeriod))
	if err != nil {
		return nil, fmt.Errorf("failed to create IP tables controller: %v", err)
	}
	return &Mesh{
		Backend:      backend,
		cleanUpIface: cleanUpIface,
		cni:          cni,
		cniPath:      cniPath,
		enc:          enc,
		externalIP:   publicIP,
		granularity:  granularity,
		hostname:     hostname,
		internalIP:   privateIP,
		ipTables:     ipTables,
		kiloIface:    kiloIface,
		nodes:        make(map[string]*Node),
		peers:        make(map[string]*Peer),
		port:         port,
		priv:         private,
		privIface:    privIface,
		pub:          public,
		// Run uses resyncPeriod for its resync timer; leaving it at zero
		// would make Run reconcile continuously.
		resyncPeriod: resyncPeriod,
		local:        local,
		stop:         make(chan struct{}),
		subnet:       subnet,
		table:        route.NewTable(),
		errorCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "kilo_errors_total",
			Help: "Number of errors that occurred while administering the mesh.",
		}, []string{"event"}),
		leaderGuage: prometheus.NewGauge(prometheus.GaugeOpts{
			Name: "kilo_leader",
			Help: "Leadership status of the node.",
		}),
		nodesGuage: prometheus.NewGauge(prometheus.GaugeOpts{
			Name: "kilo_nodes",
			Help: "Number of nodes in the mesh.",
		}),
		peersGuage: prometheus.NewGauge(prometheus.GaugeOpts{
			Name: "kilo_peers",
			Help: "Number of peers in the mesh.",
		}),
		reconcileCounter: prometheus.NewCounter(prometheus.CounterOpts{
			Name: "kilo_reconciles_total",
			Help: "Number of reconciliation attempts.",
		}),
		logger: logger,
	}, nil
}
// Run starts the mesh.
//
// It initializes the node and peer backends and starts the iptables and
// route table controllers. It then handles backend events, periodic
// check-ins and periodic resyncs until the stop channel is closed, and
// returns nil at that point. Once the controllers have started, host state
// is cleaned up when Run returns.
func (m *Mesh) Run() error {
	if err := m.Nodes().Init(m.stop); err != nil {
		return fmt.Errorf("failed to initialize node backend: %v", err)
	}
	// Try to set the CNI config quickly.
	if m.cni {
		if n, err := m.Nodes().Get(m.hostname); err == nil {
			// nodes is guarded by mu. The lock is released before calling
			// updateCNIConfig in case that method acquires it itself.
			m.mu.Lock()
			m.nodes[m.hostname] = n
			m.mu.Unlock()
			m.updateCNIConfig()
		} else {
			level.Warn(m.logger).Log("error", fmt.Errorf("failed to get node %q: %v", m.hostname, err))
		}
	}
	if err := m.Peers().Init(m.stop); err != nil {
		return fmt.Errorf("failed to initialize peer backend: %v", err)
	}
	ipTablesErrors, err := m.ipTables.Run(m.stop)
	if err != nil {
		return fmt.Errorf("failed to watch for IP tables updates: %v", err)
	}
	routeErrors, err := m.table.Run(m.stop)
	if err != nil {
		return fmt.Errorf("failed to watch for route table updates: %v", err)
	}
	// Forward controller errors to the log and error counter until stopped.
	go func() {
		ipErrs, rtErrs := ipTablesErrors, routeErrors
		for {
			var err error
			var ok bool
			select {
			case err, ok = <-ipErrs:
				if !ok {
					// A closed channel is always ready to receive; replace it
					// with nil (which never receives) instead of spinning.
					ipErrs = nil
					continue
				}
			case err, ok = <-rtErrs:
				if !ok {
					rtErrs = nil
					continue
				}
			case <-m.stop:
				return
			}
			if err != nil {
				level.Error(m.logger).Log("error", err)
				m.errorCounter.WithLabelValues("run").Inc()
			}
		}
	}()
	defer m.cleanUp()
	resync := time.NewTimer(m.resyncPeriod)
	defer resync.Stop()
	checkIn := time.NewTimer(checkInPeriod)
	defer checkIn.Stop()
	nw := m.Nodes().Watch()
	pw := m.Peers().Watch()
	for {
		select {
		case ne := <-nw:
			m.syncNodes(ne)
		case pe := <-pw:
			m.syncPeers(pe)
		case <-checkIn.C:
			m.checkIn()
			checkIn.Reset(checkInPeriod)
		case <-resync.C:
			if m.cni {
				m.updateCNIConfig()
			}
			m.applyTopology()
			resync.Reset(m.resyncPeriod)
		case <-m.stop:
			return nil
		}
	}
}
func (m *Mesh) syncNodes(e *NodeEvent) {
2019-01-18 01:50:10 +00:00
logger := log.With(m.logger, "event", e.Type)
level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type)
2019-01-18 01:50:10 +00:00
if isSelf(m.hostname, e.Node) {
level.Debug(logger).Log("msg", "processing local node", "node", e.Node)
m.handleLocal(e.Node)
return
}
var diff bool
m.mu.Lock()
if !e.Node.Ready() {
level.Debug(logger).Log("msg", "received incomplete node", "node", e.Node)
// An existing node is no longer valid
// so remove it from the mesh.
if _, ok := m.nodes[e.Node.Name]; ok {
2019-05-08 15:10:33 +00:00
level.Info(logger).Log("msg", "node is no longer ready", "node", e.Node)
2019-01-18 01:50:10 +00:00
diff = true
}
} else {
switch e.Type {
case AddEvent:
fallthrough
case UpdateEvent:
if !nodesAreEqual(m.nodes[e.Node.Name], e.Node) {
diff = true
}
// Even if the nodes are the same,
// overwrite the old node to update the timestamp.
m.nodes[e.Node.Name] = e.Node
2019-01-18 01:50:10 +00:00
case DeleteEvent:
delete(m.nodes, e.Node.Name)
diff = true
}
}
m.mu.Unlock()
if diff {
level.Info(logger).Log("node", e.Node)
m.applyTopology()
}
}
func (m *Mesh) syncPeers(e *PeerEvent) {
logger := log.With(m.logger, "event", e.Type)
level.Debug(logger).Log("msg", "syncing peers", "event", e.Type)
var diff bool
m.mu.Lock()
// Peers are indexed by public key.
key := string(e.Peer.PublicKey)
if !e.Peer.Ready() {
level.Debug(logger).Log("msg", "received incomplete peer", "peer", e.Peer)
// An existing peer is no longer valid
// so remove it from the mesh.
if _, ok := m.peers[key]; ok {
2019-05-08 15:10:33 +00:00
level.Info(logger).Log("msg", "peer is no longer ready", "peer", e.Peer)
diff = true
}
} else {
switch e.Type {
case AddEvent:
fallthrough
case UpdateEvent:
2019-05-08 15:10:33 +00:00
if e.Old != nil && key != string(e.Old.PublicKey) {
delete(m.peers, string(e.Old.PublicKey))
diff = true
}
if !peersAreEqual(m.peers[key], e.Peer) {
m.peers[key] = e.Peer
diff = true
}
case DeleteEvent:
delete(m.peers, key)
diff = true
}
}
m.mu.Unlock()
if diff {
level.Info(logger).Log("peer", e.Peer)
m.applyTopology()
}
}
// checkIn will try to update the local node's LastSeen timestamp
// in the backend.
func (m *Mesh) checkIn() {
m.mu.Lock()
2019-05-08 15:10:33 +00:00
defer m.mu.Unlock()
n := m.nodes[m.hostname]
if n == nil {
level.Debug(m.logger).Log("msg", "no local node found in backend")
return
}
2019-05-08 15:10:33 +00:00
oldTime := n.LastSeen
n.LastSeen = time.Now().Unix()
if err := m.Nodes().Set(m.hostname, n); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", n)
m.errorCounter.WithLabelValues("checkin").Inc()
2019-05-08 15:10:33 +00:00
// Revert time.
n.LastSeen = oldTime
return
}
level.Debug(m.logger).Log("msg", "successfully checked in local node in backend")
}
2019-01-18 01:50:10 +00:00
func (m *Mesh) handleLocal(n *Node) {
// Allow the IPs to be overridden.
if n.Endpoint == nil || (n.Endpoint.DNS == "" && n.Endpoint.IP == nil) {
n.Endpoint = &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: m.externalIP.IP}, Port: m.port}
2019-01-18 01:50:10 +00:00
}
if n.InternalIP == nil {
n.InternalIP = m.internalIP
}
2019-01-18 01:50:10 +00:00
// Compare the given node to the calculated local node.
// Take leader, location, and subnet from the argument, as these
// are not determined by kilo.
local := &Node{
Endpoint: n.Endpoint,
Key: m.pub,
InternalIP: n.InternalIP,
LastSeen: time.Now().Unix(),
Leader: n.Leader,
Location: n.Location,
Name: m.hostname,
PersistentKeepalive: n.PersistentKeepalive,
Subnet: n.Subnet,
WireGuardIP: m.wireGuardIP,
}
2019-01-18 01:50:10 +00:00
if !nodesAreEqual(n, local) {
level.Debug(m.logger).Log("msg", "local node differs from backend")
if err := m.Nodes().Set(m.hostname, local); err != nil {
2019-01-18 01:50:10 +00:00
level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", local)
m.errorCounter.WithLabelValues("local").Inc()
return
}
level.Debug(m.logger).Log("msg", "successfully reconciled local node against backend")
}
m.mu.Lock()
2019-05-08 15:10:33 +00:00
2019-01-18 01:50:10 +00:00
n = m.nodes[m.hostname]
if n == nil {
n = &Node{}
}
m.mu.Unlock()
if !nodesAreEqual(n, local) {
m.mu.Lock()
m.nodes[local.Name] = local
m.mu.Unlock()
m.applyTopology()
}
}
// applyTopology reconciles the host's networking state with the current view
// of the mesh. It computes a topology from the ready nodes and peers, then
// applies the resulting WireGuard configuration, iptables rules, interface
// address and link state, and routes.
//
// Every call increments the reconcile counter. The first failure is logged
// and counted under the "apply" label, and the remaining steps are skipped
// until the next attempt.
func (m *Mesh) applyTopology() {
	m.reconcileCounter.Inc()
	m.mu.Lock()
	defer m.mu.Unlock()
	if err := m.applyTopologyLocked(); err != nil {
		level.Error(m.logger).Log("error", err)
		m.errorCounter.WithLabelValues("apply").Inc()
	}
}

// applyTopologyLocked performs the work of applyTopology and returns the
// first error encountered. It must be called with m.mu held.
func (m *Mesh) applyTopologyLocked() error {
	// If we can't resolve an endpoint, then fail and retry later.
	if err := m.resolveEndpoints(); err != nil {
		return err
	}
	// Only ready nodes and peers are considered. Shallow copies are taken so
	// that the endpoint replacements made by updateNATEndpoints below do not
	// leak into m.nodes and m.peers.
	nodes := make(map[string]*Node)
	for name, n := range m.nodes {
		if n.Ready() {
			c := *n
			nodes[name] = &c
		}
	}
	peers := make(map[string]*Peer)
	for key, p := range m.peers {
		if p.Ready() {
			c := *p
			peers[key] = &c
		}
	}
	m.nodesGuage.Set(float64(len(nodes)))
	m.peersGuage.Set(float64(len(peers)))

	// We cannot do anything with the topology until the local node is available.
	self := nodes[m.hostname]
	if self == nil {
		return nil
	}
	// Find the Kilo interface name.
	link, err := linkByIndex(m.kiloIface)
	if err != nil {
		return err
	}
	ifaceName := link.Attrs().Name
	// Find the old configuration.
	oldConfRaw, err := wireguard.ShowConf(ifaceName)
	if err != nil {
		return err
	}
	oldConf := wireguard.Parse(oldConfRaw)
	updateNATEndpoints(nodes, peers, oldConf)
	t, err := NewTopology(nodes, peers, m.granularity, m.hostname, self.Endpoint.Port, m.priv, m.subnet, self.PersistentKeepalive)
	if err != nil {
		return err
	}
	// Update the node's WireGuard IP.
	m.wireGuardIP = t.wireGuardCIDR
	conf := t.Conf()
	buf, err := conf.Bytes()
	if err != nil {
		return err
	}
	if err := ioutil.WriteFile(confPath, buf, 0600); err != nil {
		return err
	}

	ipRules := t.Rules(m.cni)
	// If we are handling local routes, ensure the local
	// tunnel has an IP address and IPIP traffic is allowed.
	if m.enc.Strategy() != encapsulation.Never && m.local {
		var cidrs []*net.IPNet
		for _, s := range t.segments {
			// If the location prefix is not logicalLocation, but nodeLocation,
			// we don't need to set any extra rules for encapsulation anyways
			// because traffic will go over WireGuard.
			if s.location != logicalLocationPrefix+self.Location {
				continue
			}
			for i := range s.privateIPs {
				cidrs = append(cidrs, oneAddressCIDR(s.privateIPs[i]))
			}
			break
		}
		ipRules = append(ipRules, m.enc.Rules(cidrs)...)
		if err := m.enc.Set(oneAddressCIDR(newAllocator(*self.Subnet).next().IP)); err != nil {
			return err
		}
	}
	if err := m.ipTables.Set(ipRules); err != nil {
		return err
	}

	if t.leader {
		m.leaderGuage.Set(1)
		if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil {
			return err
		}
		// Setting the WireGuard configuration interrupts existing connections
		// so only set the configuration if it has changed.
		if !conf.Equal(oldConf) {
			level.Info(m.logger).Log("msg", "WireGuard configurations are different")
			if err := wireguard.SetConf(ifaceName, confPath); err != nil {
				return err
			}
		}
		if err := iproute.Set(m.kiloIface, true); err != nil {
			return err
		}
	} else {
		m.leaderGuage.Set(0)
		level.Debug(m.logger).Log("msg", "local node is not the leader")
		if err := iproute.Set(m.kiloIface, false); err != nil {
			return err
		}
	}

	// We need to add routes last since they may depend
	// on the WireGuard interface.
	routes, rules := t.Routes(ifaceName, m.kiloIface, m.privIface, m.enc.Index(), m.local, m.enc)
	if err := m.table.Set(routes, rules); err != nil {
		return err
	}
	return nil
}
// RegisterMetrics registers Prometheus metrics on the given Prometheus
// registerer.
func (m *Mesh) RegisterMetrics(r prometheus.Registerer) {
r.MustRegister(
m.errorCounter,
m.leaderGuage,
2019-01-18 01:50:10 +00:00
m.nodesGuage,
m.peersGuage,
m.reconcileCounter,
2019-01-18 01:50:10 +00:00
)
}
// Stop stops the mesh by closing the stop channel, which makes Run return.
//
// A repeated call is a no-op instead of panicking with "close of closed
// channel". The check is not synchronized, so two Stop calls racing each
// other can still both attempt the close.
func (m *Mesh) Stop() {
	select {
	case <-m.stop:
		// Already stopped.
	default:
		close(m.stop)
	}
}
// cleanUp tears down the host state managed by the mesh. In order, it:
//   - cleans up the iptables rules and routes,
//   - deletes the WireGuard configuration file,
//   - removes the WireGuard interface (only if cleanUpIface is set),
//   - calls CleanUp on the node and peer backends for this host,
//   - cleans up the encapsulator.
//
// Every step is attempted even if an earlier one fails. Failures are logged
// and counted under the "cleanUp" label.
func (m *Mesh) cleanUp() {
	steps := []struct {
		desc string
		run  func() error
	}{
		{"failed to clean up IP tables", func() error { return m.ipTables.CleanUp() }},
		{"failed to clean up routes", func() error { return m.table.CleanUp() }},
		{"failed to delete configuration file", func() error {
			// The file is only written once a topology has been applied, so
			// its absence is not a cleanup failure.
			if err := os.Remove(confPath); err != nil && !os.IsNotExist(err) {
				return err
			}
			return nil
		}},
		{"failed to remove WireGuard interface", func() error {
			if !m.cleanUpIface {
				return nil
			}
			return iproute.RemoveInterface(m.kiloIface)
		}},
		{"failed to clean up node backend", func() error { return m.Nodes().CleanUp(m.hostname) }},
		{"failed to clean up peer backend", func() error { return m.Peers().CleanUp(m.hostname) }},
		{"failed to clean up encapsulator", func() error { return m.enc.CleanUp() }},
	}
	for _, s := range steps {
		if err := s.run(); err != nil {
			level.Error(m.logger).Log("error", fmt.Sprintf("%s: %v", s.desc, err))
			m.errorCounter.WithLabelValues("cleanUp").Inc()
		}
	}
}
// resolveEndpoints resolves the DNS names in the endpoints of all ready nodes
// and peers, storing the chosen address in each endpoint's IP field. It stops
// at and returns the first resolution error.
//
// It reads and mutates entries of m.nodes and m.peers, which are guarded by
// m.mu, so the caller is expected to hold m.mu.
func (m *Mesh) resolveEndpoints() error {
	for _, node := range m.nodes {
		// Unready nodes are not used in the topology. A ready node always has
		// an endpoint, but it may be a bare IP with no DNS name to resolve.
		if !node.Ready() || node.Endpoint.DNS == "" {
			continue
		}
		if err := resolveEndpoint(node.Endpoint); err != nil {
			return err
		}
	}
	for _, peer := range m.peers {
		// Unlike nodes, ready peers may have no endpoint at all.
		if !peer.Ready() || peer.Endpoint == nil || peer.Endpoint.DNS == "" {
			continue
		}
		if err := resolveEndpoint(peer.Endpoint); err != nil {
			return err
		}
	}
	return nil
}
// resolveEndpoint looks up endpoint.DNS and sets endpoint.IP to the first
// address after the results are ordered by sortIPs. If the lookup fails or
// yields no addresses, the endpoint is left unchanged and an error is
// returned.
func resolveEndpoint(endpoint *wireguard.Endpoint) error {
	ips, err := net.LookupIP(endpoint.DNS)
	if err != nil {
		return fmt.Errorf("failed to look up DNS name %q: %v", endpoint.DNS, err)
	}
	if len(ips) == 0 {
		return fmt.Errorf("did not find any addresses for DNS name %q", endpoint.DNS)
	}
	candidates := make([]*net.IPNet, 0, len(ips))
	for _, ip := range ips {
		candidates = append(candidates, oneAddressCIDR(ip))
	}
	sortIPs(candidates)
	endpoint.IP = candidates[0].IP
	return nil
}
2019-01-18 01:50:10 +00:00
// isSelf reports whether node is non-nil and named hostname, i.e. whether it
// describes the local host.
func isSelf(hostname string, node *Node) bool {
	if node == nil {
		return false
	}
	return hostname == node.Name
}
func nodesAreEqual(a, b *Node) bool {
if (a != nil) != (b != nil) {
2019-01-18 01:50:10 +00:00
return false
}
if a == b {
return true
}
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
return false
}
if a.Endpoint != nil {
if a.Endpoint.Port != b.Endpoint.Port {
return false
}
// Check the DNS name first since this package
// is doing the DNS resolution.
if a.Endpoint.DNS != b.Endpoint.DNS {
return false
}
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
return false
}
}
2019-05-08 15:10:33 +00:00
// Ignore LastSeen when comparing equality we want to check if the nodes are
// equivalent. However, we do want to check if LastSeen has transitioned
// between valid and invalid.
return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive
2019-01-18 01:50:10 +00:00
}
// peersAreEqual reports whether a and b describe equivalent peers.
//
// Identical pointers (including two nils) are equal; a nil and a non-nil
// peer are not. Endpoints are compared by port and DNS name. Their IPs are
// compared only when there is no DNS name. AllowedIPs must match element by
// element and in the same order. The keys and persistent keepalive must
// also match.
func peersAreEqual(a, b *Peer) bool {
	if a == b {
		return true
	}
	if a == nil || b == nil {
		return false
	}
	switch {
	case (a.Endpoint == nil) != (b.Endpoint == nil):
		return false
	case a.Endpoint != nil:
		// Check the DNS name first since this package
		// is doing the DNS resolution.
		if a.Endpoint.Port != b.Endpoint.Port || a.Endpoint.DNS != b.Endpoint.DNS {
			return false
		}
		if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
			return false
		}
	}
	if len(a.AllowedIPs) != len(b.AllowedIPs) {
		return false
	}
	for i, allowed := range a.AllowedIPs {
		if !ipNetsEqual(allowed, b.AllowedIPs[i]) {
			return false
		}
	}
	return string(a.PublicKey) == string(b.PublicKey) &&
		string(a.PresharedKey) == string(b.PresharedKey) &&
		a.PersistentKeepalive == b.PersistentKeepalive
}
2019-01-18 01:50:10 +00:00
// ipNetsEqual reports whether a and b have the same mask and the same IP.
// Two nil values are equal; a nil and a non-nil value are not. The IPs are
// compared with net.IP.Equal, so 4-byte and 16-byte forms of the same IPv4
// address match.
func ipNetsEqual(a, b *net.IPNet) bool {
	switch {
	case a == nil || b == nil:
		return a == b
	case a.Mask.String() != b.Mask.String():
		return false
	default:
		return a.IP.Equal(b.IP)
	}
}
// subnetsEqual reports whether a and b denote the same subnet. They are
// equal when they have the same mask and each contains the other's IP, so a
// host address inside a network matches that network. Two nil values are
// equal; a nil and a non-nil value are not.
func subnetsEqual(a, b *net.IPNet) bool {
	if a == nil || b == nil {
		return a == b
	}
	return a.Mask.String() == b.Mask.String() &&
		a.Contains(b.IP) &&
		b.Contains(a.IP)
}
// linkByIndex returns the netlink link with the given interface index,
// wrapping any lookup error with context.
func linkByIndex(index int) (netlink.Link, error) {
	l, err := netlink.LinkByIndex(index)
	if err == nil {
		return l, nil
	}
	return nil, fmt.Errorf("failed to get interface: %v", err)
}
// updateNATEndpoints ensures that nodes and peers behind NAT update
// their endpoints from the WireGuard configuration so they can roam.
//
// A node or peer qualifies if it has a positive PersistentKeepalive and its
// public key appears in conf. For those, the endpoint WireGuard currently
// reports replaces the one from the backend. Map values are modified in
// place, so callers should pass copies they are free to mutate.
//
// Entries for which conf reports no endpoint keep their configured endpoint.
// Overwriting it with nil would discard the backend's address. NOTE(review):
// presumably WireGuard has no endpoint for a peer it has not yet heard from.
func updateNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf) {
	if conf == nil {
		return
	}
	keys := make(map[string]*wireguard.Peer, len(conf.Peers))
	for i := range conf.Peers {
		keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
	}
	for _, n := range nodes {
		if n.PersistentKeepalive <= 0 {
			continue
		}
		if peer, ok := keys[string(n.Key)]; ok && peer != nil && peer.Endpoint != nil {
			n.Endpoint = peer.Endpoint
		}
	}
	for _, p := range peers {
		if p.PersistentKeepalive <= 0 {
			continue
		}
		if peer, ok := keys[string(p.PublicKey)]; ok && peer != nil && peer.Endpoint != nil {
			p.Endpoint = peer.Endpoint
		}
	}
}