"""
Fernet-backed encryption helper for bring-your-own (BYO) API keys.

The service wraps ``cryptography.fernet.MultiFernet`` so that keys can be
rotated without downtime:

- primary key: the active key; every new value is encrypted with it.
- previous key: optional; kept only so values written under the old key can
  still be decrypted while a rotation is in progress.

Keys must be in the format produced by ``Fernet.generate_key()``: a URL-safe
base64 encoding of 32 bytes. The PLATFORM_ENCRYPTION_KEY environment
variable, which callers pass in as the primary key, must follow this format.

Usage:
    from shared.crypto import KeyEncryptionService

    svc = KeyEncryptionService(primary_key=settings.platform_encryption_key)
    ciphertext = svc.encrypt("sk-my-secret-key")
    plaintext = svc.decrypt(ciphertext)
    new_cipher = svc.rotate(old_ciphertext)  # re-encrypt under the primary key
"""

from __future__ import annotations

from cryptography.fernet import Fernet, MultiFernet


class KeyEncryptionService:
    """
    Symmetric encryption of BYO API keys built on Fernet.

    What Fernet provides per token:
    - AES-128 in CBC mode with PKCS7 padding for confidentiality
    - an HMAC-SHA256 tag, so tampered tokens are rejected
    - a fresh random IV on every call, so equal inputs give unequal outputs
    - an embedded creation timestamp (this class never passes a TTL, so
      tokens are not expired on decryption)

    Rotation semantics inherited from MultiFernet:
    - encryption always uses the first (primary) key
    - decryption tries each configured key in order until one works
    - rotate() decrypts with whichever key matches and re-encrypts with the
      primary key
    """

    def __init__(self, primary_key: str, previous_key: str = "") -> None:
        """
        Build the underlying MultiFernet from one or two keys.

        Args:
            primary_key: Key used for all new encryption (and tried first on
                decryption). Must be a Fernet key: URL-safe base64 of 32 bytes.
            previous_key: Key from before the latest rotation, used only for
                decryption. Any falsy value (e.g. "") means "no previous key".

        Raises:
            ValueError: if a supplied key is not a valid Fernet key (raised by
                ``Fernet`` itself).
        """
        # Order matters: MultiFernet encrypts with the first entry only, so
        # the primary key must come first.
        key_strings = [primary_key]
        if previous_key:
            key_strings.append(previous_key)
        fernets: list[Fernet] = [Fernet(k.encode()) for k in key_strings]
        self._multi = MultiFernet(fernets)

    def encrypt(self, plaintext: str) -> str:
        """
        Encrypt ``plaintext`` under the primary key.

        The string is UTF-8 encoded before encryption, and the resulting
        Fernet token (URL-safe base64) is returned as ``str``. Each call draws
        a new random IV, so encrypting the same value twice yields two
        different tokens.
        """
        token = self._multi.encrypt(plaintext.encode("utf-8"))
        return token.decode("utf-8")

    def decrypt(self, ciphertext: str) -> str:
        """
        Recover the plaintext string from a Fernet token.

        Raises:
            cryptography.fernet.InvalidToken: if the token is malformed, has
                been tampered with, or was not produced by any configured key.
        """
        raw = self._multi.decrypt(ciphertext.encode("utf-8"))
        return raw.decode("utf-8")

    def rotate(self, ciphertext: str) -> str:
        """
        Re-encrypt an existing token under the current primary key.

        Rotation workflow:
        1. Install the new key as ``primary_key`` and keep the old one as
           ``previous_key``.
        2. Call rotate() on every stored token.
        3. Drop ``previous_key`` once no stored value still depends on it.

        MultiFernet preserves the token's original timestamp when rotating.

        Returns:
            A new Fernet token encrypted with the primary key.

        Raises:
            cryptography.fernet.InvalidToken: if no configured key can
                decrypt ``ciphertext``.
        """
        rotated = self._multi.rotate(ciphertext.encode("utf-8"))
        return rotated.decode("utf-8")