*: add peer VPN support
This commit adds support for defining arbitrary peers that should have access to the VPN. In Kubernetes, this is accomplished using the new Peer CRD.
1 changed file: pkg/mesh/mesh.go (263 lines changed)
@@ -15,6 +15,7 @@
 package mesh
 
 import (
+    "bytes"
     "fmt"
     "io/ioutil"
     "net"
@@ -43,6 +44,8 @@ const (
     PrivateKeyPath = KiloPath + "/key"
     // ConfPath is the filepath where the WireGuard configuration is stored.
     ConfPath = KiloPath + "/conf"
+    // DefaultKiloPort is the default UDP port Kilo uses.
+    DefaultKiloPort = 51820
 )
 
 // Granularity represents the abstraction level at which the network
@@ -93,6 +96,17 @@ func (n *Node) Ready() bool {
     return n != nil && n.ExternalIP != nil && n.Key != nil && n.InternalIP != nil && n.Subnet != nil && time.Now().Unix()-n.LastSeen < int64(resyncPeriod)*2/int64(time.Second)
 }
 
+// Peer represents a peer in the network.
+type Peer struct {
+    wireguard.Peer
+    Name string
+}
+
+// Ready indicates whether or not the peer is ready.
+func (p *Peer) Ready() bool {
+    return p != nil && p.AllowedIPs != nil && len(p.AllowedIPs) != 0 && p.PublicKey != nil
+}
+
 // EventType describes what kind of an action an event represents.
 type EventType string
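A peer is ready once it carries a public key and at least one allowed IP. A minimal sketch of constructing one outside the mesh package (the import paths match the Kilo repository; the name, key, and addresses are illustrative only, not part of the commit):

    package main

    import (
        "fmt"
        "net"

        "github.com/squat/kilo/pkg/mesh"
        "github.com/squat/kilo/pkg/wireguard"
    )

    func main() {
        // Hypothetical values; a real public key is a base64-encoded Curve25519 key.
        _, allowedIP, _ := net.ParseCIDR("10.5.0.1/32")
        p := &mesh.Peer{
            Name: "laptop",
            Peer: wireguard.Peer{
                PublicKey:  []byte("example-public-key"),
                AllowedIPs: []*net.IPNet{allowedIP},
            },
        }
        fmt.Println(p.Ready()) // true: public key and one allowed IP are set
    }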
@@ -105,24 +119,53 @@ const (
     UpdateEvent EventType = "update"
 )
 
-// Event represents an update event concerning a node in the cluster.
-type Event struct {
+// NodeEvent represents an event concerning a node in the cluster.
+type NodeEvent struct {
     Type EventType
     Node *Node
 }
 
-// Backend can get nodes by name, init itself,
+// PeerEvent represents an event concerning a peer in the cluster.
+type PeerEvent struct {
+    Type EventType
+    Peer *Peer
+}
+
+// Backend can create clients for all of the
+// primitive types that Kilo deals with, namely:
+// * nodes; and
+// * peers.
+type Backend interface {
+    Nodes() NodeBackend
+    Peers() PeerBackend
+}
+
+// NodeBackend can get nodes by name, init itself,
 // list the nodes that should be meshed,
 // set Kilo properties for a node,
 // clean up any changes applied to the backend,
 // and watch for changes to nodes.
-type Backend interface {
+type NodeBackend interface {
     CleanUp(string) error
     Get(string) (*Node, error)
     Init(<-chan struct{}) error
     List() ([]*Node, error)
     Set(string, *Node) error
-    Watch() <-chan *Event
+    Watch() <-chan *NodeEvent
 }
 
+// PeerBackend can get peers by name, init itself,
+// list the peers that should be in the mesh,
+// set fields for a peer,
+// clean up any changes applied to the backend,
+// and watch for changes to peers.
+type PeerBackend interface {
+    CleanUp(string) error
+    Get(string) (*Peer, error)
+    Init(<-chan struct{}) error
+    List() ([]*Peer, error)
+    Set(string, *Peer) error
+    Watch() <-chan *PeerEvent
+}
+
 // Mesh is able to create Kilo network meshes.
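Splitting the old Backend into NodeBackend and PeerBackend means each primitive is initialized and watched independently, which is exactly what Run does further down. A sketch of a consumer, assuming some concrete mesh.Backend implementation (for example, the Kubernetes backend serving the new Peer CRD); the function and its event handling are illustrative:

    package example

    import "github.com/squat/kilo/pkg/mesh"

    // watchBackend initializes both backends and watches their independent channels.
    func watchBackend(b mesh.Backend, stop <-chan struct{}) error {
        if err := b.Nodes().Init(stop); err != nil {
            return err
        }
        if err := b.Peers().Init(stop); err != nil {
            return err
        }
        nw := b.Nodes().Watch()
        pw := b.Peers().Watch()
        for {
            select {
            case ne := <-nw:
                _ = ne // react to the node change
            case pe := <-pw:
                _ = pe // react to the peer change
            case <-stop:
                return nil
            }
        }
    }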
@@ -138,7 +181,7 @@ type Mesh struct {
     kiloIface int
     key       []byte
     local     bool
-    port      int
+    port      uint32
     priv      []byte
     privIface int
     pub       []byte
@@ -148,23 +191,26 @@ type Mesh struct {
     table     *route.Table
     tunlIface int
 
-    // nodes is a mutable field in the struct
+    // nodes and peers are mutable fields in the struct
     // and needs to be guarded.
     nodes map[string]*Node
+    peers map[string]*Peer
     mu    sync.Mutex
 
     errorCounter     *prometheus.CounterVec
     nodesGuage       prometheus.Gauge
+    peersGuage       prometheus.Gauge
     reconcileCounter prometheus.Counter
     logger           log.Logger
 }
 
 // New returns a new Mesh instance.
-func New(backend Backend, encapsulate Encapsulate, granularity Granularity, hostname string, port int, subnet *net.IPNet, local bool, logger log.Logger) (*Mesh, error) {
+func New(backend Backend, encapsulate Encapsulate, granularity Granularity, hostname string, port uint32, subnet *net.IPNet, local bool, logger log.Logger) (*Mesh, error) {
     if err := os.MkdirAll(KiloPath, 0700); err != nil {
         return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
     }
     private, err := ioutil.ReadFile(PrivateKeyPath)
+    private = bytes.Trim(private, "\n")
     if err != nil {
         level.Warn(logger).Log("msg", "no private key found on disk; generating one now")
         if private, err = wireguard.GenKey(); err != nil {
@@ -224,6 +270,7 @@ func New(backend Backend, encapsulate Encapsulate, granularity Granularity, host
         ipTables:  ipTables,
         kiloIface: kiloIface,
         nodes:     make(map[string]*Node),
+        peers:     make(map[string]*Peer),
         port:      port,
         priv:      private,
         privIface: privIface,
@@ -240,7 +287,11 @@ func New(backend Backend, encapsulate Encapsulate, granularity Granularity, host
         }, []string{"event"}),
         nodesGuage: prometheus.NewGauge(prometheus.GaugeOpts{
             Name: "kilo_nodes",
-            Help: "Number of in the mesh.",
+            Help: "Number of nodes in the mesh.",
         }),
+        peersGuage: prometheus.NewGauge(prometheus.GaugeOpts{
+            Name: "kilo_peers",
+            Help: "Number of peers in the mesh.",
+        }),
         reconcileCounter: prometheus.NewCounter(prometheus.CounterOpts{
             Name: "kilo_reconciles_total",
@@ -252,8 +303,11 @@ func New(backend Backend, encapsulate Encapsulate, granularity Granularity, host
 
 // Run starts the mesh.
 func (m *Mesh) Run() error {
-    if err := m.Init(m.stop); err != nil {
-        return fmt.Errorf("failed to initialize backend: %v", err)
+    if err := m.Nodes().Init(m.stop); err != nil {
+        return fmt.Errorf("failed to initialize node backend: %v", err)
     }
+    if err := m.Peers().Init(m.stop); err != nil {
+        return fmt.Errorf("failed to initialize peer backend: %v", err)
+    }
     ipsetErrors, err := m.ipset.Run(m.stop)
     if err != nil {
@@ -285,14 +339,19 @@ func (m *Mesh) Run() error {
     }()
     defer m.cleanUp()
     t := time.NewTimer(resyncPeriod)
-    w := m.Watch()
+    nw := m.Nodes().Watch()
+    pw := m.Peers().Watch()
+    var ne *NodeEvent
+    var pe *PeerEvent
     for {
-        var e *Event
         select {
-        case e = <-w:
-            m.sync(e)
+        case ne = <-nw:
+            m.syncNodes(ne)
+        case pe = <-pw:
+            m.syncPeers(pe)
         case <-t.C:
             m.checkIn()
+            m.syncEndpoints()
             m.applyTopology()
             t.Reset(resyncPeriod)
         case <-m.stop:
@@ -301,9 +360,50 @@ func (m *Mesh) Run() error {
         }
     }
 }
 
-func (m *Mesh) sync(e *Event) {
+// WireGuard updates the endpoints of peers to match the
+// last place a valid packet was received from.
+// Periodically we need to syncronize the endpoints
+// of peers in the backend to match the WireGuard configuration.
+func (m *Mesh) syncEndpoints() {
+    link, err := linkByIndex(m.kiloIface)
+    if err != nil {
+        level.Error(m.logger).Log("error", err)
+        m.errorCounter.WithLabelValues("endpoints").Inc()
+        return
+    }
+    conf, err := wireguard.ShowConf(link.Attrs().Name)
+    if err != nil {
+        level.Error(m.logger).Log("error", err)
+        m.errorCounter.WithLabelValues("endpoints").Inc()
+        return
+    }
+    m.mu.Lock()
+    defer m.mu.Unlock()
+    c := wireguard.Parse(conf)
+    var key string
+    var tmp *Peer
+    for i := range c.Peers {
+        // Peers are indexed by public key.
+        key = string(c.Peers[i].PublicKey)
+        if p, ok := m.peers[key]; ok {
+            tmp = &Peer{
+                Name: p.Name,
+                Peer: *c.Peers[i],
+            }
+            if !peersAreEqual(tmp, p) {
+                p.Endpoint = tmp.Endpoint
+                if err := m.Peers().Set(p.Name, p); err != nil {
+                    level.Error(m.logger).Log("error", err)
+                    m.errorCounter.WithLabelValues("endpoints").Inc()
+                }
+            }
+        }
+    }
+}
+
+func (m *Mesh) syncNodes(e *NodeEvent) {
     logger := log.With(m.logger, "event", e.Type)
-    level.Debug(logger).Log("msg", "syncing", "event", e.Type)
+    level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type)
     if isSelf(m.hostname, e.Node) {
         level.Debug(logger).Log("msg", "processing local node", "node", e.Node)
         m.handleLocal(e.Node)
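syncEndpoints exists because WireGuard itself learns where a roaming peer last sent a valid packet from; the backend's copy of the peer must follow, or the next topology application would revert the endpoint. A sketch of the drift it detects; the wireguard.Endpoint type and its literal here are assumptions based only on the IP and Port fields the comparison reads, and the addresses are made up:

    package main

    import (
        "fmt"
        "net"

        "github.com/squat/kilo/pkg/wireguard"
    )

    func main() {
        // Endpoint stored in the backend for a peer.
        stored := &wireguard.Endpoint{IP: net.ParseIP("203.0.113.7"), Port: 51820}
        // Endpoint the running WireGuard device reports after the peer roamed.
        observed := &wireguard.Endpoint{IP: net.ParseIP("198.51.100.4"), Port: 51820}
        // peersAreEqual sees the differing endpoint, so syncEndpoints copies
        // the observed endpoint onto the stored peer and persists it with Set.
        fmt.Println(stored.IP.Equal(observed.IP)) // false
    }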
@@ -326,9 +426,11 @@ func (m *Mesh) sync(e *Event) {
         fallthrough
     case UpdateEvent:
         if !nodesAreEqual(m.nodes[e.Node.Name], e.Node) {
-            m.nodes[e.Node.Name] = e.Node
             diff = true
         }
+        // Even if the nodes are the same,
+        // overwrite the old node to update the timestamp.
+        m.nodes[e.Node.Name] = e.Node
     case DeleteEvent:
         delete(m.nodes, e.Node.Name)
         diff = true
@@ -341,6 +443,43 @@ func (m *Mesh) sync(e *Event) {
     }
 }
 
+func (m *Mesh) syncPeers(e *PeerEvent) {
+    logger := log.With(m.logger, "event", e.Type)
+    level.Debug(logger).Log("msg", "syncing peers", "event", e.Type)
+    var diff bool
+    m.mu.Lock()
+    // Peers are indexed by public key.
+    key := string(e.Peer.PublicKey)
+    if !e.Peer.Ready() {
+        level.Debug(logger).Log("msg", "received incomplete peer", "peer", e.Peer)
+        // An existing peer is no longer valid
+        // so remove it from the mesh.
+        if _, ok := m.peers[key]; ok {
+            level.Info(logger).Log("msg", "peer is no longer in the mesh", "peer", e.Peer)
+            delete(m.peers, key)
+            diff = true
+        }
+    } else {
+        switch e.Type {
+        case AddEvent:
+            fallthrough
+        case UpdateEvent:
+            if !peersAreEqual(m.peers[key], e.Peer) {
+                m.peers[key] = e.Peer
+                diff = true
+            }
+        case DeleteEvent:
+            delete(m.peers, key)
+            diff = true
+        }
+    }
+    m.mu.Unlock()
+    if diff {
+        level.Info(logger).Log("peer", e.Peer)
+        m.applyTopology()
+    }
+}
+
 // checkIn will try to update the local node's LastSeen timestamp
 // in the backend.
 func (m *Mesh) checkIn() {
@@ -352,7 +491,7 @@ func (m *Mesh) checkIn() {
         return
     }
     n.LastSeen = time.Now().Unix()
-    if err := m.Set(m.hostname, n); err != nil {
+    if err := m.Nodes().Set(m.hostname, n); err != nil {
         level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", n)
         m.errorCounter.WithLabelValues("checkin").Inc()
         return
@@ -380,7 +519,7 @@ func (m *Mesh) handleLocal(n *Node) {
     }
     if !nodesAreEqual(n, local) {
         level.Debug(m.logger).Log("msg", "local node differs from backend")
-        if err := m.Set(m.hostname, local); err != nil {
+        if err := m.Nodes().Set(m.hostname, local); err != nil {
             level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", local)
             m.errorCounter.WithLabelValues("local").Inc()
             return
@@ -406,31 +545,42 @@ func (m *Mesh) applyTopology() {
     m.mu.Lock()
     defer m.mu.Unlock()
     // Ensure all unready nodes are removed.
-    var ready float64
-    for n := range m.nodes {
-        if !m.nodes[n].Ready() {
-            delete(m.nodes, n)
+    var readyNodes float64
+    for k := range m.nodes {
+        if !m.nodes[k].Ready() {
+            delete(m.nodes, k)
             continue
         }
-        ready++
+        readyNodes++
     }
-    m.nodesGuage.Set(ready)
+    // Ensure all unready peers are removed.
+    var readyPeers float64
+    for k := range m.peers {
+        if !m.peers[k].Ready() {
+            delete(m.peers, k)
+            continue
+        }
+        readyPeers++
+    }
+    m.nodesGuage.Set(readyNodes)
+    m.peersGuage.Set(readyPeers)
     // We cannot do anything with the topology until the local node is available.
     if m.nodes[m.hostname] == nil {
         return
     }
-    t, err := NewTopology(m.nodes, m.granularity, m.hostname, m.port, m.priv, m.subnet)
+    t, err := NewTopology(m.nodes, m.peers, m.granularity, m.hostname, m.port, m.priv, m.subnet)
     if err != nil {
         level.Error(m.logger).Log("error", err)
        m.errorCounter.WithLabelValues("apply").Inc()
         return
     }
-    conf, err := t.Conf()
+    conf := t.Conf()
+    buf, err := conf.Bytes()
     if err != nil {
         level.Error(m.logger).Log("error", err)
         m.errorCounter.WithLabelValues("apply").Inc()
     }
-    if err := ioutil.WriteFile(ConfPath, conf, 0600); err != nil {
+    if err := ioutil.WriteFile(ConfPath, buf, 0600); err != nil {
         level.Error(m.logger).Log("error", err)
         m.errorCounter.WithLabelValues("apply").Inc()
         return
@@ -443,6 +593,9 @@ func (m *Mesh) applyTopology() {
     }
     rules := iptables.MasqueradeRules(private, m.nodes[m.hostname].Subnet, t.RemoteSubnets())
     rules = append(rules, iptables.ForwardRules(m.subnet)...)
+    for _, p := range m.peers {
+        rules = append(rules, iptables.ForwardRules(p.AllowedIPs...)...)
+    }
     if err := m.ipTables.Set(rules); err != nil {
         level.Error(m.logger).Log("error", err)
         m.errorCounter.WithLabelValues("apply").Inc()
@@ -450,8 +603,8 @@ func (m *Mesh) applyTopology() {
     }
     if m.encapsulate != NeverEncapsulate {
         var peers []net.IP
-        for _, s := range t.Segments {
-            if s.Location == m.nodes[m.hostname].Location {
+        for _, s := range t.segments {
+            if s.location == m.nodes[m.hostname].Location {
                 peers = s.privateIPs
                 break
             }
@@ -461,6 +614,8 @@ func (m *Mesh) applyTopology() {
         m.errorCounter.WithLabelValues("apply").Inc()
         return
     }
+    // If we are handling local routes, ensure the local
+    // tunnel has an IP address.
     if m.local {
         if err := iproute.SetAddress(m.tunlIface, oneAddressCIDR(newAllocator(*m.nodes[m.hostname].Subnet).next().IP)); err != nil {
             level.Error(m.logger).Log("error", err)
@@ -489,14 +644,9 @@ func (m *Mesh) applyTopology() {
     }
     // Setting the WireGuard configuration interrupts existing connections
     // so only set the configuration if it has changed.
-    equal, err := wireguard.CompareConf(conf, oldConf)
-    if err != nil {
-        level.Error(m.logger).Log("error", err)
-        m.errorCounter.WithLabelValues("apply").Inc()
-        // Don't return here, simply overwrite the old configuration.
-        equal = false
-    }
+    equal := conf.Equal(wireguard.Parse(oldConf))
     if !equal {
         level.Info(m.logger).Log("msg", "WireGuard configurations are different")
         if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
             level.Error(m.logger).Log("error", err)
             m.errorCounter.WithLabelValues("apply").Inc()
@@ -531,6 +681,7 @@ func (m *Mesh) RegisterMetrics(r prometheus.Registerer) {
     r.MustRegister(
         m.errorCounter,
         m.nodesGuage,
+        m.peersGuage,
         m.reconcileCounter,
     )
 }
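With peersGuage registered alongside the rest, any Prometheus registry handed to RegisterMetrics now exposes kilo_peers. A minimal wiring sketch; the package, function, and listen address are illustrative, not part of the commit:

    package metrics

    import (
        "log"
        "net/http"

        "github.com/prometheus/client_golang/prometheus"
        "github.com/prometheus/client_golang/prometheus/promhttp"

        "github.com/squat/kilo/pkg/mesh"
    )

    // serveMetrics exposes the mesh metrics, including the new kilo_peers gauge.
    func serveMetrics(m *mesh.Mesh) {
        r := prometheus.NewRegistry()
        m.RegisterMetrics(r) // registers kilo_nodes, kilo_peers, and the counters
        http.Handle("/metrics", promhttp.HandlerFor(r, promhttp.HandlerOpts{}))
        log.Fatal(http.ListenAndServe(":9090", nil)) // arbitrary port
    }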
@@ -558,11 +709,15 @@ func (m *Mesh) cleanUp() {
         m.errorCounter.WithLabelValues("cleanUp").Inc()
     }
     if err := iproute.RemoveInterface(m.kiloIface); err != nil {
-        level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove wireguard interface: %v", err))
+        level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err))
         m.errorCounter.WithLabelValues("cleanUp").Inc()
     }
-    if err := m.CleanUp(m.hostname); err != nil {
-        level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up backend: %v", err))
+    if err := m.Nodes().CleanUp(m.hostname); err != nil {
+        level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up node backend: %v", err))
+        m.errorCounter.WithLabelValues("cleanUp").Inc()
+    }
+    if err := m.Peers().CleanUp(m.hostname); err != nil {
+        level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
         m.errorCounter.WithLabelValues("cleanUp").Inc()
     }
     if err := m.ipset.CleanUp(); err != nil {
@@ -586,6 +741,32 @@ func nodesAreEqual(a, b *Node) bool {
     return ipNetsEqual(a.ExternalIP, b.ExternalIP) && string(a.Key) == string(b.Key) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet)
 }
 
+func peersAreEqual(a, b *Peer) bool {
+    if !(a != nil) == (b != nil) {
+        return false
+    }
+    if a == b {
+        return true
+    }
+    if !(a.Endpoint != nil) == (b.Endpoint != nil) {
+        return false
+    }
+    if a.Endpoint != nil {
+        if !a.Endpoint.IP.Equal(b.Endpoint.IP) || a.Endpoint.Port != b.Endpoint.Port {
+            return false
+        }
+    }
+    if len(a.AllowedIPs) != len(b.AllowedIPs) {
+        return false
+    }
+    for i := range a.AllowedIPs {
+        if !ipNetsEqual(a.AllowedIPs[i], b.AllowedIPs[i]) {
+            return false
+        }
+    }
+    return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive
+}
+
 func ipNetsEqual(a, b *net.IPNet) bool {
     if a == nil && b == nil {
         return true
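Note that peersAreEqual compares AllowedIPs index by index, so the same CIDRs in a different order read as unequal and will trigger a reapply. A small demonstration of that ordering sensitivity; ipNetsEqual is re-created locally because the mesh helper is unexported, and its body here is a plausible reconstruction since the hunk above shows only its first lines:

    package main

    import (
        "fmt"
        "net"
    )

    // ipNetsEqual approximates the unexported helper in pkg/mesh.
    func ipNetsEqual(a, b *net.IPNet) bool {
        if a == nil && b == nil {
            return true
        }
        if a == nil || b == nil {
            return false
        }
        return a.IP.Equal(b.IP) && a.Mask.String() == b.Mask.String()
    }

    func main() {
        _, n1, _ := net.ParseCIDR("10.5.0.1/32")
        _, n2, _ := net.ParseCIDR("10.5.0.2/32")
        x := []*net.IPNet{n1, n2}
        y := []*net.IPNet{n2, n1} // same set, different order
        equal := true
        for i := range x {
            if !ipNetsEqual(x[i], y[i]) {
                equal = false
                break
            }
        }
        fmt.Println(equal) // false: order matters in the comparison
    }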