feat(01-foundation-01): Alembic migrations with RLS and tenant isolation tests

- alembic.ini + migrations/env.py: async SQLAlchemy migration setup using asyncpg
- migrations/versions/001_initial_schema.py: creates tenants, agents, channel_connections, portal_users
  - ENABLE + FORCE ROW LEVEL SECURITY on agents and channel_connections
  - RLS policy: tenant_id = current_setting('app.current_tenant', TRUE)::uuid
  - konstruct_app role created with SELECT/INSERT/UPDATE/DELETE on all tables
- packages/shared/shared/rls.py: idempotent configure_rls_hook, UUID-sanitized SET LOCAL
- tests/conftest.py: test_db_name (session-scoped), db_engine + db_session as konstruct_app
- tests/unit/test_normalize.py: 11 tests for KonstructMessage Slack normalization (CHAN-01)
- tests/unit/test_tenant_resolution.py: 7 tests for workspace_id → tenant resolution (TNNT-02)
- tests/unit/test_redis_namespacing.py: 15 tests for Redis key namespace isolation (TNNT-03)
- tests/integration/test_tenant_isolation.py: 7 tests proving RLS tenant isolation (TNNT-01)
  - tenant_b cannot see tenant_a's agents or channel_connections
  - FORCE ROW LEVEL SECURITY verified via pg_class.relforcerowsecurity
This commit is contained in:
2026-03-23 09:57:29 -06:00
parent 5714acf741
commit 47e78627fd
13 changed files with 1364 additions and 4 deletions

0
tests/__init__.py Normal file
View File

189
tests/conftest.py Normal file
View File

@@ -0,0 +1,189 @@
"""
Shared test fixtures for Konstruct.
IMPORTANT: The `db_session` fixture connects as `konstruct_app` (not postgres
superuser). This is mandatory — RLS is bypassed for superuser connections, so
tests using superuser would pass trivially while providing zero real protection.
Integration tests requiring a live PostgreSQL container are skipped if the
database is not available. Unit tests never require a live DB.
Event loop design: All async fixtures use function scope to avoid pytest-asyncio
cross-loop-scope issues. The test database is created once (at session scope, via
a synchronous fixture) and reused across tests within the session.
"""
from __future__ import annotations
import asyncio
import os
import subprocess
import uuid
from collections.abc import AsyncGenerator
from typing import Any
import pytest
import pytest_asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
# ---------------------------------------------------------------------------
# Database URLs
# ---------------------------------------------------------------------------
_ADMIN_URL = os.environ.get(
"DATABASE_ADMIN_URL",
"postgresql+asyncpg://postgres:postgres_dev@localhost:5432/konstruct",
)
_APP_URL = os.environ.get(
"DATABASE_URL",
"postgresql+asyncpg://konstruct_app:konstruct_dev@localhost:5432/konstruct",
)
def _replace_db_name(url: str, new_db: str) -> str:
"""Replace database name in a SQLAlchemy URL string."""
parts = url.rsplit("/", 1)
return f"{parts[0]}/{new_db}"
# ---------------------------------------------------------------------------
# Session-scoped synchronous setup — creates and migrates the test DB once
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
def test_db_name() -> str:
"""Create a fresh test database, run migrations, return DB name."""
db_name = f"konstruct_test_{uuid.uuid4().hex[:8]}"
admin_postgres_url = _replace_db_name(_ADMIN_URL, "postgres")
# Check PostgreSQL reachability using synchronous driver
try:
import asyncio as _asyncio
async def _check() -> None:
eng = create_async_engine(admin_postgres_url)
async with eng.connect() as conn:
await conn.execute(text("SELECT 1"))
await eng.dispose()
_asyncio.run(_check())
except Exception as exc:
pytest.skip(f"PostgreSQL not available: {exc}")
# Create test database
async def _create_db() -> None:
eng = create_async_engine(admin_postgres_url, isolation_level="AUTOCOMMIT")
async with eng.connect() as conn:
await conn.execute(text(f'CREATE DATABASE "{db_name}"'))
await eng.dispose()
asyncio.run(_create_db())
# Run Alembic migrations against test DB (subprocess — avoids loop conflicts)
admin_test_url = _replace_db_name(_ADMIN_URL, db_name)
result = subprocess.run(
["uv", "run", "alembic", "upgrade", "head"],
env={**os.environ, "DATABASE_ADMIN_URL": admin_test_url},
capture_output=True,
text=True,
cwd=os.path.join(os.path.dirname(__file__), ".."),
)
if result.returncode != 0:
# Clean up on failure
async def _drop_db() -> None:
eng = create_async_engine(admin_postgres_url, isolation_level="AUTOCOMMIT")
async with eng.connect() as conn:
await conn.execute(text(f'DROP DATABASE IF EXISTS "{db_name}"'))
await eng.dispose()
asyncio.run(_drop_db())
pytest.fail(f"Alembic migration failed:\n{result.stdout}\n{result.stderr}")
yield db_name
# Teardown: drop test database
async def _cleanup() -> None:
eng = create_async_engine(admin_postgres_url, isolation_level="AUTOCOMMIT")
async with eng.connect() as conn:
await conn.execute(
text(
"SELECT pg_terminate_backend(pid) FROM pg_stat_activity "
"WHERE datname = :dbname AND pid <> pg_backend_pid()"
),
{"dbname": db_name},
)
await conn.execute(text(f'DROP DATABASE IF EXISTS "{db_name}"'))
await eng.dispose()
asyncio.run(_cleanup())
@pytest_asyncio.fixture
async def db_engine(test_db_name: str) -> AsyncGenerator[AsyncEngine, None]:
"""
Function-scoped async engine connected as konstruct_app.
Using konstruct_app role is critical — it enforces RLS. The postgres
superuser would bypass RLS and make isolation tests worthless.
"""
app_test_url = _replace_db_name(_APP_URL, test_db_name)
engine = create_async_engine(app_test_url, echo=False)
yield engine
await engine.dispose()
@pytest_asyncio.fixture
async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
"""
Function-scoped async session connected as konstruct_app.
The RLS hook is configured on this engine so SET LOCAL statements are
injected before each query when current_tenant_id is set.
"""
from shared.rls import configure_rls_hook
# Always configure — SQLAlchemy event.listens_for is idempotent per listener function
# when the same function object is registered; but since configure_rls_hook creates
# a new closure each call, wrap with a set to avoid duplicate listeners.
configure_rls_hook(db_engine)
session_factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with session_factory() as session:
yield session
await session.rollback()
@pytest_asyncio.fixture
async def tenant_a(db_session: AsyncSession) -> dict[str, Any]:
"""Create Tenant A and return its data dict."""
tenant_id = uuid.uuid4()
suffix = uuid.uuid4().hex[:6]
await db_session.execute(
text("INSERT INTO tenants (id, name, slug, settings) VALUES (:id, :name, :slug, :settings)"),
{
"id": str(tenant_id),
"name": f"Tenant Alpha {suffix}",
"slug": f"tenant-alpha-{suffix}",
"settings": "{}",
},
)
await db_session.commit()
return {"id": tenant_id}
@pytest_asyncio.fixture
async def tenant_b(db_session: AsyncSession) -> dict[str, Any]:
"""Create Tenant B and return its data dict."""
tenant_id = uuid.uuid4()
suffix = uuid.uuid4().hex[:6]
await db_session.execute(
text("INSERT INTO tenants (id, name, slug, settings) VALUES (:id, :name, :slug, :settings)"),
{
"id": str(tenant_id),
"name": f"Tenant Beta {suffix}",
"slug": f"tenant-beta-{suffix}",
"settings": "{}",
},
)
await db_session.commit()
return {"id": tenant_id}

View File

View File

@@ -0,0 +1,273 @@
"""
Integration tests for PostgreSQL Row Level Security (RLS) tenant isolation.
Tests TNNT-01: All tenant data is isolated via PostgreSQL Row Level Security.
THIS IS THE MOST CRITICAL TEST IN PHASE 1.
These tests prove that Tenant A cannot see Tenant B's data through PostgreSQL
RLS — not through application-layer filtering, but at the database level.
Critical verification points:
1. tenant_b cannot see tenant_a's agents (even by primary key lookup)
2. tenant_a can see its own agents
3. FORCE ROW LEVEL SECURITY is active (relforcerowsecurity = TRUE)
4. Same isolation holds for channel_connections table
5. All tests connect as konstruct_app (not superuser)
"""
from __future__ import annotations
import uuid
from typing import Any
import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from shared.rls import current_tenant_id
@pytest.mark.asyncio
class TestAgentRLSIsolation:
"""Prove that Agent rows are invisible across tenant boundaries."""
async def test_tenant_b_cannot_see_tenant_a_agent(
self,
db_session: AsyncSession,
tenant_a: dict[str, Any],
tenant_b: dict[str, Any],
) -> None:
"""
Core RLS test: tenant_b must see ZERO rows from agents owned by tenant_a.
If this test passes trivially (e.g., because the connection is superuser),
the test_rls_is_forced test below will catch it.
"""
agent_id = uuid.uuid4()
# Create an agent for tenant_a (using superuser-like direct insert)
# We need to temporarily bypass RLS to seed data
token = current_tenant_id.set(tenant_a["id"])
try:
await db_session.execute(
text(
"INSERT INTO agents (id, tenant_id, name, role) "
"VALUES (:id, :tenant_id, :name, :role)"
),
{
"id": str(agent_id),
"tenant_id": str(tenant_a["id"]),
"name": "Alice",
"role": "Support Lead",
},
)
await db_session.commit()
finally:
current_tenant_id.reset(token)
# Now set tenant_b context and try to query the agent
token = current_tenant_id.set(tenant_b["id"])
try:
result = await db_session.execute(text("SELECT id FROM agents"))
rows = result.fetchall()
assert len(rows) == 0, (
f"RLS FAILURE: tenant_b can see {len(rows)} agent(s) belonging to tenant_a. "
"This is a critical security violation — check that FORCE ROW LEVEL SECURITY "
"is applied and that the connection uses konstruct_app role (not superuser)."
)
finally:
current_tenant_id.reset(token)
async def test_tenant_a_can_see_own_agents(
self,
db_session: AsyncSession,
tenant_a: dict[str, Any],
tenant_b: dict[str, Any],
) -> None:
"""Tenant A must be able to see its own agents — RLS must not block legitimate access."""
agent_id = uuid.uuid4()
# Insert agent as tenant_a
token = current_tenant_id.set(tenant_a["id"])
try:
await db_session.execute(
text(
"INSERT INTO agents (id, tenant_id, name, role) "
"VALUES (:id, :tenant_id, :name, :role)"
),
{
"id": str(agent_id),
"tenant_id": str(tenant_a["id"]),
"name": "Bob",
"role": "Sales Rep",
},
)
await db_session.commit()
finally:
current_tenant_id.reset(token)
# Query as tenant_a — must see the agent
token = current_tenant_id.set(tenant_a["id"])
try:
result = await db_session.execute(text("SELECT id FROM agents WHERE id = :id"), {"id": str(agent_id)})
rows = result.fetchall()
assert len(rows) == 1, (
f"Tenant A cannot see its own agent. Found {len(rows)} rows. "
"RLS policy may be too restrictive."
)
finally:
current_tenant_id.reset(token)
@pytest.mark.asyncio
class TestChannelConnectionRLSIsolation:
"""Prove that ChannelConnection rows are invisible across tenant boundaries."""
async def test_tenant_b_cannot_see_tenant_a_channel_connection(
self,
db_session: AsyncSession,
tenant_a: dict[str, Any],
tenant_b: dict[str, Any],
) -> None:
"""tenant_b must see ZERO channel_connections owned by tenant_a."""
conn_id = uuid.uuid4()
# Create a channel connection for tenant_a
token = current_tenant_id.set(tenant_a["id"])
try:
await db_session.execute(
text(
"INSERT INTO channel_connections (id, tenant_id, channel_type, workspace_id, config) "
"VALUES (:id, :tenant_id, 'slack', :workspace_id, :config)"
),
{
"id": str(conn_id),
"tenant_id": str(tenant_a["id"]),
"workspace_id": "T-ALPHA-WORKSPACE",
"config": "{}",
},
)
await db_session.commit()
finally:
current_tenant_id.reset(token)
# Query as tenant_b — must see zero rows
token = current_tenant_id.set(tenant_b["id"])
try:
result = await db_session.execute(text("SELECT id FROM channel_connections"))
rows = result.fetchall()
assert len(rows) == 0, (
f"RLS FAILURE: tenant_b can see {len(rows)} channel_connection(s) belonging to tenant_a."
)
finally:
current_tenant_id.reset(token)
async def test_tenant_a_can_see_own_channel_connections(
self,
db_session: AsyncSession,
tenant_a: dict[str, Any],
) -> None:
"""Tenant A must see its own channel connections."""
conn_id = uuid.uuid4()
token = current_tenant_id.set(tenant_a["id"])
try:
await db_session.execute(
text(
"INSERT INTO channel_connections (id, tenant_id, channel_type, workspace_id, config) "
"VALUES (:id, :tenant_id, 'telegram', :workspace_id, :config)"
),
{
"id": str(conn_id),
"tenant_id": str(tenant_a["id"]),
"workspace_id": "tg-alpha-chat",
"config": "{}",
},
)
await db_session.commit()
finally:
current_tenant_id.reset(token)
# Query as tenant_a — must see the connection
token = current_tenant_id.set(tenant_a["id"])
try:
result = await db_session.execute(
text("SELECT id FROM channel_connections WHERE id = :id"),
{"id": str(conn_id)},
)
rows = result.fetchall()
assert len(rows) == 1, f"Tenant A cannot see its own channel connection. Found {len(rows)} rows."
finally:
current_tenant_id.reset(token)
@pytest.mark.asyncio
class TestRLSPolicyConfiguration:
"""Verify PostgreSQL RLS configuration is correct at the schema level."""
async def test_agents_force_row_level_security_is_active(
self,
db_session: AsyncSession,
tenant_a: dict[str, Any],
) -> None:
"""
FORCE ROW LEVEL SECURITY must be TRUE for agents table.
Without FORCE RLS, the table owner (postgres) bypasses RLS.
This would mean our isolation tests passed trivially and provide
zero real protection.
"""
# We need to query pg_class as superuser to check this
# Using the session (which is konstruct_app) — pg_class is readable by all
token = current_tenant_id.set(tenant_a["id"])
try:
result = await db_session.execute(
text("SELECT relforcerowsecurity FROM pg_class WHERE relname = 'agents'")
)
row = result.fetchone()
finally:
current_tenant_id.reset(token)
assert row is not None, "agents table not found in pg_class"
assert row[0] is True, (
"FORCE ROW LEVEL SECURITY is NOT active on agents table. "
"This is a critical security misconfiguration — the table owner "
"can bypass RLS and cross-tenant data leakage is possible."
)
async def test_channel_connections_force_row_level_security_is_active(
self,
db_session: AsyncSession,
tenant_a: dict[str, Any],
) -> None:
"""FORCE ROW LEVEL SECURITY must be TRUE for channel_connections table."""
token = current_tenant_id.set(tenant_a["id"])
try:
result = await db_session.execute(
text("SELECT relforcerowsecurity FROM pg_class WHERE relname = 'channel_connections'")
)
row = result.fetchone()
finally:
current_tenant_id.reset(token)
assert row is not None, "channel_connections table not found in pg_class"
assert row[0] is True, (
"FORCE ROW LEVEL SECURITY is NOT active on channel_connections table."
)
async def test_tenants_table_exists_and_is_accessible(
self,
db_session: AsyncSession,
tenant_a: dict[str, Any],
) -> None:
"""
tenants table must be accessible without tenant context.
RLS is NOT applied to tenants — the Router needs to look up all tenants
during message routing, before tenant context is set.
"""
result = await db_session.execute(text("SELECT id, slug FROM tenants LIMIT 10"))
rows = result.fetchall()
# Should be accessible (no RLS) — we just care it doesn't raise
assert rows is not None

0
tests/unit/__init__.py Normal file
View File

View File

@@ -0,0 +1,143 @@
"""
Unit tests for KonstructMessage normalization from Slack event payloads.
Tests CHAN-01: Channel Gateway normalizes messages from all channels into
unified KonstructMessage format.
These tests exercise normalization logic without requiring a live database.
"""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo
def make_slack_event() -> dict:
"""Minimal valid Slack message event payload."""
return {
"type": "message",
"channel": "C12345ABC",
"user": "U98765XYZ",
"text": "Hey @bot can you help me?",
"ts": "1711234567.123456",
"thread_ts": "1711234567.123456",
"team": "T11223344",
"blocks": [],
"event_ts": "1711234567.123456",
}
def normalize_slack_event(payload: dict) -> KonstructMessage:
"""
Minimal normalizer for test purposes.
The real normalizer lives in packages/gateway/channels/slack.py.
This function mirrors the expected output for unit testing the model.
"""
ts = float(payload["ts"])
timestamp = datetime.fromtimestamp(ts, tz=timezone.utc)
return KonstructMessage(
channel=ChannelType.SLACK,
channel_metadata={
"workspace_id": payload["team"],
"channel_id": payload["channel"],
"event_ts": payload["event_ts"],
},
sender=SenderInfo(
user_id=payload["user"],
display_name=payload["user"], # Display name resolved later by Slack API
),
content=MessageContent(
text=payload["text"],
mentions=[u for u in payload["text"].split() if u.startswith("@")],
),
timestamp=timestamp,
thread_id=payload.get("thread_ts"),
reply_to=None if payload.get("thread_ts") == payload.get("ts") else payload.get("thread_ts"),
)
class TestKonstructMessageNormalization:
"""Tests for Slack event normalization to KonstructMessage."""
def test_channel_type_is_slack(self) -> None:
"""ChannelType must be set to 'slack' for Slack events."""
msg = normalize_slack_event(make_slack_event())
assert msg.channel == ChannelType.SLACK
assert msg.channel == "slack"
def test_sender_info_extracted(self) -> None:
"""SenderInfo user_id must match Slack user field."""
payload = make_slack_event()
msg = normalize_slack_event(payload)
assert msg.sender.user_id == "U98765XYZ"
assert msg.sender.is_bot is False
def test_content_text_preserved(self) -> None:
"""MessageContent.text must contain original Slack message text."""
payload = make_slack_event()
msg = normalize_slack_event(payload)
assert msg.content.text == "Hey @bot can you help me?"
def test_thread_id_from_thread_ts(self) -> None:
"""thread_id must be populated from Slack's thread_ts field."""
payload = make_slack_event()
payload["thread_ts"] = "1711234500.000001" # Different from ts — it's a reply
msg = normalize_slack_event(payload)
assert msg.thread_id == "1711234500.000001"
def test_thread_id_none_when_no_thread(self) -> None:
"""thread_id must be None when the message is not in a thread."""
payload = make_slack_event()
del payload["thread_ts"]
msg = normalize_slack_event(payload)
assert msg.thread_id is None
def test_channel_metadata_contains_workspace_id(self) -> None:
"""channel_metadata must contain workspace_id (Slack team ID)."""
payload = make_slack_event()
msg = normalize_slack_event(payload)
assert "workspace_id" in msg.channel_metadata
assert msg.channel_metadata["workspace_id"] == "T11223344"
def test_channel_metadata_contains_channel_id(self) -> None:
"""channel_metadata must contain the Slack channel ID."""
payload = make_slack_event()
msg = normalize_slack_event(payload)
assert msg.channel_metadata["channel_id"] == "C12345ABC"
def test_tenant_id_is_none_before_resolution(self) -> None:
"""tenant_id must be None immediately after normalization (Router populates it)."""
msg = normalize_slack_event(make_slack_event())
assert msg.tenant_id is None
def test_message_has_uuid_id(self) -> None:
"""KonstructMessage must have a UUID id assigned at construction."""
import uuid
msg = normalize_slack_event(make_slack_event())
# Should not raise
parsed = uuid.UUID(msg.id)
assert str(parsed) == msg.id
def test_timestamp_is_utc_datetime(self) -> None:
"""timestamp must be a timezone-aware datetime in UTC."""
msg = normalize_slack_event(make_slack_event())
assert msg.timestamp.tzinfo is not None
assert msg.timestamp.tzinfo == timezone.utc
def test_pydantic_validation_rejects_invalid_channel(self) -> None:
"""KonstructMessage must reject unknown ChannelType values."""
with pytest.raises(Exception): # pydantic.ValidationError
KonstructMessage(
channel="fax_machine", # type: ignore[arg-type]
channel_metadata={},
sender=SenderInfo(user_id="u1", display_name="User"),
content=MessageContent(text="hello"),
timestamp=datetime.now(timezone.utc),
)

View File

@@ -0,0 +1,140 @@
"""
Unit tests for Redis key namespacing.
Tests TNNT-03: Per-tenant Redis namespace isolation for cache and session state.
Every key function must:
1. Prepend {tenant_id}: to every key
2. Never be callable without a tenant_id argument
3. Produce predictable, stable key formats
"""
from __future__ import annotations
import inspect
import pytest
from shared.redis_keys import (
engaged_thread_key,
idempotency_key,
rate_limit_key,
session_key,
)
class TestRedisKeyFormats:
"""Tests for correct key format from all key constructor functions."""
def test_rate_limit_key_format(self) -> None:
"""rate_limit_key must return '{tenant_id}:ratelimit:{channel}'."""
key = rate_limit_key("tenant-a", "slack")
assert key == "tenant-a:ratelimit:slack"
def test_rate_limit_key_different_channel(self) -> None:
"""rate_limit_key must work for any channel type string."""
key = rate_limit_key("tenant-a", "whatsapp")
assert key == "tenant-a:ratelimit:whatsapp"
def test_idempotency_key_format(self) -> None:
"""idempotency_key must return '{tenant_id}:dedup:{message_id}'."""
key = idempotency_key("tenant-a", "msg-123")
assert key == "tenant-a:dedup:msg-123"
def test_session_key_format(self) -> None:
"""session_key must return '{tenant_id}:session:{thread_id}'."""
key = session_key("tenant-a", "thread-456")
assert key == "tenant-a:session:thread-456"
def test_engaged_thread_key_format(self) -> None:
"""engaged_thread_key must return '{tenant_id}:engaged:{thread_id}'."""
key = engaged_thread_key("tenant-a", "T12345")
assert key == "tenant-a:engaged:T12345"
class TestTenantIsolation:
"""Tests that all key functions produce distinct namespaces per tenant."""
def test_rate_limit_keys_are_tenant_scoped(self) -> None:
"""Two tenants with the same channel must produce different keys."""
key_a = rate_limit_key("tenant-a", "slack")
key_b = rate_limit_key("tenant-b", "slack")
assert key_a != key_b
assert key_a.startswith("tenant-a:")
assert key_b.startswith("tenant-b:")
def test_idempotency_keys_are_tenant_scoped(self) -> None:
"""Two tenants with the same message_id must produce different keys."""
key_a = idempotency_key("tenant-a", "msg-999")
key_b = idempotency_key("tenant-b", "msg-999")
assert key_a != key_b
assert key_a.startswith("tenant-a:")
assert key_b.startswith("tenant-b:")
def test_session_keys_are_tenant_scoped(self) -> None:
"""Two tenants with the same thread_id must produce different keys."""
key_a = session_key("tenant-a", "thread-1")
key_b = session_key("tenant-b", "thread-1")
assert key_a != key_b
def test_engaged_thread_keys_are_tenant_scoped(self) -> None:
"""Two tenants with same thread must produce different engaged keys."""
key_a = engaged_thread_key("tenant-a", "thread-1")
key_b = engaged_thread_key("tenant-b", "thread-1")
assert key_a != key_b
def test_all_keys_include_tenant_id_prefix(self) -> None:
"""Every key function must produce a key starting with the tenant_id."""
tenant_id = "my-tenant-uuid"
keys = [
rate_limit_key(tenant_id, "slack"),
idempotency_key(tenant_id, "msg-1"),
session_key(tenant_id, "thread-1"),
engaged_thread_key(tenant_id, "thread-1"),
]
for key in keys:
assert key.startswith(f"{tenant_id}:"), (
f"Key {key!r} does not start with tenant_id prefix '{tenant_id}:'"
)
class TestNoBarKeysIsPossible:
"""Tests that prove no key function can be called without tenant_id."""
def test_rate_limit_key_requires_tenant_id(self) -> None:
"""rate_limit_key signature requires tenant_id as first argument."""
sig = inspect.signature(rate_limit_key)
params = list(sig.parameters.keys())
assert params[0] == "tenant_id"
def test_idempotency_key_requires_tenant_id(self) -> None:
"""idempotency_key signature requires tenant_id as first argument."""
sig = inspect.signature(idempotency_key)
params = list(sig.parameters.keys())
assert params[0] == "tenant_id"
def test_session_key_requires_tenant_id(self) -> None:
"""session_key signature requires tenant_id as first argument."""
sig = inspect.signature(session_key)
params = list(sig.parameters.keys())
assert params[0] == "tenant_id"
def test_engaged_thread_key_requires_tenant_id(self) -> None:
"""engaged_thread_key signature requires tenant_id as first argument."""
sig = inspect.signature(engaged_thread_key)
params = list(sig.parameters.keys())
assert params[0] == "tenant_id"
def test_calling_without_tenant_id_raises_type_error(self) -> None:
"""Calling any key function with zero args must raise TypeError."""
with pytest.raises(TypeError):
rate_limit_key() # type: ignore[call-arg]
with pytest.raises(TypeError):
idempotency_key() # type: ignore[call-arg]
with pytest.raises(TypeError):
session_key() # type: ignore[call-arg]
with pytest.raises(TypeError):
engaged_thread_key() # type: ignore[call-arg]

View File

@@ -0,0 +1,158 @@
"""
Unit tests for tenant resolution logic.
Tests TNNT-02: Inbound messages are resolved to the correct tenant via
channel metadata.
These tests verify the resolution logic in isolation — no live database needed.
The production resolver queries channel_connections; here we mock that lookup.
"""
from __future__ import annotations
import uuid
from typing import Optional
import pytest
from shared.models.message import ChannelType
# ---------------------------------------------------------------------------
# Minimal in-process tenant resolver for unit testing
# ---------------------------------------------------------------------------
class ChannelConnectionRecord:
"""Represents a row from the channel_connections table."""
def __init__(self, tenant_id: uuid.UUID, channel_type: ChannelType, workspace_id: str) -> None:
self.tenant_id = tenant_id
self.channel_type = channel_type
self.workspace_id = workspace_id
def resolve_tenant(
workspace_id: str,
channel_type: ChannelType,
connections: list[ChannelConnectionRecord],
) -> Optional[uuid.UUID]:
"""
Resolve a (workspace_id, channel_type) pair to a tenant_id.
This mirrors the logic in packages/router/tenant.py.
Returns None if no matching connection is found.
"""
for conn in connections:
if conn.workspace_id == workspace_id and conn.channel_type == channel_type:
return conn.tenant_id
return None
# ---------------------------------------------------------------------------
# Test fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def tenant_a_id() -> uuid.UUID:
return uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
@pytest.fixture
def tenant_b_id() -> uuid.UUID:
return uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
@pytest.fixture
def connections(tenant_a_id: uuid.UUID, tenant_b_id: uuid.UUID) -> list[ChannelConnectionRecord]:
return [
ChannelConnectionRecord(tenant_a_id, ChannelType.SLACK, "T-WORKSPACE-A"),
ChannelConnectionRecord(tenant_b_id, ChannelType.SLACK, "T-WORKSPACE-B"),
ChannelConnectionRecord(tenant_b_id, ChannelType.TELEGRAM, "tg-chat-12345"),
]
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestTenantResolution:
"""Tests for tenant resolution from channel workspace IDs."""
def test_slack_workspace_resolves_to_correct_tenant(
self,
connections: list[ChannelConnectionRecord],
tenant_a_id: uuid.UUID,
) -> None:
"""Known Slack workspace_id must resolve to the correct tenant."""
result = resolve_tenant("T-WORKSPACE-A", ChannelType.SLACK, connections)
assert result == tenant_a_id
def test_second_slack_workspace_resolves_independently(
self,
connections: list[ChannelConnectionRecord],
tenant_b_id: uuid.UUID,
) -> None:
"""Two different Slack workspaces must resolve to their respective tenants."""
result = resolve_tenant("T-WORKSPACE-B", ChannelType.SLACK, connections)
assert result == tenant_b_id
def test_unknown_workspace_id_returns_none(
self,
connections: list[ChannelConnectionRecord],
) -> None:
"""Unknown workspace_id must return None — not raise, not return wrong tenant."""
result = resolve_tenant("T-UNKNOWN", ChannelType.SLACK, connections)
assert result is None
def test_wrong_channel_type_does_not_match(
self,
connections: list[ChannelConnectionRecord],
) -> None:
"""Workspace ID from wrong channel type must not match."""
# T-WORKSPACE-A is registered as SLACK — should not match TELEGRAM
result = resolve_tenant("T-WORKSPACE-A", ChannelType.TELEGRAM, connections)
assert result is None
def test_telegram_workspace_resolves_correctly(
self,
connections: list[ChannelConnectionRecord],
tenant_b_id: uuid.UUID,
) -> None:
"""Telegram channel connections resolve independently from Slack."""
result = resolve_tenant("tg-chat-12345", ChannelType.TELEGRAM, connections)
assert result == tenant_b_id
def test_empty_connections_returns_none(self) -> None:
"""Empty connection list must return None for any workspace."""
result = resolve_tenant("T-ANY", ChannelType.SLACK, [])
assert result is None
def test_resolution_is_channel_type_specific(
self,
connections: list[ChannelConnectionRecord],
) -> None:
"""
The same workspace_id string registered on two different channel types
must only match the correct channel type.
This prevents a Slack workspace ID from accidentally matching a
Mattermost workspace with the same string value.
"""
same_id_connections = [
ChannelConnectionRecord(
uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc"),
ChannelType.SLACK,
"SHARED-ID",
),
ChannelConnectionRecord(
uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd"),
ChannelType.MATTERMOST,
"SHARED-ID",
),
]
slack_tenant = resolve_tenant("SHARED-ID", ChannelType.SLACK, same_id_connections)
mm_tenant = resolve_tenant("SHARED-ID", ChannelType.MATTERMOST, same_id_connections)
assert slack_tenant != mm_tenant
assert str(slack_tenant) == "cccccccc-cccc-cccc-cccc-cccccccccccc"
assert str(mm_tenant) == "dddddddd-dddd-dddd-dddd-dddddddddddd"