diff --git a/cmd/kg/main.go b/cmd/kg/main.go
index 7e84fd8..c1b4870 100644
--- a/cmd/kg/main.go
+++ b/cmd/kg/main.go
@@ -15,6 +15,7 @@ package main
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"net"
@@ -275,15 +276,16 @@ func runRoot(_ *cobra.Command, _ []string) error {
 	}
 
 	{
+		ctx, cancel := context.WithCancel(context.Background())
 		// Start the mesh.
 		g.Add(func() error {
 			logger.Log("msg", fmt.Sprintf("Starting Kilo network mesh '%v'.", version.Version))
-			if err := m.Run(); err != nil {
+			if err := m.Run(ctx); err != nil {
 				return fmt.Errorf("error: Kilo exited unexpectedly: %v", err)
 			}
 			return nil
 		}, func(error) {
-			m.Stop()
+			cancel()
 		})
 	}
 
diff --git a/cmd/kgctl/main.go b/cmd/kgctl/main.go
index aa0a316..9a6c58a 100644
--- a/cmd/kgctl/main.go
+++ b/cmd/kgctl/main.go
@@ -71,7 +71,7 @@ var (
 	topologyLabel string
 )
 
-func runRoot(_ *cobra.Command, _ []string) error {
+func runRoot(c *cobra.Command, _ []string) error {
 	if opts.port < 1 || opts.port > 1<<16-1 {
 		return fmt.Errorf("invalid port: port mus be in range [%d:%d], but got %d", 1, 1<<16-1, opts.port)
 	}
@@ -99,11 +99,11 @@ func runRoot(_ *cobra.Command, _ []string) error {
 		return fmt.Errorf("backend %s unknown; posible values are: %s", backend, availableBackends)
 	}
 
-	if err := opts.backend.Nodes().Init(make(chan struct{})); err != nil {
+	if err := opts.backend.Nodes().Init(c.Context()); err != nil {
 		return fmt.Errorf("failed to initialize node backend: %w", err)
 	}
 
-	if err := opts.backend.Peers().Init(make(chan struct{})); err != nil {
+	if err := opts.backend.Peers().Init(c.Context()); err != nil {
 		return fmt.Errorf("failed to initialize peer backend: %w", err)
 	}
 	return nil
diff --git a/pkg/k8s/backend.go b/pkg/k8s/backend.go
index 9b5da36..e94a664 100644
--- a/pkg/k8s/backend.go
+++ b/pkg/k8s/backend.go
@@ -128,7 +128,7 @@ func New(c kubernetes.Interface, kc kiloclient.Interface, ec apiextensions.Inter
 }
 
 // CleanUp removes configuration applied to the backend.
-func (nb *nodeBackend) CleanUp(name string) error {
+func (nb *nodeBackend) CleanUp(ctx context.Context, name string) error {
 	patch := []byte("[" + strings.Join([]string{
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(endpointAnnotationKey, "/", jsonPatchSlash, 1))),
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(internalIPAnnotationKey, "/", jsonPatchSlash, 1))),
@@ -138,7 +138,7 @@ func (nb *nodeBackend) CleanUp(name string) error {
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(discoveredEndpointsKey, "/", jsonPatchSlash, 1))),
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(granularityKey, "/", jsonPatchSlash, 1))),
 	}, ",") + "]")
-	if _, err := nb.client.CoreV1().Nodes().Patch(context.TODO(), name, types.JSONPatchType, patch, metav1.PatchOptions{}); err != nil {
+	if _, err := nb.client.CoreV1().Nodes().Patch(ctx, name, types.JSONPatchType, patch, metav1.PatchOptions{}); err != nil {
 		return fmt.Errorf("failed to patch node: %v", err)
 	}
 	return nil
@@ -155,9 +155,9 @@ func (nb *nodeBackend) Get(name string) (*mesh.Node, error) {
 
 // Init initializes the backend; for this backend that means
 // syncing the informer cache.
-func (nb *nodeBackend) Init(stop <-chan struct{}) error {
-	go nb.informer.Run(stop)
-	if ok := cache.WaitForCacheSync(stop, func() bool {
+func (nb *nodeBackend) Init(ctx context.Context) error {
+	go nb.informer.Run(ctx.Done())
+	if ok := cache.WaitForCacheSync(ctx.Done(), func() bool {
 		return nb.informer.HasSynced()
 	}); !ok {
 		return errors.New("failed to sync node cache")
@@ -212,7 +212,7 @@ func (nb *nodeBackend) List() ([]*mesh.Node, error) {
 }
 
 // Set sets the fields of a node.
-func (nb *nodeBackend) Set(name string, node *mesh.Node) error {
+func (nb *nodeBackend) Set(ctx context.Context, name string, node *mesh.Node) error {
 	old, err := nb.lister.Get(name)
 	if err != nil {
 		return fmt.Errorf("failed to find node: %v", err)
 	}
@@ -253,7 +253,7 @@ func (nb *nodeBackend) Set(name string, node *mesh.Node) error {
 	if err != nil {
 		return fmt.Errorf("failed to create patch for node %q: %v", n.Name, err)
 	}
-	if _, err = nb.client.CoreV1().Nodes().Patch(context.TODO(), name, types.StrategicMergePatchType, patch, metav1.PatchOptions{}); err != nil {
+	if _, err = nb.client.CoreV1().Nodes().Patch(ctx, name, types.StrategicMergePatchType, patch, metav1.PatchOptions{}); err != nil {
 		return fmt.Errorf("failed to patch node: %v", err)
 	}
 	return nil
@@ -431,7 +431,7 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
 }
 
 // CleanUp removes configuration applied to the backend.
-func (pb *peerBackend) CleanUp(name string) error {
+func (pb *peerBackend) CleanUp(_ context.Context, _ string) error {
 	return nil
 }
 
@@ -446,14 +446,14 @@ func (pb *peerBackend) Get(name string) (*mesh.Peer, error) {
 
 // Init initializes the backend; for this backend that means
 // syncing the informer cache.
-func (pb *peerBackend) Init(stop <-chan struct{}) error {
+func (pb *peerBackend) Init(ctx context.Context) error {
 	// Check the presents of the CRD peers.kilo.squat.ai.
-	if _, err := pb.extensionsClient.ApiextensionsV1().CustomResourceDefinitions().Get(context.TODO(), strings.Join([]string{v1alpha1.PeerPlural, v1alpha1.GroupName}, "."), metav1.GetOptions{}); err != nil {
+	if _, err := pb.extensionsClient.ApiextensionsV1().CustomResourceDefinitions().Get(ctx, strings.Join([]string{v1alpha1.PeerPlural, v1alpha1.GroupName}, "."), metav1.GetOptions{}); err != nil {
 		return fmt.Errorf("CRD is not present: %v", err)
 	}
-	go pb.informer.Run(stop)
-	if ok := cache.WaitForCacheSync(stop, func() bool {
+	go pb.informer.Run(ctx.Done())
+	if ok := cache.WaitForCacheSync(ctx.Done(), func() bool {
 		return pb.informer.HasSynced()
 	}); !ok {
 		return errors.New("failed to sync peer cache")
@@ -512,7 +512,7 @@ func (pb *peerBackend) List() ([]*mesh.Peer, error) {
 }
 
 // Set sets the fields of a peer.
-func (pb *peerBackend) Set(name string, peer *mesh.Peer) error {
+func (pb *peerBackend) Set(ctx context.Context, name string, peer *mesh.Peer) error {
 	old, err := pb.lister.Get(name)
 	if err != nil {
 		return fmt.Errorf("failed to find peer: %v", err)
 	}
@@ -542,7 +542,7 @@ func (pb *peerBackend) Set(name string, peer *mesh.Peer) error {
 		p.Spec.PresharedKey = peer.PresharedKey.String()
 	}
 	p.Spec.PublicKey = peer.PublicKey.String()
-	if _, err = pb.client.KiloV1alpha1().Peers().Update(context.TODO(), p, metav1.UpdateOptions{}); err != nil {
+	if _, err = pb.client.KiloV1alpha1().Peers().Update(ctx, p, metav1.UpdateOptions{}); err != nil {
 		return fmt.Errorf("failed to update peer: %v", err)
 	}
 	return nil
diff --git a/pkg/mesh/backend.go b/pkg/mesh/backend.go
index 562d32b..203661d 100644
--- a/pkg/mesh/backend.go
+++ b/pkg/mesh/backend.go
@@ -15,6 +15,7 @@ package mesh
 
 import (
+	"context"
 	"net"
 	"time"
@@ -146,11 +147,11 @@ type Backend interface {
 // clean up any changes applied to the backend,
 // and watch for changes to nodes.
 type NodeBackend interface {
-	CleanUp(string) error
+	CleanUp(context.Context, string) error
 	Get(string) (*Node, error)
-	Init(<-chan struct{}) error
+	Init(context.Context) error
 	List() ([]*Node, error)
-	Set(string, *Node) error
+	Set(context.Context, string, *Node) error
 	Watch() <-chan *NodeEvent
 }
 
@@ -160,10 +161,10 @@ type NodeBackend interface {
 // clean up any changes applied to the backend,
 // and watch for changes to peers.
 type PeerBackend interface {
-	CleanUp(string) error
+	CleanUp(context.Context, string) error
 	Get(string) (*Peer, error)
-	Init(<-chan struct{}) error
+	Init(context.Context) error
 	List() ([]*Peer, error)
-	Set(string, *Peer) error
+	Set(context.Context, string, *Peer) error
 	Watch() <-chan *PeerEvent
 }
diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go
index bc9b43d..ea405d3 100644
--- a/pkg/mesh/mesh.go
+++ b/pkg/mesh/mesh.go
@@ -19,6 +19,7 @@ package mesh
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"io/ioutil"
 	"net"
@@ -69,7 +70,6 @@ type Mesh struct {
 	pub                 wgtypes.Key
 	resyncPeriod        time.Duration
 	iptablesForwardRule bool
-	stop                chan struct{}
 	subnet              *net.IPNet
 	table               *route.Table
 	wireGuardIP         *net.IPNet
@@ -180,7 +180,6 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
 		resyncPeriod:        resyncPeriod,
 		iptablesForwardRule: iptablesForwardRule,
 		local:               local,
-		stop:                make(chan struct{}),
 		subnet:              subnet,
 		table:               route.NewTable(),
 		errorCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
@@ -208,8 +207,8 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
 }
 
 // Run starts the mesh.
-func (m *Mesh) Run() error {
-	if err := m.Nodes().Init(m.stop); err != nil {
+func (m *Mesh) Run(ctx context.Context) error {
+	if err := m.Nodes().Init(ctx); err != nil {
 		return fmt.Errorf("failed to initialize node backend: %v", err)
 	}
 	// Try to set the CNI config quickly.
@@ -221,14 +220,14 @@ func (m *Mesh) Run() error {
 			level.Warn(m.logger).Log("error", fmt.Errorf("failed to get node %q: %v", m.hostname, err))
 		}
 	}
-	if err := m.Peers().Init(m.stop); err != nil {
+	if err := m.Peers().Init(ctx); err != nil {
 		return fmt.Errorf("failed to initialize peer backend: %v", err)
 	}
-	ipTablesErrors, err := m.ipTables.Run(m.stop)
+	ipTablesErrors, err := m.ipTables.Run(ctx.Done())
 	if err != nil {
 		return fmt.Errorf("failed to watch for IP tables updates: %v", err)
 	}
-	routeErrors, err := m.table.Run(m.stop)
+	routeErrors, err := m.table.Run(ctx.Done())
 	if err != nil {
 		return fmt.Errorf("failed to watch for route table updates: %v", err)
 	}
@@ -238,7 +237,7 @@ func (m *Mesh) Run() error {
 			select {
 			case err = <-ipTablesErrors:
 			case err = <-routeErrors:
-			case <-m.stop:
+			case <-ctx.Done():
 				return
 			}
 			if err != nil {
@@ -257,11 +256,11 @@ func (m *Mesh) Run() error {
 	for {
 		select {
 		case ne = <-nw:
-			m.syncNodes(ne)
+			m.syncNodes(ctx, ne)
 		case pe = <-pw:
 			m.syncPeers(pe)
 		case <-checkIn.C:
-			m.checkIn()
+			m.checkIn(ctx)
 			checkIn.Reset(checkInPeriod)
 		case <-resync.C:
 			if m.cni {
@@ -269,18 +268,18 @@ func (m *Mesh) Run() error {
 			}
 			m.applyTopology()
 			resync.Reset(m.resyncPeriod)
-		case <-m.stop:
+		case <-ctx.Done():
 			return nil
 		}
 	}
 }
 
-func (m *Mesh) syncNodes(e *NodeEvent) {
+func (m *Mesh) syncNodes(ctx context.Context, e *NodeEvent) {
 	logger := log.With(m.logger, "event", e.Type)
 	level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type)
 	if isSelf(m.hostname, e.Node) {
 		level.Debug(logger).Log("msg", "processing local node", "node", e.Node)
-		m.handleLocal(e.Node)
+		m.handleLocal(ctx, e.Node)
 		return
 	}
 	var diff bool
@@ -348,7 +347,7 @@ func (m *Mesh) syncPeers(e *PeerEvent) {
 
 // checkIn will try to update the local node's LastSeen timestamp
 // in the backend.
-func (m *Mesh) checkIn() {
+func (m *Mesh) checkIn(ctx context.Context) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 	n := m.nodes[m.hostname]
@@ -358,7 +357,7 @@ func (m *Mesh) checkIn() {
 	}
 	oldTime := n.LastSeen
 	n.LastSeen = time.Now().Unix()
-	if err := m.Nodes().Set(m.hostname, n); err != nil {
+	if err := m.Nodes().Set(ctx, m.hostname, n); err != nil {
 		level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", n)
 		m.errorCounter.WithLabelValues("checkin").Inc()
 		// Revert time.
@@ -368,7 +367,7 @@ func (m *Mesh) checkIn() {
 	level.Debug(m.logger).Log("msg", "successfully checked in local node in backend")
 }
 
-func (m *Mesh) handleLocal(n *Node) {
+func (m *Mesh) handleLocal(ctx context.Context, n *Node) {
 	// Allow the IPs to be overridden.
 	if !n.Endpoint.Ready() {
 		e := wireguard.NewEndpoint(m.externalIP.IP, m.port)
@@ -399,7 +398,7 @@ func (m *Mesh) handleLocal(n *Node) {
 	}
 	if !nodesAreEqual(n, local) {
 		level.Debug(m.logger).Log("msg", "local node differs from backend")
-		if err := m.Nodes().Set(m.hostname, local); err != nil {
+		if err := m.Nodes().Set(ctx, m.hostname, local); err != nil {
 			level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", local)
 			m.errorCounter.WithLabelValues("local").Inc()
 			return
@@ -584,11 +583,6 @@ func (m *Mesh) RegisterMetrics(r prometheus.Registerer) {
 	)
 }
 
-// Stop stops the mesh.
-func (m *Mesh) Stop() {
-	close(m.stop)
-}
-
 func (m *Mesh) cleanUp() {
 	if err := m.ipTables.CleanUp(); err != nil {
 		level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up IP tables: %v", err))
@@ -604,13 +598,21 @@ func (m *Mesh) RegisterMetrics(r prometheus.Registerer) {
 			m.errorCounter.WithLabelValues("cleanUp").Inc()
 		}
 	}
-	if err := m.Nodes().CleanUp(m.hostname); err != nil {
-		level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up node backend: %v", err))
-		m.errorCounter.WithLabelValues("cleanUp").Inc()
+	{
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+		defer cancel()
+		if err := m.Nodes().CleanUp(ctx, m.hostname); err != nil {
+			level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up node backend: %v", err))
+			m.errorCounter.WithLabelValues("cleanUp").Inc()
+		}
 	}
-	if err := m.Peers().CleanUp(m.hostname); err != nil {
-		level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
-		m.errorCounter.WithLabelValues("cleanUp").Inc()
+	{
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+		defer cancel()
+		if err := m.Peers().CleanUp(ctx, m.hostname); err != nil {
+			level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
+			m.errorCounter.WithLabelValues("cleanUp").Inc()
+		}
 	}
 	if err := m.enc.CleanUp(); err != nil {
 		level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up encapsulator: %v", err))
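
Not part of the patch: a minimal, self-contained sketch of the shutdown pattern this diff adopts. The `runner` type below is hypothetical and stands in for `Mesh`; `signal.NotifyContext` is assumed here as a simple stand-in for the `run.Group` wiring in `cmd/kg/main.go`. The point it illustrates is that cancelling one context replaces the old `Stop()` method and internal `stop` channel: everything that used to select on `m.stop` now selects on `ctx.Done()`.

```go
// Sketch only: not Kilo code, just the stop-channel-to-context pattern.
package main

import (
	"context"
	"fmt"
	"os/signal"
	"syscall"
	"time"
)

// runner stands in for a component like Mesh whose Run now takes a context.
type runner struct{}

// Run blocks until the context is cancelled, the way Mesh.Run now selects
// on ctx.Done() instead of an internal stop channel.
func (r *runner) Run(ctx context.Context) error {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			fmt.Println("resync tick")
		case <-ctx.Done():
			fmt.Println("shutting down:", ctx.Err())
			return nil
		}
	}
}

func main() {
	// Cancelling the context is the single shutdown signal for everything
	// Run started; this replaces calling a separate Stop() method.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	if err := (&runner{}).Run(ctx); err != nil {
		fmt.Println("error:", err)
	}
}
```

The same context flows into the backends' `Init(ctx)`, which hands `ctx.Done()` to the informers and cache sync, so Kubernetes API calls and watches stop together with the rest of the mesh.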