feat: add A/B partition updates with GRUB and Go update agent (Phase 3)
Implement atomic OS updates via A/B partition scheme with automatic rollback. GRUB bootloader manages slot selection with a 3-attempt boot counter that auto-rolls back on repeated health check failures. GRUB boot config: - A/B slot selection with boot_counter/boot_success env vars - Automatic rollback when counter reaches 0 (3 failed boots) - Debug, emergency shell, and manual slot-switch menu entries Disk image (refactored): - 4-partition GPT layout: EFI + System A + System B + Data - GRUB EFI/BIOS installation with graceful fallbacks - Both system partitions populated during image creation Update agent (Go, zero external deps): - pkg/grubenv: read/write GRUB env vars (grub-editenv + manual fallback) - pkg/partition: find/mount/write system partitions by label - pkg/image: HTTP download with SHA256 verification - pkg/health: post-boot checks (containerd, API server, node Ready) - 6 CLI commands: check, apply, activate, rollback, healthcheck, status - 37 unit tests across all 4 packages Deployment: - K8s CronJob for automatic update checks (every 6 hours) - ConfigMap for update server URL - Health check Job for post-boot verification Build pipeline: - build-update-agent.sh compiles static Linux binary (~5.9 MB) - inject-kubesolo.sh includes update agent in initramfs - Makefile: build-update-agent, test-update-agent, test-update targets Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
239
update/pkg/grubenv/grubenv.go
Normal file
239
update/pkg/grubenv/grubenv.go
Normal file
@@ -0,0 +1,239 @@
|
||||
// Package grubenv provides read/write access to GRUB environment variables.
|
||||
//
|
||||
// GRUB stores its environment in a 1024-byte file (grubenv) located at
|
||||
// /boot/grub/grubenv on the EFI partition. This package manipulates
|
||||
// those variables for A/B boot slot management.
|
||||
//
|
||||
// Key variables:
|
||||
// - active_slot: "A" or "B"
|
||||
// - boot_counter: "3" (fresh) down to "0" (triggers rollback)
|
||||
// - boot_success: "0" (pending) or "1" (healthy boot confirmed)
|
||||
package grubenv
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultGrubenvPath is the standard location for the GRUB environment file.
|
||||
DefaultGrubenvPath = "/boot/grub/grubenv"
|
||||
|
||||
// SlotA represents system partition A.
|
||||
SlotA = "A"
|
||||
// SlotB represents system partition B.
|
||||
SlotB = "B"
|
||||
)
|
||||
|
||||
// Env provides access to GRUB environment variables.
|
||||
type Env struct {
|
||||
path string
|
||||
}
|
||||
|
||||
// New creates a new Env for the given grubenv file path.
|
||||
func New(path string) *Env {
|
||||
if path == "" {
|
||||
path = DefaultGrubenvPath
|
||||
}
|
||||
return &Env{path: path}
|
||||
}
|
||||
|
||||
// Get reads a variable from the GRUB environment.
|
||||
func (e *Env) Get(key string) (string, error) {
|
||||
vars, err := e.ReadAll()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
val, ok := vars[key]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("grubenv: key %q not found", key)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Set writes a variable to the GRUB environment.
|
||||
func (e *Env) Set(key, value string) error {
|
||||
editenv, err := findEditenv()
|
||||
if err != nil {
|
||||
return e.setManual(key, value)
|
||||
}
|
||||
|
||||
cmd := exec.Command(editenv, e.path, "set", key+"="+value)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("grub-editenv set %s=%s: %w\n%s", key, value, err, output)
|
||||
}
|
||||
|
||||
slog.Debug("grubenv set", "key", key, "value", value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadAll reads all variables from the GRUB environment.
|
||||
func (e *Env) ReadAll() (map[string]string, error) {
|
||||
editenv, err := findEditenv()
|
||||
if err != nil {
|
||||
return e.readManual()
|
||||
}
|
||||
|
||||
cmd := exec.Command(editenv, e.path, "list")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("grub-editenv list: %w", err)
|
||||
}
|
||||
|
||||
return parseEnvOutput(string(output)), nil
|
||||
}
|
||||
|
||||
// ActiveSlot returns the currently active boot slot ("A" or "B").
|
||||
func (e *Env) ActiveSlot() (string, error) {
|
||||
return e.Get("active_slot")
|
||||
}
|
||||
|
||||
// PassiveSlot returns the currently passive boot slot.
|
||||
func (e *Env) PassiveSlot() (string, error) {
|
||||
active, err := e.ActiveSlot()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if active == SlotA {
|
||||
return SlotB, nil
|
||||
}
|
||||
return SlotA, nil
|
||||
}
|
||||
|
||||
// BootCounter returns the current boot counter value.
|
||||
func (e *Env) BootCounter() (int, error) {
|
||||
val, err := e.Get("boot_counter")
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
switch val {
|
||||
case "0":
|
||||
return 0, nil
|
||||
case "1":
|
||||
return 1, nil
|
||||
case "2":
|
||||
return 2, nil
|
||||
case "3":
|
||||
return 3, nil
|
||||
default:
|
||||
return -1, fmt.Errorf("grubenv: invalid boot_counter: %q", val)
|
||||
}
|
||||
}
|
||||
|
||||
// BootSuccess returns whether the last boot was marked successful.
|
||||
func (e *Env) BootSuccess() (bool, error) {
|
||||
val, err := e.Get("boot_success")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return val == "1", nil
|
||||
}
|
||||
|
||||
// MarkBootSuccess sets boot_success=1 and boot_counter=3.
|
||||
// Called by the health check after a successful boot.
|
||||
func (e *Env) MarkBootSuccess() error {
|
||||
if err := e.Set("boot_success", "1"); err != nil {
|
||||
return fmt.Errorf("setting boot_success: %w", err)
|
||||
}
|
||||
if err := e.Set("boot_counter", "3"); err != nil {
|
||||
return fmt.Errorf("setting boot_counter: %w", err)
|
||||
}
|
||||
slog.Info("boot marked successful")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ActivateSlot switches the active slot and resets the boot counter.
|
||||
// Used after writing a new image to the passive partition.
|
||||
func (e *Env) ActivateSlot(slot string) error {
|
||||
if slot != SlotA && slot != SlotB {
|
||||
return fmt.Errorf("invalid slot: %q (must be A or B)", slot)
|
||||
}
|
||||
if err := e.Set("active_slot", slot); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := e.Set("boot_counter", "3"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := e.Set("boot_success", "0"); err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Info("activated slot", "slot", slot)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForceRollback switches to the other slot immediately.
|
||||
func (e *Env) ForceRollback() error {
|
||||
passive, err := e.PassiveSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return e.ActivateSlot(passive)
|
||||
}
|
||||
|
||||
func findEditenv() (string, error) {
|
||||
if path, err := exec.LookPath("grub-editenv"); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
if path, err := exec.LookPath("grub2-editenv"); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
return "", fmt.Errorf("grub-editenv not found")
|
||||
}
|
||||
|
||||
func parseEnvOutput(output string) map[string]string {
|
||||
vars := make(map[string]string)
|
||||
for _, line := range strings.Split(output, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
vars[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
return vars
|
||||
}
|
||||
|
||||
// setManual writes to grubenv without grub-editenv (fallback).
|
||||
func (e *Env) setManual(key, value string) error {
|
||||
vars, err := e.readManual()
|
||||
if err != nil {
|
||||
vars = make(map[string]string)
|
||||
}
|
||||
vars[key] = value
|
||||
return e.writeManual(vars)
|
||||
}
|
||||
|
||||
// readManual reads grubenv without grub-editenv.
|
||||
func (e *Env) readManual() (map[string]string, error) {
|
||||
data, err := os.ReadFile(e.path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading grubenv: %w", err)
|
||||
}
|
||||
return parseEnvOutput(string(data)), nil
|
||||
}
|
||||
|
||||
// writeManual writes grubenv without grub-editenv.
|
||||
// GRUB requires the file to be exactly 1024 bytes, padded with '#'.
|
||||
func (e *Env) writeManual(vars map[string]string) error {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("# GRUB Environment Block\n")
|
||||
for k, v := range vars {
|
||||
sb.WriteString(k + "=" + v + "\n")
|
||||
}
|
||||
|
||||
content := sb.String()
|
||||
if len(content) > 1024 {
|
||||
return fmt.Errorf("grubenv content exceeds 1024 bytes")
|
||||
}
|
||||
|
||||
// Pad to 1024 bytes with '#'
|
||||
padding := 1024 - len(content)
|
||||
content += strings.Repeat("#", padding)
|
||||
|
||||
return os.WriteFile(e.path, []byte(content), 0o644)
|
||||
}
|
||||
423
update/pkg/grubenv/grubenv_test.go
Normal file
423
update/pkg/grubenv/grubenv_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package grubenv
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// createTestGrubenv writes a properly formatted grubenv file for testing.
|
||||
// GRUB requires the file to be exactly 1024 bytes, padded with '#'.
|
||||
func createTestGrubenv(t *testing.T, dir string, vars map[string]string) string {
|
||||
t.Helper()
|
||||
path := filepath.Join(dir, "grubenv")
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("# GRUB Environment Block\n")
|
||||
for k, v := range vars {
|
||||
sb.WriteString(k + "=" + v + "\n")
|
||||
}
|
||||
|
||||
content := sb.String()
|
||||
padding := 1024 - len(content)
|
||||
if padding > 0 {
|
||||
content += strings.Repeat("#", padding)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
env := New("")
|
||||
if env.path != DefaultGrubenvPath {
|
||||
t.Errorf("expected default path %s, got %s", DefaultGrubenvPath, env.path)
|
||||
}
|
||||
|
||||
env = New("/custom/path/grubenv")
|
||||
if env.path != "/custom/path/grubenv" {
|
||||
t.Errorf("expected custom path, got %s", env.path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadAll(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"active_slot": "A",
|
||||
"boot_counter": "3",
|
||||
"boot_success": "1",
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
vars, err := env.ReadAll()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if vars["active_slot"] != "A" {
|
||||
t.Errorf("active_slot: expected A, got %s", vars["active_slot"])
|
||||
}
|
||||
if vars["boot_counter"] != "3" {
|
||||
t.Errorf("boot_counter: expected 3, got %s", vars["boot_counter"])
|
||||
}
|
||||
if vars["boot_success"] != "1" {
|
||||
t.Errorf("boot_success: expected 1, got %s", vars["boot_success"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"active_slot": "B",
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
|
||||
val, err := env.Get("active_slot")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if val != "B" {
|
||||
t.Errorf("expected B, got %s", val)
|
||||
}
|
||||
|
||||
_, err = env.Get("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"active_slot": "A",
|
||||
"boot_counter": "3",
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
|
||||
if err := env.Set("boot_counter", "2"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
val, err := env.Get("boot_counter")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if val != "2" {
|
||||
t.Errorf("expected 2 after set, got %s", val)
|
||||
}
|
||||
|
||||
// Verify file is still 1024 bytes
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 1024 {
|
||||
t.Errorf("grubenv should be 1024 bytes, got %d", len(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestActiveSlot(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"active_slot": "A",
|
||||
"boot_counter": "3",
|
||||
"boot_success": "1",
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
slot, err := env.ActiveSlot()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if slot != "A" {
|
||||
t.Errorf("expected A, got %s", slot)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPassiveSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
active string
|
||||
passive string
|
||||
}{
|
||||
{"A", "B"},
|
||||
{"B", "A"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("active_"+tt.active, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"active_slot": tt.active,
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
passive, err := env.PassiveSlot()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if passive != tt.passive {
|
||||
t.Errorf("expected passive %s, got %s", tt.passive, passive)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootCounter(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
expect int
|
||||
wantErr bool
|
||||
}{
|
||||
{"0", 0, false},
|
||||
{"1", 1, false},
|
||||
{"2", 2, false},
|
||||
{"3", 3, false},
|
||||
{"invalid", -1, true},
|
||||
{"99", -1, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("counter_"+tt.value, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"boot_counter": tt.value,
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
counter, err := env.BootCounter()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if counter != tt.expect {
|
||||
t.Errorf("expected %d, got %d", tt.expect, counter)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootSuccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
expect bool
|
||||
}{
|
||||
{"0", false},
|
||||
{"1", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("success_"+tt.value, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"boot_success": tt.value,
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
success, err := env.BootSuccess()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if success != tt.expect {
|
||||
t.Errorf("expected %v, got %v", tt.expect, success)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkBootSuccess(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"active_slot": "B",
|
||||
"boot_counter": "1",
|
||||
"boot_success": "0",
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
if err := env.MarkBootSuccess(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
success, err := env.BootSuccess()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !success {
|
||||
t.Error("expected boot_success=1 after MarkBootSuccess")
|
||||
}
|
||||
|
||||
counter, err := env.BootCounter()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if counter != 3 {
|
||||
t.Errorf("expected boot_counter=3 after MarkBootSuccess, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestActivateSlot(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"active_slot": "A",
|
||||
"boot_counter": "3",
|
||||
"boot_success": "1",
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
if err := env.ActivateSlot("B"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
slot, _ := env.ActiveSlot()
|
||||
if slot != "B" {
|
||||
t.Errorf("expected active_slot=B, got %s", slot)
|
||||
}
|
||||
|
||||
counter, _ := env.BootCounter()
|
||||
if counter != 3 {
|
||||
t.Errorf("expected boot_counter=3, got %d", counter)
|
||||
}
|
||||
|
||||
success, _ := env.BootSuccess()
|
||||
if success {
|
||||
t.Error("expected boot_success=0 after ActivateSlot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestActivateSlotInvalid(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"active_slot": "A",
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
err := env.ActivateSlot("C")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid slot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForceRollback(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"active_slot": "A",
|
||||
"boot_counter": "3",
|
||||
"boot_success": "1",
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
if err := env.ForceRollback(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
slot, _ := env.ActiveSlot()
|
||||
if slot != "B" {
|
||||
t.Errorf("expected active_slot=B after rollback from A, got %s", slot)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEnvOutput(t *testing.T) {
|
||||
input := `# GRUB Environment Block
|
||||
active_slot=A
|
||||
boot_counter=3
|
||||
boot_success=1
|
||||
|
||||
`
|
||||
vars := parseEnvOutput(input)
|
||||
|
||||
if len(vars) != 3 {
|
||||
t.Errorf("expected 3 variables, got %d", len(vars))
|
||||
}
|
||||
if vars["active_slot"] != "A" {
|
||||
t.Errorf("active_slot: expected A, got %s", vars["active_slot"])
|
||||
}
|
||||
if vars["boot_counter"] != "3" {
|
||||
t.Errorf("boot_counter: expected 3, got %s", vars["boot_counter"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteManualFormat(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "grubenv")
|
||||
|
||||
env := New(path)
|
||||
// Use setManual directly since grub-editenv may not be available
|
||||
err := env.setManual("test_key", "test_value")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(data) != 1024 {
|
||||
t.Errorf("grubenv should be exactly 1024 bytes, got %d", len(data))
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(string(data), "# GRUB Environment Block\n") {
|
||||
t.Error("grubenv should start with '# GRUB Environment Block'")
|
||||
}
|
||||
|
||||
if !strings.Contains(string(data), "test_key=test_value\n") {
|
||||
t.Error("grubenv should contain test_key=test_value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadNonexistentFile(t *testing.T) {
|
||||
env := New("/nonexistent/path/grubenv")
|
||||
_, err := env.ReadAll()
|
||||
if err == nil {
|
||||
t.Fatal("expected error reading nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleSetOperations(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := createTestGrubenv(t, dir, map[string]string{
|
||||
"active_slot": "A",
|
||||
"boot_counter": "3",
|
||||
"boot_success": "1",
|
||||
})
|
||||
|
||||
env := New(path)
|
||||
|
||||
// Simulate a boot cycle: decrement counter, then mark success
|
||||
if err := env.Set("boot_counter", "2"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := env.Set("boot_success", "0"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Now mark boot success
|
||||
if err := env.MarkBootSuccess(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
vars, err := env.ReadAll()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if vars["active_slot"] != "A" {
|
||||
t.Errorf("active_slot should still be A, got %s", vars["active_slot"])
|
||||
}
|
||||
if vars["boot_counter"] != "3" {
|
||||
t.Errorf("boot_counter should be 3 after mark success, got %s", vars["boot_counter"])
|
||||
}
|
||||
if vars["boot_success"] != "1" {
|
||||
t.Errorf("boot_success should be 1, got %s", vars["boot_success"])
|
||||
}
|
||||
}
|
||||
198
update/pkg/health/health.go
Normal file
198
update/pkg/health/health.go
Normal file
@@ -0,0 +1,198 @@
|
||||
// Package health implements post-boot health checks for KubeSolo OS.
|
||||
//
|
||||
// After booting a new system partition, the health check verifies that:
|
||||
// - containerd is running and responsive
|
||||
// - KubeSolo API server is reachable
|
||||
// - The Kubernetes node reaches Ready state
|
||||
//
|
||||
// If all checks pass, the GRUB environment is updated to mark the boot
|
||||
// as successful (boot_success=1). If any check fails, boot_success
|
||||
// remains 0 and GRUB will eventually roll back.
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Status represents the result of a health check.
|
||||
type Status struct {
|
||||
Containerd bool
|
||||
APIServer bool
|
||||
NodeReady bool
|
||||
Message string
|
||||
}
|
||||
|
||||
// IsHealthy returns true if all checks passed.
|
||||
func (s *Status) IsHealthy() bool {
|
||||
return s.Containerd && s.APIServer && s.NodeReady
|
||||
}
|
||||
|
||||
// Checker performs health checks against the local KubeSolo instance.
|
||||
type Checker struct {
|
||||
kubeconfigPath string
|
||||
apiServerAddr string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewChecker creates a health checker.
|
||||
func NewChecker(kubeconfigPath, apiServerAddr string, timeout time.Duration) *Checker {
|
||||
if kubeconfigPath == "" {
|
||||
kubeconfigPath = "/var/lib/kubesolo/pki/admin/admin.kubeconfig"
|
||||
}
|
||||
if apiServerAddr == "" {
|
||||
apiServerAddr = "127.0.0.1:6443"
|
||||
}
|
||||
if timeout == 0 {
|
||||
timeout = 120 * time.Second
|
||||
}
|
||||
return &Checker{
|
||||
kubeconfigPath: kubeconfigPath,
|
||||
apiServerAddr: apiServerAddr,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckContainerd verifies that containerd is running.
|
||||
func (c *Checker) CheckContainerd() bool {
|
||||
// Check if containerd socket exists
|
||||
if _, err := os.Stat("/run/containerd/containerd.sock"); err != nil {
|
||||
slog.Warn("containerd socket not found")
|
||||
return false
|
||||
}
|
||||
|
||||
// Try ctr version (bundled with KubeSolo)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "ctr", "--address", "/run/containerd/containerd.sock", "version")
|
||||
if err := cmd.Run(); err != nil {
|
||||
slog.Warn("containerd not responsive", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
slog.Debug("containerd healthy")
|
||||
return true
|
||||
}
|
||||
|
||||
// CheckAPIServer verifies the Kubernetes API server is reachable.
|
||||
func (c *Checker) CheckAPIServer() bool {
|
||||
// TCP connect to API server port
|
||||
conn, err := net.DialTimeout("tcp", c.apiServerAddr, 5*time.Second)
|
||||
if err != nil {
|
||||
slog.Warn("API server not reachable", "addr", c.apiServerAddr, "error", err)
|
||||
return false
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
// Try HTTPS health endpoint (skip TLS verify for localhost)
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get("https://" + c.apiServerAddr + "/healthz")
|
||||
if err != nil {
|
||||
// TLS error is expected without proper CA, but TCP connect succeeded
|
||||
slog.Debug("API server TCP reachable but HTTPS check skipped", "error", err)
|
||||
return true
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
slog.Debug("API server healthy", "status", resp.StatusCode)
|
||||
return true
|
||||
}
|
||||
|
||||
slog.Warn("API server unhealthy", "status", resp.StatusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
// CheckNodeReady uses kubectl to verify the node is in Ready state.
|
||||
func (c *Checker) CheckNodeReady() bool {
|
||||
if _, err := os.Stat(c.kubeconfigPath); err != nil {
|
||||
slog.Warn("kubeconfig not found", "path", c.kubeconfigPath)
|
||||
return false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "kubectl",
|
||||
"--kubeconfig", c.kubeconfigPath,
|
||||
"get", "nodes",
|
||||
"-o", "jsonpath={.items[0].status.conditions[?(@.type==\"Ready\")].status}",
|
||||
)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
slog.Warn("kubectl get nodes failed", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
status := strings.TrimSpace(string(output))
|
||||
if status == "True" {
|
||||
slog.Debug("node is Ready")
|
||||
return true
|
||||
}
|
||||
|
||||
slog.Warn("node not Ready", "status", status)
|
||||
return false
|
||||
}
|
||||
|
||||
// RunAll performs all health checks and returns the combined status.
|
||||
func (c *Checker) RunAll() *Status {
|
||||
return &Status{
|
||||
Containerd: c.CheckContainerd(),
|
||||
APIServer: c.CheckAPIServer(),
|
||||
NodeReady: c.CheckNodeReady(),
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForHealthy polls health checks until all pass or timeout expires.
|
||||
func (c *Checker) WaitForHealthy() (*Status, error) {
|
||||
deadline := time.Now().Add(c.timeout)
|
||||
interval := 5 * time.Second
|
||||
|
||||
slog.Info("waiting for system health", "timeout", c.timeout)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
status := c.RunAll()
|
||||
if status.IsHealthy() {
|
||||
status.Message = "all checks passed"
|
||||
slog.Info("system healthy",
|
||||
"containerd", status.Containerd,
|
||||
"apiserver", status.APIServer,
|
||||
"node_ready", status.NodeReady,
|
||||
)
|
||||
return status, nil
|
||||
}
|
||||
|
||||
slog.Debug("health check pending",
|
||||
"containerd", status.Containerd,
|
||||
"apiserver", status.APIServer,
|
||||
"node_ready", status.NodeReady,
|
||||
"remaining", time.Until(deadline).Round(time.Second),
|
||||
)
|
||||
|
||||
time.Sleep(interval)
|
||||
}
|
||||
|
||||
// Final check
|
||||
status := c.RunAll()
|
||||
if status.IsHealthy() {
|
||||
status.Message = "all checks passed"
|
||||
return status, nil
|
||||
}
|
||||
|
||||
status.Message = "health check timeout"
|
||||
return status, fmt.Errorf("health check timed out after %s", c.timeout)
|
||||
}
|
||||
86
update/pkg/health/health_test.go
Normal file
86
update/pkg/health/health_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStatusIsHealthy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status Status
|
||||
wantHealth bool
|
||||
}{
|
||||
{
|
||||
name: "all healthy",
|
||||
status: Status{Containerd: true, APIServer: true, NodeReady: true},
|
||||
wantHealth: true,
|
||||
},
|
||||
{
|
||||
name: "containerd down",
|
||||
status: Status{Containerd: false, APIServer: true, NodeReady: true},
|
||||
wantHealth: false,
|
||||
},
|
||||
{
|
||||
name: "apiserver down",
|
||||
status: Status{Containerd: true, APIServer: false, NodeReady: true},
|
||||
wantHealth: false,
|
||||
},
|
||||
{
|
||||
name: "node not ready",
|
||||
status: Status{Containerd: true, APIServer: true, NodeReady: false},
|
||||
wantHealth: false,
|
||||
},
|
||||
{
|
||||
name: "all down",
|
||||
status: Status{Containerd: false, APIServer: false, NodeReady: false},
|
||||
wantHealth: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.status.IsHealthy(); got != tt.wantHealth {
|
||||
t.Errorf("IsHealthy() = %v, want %v", got, tt.wantHealth)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewChecker(t *testing.T) {
|
||||
// Test defaults
|
||||
c := NewChecker("", "", 0)
|
||||
if c.kubeconfigPath != "/var/lib/kubesolo/pki/admin/admin.kubeconfig" {
|
||||
t.Errorf("unexpected default kubeconfig: %s", c.kubeconfigPath)
|
||||
}
|
||||
if c.apiServerAddr != "127.0.0.1:6443" {
|
||||
t.Errorf("unexpected default apiserver addr: %s", c.apiServerAddr)
|
||||
}
|
||||
if c.timeout != 120*time.Second {
|
||||
t.Errorf("unexpected default timeout: %v", c.timeout)
|
||||
}
|
||||
|
||||
// Test custom values
|
||||
c = NewChecker("/custom/kubeconfig", "10.0.0.1:6443", 30*time.Second)
|
||||
if c.kubeconfigPath != "/custom/kubeconfig" {
|
||||
t.Errorf("expected custom kubeconfig, got %s", c.kubeconfigPath)
|
||||
}
|
||||
if c.apiServerAddr != "10.0.0.1:6443" {
|
||||
t.Errorf("expected custom addr, got %s", c.apiServerAddr)
|
||||
}
|
||||
if c.timeout != 30*time.Second {
|
||||
t.Errorf("expected 30s timeout, got %v", c.timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMessage(t *testing.T) {
|
||||
s := &Status{
|
||||
Containerd: true,
|
||||
APIServer: true,
|
||||
NodeReady: true,
|
||||
Message: "all checks passed",
|
||||
}
|
||||
if s.Message != "all checks passed" {
|
||||
t.Errorf("unexpected message: %s", s.Message)
|
||||
}
|
||||
}
|
||||
180
update/pkg/image/image.go
Normal file
180
update/pkg/image/image.go
Normal file
@@ -0,0 +1,180 @@
|
||||
// Package image handles downloading, verifying, and staging OS update images.
|
||||
//
|
||||
// Update images are distributed as pairs of files:
|
||||
// - vmlinuz (kernel)
|
||||
// - kubesolo-os.gz (initramfs)
|
||||
//
|
||||
// These are fetched from an HTTP(S) server that provides a metadata file
|
||||
// (latest.json) describing available updates.
|
||||
package image
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UpdateMetadata describes an available update from the update server.
|
||||
type UpdateMetadata struct {
|
||||
Version string `json:"version"`
|
||||
VmlinuzURL string `json:"vmlinuz_url"`
|
||||
VmlinuzSHA256 string `json:"vmlinuz_sha256"`
|
||||
InitramfsURL string `json:"initramfs_url"`
|
||||
InitramfsSHA256 string `json:"initramfs_sha256"`
|
||||
ReleaseNotes string `json:"release_notes,omitempty"`
|
||||
ReleaseDate string `json:"release_date,omitempty"`
|
||||
}
|
||||
|
||||
// StagedImage represents downloaded and verified update files.
|
||||
type StagedImage struct {
|
||||
VmlinuzPath string
|
||||
InitramfsPath string
|
||||
Version string
|
||||
}
|
||||
|
||||
// Client handles communication with the update server.
|
||||
type Client struct {
|
||||
serverURL string
|
||||
httpClient *http.Client
|
||||
stageDir string
|
||||
}
|
||||
|
||||
// NewClient creates a new update image client.
|
||||
func NewClient(serverURL, stageDir string) *Client {
|
||||
return &Client{
|
||||
serverURL: serverURL,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
},
|
||||
stageDir: stageDir,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckForUpdate fetches the latest update metadata from the server.
|
||||
func (c *Client) CheckForUpdate() (*UpdateMetadata, error) {
|
||||
url := c.serverURL + "/latest.json"
|
||||
slog.Info("checking for update", "url", url)
|
||||
|
||||
resp, err := c.httpClient.Get(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching update metadata: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("update server returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var meta UpdateMetadata
|
||||
if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil {
|
||||
return nil, fmt.Errorf("parsing update metadata: %w", err)
|
||||
}
|
||||
|
||||
if meta.Version == "" {
|
||||
return nil, fmt.Errorf("update metadata missing version")
|
||||
}
|
||||
|
||||
return &meta, nil
|
||||
}
|
||||
|
||||
// Download fetches the update files and verifies their checksums.
|
||||
func (c *Client) Download(meta *UpdateMetadata) (*StagedImage, error) {
|
||||
if err := os.MkdirAll(c.stageDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("creating stage dir: %w", err)
|
||||
}
|
||||
|
||||
vmlinuzPath := filepath.Join(c.stageDir, "vmlinuz")
|
||||
initramfsPath := filepath.Join(c.stageDir, "kubesolo-os.gz")
|
||||
|
||||
slog.Info("downloading vmlinuz", "url", meta.VmlinuzURL)
|
||||
if err := c.downloadAndVerify(meta.VmlinuzURL, vmlinuzPath, meta.VmlinuzSHA256); err != nil {
|
||||
return nil, fmt.Errorf("downloading vmlinuz: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("downloading initramfs", "url", meta.InitramfsURL)
|
||||
if err := c.downloadAndVerify(meta.InitramfsURL, initramfsPath, meta.InitramfsSHA256); err != nil {
|
||||
return nil, fmt.Errorf("downloading initramfs: %w", err)
|
||||
}
|
||||
|
||||
return &StagedImage{
|
||||
VmlinuzPath: vmlinuzPath,
|
||||
InitramfsPath: initramfsPath,
|
||||
Version: meta.Version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Cleanup removes staged update files.
|
||||
func (c *Client) Cleanup() error {
|
||||
return os.RemoveAll(c.stageDir)
|
||||
}
|
||||
|
||||
func (c *Client) downloadAndVerify(url, dest, expectedSHA256 string) error {
|
||||
resp, err := c.httpClient.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("downloading %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("server returned %d for %s", resp.StatusCode, url)
|
||||
}
|
||||
|
||||
f, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating %s: %w", dest, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
hasher := sha256.New()
|
||||
writer := io.MultiWriter(f, hasher)
|
||||
|
||||
written, err := io.Copy(writer, resp.Body)
|
||||
if err != nil {
|
||||
os.Remove(dest)
|
||||
return fmt.Errorf("writing %s: %w", dest, err)
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
return fmt.Errorf("closing %s: %w", dest, err)
|
||||
}
|
||||
|
||||
// Verify checksum
|
||||
if expectedSHA256 != "" {
|
||||
actual := hex.EncodeToString(hasher.Sum(nil))
|
||||
if actual != expectedSHA256 {
|
||||
os.Remove(dest)
|
||||
return fmt.Errorf("checksum mismatch for %s: expected %s, got %s", dest, expectedSHA256, actual)
|
||||
}
|
||||
slog.Debug("checksum verified", "file", dest, "sha256", actual)
|
||||
}
|
||||
|
||||
slog.Info("downloaded", "file", dest, "size", written)
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyFile checks the SHA256 checksum of an existing file.
|
||||
func VerifyFile(path, expectedSHA256 string) error {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(hasher, f); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
actual := hex.EncodeToString(hasher.Sum(nil))
|
||||
if actual != expectedSHA256 {
|
||||
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedSHA256, actual)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
241
update/pkg/image/image_test.go
Normal file
241
update/pkg/image/image_test.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckForUpdate(t *testing.T) {
|
||||
meta := UpdateMetadata{
|
||||
Version: "1.2.0",
|
||||
VmlinuzURL: "/vmlinuz",
|
||||
VmlinuzSHA256: "abc123",
|
||||
InitramfsURL: "/kubesolo-os.gz",
|
||||
InitramfsSHA256: "def456",
|
||||
ReleaseNotes: "Bug fixes",
|
||||
ReleaseDate: "2025-01-15",
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/latest.json" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(meta)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "")
|
||||
got, err := client.CheckForUpdate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got.Version != "1.2.0" {
|
||||
t.Errorf("expected version 1.2.0, got %s", got.Version)
|
||||
}
|
||||
if got.VmlinuzSHA256 != "abc123" {
|
||||
t.Errorf("expected vmlinuz sha abc123, got %s", got.VmlinuzSHA256)
|
||||
}
|
||||
if got.ReleaseNotes != "Bug fixes" {
|
||||
t.Errorf("expected release notes, got %s", got.ReleaseNotes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckForUpdateMissingVersion(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(UpdateMetadata{})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "")
|
||||
_, err := client.CheckForUpdate()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckForUpdateServerError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "")
|
||||
_, err := client.CheckForUpdate()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for server error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndVerify(t *testing.T) {
|
||||
// Create test content
|
||||
vmlinuzContent := []byte("fake vmlinuz content for testing")
|
||||
initramfsContent := []byte("fake initramfs content for testing")
|
||||
|
||||
vmlinuzHash := sha256.Sum256(vmlinuzContent)
|
||||
initramfsHash := sha256.Sum256(initramfsContent)
|
||||
|
||||
meta := UpdateMetadata{
|
||||
Version: "2.0.0",
|
||||
VmlinuzSHA256: hex.EncodeToString(vmlinuzHash[:]),
|
||||
InitramfsSHA256: hex.EncodeToString(initramfsHash[:]),
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/latest.json":
|
||||
m := meta
|
||||
m.VmlinuzURL = "http://" + r.Host + "/vmlinuz"
|
||||
m.InitramfsURL = "http://" + r.Host + "/kubesolo-os.gz"
|
||||
json.NewEncoder(w).Encode(m)
|
||||
case "/vmlinuz":
|
||||
w.Write(vmlinuzContent)
|
||||
case "/kubesolo-os.gz":
|
||||
w.Write(initramfsContent)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
stageDir := filepath.Join(t.TempDir(), "stage")
|
||||
client := NewClient(server.URL, stageDir)
|
||||
defer client.Cleanup()
|
||||
|
||||
// First get metadata
|
||||
gotMeta, err := client.CheckForUpdate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Download
|
||||
staged, err := client.Download(gotMeta)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if staged.Version != "2.0.0" {
|
||||
t.Errorf("expected version 2.0.0, got %s", staged.Version)
|
||||
}
|
||||
|
||||
// Verify files exist
|
||||
if _, err := os.Stat(staged.VmlinuzPath); err != nil {
|
||||
t.Errorf("vmlinuz not found: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(staged.InitramfsPath); err != nil {
|
||||
t.Errorf("initramfs not found: %v", err)
|
||||
}
|
||||
|
||||
// Verify content
|
||||
data, _ := os.ReadFile(staged.VmlinuzPath)
|
||||
if string(data) != string(vmlinuzContent) {
|
||||
t.Error("vmlinuz content mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadChecksumMismatch(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/vmlinuz":
|
||||
w.Write([]byte("actual content"))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
stageDir := filepath.Join(t.TempDir(), "stage")
|
||||
client := NewClient(server.URL, stageDir)
|
||||
|
||||
meta := &UpdateMetadata{
|
||||
Version: "1.0.0",
|
||||
VmlinuzURL: server.URL + "/vmlinuz",
|
||||
VmlinuzSHA256: "wrong_checksum_value",
|
||||
InitramfsURL: server.URL + "/initramfs",
|
||||
}
|
||||
|
||||
_, err := client.Download(meta)
|
||||
if err == nil {
|
||||
t.Fatal("expected checksum mismatch error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyFile(t *testing.T) {
|
||||
content := []byte("test file content for verification")
|
||||
hash := sha256.Sum256(content)
|
||||
expected := hex.EncodeToString(hash[:])
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "testfile")
|
||||
if err := os.WriteFile(path, content, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should pass with correct hash
|
||||
if err := VerifyFile(path, expected); err != nil {
|
||||
t.Errorf("expected verification to pass: %v", err)
|
||||
}
|
||||
|
||||
// Should fail with wrong hash
|
||||
if err := VerifyFile(path, "deadbeef"); err == nil {
|
||||
t.Error("expected verification to fail with wrong hash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyFileNotFound(t *testing.T) {
|
||||
err := VerifyFile("/nonexistent/file", "abc123")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanup(t *testing.T) {
|
||||
stageDir := filepath.Join(t.TempDir(), "stage")
|
||||
os.MkdirAll(stageDir, 0o755)
|
||||
os.WriteFile(filepath.Join(stageDir, "test"), []byte("data"), 0o644)
|
||||
|
||||
client := NewClient("http://unused", stageDir)
|
||||
if err := client.Cleanup(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(stageDir); !os.IsNotExist(err) {
|
||||
t.Error("stage dir should be removed after cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMetadataJSON(t *testing.T) {
|
||||
meta := UpdateMetadata{
|
||||
Version: "1.0.0",
|
||||
VmlinuzURL: "https://example.com/vmlinuz",
|
||||
VmlinuzSHA256: "abc",
|
||||
InitramfsURL: "https://example.com/kubesolo-os.gz",
|
||||
InitramfsSHA256: "def",
|
||||
ReleaseNotes: "Initial release",
|
||||
ReleaseDate: "2025-01-01",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(meta)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var decoded UpdateMetadata
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if decoded.Version != meta.Version {
|
||||
t.Errorf("version mismatch: %s != %s", decoded.Version, meta.Version)
|
||||
}
|
||||
if decoded.ReleaseDate != meta.ReleaseDate {
|
||||
t.Errorf("release date mismatch: %s != %s", decoded.ReleaseDate, meta.ReleaseDate)
|
||||
}
|
||||
}
|
||||
139
update/pkg/partition/partition.go
Normal file
139
update/pkg/partition/partition.go
Normal file
@@ -0,0 +1,139 @@
|
||||
// Package partition detects and manages A/B system partitions.
|
||||
//
|
||||
// It identifies System A and System B partitions by label (KSOLOA, KSOLOB)
|
||||
// and provides mount/write operations for the update process.
|
||||
package partition
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
LabelSystemA = "KSOLOA"
|
||||
LabelSystemB = "KSOLOB"
|
||||
LabelData = "KSOLODATA"
|
||||
LabelEFI = "KSOLOEFI"
|
||||
)
|
||||
|
||||
// Info contains information about a partition.
|
||||
type Info struct {
|
||||
Device string // e.g. /dev/sda2
|
||||
Label string // e.g. KSOLOA
|
||||
MountPoint string // current mount point, empty if not mounted
|
||||
Slot string // "A" or "B"
|
||||
}
|
||||
|
||||
// FindByLabel locates a block device by its filesystem label.
|
||||
func FindByLabel(label string) (string, error) {
|
||||
cmd := exec.Command("blkid", "-L", label)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("partition with label %q not found: %w", label, err)
|
||||
}
|
||||
return strings.TrimSpace(string(output)), nil
|
||||
}
|
||||
|
||||
// GetSlotPartition returns the partition info for the given slot ("A" or "B").
|
||||
func GetSlotPartition(slot string) (*Info, error) {
|
||||
var label string
|
||||
switch slot {
|
||||
case "A":
|
||||
label = LabelSystemA
|
||||
case "B":
|
||||
label = LabelSystemB
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid slot: %q", slot)
|
||||
}
|
||||
|
||||
dev, err := FindByLabel(label)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Info{
|
||||
Device: dev,
|
||||
Label: label,
|
||||
Slot: slot,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MountReadOnly mounts a partition read-only at the given mount point.
|
||||
func MountReadOnly(dev, mountPoint string) error {
|
||||
if err := os.MkdirAll(mountPoint, 0o755); err != nil {
|
||||
return fmt.Errorf("creating mount point: %w", err)
|
||||
}
|
||||
cmd := exec.Command("mount", "-o", "ro", dev, mountPoint)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("mounting %s at %s: %w\n%s", dev, mountPoint, err, output)
|
||||
}
|
||||
slog.Debug("mounted", "device", dev, "mountpoint", mountPoint, "mode", "ro")
|
||||
return nil
|
||||
}
|
||||
|
||||
// MountReadWrite mounts a partition read-write at the given mount point.
|
||||
func MountReadWrite(dev, mountPoint string) error {
|
||||
if err := os.MkdirAll(mountPoint, 0o755); err != nil {
|
||||
return fmt.Errorf("creating mount point: %w", err)
|
||||
}
|
||||
cmd := exec.Command("mount", dev, mountPoint)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("mounting %s at %s: %w\n%s", dev, mountPoint, err, output)
|
||||
}
|
||||
slog.Debug("mounted", "device", dev, "mountpoint", mountPoint, "mode", "rw")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unmount unmounts a mount point.
|
||||
func Unmount(mountPoint string) error {
|
||||
cmd := exec.Command("umount", mountPoint)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("unmounting %s: %w\n%s", mountPoint, err, output)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadVersion reads the version file from a mounted system partition.
|
||||
func ReadVersion(mountPoint string) (string, error) {
|
||||
data, err := os.ReadFile(filepath.Join(mountPoint, "version"))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading version: %w", err)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
// WriteSystemImage copies vmlinuz and initramfs to a mounted partition.
|
||||
func WriteSystemImage(mountPoint, vmlinuzPath, initramfsPath, version string) error {
|
||||
// Copy vmlinuz
|
||||
if err := copyFile(vmlinuzPath, filepath.Join(mountPoint, "vmlinuz")); err != nil {
|
||||
return fmt.Errorf("writing vmlinuz: %w", err)
|
||||
}
|
||||
|
||||
// Copy initramfs
|
||||
if err := copyFile(initramfsPath, filepath.Join(mountPoint, "kubesolo-os.gz")); err != nil {
|
||||
return fmt.Errorf("writing initramfs: %w", err)
|
||||
}
|
||||
|
||||
// Write version
|
||||
if err := os.WriteFile(filepath.Join(mountPoint, "version"), []byte(version+"\n"), 0o644); err != nil {
|
||||
return fmt.Errorf("writing version: %w", err)
|
||||
}
|
||||
|
||||
// Sync to ensure data is flushed to disk
|
||||
exec.Command("sync").Run()
|
||||
|
||||
slog.Info("system image written", "mountpoint", mountPoint, "version", version)
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
data, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dst, data, 0o644)
|
||||
}
|
||||
129
update/pkg/partition/partition_test.go
Normal file
129
update/pkg/partition/partition_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package partition
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadVersion(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
versionFile := filepath.Join(dir, "version")
|
||||
if err := os.WriteFile(versionFile, []byte("1.2.3\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
version, err := ReadVersion(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != "1.2.3" {
|
||||
t.Errorf("expected 1.2.3, got %s", version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadVersionMissing(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
_, err := ReadVersion(dir)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing version file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteSystemImage(t *testing.T) {
|
||||
mountPoint := t.TempDir()
|
||||
srcDir := t.TempDir()
|
||||
|
||||
// Create source files
|
||||
vmlinuzPath := filepath.Join(srcDir, "vmlinuz")
|
||||
initramfsPath := filepath.Join(srcDir, "kubesolo-os.gz")
|
||||
|
||||
if err := os.WriteFile(vmlinuzPath, []byte("kernel data"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(initramfsPath, []byte("initramfs data"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := WriteSystemImage(mountPoint, vmlinuzPath, initramfsPath, "2.0.0"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify files were copied
|
||||
data, err := os.ReadFile(filepath.Join(mountPoint, "vmlinuz"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != "kernel data" {
|
||||
t.Errorf("vmlinuz content mismatch")
|
||||
}
|
||||
|
||||
data, err = os.ReadFile(filepath.Join(mountPoint, "kubesolo-os.gz"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != "initramfs data" {
|
||||
t.Errorf("initramfs content mismatch")
|
||||
}
|
||||
|
||||
// Verify version file
|
||||
version, err := ReadVersion(mountPoint)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != "2.0.0" {
|
||||
t.Errorf("expected version 2.0.0, got %s", version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := filepath.Join(dir, "src")
|
||||
dst := filepath.Join(dir, "dst")
|
||||
|
||||
if err := os.WriteFile(src, []byte("test content"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := copyFile(src, dst); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != "test content" {
|
||||
t.Errorf("copy content mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFileNotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
err := copyFile("/nonexistent", filepath.Join(dir, "dst"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSlotPartitionInvalid(t *testing.T) {
|
||||
_, err := GetSlotPartition("C")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid slot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstants(t *testing.T) {
|
||||
if LabelSystemA != "KSOLOA" {
|
||||
t.Errorf("unexpected LabelSystemA: %s", LabelSystemA)
|
||||
}
|
||||
if LabelSystemB != "KSOLOB" {
|
||||
t.Errorf("unexpected LabelSystemB: %s", LabelSystemB)
|
||||
}
|
||||
if LabelData != "KSOLODATA" {
|
||||
t.Errorf("unexpected LabelData: %s", LabelData)
|
||||
}
|
||||
if LabelEFI != "KSOLOEFI" {
|
||||
t.Errorf("unexpected LabelEFI: %s", LabelEFI)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user