Merge pull request #43 from squat/fix_keepalive_logic

pkg/mesh,pkg/wireguard: update NAT endpoints
This commit is contained in:
Lucas Servén Marín 2020-03-04 02:15:11 +01:00 committed by GitHub
commit 6947eb4154
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 181 additions and 204 deletions

View File

@ -60,7 +60,7 @@ func runGraph(_ *cobra.Command, _ []string) error {
peers[p.Name] = p peers[p.Name] = p
} }
} }
t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, 0, []byte{}, subnet) t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, 0, []byte{}, subnet, nodes[hostname].PersistentKeepalive)
if err != nil { if err != nil {
return fmt.Errorf("failed to create topology: %v", err) return fmt.Errorf("failed to create topology: %v", err)
} }

View File

@ -147,7 +147,7 @@ func runShowConfNode(_ *cobra.Command, args []string) error {
} }
} }
t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, []byte{}, subnet) t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, []byte{}, subnet, nodes[hostname].PersistentKeepalive)
if err != nil { if err != nil {
return fmt.Errorf("failed to create topology: %v", err) return fmt.Errorf("failed to create topology: %v", err)
} }
@ -236,7 +236,7 @@ func runShowConfPeer(_ *cobra.Command, args []string) error {
return fmt.Errorf("did not find any peer named %q in the cluster", peer) return fmt.Errorf("did not find any peer named %q in the cluster", peer)
} }
t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, mesh.DefaultKiloPort, []byte{}, subnet) t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, mesh.DefaultKiloPort, []byte{}, subnet, peers[peer].PersistentKeepalive)
if err != nil { if err != nil {
return fmt.Errorf("failed to create topology: %v", err) return fmt.Errorf("failed to create topology: %v", err)
} }

View File

@ -4,10 +4,11 @@ The following annotations can be added to any Kubernetes Node object to configur
|Name|type|examples| |Name|type|examples|
|----|----|-------| |----|----|-------|
|[kilo.squat.ai/force-endpoint](#force-endpoint)|host:port|`55.55.55.55:51820`, `example.com:1337| |[kilo.squat.ai/force-endpoint](#force-endpoint)|host:port|`55.55.55.55:51820`, `example.com:1337`|
|[kilo.squat.ai/force-internal-ip](#force-internal-ip)|CIDR|`55.55.55.55/32`| |[kilo.squat.ai/force-internal-ip](#force-internal-ip)|CIDR|`55.55.55.55/32`|
|[kilo.squat.ai/leader](#leader)|string|`""`, `true`| |[kilo.squat.ai/leader](#leader)|string|`""`, `true`|
|[kilo.squat.ai/location](#location)|string|`gcp-east`, `lab`| |[kilo.squat.ai/location](#location)|string|`gcp-east`, `lab`|
|[kilo.squat.ai/persistent-keepalive](#persistent-keepalive)|uint|`10`|
### force-endpoint ### force-endpoint
In order to create links between locations, Kilo requires at least one node in each location to have an endpoint, ie a `host:port` combination, that is routable from the other locations. In order to create links between locations, Kilo requires at least one node in each location to have an endpoint, ie a `host:port` combination, that is routable from the other locations.
@ -42,3 +43,11 @@ Kilo will try to infer each node's location from the [topology.kubernetes.io/reg
If the label is not present for a node, for example if running a bare-metal cluster or on an unsupported cloud provider, then the location annotation should be specified. If the label is not present for a node, for example if running a bare-metal cluster or on an unsupported cloud provider, then the location annotation should be specified.
_Note_: all nodes without a defined location will be considered to be in the default location `""`. _Note_: all nodes without a defined location will be considered to be in the default location `""`.
### persistent-keepalive
In certain deployments, cluster nodes may be located behind NAT or a firewall, e.g. edge nodes located behind a commodity router.
In these scenarios, the nodes behind NAT can send packets to the nodes outside of the NATed network, however the outside nodes can only send packets into the NATed network as long as the NAT mapping remains valid.
In order for a node behind NAT to receive packets from nodes outside of the NATed network, it must maintain the NAT mapping by regularly sending packets to those nodes, ie by sending _keepalives_.
The frequency of emission of these keepalive packets can be controlled by setting the persistent-keepalive annotation on the node behind NAT.
The annotated node will use the specified value will as the persistent-keepalive interval for all of its peers.
For more background, [see the WireGuard documentation on NAT and firewall traversal](https://www.wireguard.com/quickstart/#nat-and-firewall-traversal-persistence).

View File

@ -299,7 +299,7 @@ func schema_k8s_apis_kilo_v1alpha1_PeerSpec(ref common.ReferenceCallback) common
}, },
"persistentKeepalive": { "persistentKeepalive": {
SchemaProps: spec.SchemaProps{ SchemaProps: spec.SchemaProps{
Description: "PersistentKeepalive is the interval in seconds of the emission of keepalive packets to the peer. This defaults to 0, which disables the feature.", Description: "PersistentKeepalive is the interval in seconds of the emission of keepalive packets by the peer. This defaults to 0, which disables the feature.",
Type: []string{"integer"}, Type: []string{"integer"},
Format: "int32", Format: "int32",
}, },

View File

@ -68,7 +68,7 @@ type PeerSpec struct {
// +optional // +optional
Endpoint *PeerEndpoint `json:"endpoint,omitempty"` Endpoint *PeerEndpoint `json:"endpoint,omitempty"`
// PersistentKeepalive is the interval in seconds of the emission // PersistentKeepalive is the interval in seconds of the emission
// of keepalive packets to the peer. This defaults to 0, which // of keepalive packets by the peer. This defaults to 0, which
// disables the feature. // disables the feature.
// +optional // +optional
PersistentKeepalive int `json:"persistentKeepalive,omitempty"` PersistentKeepalive int `json:"persistentKeepalive,omitempty"`

View File

@ -529,7 +529,9 @@ func (m *Mesh) applyTopology() {
if !m.nodes[k].Ready() { if !m.nodes[k].Ready() {
continue continue
} }
nodes[k] = m.nodes[k] // Make a shallow copy of the node.
node := *m.nodes[k]
nodes[k] = &node
readyNodes++ readyNodes++
} }
// Ensure only ready nodes are considered. // Ensure only ready nodes are considered.
@ -539,7 +541,9 @@ func (m *Mesh) applyTopology() {
if !m.peers[k].Ready() { if !m.peers[k].Ready() {
continue continue
} }
peers[k] = m.peers[k] // Make a shallow copy of the peer.
peer := *m.peers[k]
peers[k] = &peer
readyPeers++ readyPeers++
} }
m.nodesGuage.Set(readyNodes) m.nodesGuage.Set(readyNodes)
@ -548,7 +552,23 @@ func (m *Mesh) applyTopology() {
if nodes[m.hostname] == nil { if nodes[m.hostname] == nil {
return return
} }
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet) // Find the Kilo interface name.
link, err := linkByIndex(m.kiloIface)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
// Find the old configuration.
oldConfRaw, err := wireguard.ShowConf(link.Attrs().Name)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
oldConf := wireguard.Parse(oldConfRaw)
updateNATEndpoints(nodes, peers, oldConf)
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive)
if err != nil { if err != nil {
level.Error(m.logger).Log("error", err) level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc() m.errorCounter.WithLabelValues("apply").Inc()
@ -582,7 +602,6 @@ func (m *Mesh) applyTopology() {
} }
} }
ipRules = append(ipRules, m.enc.Rules(cidrs)...) ipRules = append(ipRules, m.enc.Rules(cidrs)...)
// If we are handling local routes, ensure the local // If we are handling local routes, ensure the local
// tunnel has an IP address. // tunnel has an IP address.
if err := m.enc.Set(oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil { if err := m.enc.Set(oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil {
@ -596,28 +615,15 @@ func (m *Mesh) applyTopology() {
m.errorCounter.WithLabelValues("apply").Inc() m.errorCounter.WithLabelValues("apply").Inc()
return return
} }
// Find the Kilo interface name.
link, err := linkByIndex(m.kiloIface)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
if t.leader { if t.leader {
if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil { if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil {
level.Error(m.logger).Log("error", err) level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc() m.errorCounter.WithLabelValues("apply").Inc()
return return
} }
oldConf, err := wireguard.ShowConf(link.Attrs().Name)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
// Setting the WireGuard configuration interrupts existing connections // Setting the WireGuard configuration interrupts existing connections
// so only set the configuration if it has changed. // so only set the configuration if it has changed.
equal := conf.EqualWithPeerCheck(wireguard.Parse(oldConf), peersAreEqualIgnoreNAT) equal := conf.Equal(oldConf)
if !equal { if !equal {
level.Info(m.logger).Log("msg", "WireGuard configurations are different") level.Info(m.logger).Log("msg", "WireGuard configurations are different")
if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil { if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
@ -814,41 +820,6 @@ func peersAreEqual(a, b *Peer) bool {
return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive
} }
// Basic nil checks and checking the lengths of the allowed IPs is
// done by the WireGuard package.
func peersAreEqualIgnoreNAT(a, b *wireguard.Peer) bool {
for j := range a.AllowedIPs {
if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
return false
}
}
if a.PersistentKeepalive != b.PersistentKeepalive || !bytes.Equal(a.PublicKey, b.PublicKey) {
return false
}
// If a persistent keepalive is set, then the peer is behind NAT
// and we want to ignore changes in endpoints, since it may roam.
if a.PersistentKeepalive != 0 {
return true
}
if (a.Endpoint == nil) != (b.Endpoint == nil) {
return false
}
if a.Endpoint != nil {
if a.Endpoint.Port != b.Endpoint.Port {
return false
}
// IPs take priority, so check them first.
if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
return false
}
}
return true
}
func ipNetsEqual(a, b *net.IPNet) bool { func ipNetsEqual(a, b *net.IPNet) bool {
if a == nil && b == nil { if a == nil && b == nil {
return true return true
@ -888,3 +859,22 @@ func linkByIndex(index int) (netlink.Link, error) {
} }
return link, nil return link, nil
} }
// updateNATEndpoints ensures that nodes and peers behind NAT update
// their endpoints from the WireGuard configuration so they can roam.
func updateNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf) {
keys := make(map[string]*wireguard.Peer)
for i := range conf.Peers {
keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
}
for _, n := range nodes {
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
n.Endpoint = peer.Endpoint
}
}
for _, p := range peers {
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
p.Endpoint = peer.Endpoint
}
}
}

View File

@ -42,10 +42,13 @@ type Topology struct {
// leader represents whether or not the local host // leader represents whether or not the local host
// is the segment leader. // is the segment leader.
leader bool leader bool
// persistentKeepalive is the interval in seconds of the emission
// of keepalive packets by the local node to its peers.
persistentKeepalive int
// privateIP is the private IP address of the local node.
privateIP *net.IPNet
// subnet is the Pod subnet of the local node. // subnet is the Pod subnet of the local node.
subnet *net.IPNet subnet *net.IPNet
// privateIP is the private IP address of the local node.
privateIP *net.IPNet
// wireGuardCIDR is the allocated CIDR of the WireGuard // wireGuardCIDR is the allocated CIDR of the WireGuard
// interface of the local node. If the local node is not // interface of the local node. If the local node is not
// the leader, then it is nil. // the leader, then it is nil.
@ -65,9 +68,6 @@ type segment struct {
hostnames []string hostnames []string
// leader is the index of the leader of the segment. // leader is the index of the leader of the segment.
leader int leader int
// persistentKeepalive is the interval in seconds of the emission
// of keepalive packets to the peer.
persistentKeepalive int
// privateIPs is a slice of private IPs of all peers in the segment. // privateIPs is a slice of private IPs of all peers in the segment.
privateIPs []net.IP privateIPs []net.IP
// wireGuardIP is the allocated IP address of the WireGuard // wireGuardIP is the allocated IP address of the WireGuard
@ -76,7 +76,7 @@ type segment struct {
} }
// NewTopology creates a new Topology struct from a given set of nodes and peers. // NewTopology creates a new Topology struct from a given set of nodes and peers.
func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet) (*Topology, error) { func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int) (*Topology, error) {
topoMap := make(map[string][]*Node) topoMap := make(map[string][]*Node)
for _, node := range nodes { for _, node := range nodes {
var location string var location string
@ -96,7 +96,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
localLocation = hostname localLocation = hostname
} }
t := Topology{key: key, port: port, hostname: hostname, location: localLocation, subnet: nodes[hostname].Subnet, privateIP: nodes[hostname].InternalIP} t := Topology{key: key, port: port, hostname: hostname, location: localLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet}
for location := range topoMap { for location := range topoMap {
// Sort the location so the result is stable. // Sort the location so the result is stable.
sort.Slice(topoMap[location], func(i, j int) bool { sort.Slice(topoMap[location], func(i, j int) bool {
@ -121,15 +121,14 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
privateIPs = append(privateIPs, node.InternalIP.IP) privateIPs = append(privateIPs, node.InternalIP.IP)
} }
t.segments = append(t.segments, &segment{ t.segments = append(t.segments, &segment{
allowedIPs: allowedIPs, allowedIPs: allowedIPs,
endpoint: topoMap[location][leader].Endpoint, endpoint: topoMap[location][leader].Endpoint,
key: topoMap[location][leader].Key, key: topoMap[location][leader].Key,
location: location, location: location,
cidrs: cidrs, cidrs: cidrs,
hostnames: hostnames, hostnames: hostnames,
leader: leader, leader: leader,
privateIPs: privateIPs, privateIPs: privateIPs,
persistentKeepalive: topoMap[location][leader].PersistentKeepalive,
}) })
} }
// Sort the Topology segments so the result is stable. // Sort the Topology segments so the result is stable.
@ -367,14 +366,14 @@ func (t *Topology) Conf() *wireguard.Conf {
AllowedIPs: s.allowedIPs, AllowedIPs: s.allowedIPs,
Endpoint: s.endpoint, Endpoint: s.endpoint,
PublicKey: s.key, PublicKey: s.key,
PersistentKeepalive: s.persistentKeepalive, PersistentKeepalive: t.persistentKeepalive,
} }
c.Peers = append(c.Peers, peer) c.Peers = append(c.Peers, peer)
} }
for _, p := range t.peers { for _, p := range t.peers {
peer := &wireguard.Peer{ peer := &wireguard.Peer{
AllowedIPs: p.AllowedIPs, AllowedIPs: p.AllowedIPs,
PersistentKeepalive: p.PersistentKeepalive, PersistentKeepalive: t.persistentKeepalive,
PublicKey: p.PublicKey, PublicKey: p.PublicKey,
Endpoint: p.Endpoint, Endpoint: p.Endpoint,
} }
@ -391,10 +390,9 @@ func (t *Topology) AsPeer() *wireguard.Peer {
continue continue
} }
return &wireguard.Peer{ return &wireguard.Peer{
AllowedIPs: s.allowedIPs, AllowedIPs: s.allowedIPs,
Endpoint: s.endpoint, Endpoint: s.endpoint,
PersistentKeepalive: s.persistentKeepalive, PublicKey: s.key,
PublicKey: s.key,
} }
} }
return nil return nil
@ -402,25 +400,35 @@ func (t *Topology) AsPeer() *wireguard.Peer {
// PeerConf generates a WireGuard configuration file for a given peer in a Topology. // PeerConf generates a WireGuard configuration file for a given peer in a Topology.
func (t *Topology) PeerConf(name string) *wireguard.Conf { func (t *Topology) PeerConf(name string) *wireguard.Conf {
var p *Peer
for i := range t.peers {
if t.peers[i].Name == name {
p = t.peers[i]
break
}
}
if p == nil {
return nil
}
c := &wireguard.Conf{} c := &wireguard.Conf{}
for _, s := range t.segments { for _, s := range t.segments {
peer := &wireguard.Peer{ peer := &wireguard.Peer{
AllowedIPs: s.allowedIPs, AllowedIPs: s.allowedIPs,
Endpoint: s.endpoint, Endpoint: s.endpoint,
PersistentKeepalive: s.persistentKeepalive, PersistentKeepalive: p.PersistentKeepalive,
PublicKey: s.key, PublicKey: s.key,
} }
c.Peers = append(c.Peers, peer) c.Peers = append(c.Peers, peer)
} }
for _, p := range t.peers { for i := range t.peers {
if p.Name == name { if t.peers[i].Name == name {
continue continue
} }
peer := &wireguard.Peer{ peer := &wireguard.Peer{
AllowedIPs: p.AllowedIPs, AllowedIPs: t.peers[i].AllowedIPs,
PersistentKeepalive: p.PersistentKeepalive, PersistentKeepalive: p.PersistentKeepalive,
PublicKey: p.PublicKey, PublicKey: t.peers[i].PublicKey,
Endpoint: p.Endpoint, Endpoint: t.peers[i].Endpoint,
} }
c.Peers = append(c.Peers, peer) c.Peers = append(c.Peers, peer)
} }

View File

@ -118,15 +118,14 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)}, wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
segments: []*segment{ segments: []*segment{
{ {
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint, endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key, key: nodes["a"].Key,
location: nodes["a"].Location, location: nodes["a"].Location,
cidrs: []*net.IPNet{nodes["a"].Subnet}, cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"}, hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP}, privateIPs: []net.IP{nodes["a"].InternalIP.IP},
persistentKeepalive: nodes["a"].PersistentKeepalive, wireGuardIP: w1,
wireGuardIP: w1,
}, },
{ {
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
@ -155,15 +154,14 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)}, wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
segments: []*segment{ segments: []*segment{
{ {
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint, endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key, key: nodes["a"].Key,
location: nodes["a"].Location, location: nodes["a"].Location,
cidrs: []*net.IPNet{nodes["a"].Subnet}, cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"}, hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP}, privateIPs: []net.IP{nodes["a"].InternalIP.IP},
persistentKeepalive: nodes["a"].PersistentKeepalive, wireGuardIP: w1,
wireGuardIP: w1,
}, },
{ {
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
@ -192,15 +190,14 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: nil, wireGuardCIDR: nil,
segments: []*segment{ segments: []*segment{
{ {
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint, endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key, key: nodes["a"].Key,
location: nodes["a"].Location, location: nodes["a"].Location,
cidrs: []*net.IPNet{nodes["a"].Subnet}, cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"}, hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP}, privateIPs: []net.IP{nodes["a"].InternalIP.IP},
persistentKeepalive: nodes["a"].PersistentKeepalive, wireGuardIP: w1,
wireGuardIP: w1,
}, },
{ {
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
@ -229,15 +226,14 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)}, wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
segments: []*segment{ segments: []*segment{
{ {
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint, endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key, key: nodes["a"].Key,
location: nodes["a"].Name, location: nodes["a"].Name,
cidrs: []*net.IPNet{nodes["a"].Subnet}, cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"}, hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP}, privateIPs: []net.IP{nodes["a"].InternalIP.IP},
persistentKeepalive: nodes["a"].PersistentKeepalive, wireGuardIP: w1,
wireGuardIP: w1,
}, },
{ {
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
@ -276,15 +272,14 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)}, wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
segments: []*segment{ segments: []*segment{
{ {
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint, endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key, key: nodes["a"].Key,
location: nodes["a"].Name, location: nodes["a"].Name,
cidrs: []*net.IPNet{nodes["a"].Subnet}, cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"}, hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP}, privateIPs: []net.IP{nodes["a"].InternalIP.IP},
persistentKeepalive: nodes["a"].PersistentKeepalive, wireGuardIP: w1,
wireGuardIP: w1,
}, },
{ {
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
@ -323,15 +318,14 @@ func TestNewTopology(t *testing.T) {
wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)}, wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)},
segments: []*segment{ segments: []*segment{
{ {
allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
endpoint: nodes["a"].Endpoint, endpoint: nodes["a"].Endpoint,
key: nodes["a"].Key, key: nodes["a"].Key,
location: nodes["a"].Name, location: nodes["a"].Name,
cidrs: []*net.IPNet{nodes["a"].Subnet}, cidrs: []*net.IPNet{nodes["a"].Subnet},
hostnames: []string{"a"}, hostnames: []string{"a"},
privateIPs: []net.IP{nodes["a"].InternalIP.IP}, privateIPs: []net.IP{nodes["a"].InternalIP.IP},
persistentKeepalive: nodes["a"].PersistentKeepalive, wireGuardIP: w1,
wireGuardIP: w1,
}, },
{ {
allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}}, allowedIPs: []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
@ -360,7 +354,7 @@ func TestNewTopology(t *testing.T) {
} { } {
tc.result.key = key tc.result.key = key
tc.result.port = port tc.result.port = port
topo, err := NewTopology(nodes, peers, tc.granularity, tc.hostname, port, key, DefaultKiloSubnet) topo, err := NewTopology(nodes, peers, tc.granularity, tc.hostname, port, key, DefaultKiloSubnet, 0)
if err != nil { if err != nil {
t.Errorf("test case %q: failed to generate Topology: %v", tc.name, err) t.Errorf("test case %q: failed to generate Topology: %v", tc.name, err)
} }
@ -370,8 +364,8 @@ func TestNewTopology(t *testing.T) {
} }
} }
func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet) *Topology { func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int) *Topology {
topo, err := NewTopology(nodes, peers, granularity, hostname, port, key, subnet) topo, err := NewTopology(nodes, peers, granularity, hostname, port, key, subnet, persistentKeepalive)
if err != nil { if err != nil {
t.Errorf("failed to generate Topology: %v", err) t.Errorf("failed to generate Topology: %v", err)
} }
@ -384,7 +378,7 @@ func TestRoutes(t *testing.T) {
privIface := 1 privIface := 1
tunlIface := 2 tunlIface := 2
mustTopoForGranularityAndHost := func(granularity Granularity, hostname string) *Topology { mustTopoForGranularityAndHost := func(granularity Granularity, hostname string) *Topology {
return mustTopo(t, nodes, peers, granularity, hostname, port, key, DefaultKiloSubnet) return mustTopo(t, nodes, peers, granularity, hostname, port, key, DefaultKiloSubnet, 0)
} }
for _, tc := range []struct { for _, tc := range []struct {
@ -1213,7 +1207,7 @@ func TestConf(t *testing.T) {
}{ }{
{ {
name: "logical from a", name: "logical from a",
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet), topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive),
result: `[Interface] result: `[Interface]
PrivateKey = private PrivateKey = private
ListenPort = 51820 ListenPort = 51820
@ -1222,22 +1216,23 @@ ListenPort = 51820
PublicKey = key2 PublicKey = key2
Endpoint = 10.1.0.2:51820 Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32 AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
PersistentKeepalive = 25
[Peer] [Peer]
PublicKey = key4 PublicKey = key4
PersistentKeepalive = 0
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
PersistentKeepalive = 25
[Peer] [Peer]
PublicKey = key5 PublicKey = key5
Endpoint = 192.168.0.1:51820 Endpoint = 192.168.0.1:51820
PersistentKeepalive = 0
AllowedIPs = 10.5.0.3/24 AllowedIPs = 10.5.0.3/24
PersistentKeepalive = 25
`, `,
}, },
{ {
name: "logical from b", name: "logical from b",
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet), topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet, nodes["b"].PersistentKeepalive),
result: `[Interface] result: `[Interface]
PrivateKey = private PrivateKey = private
ListenPort = 51820 ListenPort = 51820
@ -1245,24 +1240,21 @@ AllowedIPs = 10.5.0.3/24
[Peer] [Peer]
PublicKey = key1 PublicKey = key1
Endpoint = 10.1.0.1:51820 Endpoint = 10.1.0.1:51820
PersistentKeepalive = 25
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32 AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer] [Peer]
PublicKey = key4 PublicKey = key4
PersistentKeepalive = 0
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer] [Peer]
PublicKey = key5 PublicKey = key5
Endpoint = 192.168.0.1:51820 Endpoint = 192.168.0.1:51820
PersistentKeepalive = 0
AllowedIPs = 10.5.0.3/24 AllowedIPs = 10.5.0.3/24
`, `,
}, },
{ {
name: "logical from c", name: "logical from c",
topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet), topology: mustTopo(t, nodes, peers, LogicalGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet, nodes["c"].PersistentKeepalive),
result: `[Interface] result: `[Interface]
PrivateKey = private PrivateKey = private
ListenPort = 51820 ListenPort = 51820
@ -1270,24 +1262,21 @@ AllowedIPs = 10.5.0.3/24
[Peer] [Peer]
PublicKey = key1 PublicKey = key1
Endpoint = 10.1.0.1:51820 Endpoint = 10.1.0.1:51820
PersistentKeepalive = 25
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32 AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer] [Peer]
PublicKey = key4 PublicKey = key4
PersistentKeepalive = 0
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer] [Peer]
PublicKey = key5 PublicKey = key5
Endpoint = 192.168.0.1:51820 Endpoint = 192.168.0.1:51820
PersistentKeepalive = 0
AllowedIPs = 10.5.0.3/24 AllowedIPs = 10.5.0.3/24
`, `,
}, },
{ {
name: "full from a", name: "full from a",
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet), topology: mustTopo(t, nodes, peers, FullGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive),
result: `[Interface] result: `[Interface]
PrivateKey = private PrivateKey = private
ListenPort = 51820 ListenPort = 51820
@ -1296,27 +1285,29 @@ AllowedIPs = 10.5.0.3/24
PublicKey = key2 PublicKey = key2
Endpoint = 10.1.0.2:51820 Endpoint = 10.1.0.2:51820
AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32 AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32
PersistentKeepalive = 25
[Peer] [Peer]
PublicKey = key3 PublicKey = key3
Endpoint = 10.1.0.3:51820 Endpoint = 10.1.0.3:51820
AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32 AllowedIPs = 10.2.3.0/24, 192.168.0.2/32, 10.4.0.3/32
PersistentKeepalive = 25
[Peer] [Peer]
PublicKey = key4 PublicKey = key4
PersistentKeepalive = 0
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
PersistentKeepalive = 25
[Peer] [Peer]
PublicKey = key5 PublicKey = key5
Endpoint = 192.168.0.1:51820 Endpoint = 192.168.0.1:51820
PersistentKeepalive = 0
AllowedIPs = 10.5.0.3/24 AllowedIPs = 10.5.0.3/24
PersistentKeepalive = 25
`, `,
}, },
{ {
name: "full from b", name: "full from b",
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet), topology: mustTopo(t, nodes, peers, FullGranularity, nodes["b"].Name, port, key, DefaultKiloSubnet, nodes["b"].PersistentKeepalive),
result: `[Interface] result: `[Interface]
PrivateKey = private PrivateKey = private
ListenPort = 51820 ListenPort = 51820
@ -1324,7 +1315,6 @@ AllowedIPs = 10.5.0.3/24
[Peer] [Peer]
PublicKey = key1 PublicKey = key1
Endpoint = 10.1.0.1:51820 Endpoint = 10.1.0.1:51820
PersistentKeepalive = 25
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32 AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer] [Peer]
@ -1334,19 +1324,17 @@ AllowedIPs = 10.5.0.3/24
[Peer] [Peer]
PublicKey = key4 PublicKey = key4
PersistentKeepalive = 0
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer] [Peer]
PublicKey = key5 PublicKey = key5
Endpoint = 192.168.0.1:51820 Endpoint = 192.168.0.1:51820
PersistentKeepalive = 0
AllowedIPs = 10.5.0.3/24 AllowedIPs = 10.5.0.3/24
`, `,
}, },
{ {
name: "full from c", name: "full from c",
topology: mustTopo(t, nodes, peers, FullGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet), topology: mustTopo(t, nodes, peers, FullGranularity, nodes["c"].Name, port, key, DefaultKiloSubnet, nodes["c"].PersistentKeepalive),
result: `[Interface] result: `[Interface]
PrivateKey = private PrivateKey = private
ListenPort = 51820 ListenPort = 51820
@ -1354,7 +1342,6 @@ AllowedIPs = 10.5.0.3/24
[Peer] [Peer]
PublicKey = key1 PublicKey = key1
Endpoint = 10.1.0.1:51820 Endpoint = 10.1.0.1:51820
PersistentKeepalive = 25
AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32 AllowedIPs = 10.2.1.0/24, 192.168.0.1/32, 10.4.0.1/32
[Peer] [Peer]
@ -1364,13 +1351,11 @@ AllowedIPs = 10.5.0.3/24
[Peer] [Peer]
PublicKey = key4 PublicKey = key4
PersistentKeepalive = 0
AllowedIPs = 10.5.0.1/24, 10.5.0.2/24 AllowedIPs = 10.5.0.1/24, 10.5.0.2/24
[Peer] [Peer]
PublicKey = key5 PublicKey = key5
Endpoint = 192.168.0.1:51820 Endpoint = 192.168.0.1:51820
PersistentKeepalive = 0
AllowedIPs = 10.5.0.3/24 AllowedIPs = 10.5.0.3/24
`, `,
}, },

View File

@ -275,12 +275,6 @@ func (c *Conf) Bytes() ([]byte, error) {
// Equal checks if two WireGuard configurations are equivalent. // Equal checks if two WireGuard configurations are equivalent.
func (c *Conf) Equal(b *Conf) bool { func (c *Conf) Equal(b *Conf) bool {
return c.EqualWithPeerCheck(b, strictPeerCheck)
}
// EqualWithPeerCheck checks if two WireGuard configurations are equivalent
// when their peers are compared using the given peer comparison func.
func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool {
if (c.Interface == nil) != (b.Interface == nil) { if (c.Interface == nil) != (b.Interface == nil) {
return false return false
} }
@ -294,47 +288,38 @@ func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool {
} }
sortPeers(c.Peers) sortPeers(c.Peers)
sortPeers(b.Peers) sortPeers(b.Peers)
var ok bool
for i := range c.Peers { for i := range c.Peers {
if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) { if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
return false return false
} }
sortCIDRs(c.Peers[i].AllowedIPs) sortCIDRs(c.Peers[i].AllowedIPs)
sortCIDRs(b.Peers[i].AllowedIPs) sortCIDRs(b.Peers[i].AllowedIPs)
if ok = pc(c.Peers[i], b.Peers[i]); !ok { for j := range c.Peers[i].AllowedIPs {
if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() {
return false
}
}
if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) {
return false
}
if c.Peers[i].Endpoint != nil {
if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port {
return false
}
// IPs take priority, so check them first.
if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS {
return false
}
}
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
return false return false
} }
} }
return true return true
}
// PeerCheck is a function that compares two peers.
type PeerCheck func(a, b *Peer) bool
func strictPeerCheck(a, b *Peer) bool {
for j := range a.AllowedIPs {
if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
return false
}
}
if (a.Endpoint == nil) != (b.Endpoint == nil) {
return false
}
if a.Endpoint != nil {
if a.Endpoint.Port != b.Endpoint.Port {
return false
}
// IPs take priority, so check them first.
if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
return false
}
}
return a.PersistentKeepalive == b.PersistentKeepalive && bytes.Equal(a.PublicKey, b.PublicKey)
} }
func sortPeers(peers []*Peer) { func sortPeers(peers []*Peer) {