// Copyright 2021 the Kilo authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "context" "encoding/json" "errors" "fmt" "io/ioutil" "net/http" "os" "syscall" "time" "github.com/go-kit/kit/log/level" "github.com/oklog/run" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" v1 "k8s.io/api/admission/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/serializer" kilo "github.com/squat/kilo/pkg/k8s/apis/kilo/v1alpha1" "github.com/squat/kilo/pkg/version" ) var webhookCmd = &cobra.Command{ Use: "webhook", PreRunE: func(c *cobra.Command, a []string) error { if c.HasParent() { return c.Parent().PreRunE(c, a) } return nil }, Short: "webhook starts a HTTPS server to validate updates and creations of Kilo peers.", RunE: webhook, } var ( certPath string keyPath string metricsAddr string listenAddr string ) func init() { webhookCmd.Flags().StringVar(&certPath, "cert-file", "", "The path to a certificate file") webhookCmd.Flags().StringVar(&keyPath, "key-file", "", "The path to a key file") webhookCmd.Flags().StringVar(&metricsAddr, "listen-metrics", ":1107", "The metrics server will be listening to that address") webhookCmd.Flags().StringVar(&listenAddr, "listen", ":8443", "The webhook server will be listening to that address") } var deserializer = serializer.NewCodecFactory(runtime.NewScheme()).UniversalDeserializer() var ( validationCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "admission_requests_total", Help: "The number of received admission reviews requests", }, []string{"operation", "response"}, ) requestCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "http_requests_total", Help: "The number of received http requests", }, []string{"handler", "method"}, ) errorCounter = prometheus.NewCounter( prometheus.CounterOpts{ Name: "errors_total", Help: "The total number of errors", }, ) ) func validationHandler(w http.ResponseWriter, r *http.Request) { level.Debug(logger).Log("msg", "handling request", "source", r.RemoteAddr) body, err := ioutil.ReadAll(r.Body) if err != nil { errorCounter.Inc() level.Error(logger).Log("err", "failed to parse body from incoming request", "source", r.RemoteAddr) http.Error(w, err.Error(), http.StatusBadRequest) return } var admissionReview v1.AdmissionReview contentType := r.Header.Get("Content-Type") if contentType != "application/json" { errorCounter.Inc() msg := fmt.Sprintf("received Content-Type=%s, expected application/json", contentType) level.Error(logger).Log("err", msg) http.Error(w, msg, http.StatusBadRequest) return } response := v1.AdmissionReview{} _, gvk, err := deserializer.Decode(body, nil, &admissionReview) if err != nil { errorCounter.Inc() msg := fmt.Sprintf("Request could not be decoded: %v", err) level.Error(logger).Log("err", msg) http.Error(w, msg, http.StatusBadRequest) return } if *gvk != v1.SchemeGroupVersion.WithKind("AdmissionReview") { errorCounter.Inc() msg := "only API v1 is supported" level.Error(logger).Log("err", msg) http.Error(w, msg, http.StatusBadRequest) return } response.SetGroupVersionKind(*gvk) response.Response = &v1.AdmissionResponse{ UID: admissionReview.Request.UID, } rawExtension := admissionReview.Request.Object var peer kilo.Peer if err := json.Unmarshal(rawExtension.Raw, &peer); err != nil { errorCounter.Inc() msg := fmt.Sprintf("could not unmarshal extension to peer spec: %v:", err) level.Error(logger).Log("err", msg) http.Error(w, msg, http.StatusBadRequest) return } if err := peer.Validate(); err == nil { level.Debug(logger).Log("msg", "got valid peer spec", "spec", peer.Spec, "name", peer.ObjectMeta.Name) validationCounter.With(prometheus.Labels{"operation": string(admissionReview.Request.Operation), "response": "allowed"}).Inc() response.Response.Allowed = true } else { level.Debug(logger).Log("msg", "got invalid peer spec", "spec", peer.Spec, "name", peer.ObjectMeta.Name) validationCounter.With(prometheus.Labels{"operation": string(admissionReview.Request.Operation), "response": "denied"}).Inc() response.Response.Result = &metav1.Status{ Message: err.Error(), } } res, err := json.Marshal(response) if err != nil { errorCounter.Inc() msg := fmt.Sprintf("failed to marshal response: %v", err) level.Error(logger).Log("err", msg) http.Error(w, msg, http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") if _, err := w.Write(res); err != nil { level.Error(logger).Log("err", err, "msg", "failed to write response") } } func metricsMiddleWare(path string, next func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { requestCounter.With(prometheus.Labels{"method": r.Method, "handler": path}).Inc() next(w, r) } } func webhook(_ *cobra.Command, _ []string) error { if printVersion { fmt.Println(version.Version) os.Exit(0) } registry.MustRegister( errorCounter, validationCounter, requestCounter, ) ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() }() var g run.Group g.Add(run.SignalHandler(ctx, syscall.SIGINT, syscall.SIGTERM)) { mm := http.NewServeMux() mm.Handle("/metrics", promhttp.HandlerFor(registry, promhttp.HandlerOpts{})) msrv := &http.Server{ Addr: metricsAddr, Handler: mm, } g.Add( func() error { level.Info(logger).Log("msg", "starting metrics server", "address", msrv.Addr) err := msrv.ListenAndServe() level.Info(logger).Log("msg", "metrics server exited", "err", err) return err }, func(err error) { var serr run.SignalError if ok := errors.As(err, &serr); ok { level.Info(logger).Log("msg", "received signal", "signal", serr.Signal.String(), "err", err.Error()) } else { level.Error(logger).Log("msg", "received error", "err", err.Error()) } level.Info(logger).Log("msg", "shutting down metrics server gracefully") ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer func() { cancel() }() if err := msrv.Shutdown(ctx); err != nil { level.Error(logger).Log("msg", "failed to shut down metrics server gracefully", "err", err.Error()) msrv.Close() } }, ) } { mux := http.NewServeMux() mux.HandleFunc("/validate", metricsMiddleWare("/validate", validationHandler)) srv := &http.Server{ Addr: listenAddr, Handler: mux, } g.Add( func() error { level.Info(logger).Log("msg", "starting webhook server", "address", srv.Addr) err := srv.ListenAndServeTLS(certPath, keyPath) level.Info(logger).Log("msg", "webhook server exited", "err", err) return err }, func(err error) { var serr run.SignalError if ok := errors.As(err, &serr); ok { level.Info(logger).Log("msg", "received signal", "signal", serr.Signal.String(), "err", err.Error()) } else { level.Error(logger).Log("msg", "received error", "err", err.Error()) } level.Info(logger).Log("msg", "shutting down webhook server gracefully") ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer func() { cancel() }() if err := srv.Shutdown(ctx); err != nil { level.Error(logger).Log("msg", "failed to shut down webhook server gracefully", "err", err.Error()) srv.Close() } }, ) } err := g.Run() var serr run.SignalError if ok := errors.As(err, &serr); ok { return nil } return err }