*: add peer VPN support
This commit adds support for defining arbitrary peers that should have access to the VPN. In Kubernetes, this is accomplished using the new Peer CRD.
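For context, the Peer primitive introduced in this commit is what the new Peer CRD maps onto: a named WireGuard peer with a public key and a set of allowed IPs. Below is a minimal, hypothetical sketch of how a consumer of the new split backend might register such a peer; the peer name, key, and CIDR are placeholder values, and the types mirror the Peer struct and Backend/PeerBackend interfaces shown in the pkg/mesh/mesh.go diff further down.

```go
package example

import (
	"fmt"
	"net"

	"github.com/squat/kilo/pkg/mesh"
	"github.com/squat/kilo/pkg/wireguard"
)

// registerPeer is a hypothetical helper: it builds a mesh.Peer, checks that
// it is complete with Ready(), and hands it to the peer backend.
func registerPeer(b mesh.Backend) error {
	// Example allowed IP for the peer; any CIDR the peer should own works here.
	_, allowed, err := net.ParseCIDR("10.5.0.1/32")
	if err != nil {
		return err
	}
	p := &mesh.Peer{
		Name: "laptop", // placeholder peer name
		Peer: wireguard.Peer{
			AllowedIPs: []*net.IPNet{allowed},
			PublicKey:  []byte("BASE64_WIREGUARD_PUBLIC_KEY"), // placeholder key
		},
	}
	// Ready() requires a public key and at least one allowed IP, per the diff below.
	if !p.Ready() {
		return fmt.Errorf("peer is missing a public key or allowed IPs")
	}
	// The backend persists the peer, e.g. as a Peer custom resource when the
	// Kubernetes backend is in use.
	return b.Peers().Set(p.Name, p)
}
```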
@@ -37,22 +37,22 @@ func (t *Topology) Dot() (string, error) {
if err := g.SetDir(true); err != nil {
return "", fmt.Errorf("failed to set direction")
}
leaders := make([]string, len(t.Segments))
leaders := make([]string, len(t.segments))
nodeAttrs := map[string]string{
string(gographviz.Shape): "ellipse",
}
for i, s := range t.Segments {
if err := g.AddSubGraph("kilo", subGraphName(s.Location), nil); err != nil {
for i, s := range t.segments {
if err := g.AddSubGraph("kilo", subGraphName(s.location), nil); err != nil {
return "", fmt.Errorf("failed to add subgraph")
}
if err := g.AddAttr(subGraphName(s.Location), string(gographviz.Label), graphEscape(s.Location)); err != nil {
if err := g.AddAttr(subGraphName(s.location), string(gographviz.Label), graphEscape(s.location)); err != nil {
return "", fmt.Errorf("failed to add label to subgraph")
}
if err := g.AddAttr(subGraphName(s.Location), string(gographviz.Style), `"dashed,rounded"`); err != nil {
if err := g.AddAttr(subGraphName(s.location), string(gographviz.Style), `"dashed,rounded"`); err != nil {
return "", fmt.Errorf("failed to add style to subgraph")
}
for j := range s.cidrs {
if err := g.AddNode(subGraphName(s.Location), graphEscape(s.hostnames[j]), nodeAttrs); err != nil {
if err := g.AddNode(subGraphName(s.location), graphEscape(s.hostnames[j]), nodeAttrs); err != nil {
return "", fmt.Errorf("failed to add node to subgraph")
}
var wg net.IP
@@ -62,11 +62,11 @@ func (t *Topology) Dot() (string, error) {
return "", fmt.Errorf("failed to add rank to node")
}
}
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.Location, s.hostnames[j], s.cidrs[j], s.privateIPs[j], wg)); err != nil {
if err := g.Nodes.Lookup[graphEscape(s.hostnames[j])].Attrs.Add(string(gographviz.Label), nodeLabel(s.location, s.hostnames[j], s.cidrs[j], s.privateIPs[j], wg)); err != nil {
return "", fmt.Errorf("failed to add label to node")
}
}
meshSubGraph(g, g.Relations.SortedChildren(subGraphName(s.Location)), s.leader)
meshSubGraph(g, g.Relations.SortedChildren(subGraphName(s.location)), s.leader)
leaders[i] = graphEscape(s.hostnames[s.leader])
}
meshSubGraph(g, leaders, 0)
pkg/mesh/mesh.go
@@ -15,6 +15,7 @@
package mesh

import (
"bytes"
"fmt"
"io/ioutil"
"net"
@@ -43,6 +44,8 @@ const (
PrivateKeyPath = KiloPath + "/key"
// ConfPath is the filepath where the WireGuard configuration is stored.
ConfPath = KiloPath + "/conf"
// DefaultKiloPort is the default UDP port Kilo uses.
DefaultKiloPort = 51820
)

// Granularity represents the abstraction level at which the network
@@ -93,6 +96,17 @@ func (n *Node) Ready() bool {
return n != nil && n.ExternalIP != nil && n.Key != nil && n.InternalIP != nil && n.Subnet != nil && time.Now().Unix()-n.LastSeen < int64(resyncPeriod)*2/int64(time.Second)
}

// Peer represents a peer in the network.
type Peer struct {
wireguard.Peer
Name string
}

// Ready indicates whether or not the peer is ready.
func (p *Peer) Ready() bool {
return p != nil && p.AllowedIPs != nil && len(p.AllowedIPs) != 0 && p.PublicKey != nil
}

// EventType describes what kind of an action an event represents.
type EventType string

@@ -105,24 +119,53 @@ const (
UpdateEvent EventType = "update"
)

// Event represents an update event concerning a node in the cluster.
type Event struct {
// NodeEvent represents an event concerning a node in the cluster.
type NodeEvent struct {
Type EventType
Node *Node
}

// Backend can get nodes by name, init itself,
// PeerEvent represents an event concerning a peer in the cluster.
type PeerEvent struct {
Type EventType
Peer *Peer
}

// Backend can create clients for all of the
// primitive types that Kilo deals with, namely:
// * nodes; and
// * peers.
type Backend interface {
Nodes() NodeBackend
Peers() PeerBackend
}

// NodeBackend can get nodes by name, init itself,
// list the nodes that should be meshed,
// set Kilo properties for a node,
// clean up any changes applied to the backend,
// and watch for changes to nodes.
type Backend interface {
type NodeBackend interface {
CleanUp(string) error
Get(string) (*Node, error)
Init(<-chan struct{}) error
List() ([]*Node, error)
Set(string, *Node) error
Watch() <-chan *Event
Watch() <-chan *NodeEvent
}

// PeerBackend can get peers by name, init itself,
// list the peers that should be in the mesh,
// set fields for a peer,
// clean up any changes applied to the backend,
// and watch for changes to peers.
type PeerBackend interface {
CleanUp(string) error
Get(string) (*Peer, error)
Init(<-chan struct{}) error
List() ([]*Peer, error)
Set(string, *Peer) error
Watch() <-chan *PeerEvent
}

// Mesh is able to create Kilo network meshes.
@@ -138,7 +181,7 @@ type Mesh struct {
kiloIface int
key []byte
local bool
port int
port uint32
priv []byte
privIface int
pub []byte
@@ -148,23 +191,26 @@ type Mesh struct {
table *route.Table
tunlIface int

// nodes is a mutable field in the struct
// nodes and peers are mutable fields in the struct
// and needs to be guarded.
nodes map[string]*Node
peers map[string]*Peer
mu sync.Mutex

errorCounter *prometheus.CounterVec
nodesGuage prometheus.Gauge
peersGuage prometheus.Gauge
reconcileCounter prometheus.Counter
logger log.Logger
}

// New returns a new Mesh instance.
func New(backend Backend, encapsulate Encapsulate, granularity Granularity, hostname string, port int, subnet *net.IPNet, local bool, logger log.Logger) (*Mesh, error) {
func New(backend Backend, encapsulate Encapsulate, granularity Granularity, hostname string, port uint32, subnet *net.IPNet, local bool, logger log.Logger) (*Mesh, error) {
if err := os.MkdirAll(KiloPath, 0700); err != nil {
return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
}
private, err := ioutil.ReadFile(PrivateKeyPath)
private = bytes.Trim(private, "\n")
if err != nil {
level.Warn(logger).Log("msg", "no private key found on disk; generating one now")
if private, err = wireguard.GenKey(); err != nil {
@@ -224,6 +270,7 @@ func New(backend Backend, encapsulate Encapsulate, granularity Granularity, host
ipTables: ipTables,
kiloIface: kiloIface,
nodes: make(map[string]*Node),
peers: make(map[string]*Peer),
port: port,
priv: private,
privIface: privIface,
@@ -240,7 +287,11 @@ func New(backend Backend, encapsulate Encapsulate, granularity Granularity, host
}, []string{"event"}),
nodesGuage: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "kilo_nodes",
Help: "Number of in the mesh.",
Help: "Number of nodes in the mesh.",
}),
peersGuage: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "kilo_peers",
Help: "Number of peers in the mesh.",
}),
reconcileCounter: prometheus.NewCounter(prometheus.CounterOpts{
Name: "kilo_reconciles_total",
@@ -252,8 +303,11 @@ func New(backend Backend, encapsulate Encapsulate, granularity Granularity, host

// Run starts the mesh.
func (m *Mesh) Run() error {
if err := m.Init(m.stop); err != nil {
return fmt.Errorf("failed to initialize backend: %v", err)
if err := m.Nodes().Init(m.stop); err != nil {
return fmt.Errorf("failed to initialize node backend: %v", err)
}
if err := m.Peers().Init(m.stop); err != nil {
return fmt.Errorf("failed to initialize peer backend: %v", err)
}
ipsetErrors, err := m.ipset.Run(m.stop)
if err != nil {
@@ -285,14 +339,19 @@ func (m *Mesh) Run() error {
}()
defer m.cleanUp()
t := time.NewTimer(resyncPeriod)
w := m.Watch()
nw := m.Nodes().Watch()
pw := m.Peers().Watch()
var ne *NodeEvent
var pe *PeerEvent
for {
var e *Event
select {
case e = <-w:
m.sync(e)
case ne = <-nw:
m.syncNodes(ne)
case pe = <-pw:
m.syncPeers(pe)
case <-t.C:
m.checkIn()
m.syncEndpoints()
m.applyTopology()
t.Reset(resyncPeriod)
case <-m.stop:
@@ -301,9 +360,50 @@ func (m *Mesh) Run() error {
}
}

func (m *Mesh) sync(e *Event) {
// WireGuard updates the endpoints of peers to match the
// last place a valid packet was received from.
// Periodically we need to syncronize the endpoints
// of peers in the backend to match the WireGuard configuration.
func (m *Mesh) syncEndpoints() {
link, err := linkByIndex(m.kiloIface)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("endpoints").Inc()
return
}
conf, err := wireguard.ShowConf(link.Attrs().Name)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("endpoints").Inc()
return
}
m.mu.Lock()
defer m.mu.Unlock()
c := wireguard.Parse(conf)
var key string
var tmp *Peer
for i := range c.Peers {
// Peers are indexed by public key.
key = string(c.Peers[i].PublicKey)
if p, ok := m.peers[key]; ok {
tmp = &Peer{
Name: p.Name,
Peer: *c.Peers[i],
}
if !peersAreEqual(tmp, p) {
p.Endpoint = tmp.Endpoint
if err := m.Peers().Set(p.Name, p); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("endpoints").Inc()
}
}
}
}
}

func (m *Mesh) syncNodes(e *NodeEvent) {
logger := log.With(m.logger, "event", e.Type)
level.Debug(logger).Log("msg", "syncing", "event", e.Type)
level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type)
if isSelf(m.hostname, e.Node) {
level.Debug(logger).Log("msg", "processing local node", "node", e.Node)
m.handleLocal(e.Node)
@@ -326,9 +426,11 @@ func (m *Mesh) sync(e *Event) {
fallthrough
case UpdateEvent:
if !nodesAreEqual(m.nodes[e.Node.Name], e.Node) {
m.nodes[e.Node.Name] = e.Node
diff = true
}
// Even if the nodes are the same,
// overwrite the old node to update the timestamp.
m.nodes[e.Node.Name] = e.Node
case DeleteEvent:
delete(m.nodes, e.Node.Name)
diff = true
@@ -341,6 +443,43 @@ func (m *Mesh) sync(e *Event) {
}
}

func (m *Mesh) syncPeers(e *PeerEvent) {
logger := log.With(m.logger, "event", e.Type)
level.Debug(logger).Log("msg", "syncing peers", "event", e.Type)
var diff bool
m.mu.Lock()
// Peers are indexed by public key.
key := string(e.Peer.PublicKey)
if !e.Peer.Ready() {
level.Debug(logger).Log("msg", "received incomplete peer", "peer", e.Peer)
// An existing peer is no longer valid
// so remove it from the mesh.
if _, ok := m.peers[key]; ok {
level.Info(logger).Log("msg", "peer is no longer in the mesh", "peer", e.Peer)
delete(m.peers, key)
diff = true
}
} else {
switch e.Type {
case AddEvent:
fallthrough
case UpdateEvent:
if !peersAreEqual(m.peers[key], e.Peer) {
m.peers[key] = e.Peer
diff = true
}
case DeleteEvent:
delete(m.peers, key)
diff = true
}
}
m.mu.Unlock()
if diff {
level.Info(logger).Log("peer", e.Peer)
m.applyTopology()
}
}

// checkIn will try to update the local node's LastSeen timestamp
// in the backend.
func (m *Mesh) checkIn() {
@@ -352,7 +491,7 @@ func (m *Mesh) checkIn() {
return
}
n.LastSeen = time.Now().Unix()
if err := m.Set(m.hostname, n); err != nil {
if err := m.Nodes().Set(m.hostname, n); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", n)
m.errorCounter.WithLabelValues("checkin").Inc()
return
@@ -380,7 +519,7 @@ func (m *Mesh) handleLocal(n *Node) {
}
if !nodesAreEqual(n, local) {
level.Debug(m.logger).Log("msg", "local node differs from backend")
if err := m.Set(m.hostname, local); err != nil {
if err := m.Nodes().Set(m.hostname, local); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", local)
m.errorCounter.WithLabelValues("local").Inc()
return
@@ -406,31 +545,42 @@ func (m *Mesh) applyTopology() {
m.mu.Lock()
defer m.mu.Unlock()
// Ensure all unready nodes are removed.
var ready float64
for n := range m.nodes {
if !m.nodes[n].Ready() {
delete(m.nodes, n)
var readyNodes float64
for k := range m.nodes {
if !m.nodes[k].Ready() {
delete(m.nodes, k)
continue
}
ready++
readyNodes++
}
m.nodesGuage.Set(ready)
// Ensure all unready peers are removed.
var readyPeers float64
for k := range m.peers {
if !m.peers[k].Ready() {
delete(m.peers, k)
continue
}
readyPeers++
}
m.nodesGuage.Set(readyNodes)
m.peersGuage.Set(readyPeers)
// We cannot do anything with the topology until the local node is available.
if m.nodes[m.hostname] == nil {
return
}
t, err := NewTopology(m.nodes, m.granularity, m.hostname, m.port, m.priv, m.subnet)
t, err := NewTopology(m.nodes, m.peers, m.granularity, m.hostname, m.port, m.priv, m.subnet)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
}
conf, err := t.Conf()
conf := t.Conf()
buf, err := conf.Bytes()
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
}
if err := ioutil.WriteFile(ConfPath, conf, 0600); err != nil {
if err := ioutil.WriteFile(ConfPath, buf, 0600); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
return
@@ -443,6 +593,9 @@ func (m *Mesh) applyTopology() {
}
rules := iptables.MasqueradeRules(private, m.nodes[m.hostname].Subnet, t.RemoteSubnets())
rules = append(rules, iptables.ForwardRules(m.subnet)...)
for _, p := range m.peers {
rules = append(rules, iptables.ForwardRules(p.AllowedIPs...)...)
}
if err := m.ipTables.Set(rules); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
@@ -450,8 +603,8 @@ func (m *Mesh) applyTopology() {
}
if m.encapsulate != NeverEncapsulate {
var peers []net.IP
for _, s := range t.Segments {
if s.Location == m.nodes[m.hostname].Location {
for _, s := range t.segments {
if s.location == m.nodes[m.hostname].Location {
peers = s.privateIPs
break
}
@@ -461,6 +614,8 @@ func (m *Mesh) applyTopology() {
m.errorCounter.WithLabelValues("apply").Inc()
return
}
// If we are handling local routes, ensure the local
// tunnel has an IP address.
if m.local {
if err := iproute.SetAddress(m.tunlIface, oneAddressCIDR(newAllocator(*m.nodes[m.hostname].Subnet).next().IP)); err != nil {
level.Error(m.logger).Log("error", err)
@@ -489,14 +644,9 @@ func (m *Mesh) applyTopology() {
}
// Setting the WireGuard configuration interrupts existing connections
// so only set the configuration if it has changed.
equal, err := wireguard.CompareConf(conf, oldConf)
if err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
// Don't return here, simply overwrite the old configuration.
equal = false
}
equal := conf.Equal(wireguard.Parse(oldConf))
if !equal {
level.Info(m.logger).Log("msg", "WireGuard configurations are different")
if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
level.Error(m.logger).Log("error", err)
m.errorCounter.WithLabelValues("apply").Inc()
@@ -531,6 +681,7 @@ func (m *Mesh) RegisterMetrics(r prometheus.Registerer) {
r.MustRegister(
m.errorCounter,
m.nodesGuage,
m.peersGuage,
m.reconcileCounter,
)
}
@@ -558,11 +709,15 @@ func (m *Mesh) cleanUp() {
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := iproute.RemoveInterface(m.kiloIface); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove wireguard interface: %v", err))
level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := m.CleanUp(m.hostname); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up backend: %v", err))
if err := m.Nodes().CleanUp(m.hostname); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up node backend: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := m.Peers().CleanUp(m.hostname); err != nil {
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
m.errorCounter.WithLabelValues("cleanUp").Inc()
}
if err := m.ipset.CleanUp(); err != nil {
@@ -586,6 +741,32 @@ func nodesAreEqual(a, b *Node) bool {
return ipNetsEqual(a.ExternalIP, b.ExternalIP) && string(a.Key) == string(b.Key) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet)
}

func peersAreEqual(a, b *Peer) bool {
if !(a != nil) == (b != nil) {
return false
}
if a == b {
return true
}
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
return false
}
if a.Endpoint != nil {
if !a.Endpoint.IP.Equal(b.Endpoint.IP) || a.Endpoint.Port != b.Endpoint.Port {
return false
}
}
if len(a.AllowedIPs) != len(b.AllowedIPs) {
return false
}
for i := range a.AllowedIPs {
if !ipNetsEqual(a.AllowedIPs[i], b.AllowedIPs[i]) {
return false
}
}
return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive
}

func ipNetsEqual(a, b *net.IPNet) bool {
if a == nil && b == nil {
return true
@@ -15,41 +15,24 @@
package mesh

import (
"bytes"
"errors"
"fmt"
"net"
"sort"
"strings"
"text/template"

"github.com/squat/kilo/pkg/wireguard"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)

var (
confTemplate = template.Must(template.New("").Parse(`[Interface]
PrivateKey = {{.Key}}
ListenPort = {{.Port}}
{{range .Segments -}}
{{if ne .Location $.Location}}
[Peer]
PublicKey = {{.Key}}
Endpoint = {{.Endpoint}}:{{$.Port}}
AllowedIPs = {{.AllowedIPs}}
{{end}}
{{- end -}}
`))
)

// Topology represents the logical structure of the overlay network.
type Topology struct {
// Some fields need to be exported so that the template can read them.
Key string
Port int
// key is the private key of the node creating the topology.
key []byte
port uint32
// Location is the logical location of the local host.
Location string
Segments []*segment
location string
segments []*segment
peers []*Peer

// hostname is the hostname of the local host.
hostname string
@@ -69,11 +52,11 @@ type Topology struct {

type segment struct {
// Some fields need to be exported so that the template can read them.
AllowedIPs string
Endpoint string
Key string
allowedIPs []*net.IPNet
endpoint net.IP
key []byte
// Location is the logical location of this segment.
Location string
location string

// cidrs is a slice of subnets of all peers in the segment.
cidrs []*net.IPNet
@@ -88,8 +71,8 @@ type segment struct {
wireGuardIP net.IP
}

// NewTopology creates a new Topology struct from a given set of nodes.
func NewTopology(nodes map[string]*Node, granularity Granularity, hostname string, port int, key []byte, subnet *net.IPNet) (*Topology, error) {
// NewTopology creates a new Topology struct from a given set of nodes and peers.
func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet) (*Topology, error) {
topoMap := make(map[string][]*Node)
for _, node := range nodes {
var location string
@@ -109,7 +92,7 @@ func NewTopology(nodes map[string]*Node, granularity Granularity, hostname strin
localLocation = hostname
}

t := Topology{Key: strings.TrimSpace(string(key)), Port: port, hostname: hostname, Location: localLocation, subnet: subnet, privateIP: nodes[hostname].InternalIP}
t := Topology{key: key, port: port, hostname: hostname, location: localLocation, subnet: subnet, privateIP: nodes[hostname].InternalIP}
for location := range topoMap {
// Sort the location so the result is stable.
sort.Slice(topoMap[location], func(i, j int) bool {
@@ -119,7 +102,7 @@ func NewTopology(nodes map[string]*Node, granularity Granularity, hostname strin
if location == localLocation && topoMap[location][leader].Name == hostname {
t.leader = true
}
var allowedIPs []string
var allowedIPs []*net.IPNet
var cidrs []*net.IPNet
var hostnames []string
var privateIPs []net.IP
@@ -128,37 +111,45 @@ func NewTopology(nodes map[string]*Node, granularity Granularity, hostname strin
// - the node's allocated subnet
// - the node's WireGuard IP
// - the node's internal IP
allowedIPs = append(allowedIPs, node.Subnet.String(), oneAddressCIDR(node.InternalIP.IP).String())
allowedIPs = append(allowedIPs, node.Subnet, oneAddressCIDR(node.InternalIP.IP))
cidrs = append(cidrs, node.Subnet)
hostnames = append(hostnames, node.Name)
privateIPs = append(privateIPs, node.InternalIP.IP)
}
t.Segments = append(t.Segments, &segment{
AllowedIPs: strings.Join(allowedIPs, ", "),
Endpoint: topoMap[location][leader].ExternalIP.IP.String(),
Key: strings.TrimSpace(string(topoMap[location][leader].Key)),
Location: location,
t.segments = append(t.segments, &segment{
allowedIPs: allowedIPs,
endpoint: topoMap[location][leader].ExternalIP.IP,
key: topoMap[location][leader].Key,
location: location,
cidrs: cidrs,
hostnames: hostnames,
leader: leader,
privateIPs: privateIPs,
})
}
// Sort the Topology so the result is stable.
sort.Slice(t.Segments, func(i, j int) bool {
return t.Segments[i].Location < t.Segments[j].Location
// Sort the Topology segments so the result is stable.
sort.Slice(t.segments, func(i, j int) bool {
return t.segments[i].location < t.segments[j].location
})

for _, peer := range peers {
t.peers = append(t.peers, peer)
}
// Sort the Topology peers so the result is stable.
sort.Slice(t.peers, func(i, j int) bool {
return t.peers[i].Name < t.peers[j].Name
})

// Allocate IPs to the segment leaders in a stable, coordination-free manner.
a := newAllocator(*subnet)
for _, segment := range t.Segments {
for _, segment := range t.segments {
ipNet := a.next()
if ipNet == nil {
return nil, errors.New("failed to allocate an IP address; ran out of IP addresses")
}
segment.wireGuardIP = ipNet.IP
segment.AllowedIPs = fmt.Sprintf("%s, %s", segment.AllowedIPs, ipNet.String())
if t.leader && segment.Location == t.Location {
segment.allowedIPs = append(segment.allowedIPs, ipNet)
if t.leader && segment.location == t.location {
t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: t.subnet.Mask}
}
}
@@ -169,8 +160,8 @@ func NewTopology(nodes map[string]*Node, granularity Granularity, hostname strin
// RemoteSubnets identifies the subnets of the hosts in segments different than the host's.
func (t *Topology) RemoteSubnets() []*net.IPNet {
var remote []*net.IPNet
for _, s := range t.Segments {
if s == nil || s.Location == t.Location {
for _, s := range t.segments {
if s == nil || s.location == t.location {
continue
}
remote = append(remote, s.cidrs...)
@@ -184,13 +175,13 @@ func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, encap
if !t.leader {
// Find the leader for this segment.
var leader net.IP
for _, segment := range t.Segments {
if segment.Location == t.Location {
for _, segment := range t.segments {
if segment.location == t.location {
leader = segment.privateIPs[segment.leader]
break
}
}
for _, segment := range t.Segments {
for _, segment := range t.segments {
// First, add a route to the WireGuard IP of the segment.
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: oneAddressCIDR(segment.wireGuardIP),
@@ -200,7 +191,7 @@ func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, encap
Protocol: unix.RTPROT_STATIC,
}, encapsulate, t.privateIP, tunlIface))
// Add routes for the current segment if local is true.
if segment.Location == t.Location {
if segment.location == t.location {
if local {
for i := range segment.cidrs {
// Don't add routes for the local node.
@@ -239,11 +230,23 @@ func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, encap
}, encapsulate, t.privateIP, tunlIface))
}
}
// Add routes for the allowed IPs of peers.
for _, peer := range t.peers {
for i := range peer.AllowedIPs {
routes = append(routes, encapsulateRoute(&netlink.Route{
Dst: peer.AllowedIPs[i],
Flags: int(netlink.FLAG_ONLINK),
Gw: leader,
LinkIndex: privIface,
Protocol: unix.RTPROT_STATIC,
}, encapsulate, t.privateIP, tunlIface))
}
}
return routes
}
for _, segment := range t.Segments {
for _, segment := range t.segments {
// Add routes for the current segment if local is true.
if segment.Location == t.Location {
if segment.location == t.location {
if local {
for i := range segment.cidrs {
// Don't add routes for the local node.
@@ -282,6 +285,16 @@ func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, encap
})
}
}
// Add routes for the allowed IPs of peers.
for _, peer := range t.peers {
for i := range peer.AllowedIPs {
routes = append(routes, &netlink.Route{
Dst: peer.AllowedIPs[i],
LinkIndex: kiloIface,
Protocol: unix.RTPROT_STATIC,
})
}
}
return routes
}

@@ -293,12 +306,66 @@ func encapsulateRoute(route *netlink.Route, encapsulate Encapsulate, subnet *net
}

// Conf generates a WireGuard configuration file for a given Topology.
func (t *Topology) Conf() ([]byte, error) {
conf := new(bytes.Buffer)
if err := confTemplate.Execute(conf, t); err != nil {
return nil, err
func (t *Topology) Conf() *wireguard.Conf {
c := &wireguard.Conf{
Interface: &wireguard.Interface{
PrivateKey: t.key,
ListenPort: t.port,
},
}
return conf.Bytes(), nil
for _, s := range t.segments {
if s.location == t.location {
continue
}
peer := &wireguard.Peer{
AllowedIPs: s.allowedIPs,
Endpoint: &wireguard.Endpoint{
IP: s.endpoint,
Port: uint32(t.port),
},
PublicKey: s.key,
}
c.Peers = append(c.Peers, peer)
}
for _, p := range t.peers {
peer := &wireguard.Peer{
AllowedIPs: p.AllowedIPs,
PersistentKeepalive: p.PersistentKeepalive,
PublicKey: p.PublicKey,
Endpoint: p.Endpoint,
}
c.Peers = append(c.Peers, peer)
}
return c
}

// PeerConf generates a WireGuard configuration file for a given peer in a Topology.
func (t *Topology) PeerConf(name string) *wireguard.Conf {
c := &wireguard.Conf{}
for _, s := range t.segments {
peer := &wireguard.Peer{
AllowedIPs: s.allowedIPs,
Endpoint: &wireguard.Endpoint{
IP: s.endpoint,
Port: uint32(t.port),
},
PublicKey: s.key,
}
c.Peers = append(c.Peers, peer)
}
for _, p := range t.peers {
if p.Name == name {
continue
}
peer := &wireguard.Peer{
AllowedIPs: p.AllowedIPs,
PersistentKeepalive: p.PersistentKeepalive,
PublicKey: p.PublicKey,
Endpoint: p.Endpoint,
}
c.Peers = append(c.Peers, peer)
}
return c
}

// oneAddressCIDR takes an IP address and returns a CIDR
File diff suppressed because it is too large