pkg/*: use wireguard.Enpoint

This commit introduces the wireguard.Enpoint struct.
It encapsulates a DN name with port and a net.UPDAddr.
The fields are private and only accessible over exported Methods
to avoid accidental modification.

Also iptables.GetProtocol is improved to avoid ipv4 rules being applied
by `ip6tables`.

Signed-off-by: leonnicolas <leonloechner@gmx.de>
This commit is contained in:
leonnicolas
2021-09-29 22:30:32 +02:00
parent b370ed3511
commit 08eea4f3c1
17 changed files with 287 additions and 744 deletions

View File

@@ -56,8 +56,7 @@ const (
// Node represents a node in the network.
type Node struct {
Endpoint *net.UDPAddr
Addr string // eg. dnsname:port
Endpoint *wireguard.Endpoint
Key wgtypes.Key
NoInternalIP bool
InternalIP *net.IPNet
@@ -82,7 +81,7 @@ type Node struct {
func (n *Node) Ready() bool {
// Nodes that are not leaders will not have WireGuardIPs, so it is not required.
return n != nil &&
(n.Endpoint != nil || n.Addr != "") &&
n.Endpoint.Ready() &&
n.Key != wgtypes.Key{} &&
n.Subnet != nil &&
time.Now().Unix()-n.LastSeen < int64(checkInPeriod)*2/int64(time.Second)

View File

@@ -1,4 +1,4 @@
// Copyright 2019 the Kilo authors
// Copyright 2021 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -20,6 +20,8 @@ import (
"strings"
"github.com/awalterschulze/gographviz"
"github.com/squat/kilo/pkg/wireguard"
)
// Dot generates a Graphviz graph of the Topology in DOT fomat.
@@ -61,7 +63,7 @@ func (t *Topology) Dot() (string, error) {
return "", fmt.Errorf("failed to add node to subgraph")
}
var wg net.IP
var endpoint *net.UDPAddr
var endpoint *wireguard.Endpoint
if j == s.leader {
wg = s.wireGuardIP
endpoint = s.endpoint
@@ -73,7 +75,7 @@ func (t *Topology) Dot() (string, error) {
if s.privateIPs != nil {
priv = s.privateIPs[j]
}
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], priv, wg, endpoint, s.addr)); err != nil {
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], priv, wg, endpoint)); err != nil {
return "", fmt.Errorf("failed to add label to node")
}
}
@@ -153,7 +155,7 @@ func subGraphName(name string) string {
return graphEscape(fmt.Sprintf("cluster_location_%s", name))
}
func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *net.UDPAddr, addr string) string {
func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoint *wireguard.Endpoint) string {
label := []string{
location,
name,
@@ -165,12 +167,7 @@ func nodeLabel(location, name string, cidr *net.IPNet, priv, wgIP net.IP, endpoi
if wgIP != nil {
label = append(label, wgIP.String())
}
var str string
if addr != "" {
str = addr
} else if endpoint != nil {
str = endpoint.String()
}
str := endpoint.String()
if str != "" {
label = append(label, str)
}

View File

@@ -370,8 +370,10 @@ func (m *Mesh) checkIn() {
func (m *Mesh) handleLocal(n *Node) {
// Allow the IPs to be overridden.
if n.Endpoint == nil || n.Addr == "" {
n.Endpoint = &net.UDPAddr{IP: m.externalIP.IP, Port: m.port}
if !n.Endpoint.Ready() {
e := wireguard.NewEndpoint(m.externalIP.IP, m.port)
level.Info(m.logger).Log("msg", "overriding endpoint", "node", m.hostname, "old endpoint", n.Endpoint.String(), "new endpoint", e.String())
n.Endpoint = e
}
if n.InternalIP == nil && !n.NoInternalIP {
n.InternalIP = m.internalIP
@@ -484,7 +486,7 @@ func (m *Mesh) applyTopology() {
natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger)
nodes[m.hostname].DiscoveredEndpoints = natEndpoints
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port(), m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
@@ -623,15 +625,8 @@ func (m *Mesh) resolveEndpoints() error {
if !m.nodes[k].Ready() {
continue
}
// If the node is ready, then the endpoint is not nil
// but it may not have a DNS name.
if m.nodes[k].Addr == "" {
continue
}
if u, err := net.ResolveUDPAddr("udp", m.nodes[k].Addr); err == nil {
m.nodes[k].Endpoint = u
m.nodes[k].Endpoint.IP = u.IP
} else {
// Resolve the Endpoint
if _, err := m.nodes[k].Endpoint.UDPAddr(true); err != nil {
return err
}
}
@@ -642,12 +637,10 @@ func (m *Mesh) resolveEndpoints() error {
continue
}
// Peers may have nil endpoints.
if m.peers[k].Addr == "" {
if !m.peers[k].Endpoint.Ready() {
continue
}
if u, err := net.ResolveUDPAddr("udp", m.peers[k].Addr); err == nil {
m.peers[k].Endpoint = u
} else {
if _, err := m.peers[k].Endpoint.UDPAddr(true); err != nil {
return err
}
}
@@ -667,7 +660,7 @@ func nodesAreEqual(a, b *Node) bool {
}
// Check the DNS name first since this package
// is doing the DNS resolution.
if a.Addr != b.Addr || a.Endpoint.String() != b.Endpoint.String() {
if a.Endpoint.StringOpt(false) != b.Endpoint.StringOpt(false) {
return false
}
// Ignore LastSeen when comparing equality we want to check if the nodes are
@@ -696,7 +689,7 @@ func peersAreEqual(a, b *Peer) bool {
}
// Check the DNS name first since this package
// is doing the DNS resolution.
if a.Addr != b.Addr || a.Endpoint.String() != b.Endpoint.String() {
if a.Endpoint.StringOpt(false) != b.Endpoint.StringOpt(false) {
return false
}
if len(a.AllowedIPs) != len(b.AllowedIPs) {

View File

@@ -20,6 +20,8 @@ import (
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/squat/kilo/pkg/wireguard"
)
func mustKey() wgtypes.Key {
@@ -62,7 +64,7 @@ func TestReady(t *testing.T) {
{
name: "empty endpoint IP",
node: &Node{
Endpoint: &net.UDPAddr{Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(nil, DefaultKiloPort),
InternalIP: internalIP,
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
@@ -72,7 +74,7 @@ func TestReady(t *testing.T) {
{
name: "empty endpoint port",
node: &Node{
Endpoint: &net.UDPAddr{IP: externalIP.IP},
Endpoint: wireguard.NewEndpoint(externalIP.IP, 0),
InternalIP: internalIP,
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
@@ -82,7 +84,7 @@ func TestReady(t *testing.T) {
{
name: "empty internal IP",
node: &Node{
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
Key: wgtypes.Key{},
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
@@ -91,7 +93,7 @@ func TestReady(t *testing.T) {
{
name: "empty key",
node: &Node{
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
InternalIP: internalIP,
Subnet: &net.IPNet{IP: net.ParseIP("10.2.0.0"), Mask: net.CIDRMask(16, 32)},
},
@@ -100,7 +102,7 @@ func TestReady(t *testing.T) {
{
name: "empty subnet",
node: &Node{
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
InternalIP: internalIP,
Key: wgtypes.Key{},
},
@@ -109,7 +111,7 @@ func TestReady(t *testing.T) {
{
name: "valid",
node: &Node{
Endpoint: &net.UDPAddr{IP: externalIP.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(externalIP.IP, DefaultKiloPort),
InternalIP: internalIP,
Key: key,
LastSeen: time.Now().Unix(),

View File

@@ -40,7 +40,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
var gw net.IP
for _, segment := range t.segments {
if segment.location == t.location {
gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
gw = enc.Gw(segment.endpoint.IP(), segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
break
}
}
@@ -196,7 +196,7 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
// equals the external IP. This means that the node
// is only accessible through an external IP and we
// cannot encapsulate traffic to an IP through the IP.
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP) {
if segment.privateIPs == nil || segment.privateIPs[i].Equal(segment.endpoint.IP()) {
continue
}
// Add routes to the private IPs of nodes in other segments.
@@ -248,7 +248,7 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
if cni {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
// Some linux distros or docker will set forward DROP in the filter table.
// To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location.
// Leader nodes will forward packets from all nodes within a location because they act as a gateway for them.
@@ -258,30 +258,30 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
if s.location == t.location {
// Make sure packets to and from pod cidrs are not dropped in the forward chain.
for _, c := range s.cidrs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
}
// Make sure packets to and from allowed location IPs are not dropped in the forward chain.
for _, c := range s.allowedLocationIPs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
}
// Make sure packets to and from private IPs are not dropped in the forward chain.
for _, c := range s.privateIPs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(c)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
}
}
}
} else if iptablesForwardRule {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
}
}
for _, s := range t.segments {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(s.wireGuardIP)), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
for _, aip := range s.allowedIPs {
rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
rules = append(rules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
}
// Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd,
// Otherwise packets to these destinations will reach the destination, but never find their way back.
@@ -289,7 +289,7 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
if t.location == s.location {
for _, alip := range s.allowedLocationIPs {
rules = append(rules,
iptables.NewRule(iptables.GetProtocol(len(alip.IP)), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
)
}
}
@@ -297,8 +297,8 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
for _, p := range t.peers {
for _, aip := range p.AllowedIPs {
rules = append(rules,
iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
)
}
}

View File

@@ -67,8 +67,7 @@ type Topology struct {
type segment struct {
allowedIPs []net.IPNet
addr string
endpoint *net.UDPAddr
endpoint *wireguard.Endpoint
key wgtypes.Key
persistentKeepalive time.Duration
// Location is the logical location of this segment.
@@ -178,7 +177,6 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
})
t.segments = append(t.segments, &segment{
allowedIPs: allowedIPs,
addr: topoMap[location][leader].Addr,
endpoint: topoMap[location][leader].Endpoint,
key: topoMap[location][leader].Key,
persistentKeepalive: topoMap[location][leader].PersistentKeepalive,
@@ -287,14 +285,14 @@ CheckIPs:
return
}
func (t *Topology) updateEndpoint(endpoint *net.UDPAddr, key wgtypes.Key, persistentKeepalive *time.Duration) *net.UDPAddr {
func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *wireguard.Endpoint {
// Do not update non-nat peers
if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) {
return endpoint
}
e, ok := t.discoveredEndpoints[key.String()]
if ok {
return e
return wireguard.NewEndpointFromUDPAddr(e)
}
return nil
}
@@ -315,12 +313,11 @@ func (t *Topology) Conf() *wireguard.Conf {
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
PersistentKeepaliveInterval: &t.persistentKeepalive,
PublicKey: s.key,
ReplaceAllowedIPs: true,
},
Addr: s.addr,
Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
}
c.Peers = append(c.Peers, peer)
}
@@ -328,13 +325,12 @@ func (t *Topology) Conf() *wireguard.Conf {
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: p.AllowedIPs,
Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval),
PersistentKeepaliveInterval: &t.persistentKeepalive,
PresharedKey: p.PresharedKey,
PublicKey: p.PublicKey,
ReplaceAllowedIPs: true,
},
Addr: p.Addr,
Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval),
}
c.Peers = append(c.Peers, peer)
}
@@ -352,9 +348,8 @@ func (t *Topology) AsPeer() wireguard.Peer {
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: s.allowedIPs,
PublicKey: s.key,
Endpoint: s.endpoint,
},
Addr: s.addr,
Endpoint: s.endpoint,
}
return p
}
@@ -377,12 +372,11 @@ func (t *Topology) PeerConf(name string) wireguard.Conf {
peer := wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
AllowedIPs: s.allowedIPs,
Endpoint: s.endpoint,
PersistentKeepaliveInterval: pka,
PresharedKey: psk,
PublicKey: s.key,
},
Addr: s.addr,
Endpoint: s.endpoint,
}
c.Peers = append(c.Peers, peer)
}
@@ -395,8 +389,8 @@ func (t *Topology) PeerConf(name string) wireguard.Conf {
AllowedIPs: t.peers[i].AllowedIPs,
PersistentKeepaliveInterval: pka,
PublicKey: t.peers[i].PublicKey,
Endpoint: t.peers[i].Endpoint,
},
Endpoint: t.peers[i].Endpoint,
}
c.Peers = append(c.Peers, peer)
}
@@ -417,13 +411,13 @@ func findLeader(nodes []*Node) int {
var leaders, public []int
for i := range nodes {
if nodes[i].Leader {
if isPublic(nodes[i].Endpoint.IP) {
if isPublic(nodes[i].Endpoint.IP()) {
return i
}
leaders = append(leaders, i)
}
if nodes[i].Endpoint != nil && isPublic(nodes[i].Endpoint.IP) {
if nodes[i].Endpoint != nil && isPublic(nodes[i].Endpoint.IP()) {
public = append(public, i)
}
}
@@ -444,12 +438,11 @@ func deduplicatePeerIPs(peers []*Peer) []*Peer {
Name: peer.Name,
Peer: wireguard.Peer{
PeerConfig: wgtypes.PeerConfig{
Endpoint: peer.Endpoint,
PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval,
PresharedKey: peer.PresharedKey,
PublicKey: peer.PublicKey,
},
Addr: peer.Addr,
Endpoint: peer.Endpoint,
},
}
for _, ip := range peer.AllowedIPs {

View File

@@ -60,7 +60,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
nodes := map[string]*Node{
"a": {
Name: "a",
Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
InternalIP: i1,
Location: "1",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.1.0"), Mask: net.CIDRMask(24, 32)},
@@ -69,7 +69,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
},
"b": {
Name: "b",
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
InternalIP: i1,
Location: "2",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)},
@@ -78,7 +78,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
},
"c": {
Name: "c",
Endpoint: &net.UDPAddr{IP: e3.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e3.IP, DefaultKiloPort),
InternalIP: i2,
// Same location as node b.
Location: "2",
@@ -87,7 +87,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
},
"d": {
Name: "d",
Endpoint: &net.UDPAddr{IP: e4.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e4.IP, DefaultKiloPort),
// Same location as node a, but without private IP
Location: "1",
Subnet: &net.IPNet{IP: net.ParseIP("10.2.4.0"), Mask: net.CIDRMask(24, 32)},
@@ -115,11 +115,8 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, wgtypes.Key, int)
{IP: net.ParseIP("10.5.0.3"), Mask: net.CIDRMask(24, 32)},
},
PublicKey: key5,
Endpoint: &net.UDPAddr{
IP: net.ParseIP("192.168.0.1"),
Port: DefaultKiloPort,
},
},
Endpoint: wireguard.NewEndpoint(net.ParseIP("192.168.0.1"), DefaultKiloPort),
},
},
}
@@ -576,24 +573,24 @@ func TestFindLeader(t *testing.T) {
nodes := []*Node{
{
Name: "a",
Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
},
{
Name: "b",
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
},
{
Name: "c",
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
},
{
Name: "d",
Endpoint: &net.UDPAddr{IP: e1.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e1.IP, DefaultKiloPort),
Leader: true,
},
{
Name: "2",
Endpoint: &net.UDPAddr{IP: e2.IP, Port: DefaultKiloPort},
Endpoint: wireguard.NewEndpoint(e2.IP, DefaultKiloPort),
Leader: true,
},
}